提交 225a484d 编写于 作者: G guosheng

Reuse input as output of reshape_op in Transformer

上级 41cee4e1
......@@ -80,7 +80,7 @@ def multi_head_attention(queries,
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head])
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
......@@ -99,7 +99,9 @@ def multi_head_attention(queries,
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x, shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]])
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
......@@ -637,7 +639,8 @@ def wrap_decoder(trg_vocab_size,
postprocess_cmd,
caches=caches)
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output = layers.reshape(dec_output, shape=[-1, dec_output.shape[-1]])
dec_output = layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
if weight_sharing:
predict = layers.matmul(
x=dec_output,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册