提交 225a484d 编写于 作者: G guosheng

Reuse input as output of reshape_op in Transformer

上级 41cee4e1
...@@ -80,7 +80,7 @@ def multi_head_attention(queries, ...@@ -80,7 +80,7 @@ def multi_head_attention(queries,
# The value 0 in shape attr means copying the corresponding dimension # The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size. # size of the input as the output dimension size.
reshaped = layers.reshape( reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head]) x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
# permuate the dimensions into: # permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head] # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
...@@ -99,7 +99,9 @@ def multi_head_attention(queries, ...@@ -99,7 +99,9 @@ def multi_head_attention(queries,
# The value 0 in shape attr means copying the corresponding dimension # The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size. # size of the input as the output dimension size.
return layers.reshape( return layers.reshape(
x=trans_x, shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]]) x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate): def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
""" """
...@@ -637,7 +639,8 @@ def wrap_decoder(trg_vocab_size, ...@@ -637,7 +639,8 @@ def wrap_decoder(trg_vocab_size,
postprocess_cmd, postprocess_cmd,
caches=caches) caches=caches)
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM # Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output = layers.reshape(dec_output, shape=[-1, dec_output.shape[-1]]) dec_output = layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
if weight_sharing: if weight_sharing:
predict = layers.matmul( predict = layers.matmul(
x=dec_output, x=dec_output,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册