Commit f610916c authored by G guosheng

Fix the reshape in no-weight-sharing mode of Transformer

Parent 4f244555
@@ -129,10 +129,12 @@ def multi_head_attention(queries,
         # input from the previous time step first.
         k = cache["k"] = layers.concat(
             [layers.reshape(
-                cache["k"], shape=[0, 0, d_model]), k], axis=1)
+                cache["k"], shape=[0, 0, d_key * n_head]), k],
+            axis=1)
         v = cache["v"] = layers.concat(
             [layers.reshape(
-                cache["v"], shape=[0, 0, d_model]), v], axis=1)
+                cache["v"], shape=[0, 0, d_value * n_head]), v],
+            axis=1)
 
     q = __split_heads(q, n_head)
     k = __split_heads(k, n_head)
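For context on why the reshape target changes: the cached keys/values are restored to rank 3 before the current step is concatenated along the time axis (shape=[0, 0, ...] keeps the first two dimensions unchanged in fluid). The cached tensors carry the projected hidden size d_key * n_head (resp. d_value * n_head), which equals d_model only when those dimensions happen to match; once they are configured independently, reshaping to d_model asks for the wrong number of elements. A minimal numpy sketch of that shape arithmetic, with hypothetical sizes chosen so that d_key * n_head != d_model:

import numpy as np

# Hypothetical sizes, deliberately chosen so n_head * d_key != d_model.
batch, cached_len, n_head, d_key, d_model = 2, 5, 8, 32, 512

# cache["k"] as left by the previous step, after __split_heads reshaped it
# in place to [batch, cached_len, n_head, d_key].
cached_k = np.zeros((batch, cached_len, n_head, d_key), dtype="float32")

# Projected k of the current decoding step: [batch, 1, d_key * n_head].
new_k = np.zeros((batch, 1, d_key * n_head), dtype="float32")

# New behaviour: restore the cache to [batch, cached_len, d_key * n_head]
# and concatenate along the time axis.
restored = cached_k.reshape(batch, cached_len, d_key * n_head)
k = np.concatenate([restored, new_k], axis=1)
print(k.shape)  # (2, 6, 256)

# Old behaviour: reshaping to [batch, cached_len, d_model] requires
# batch * cached_len * d_model elements (5120 here), but the cache only
# holds 2560, so the reshape cannot succeed.
try:
    cached_k.reshape(batch, cached_len, d_model)
except ValueError as e:
    print("old reshape fails:", e)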
@@ -657,8 +659,7 @@ def wrap_decoder(trg_vocab_size,
     else:
         predict = layers.fc(input=dec_output,
                             size=trg_vocab_size,
-                            bias_attr=False,
-                            num_flatten_dims=2)
+                            bias_attr=False)
     if dec_inputs is None:
         # Return probs for independent decoder program.
         predict = layers.softmax(predict)
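On the second hunk: in fluid.layers.fc, num_flatten_dims controls how an input of rank greater than 2 is flattened into a matrix before the multiplication, so num_flatten_dims=2 turns a [batch, seq_len, d_model] input into [batch * seq_len, d_model]. Dropping the argument falls back to the default of 1, which presumes dec_output already reaches the projection as a 2-D [batch * seq_len, d_model] tensor (any such reshape happens outside the lines shown here). A rough numpy sketch of the flattening, with hypothetical sizes and a stand-in weight matrix:

import numpy as np

# Hypothetical sizes and a stand-in projection weight (not the model's).
batch, seq_len, d_model, trg_vocab_size = 2, 7, 16, 100
w = np.random.rand(d_model, trg_vocab_size).astype("float32")

# num_flatten_dims=2 on a 3-D input: fold the first two dims together,
# then multiply -- the fc layer does this flattening internally.
dec_output_3d = np.random.rand(batch, seq_len, d_model).astype("float32")
predict = dec_output_3d.reshape(batch * seq_len, d_model).dot(w)
print(predict.shape)  # (14, 100)

# Default num_flatten_dims=1: no folding, so the input itself must already
# be the 2-D [batch * seq_len, d_model] tensor.
dec_output_2d = dec_output_3d.reshape(batch * seq_len, d_model)
predict = dec_output_2d.dot(w)
print(predict.shape)  # (14, 100)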