Commit e4392454 authored by guosheng

Small fix for the Transformer model

Parent aff79465
@@ -15,9 +15,6 @@ def position_encoding_init(n_position, d_pos_vec):
             pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
             for j in range(d_pos_vec)
         ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
-    # Set the position encoding of padding to small values rather than 0s to
-    # avoid nan in attention softmax.
-    position_enc[0, :] = 1e-9
     position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
     position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
     return position_enc.astype("float32")
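The hunk above drops the old workaround that filled the padding row of the position encoding table with 1e-9. After this commit the padding row is left as plain zeros, and padding positions are instead zeroed at lookup time via the padding_idx argument added in the last hunk below. A minimal NumPy sketch of the resulting behavior (the n_position and d_pos_vec values here are arbitrary, for illustration only):

import numpy as np

def position_encoding_init(n_position, d_pos_vec):
    # Sinusoidal position encoding; row 0 is the padding position and is
    # left as all zeros, with no 1e-9 fill.
    position_enc = np.array([[
        pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
        for j in range(d_pos_vec)
    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
    return position_enc.astype("float32")

table = position_encoding_init(n_position=8, d_pos_vec=4)
assert np.allclose(table[0], 0.0)  # padding row is all-zero after this change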
@@ -106,7 +103,7 @@ def multi_head_attention(queries,
         # define the softmax temporarily.
         def __softmax(x, eps=1e-9):
             exp_out = layers.exp(x=x)
-            sum_out = layers.reduce_sum(x, dim=-1, keep_dim=False)
+            sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
             return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)

         scaled_q = layers.scale(x=q, scale=d_key**-0.5)
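The one-line change above fixes the softmax normalizer: the denominator must be the sum of exp(x), not the sum of the raw logits x (which is not even guaranteed to be positive). A minimal NumPy illustration of the corrected formula; note that the eps argument of the temporary helper is unused, and a numerically stable version would also subtract the row max before exponentiating:

import numpy as np

def softmax_fixed(x):
    # Matches the corrected line: normalize exp(x) by the sum of exp(x).
    exp_out = np.exp(x)
    return exp_out / exp_out.sum(axis=-1, keepdims=True)

x = np.array([[1.0, 2.0, 3.0]])
probs = softmax_fixed(x)
assert np.allclose(probs.sum(axis=-1), 1.0)  # rows now sum to 1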
@@ -196,6 +193,7 @@ def prepare_encoder(src_word,
     src_pos_enc = layers.embedding(
         src_pos,
         size=[src_max_len, src_emb_dim],
+        padding_idx=pos_pad_idx,
         param_attr=fluid.ParamAttr(
             name=pos_enc_param_name, trainable=False))
     enc_input = src_word_emb + src_pos_enc
...
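The padding_idx=pos_pad_idx argument added above makes the position embedding lookup return a zero vector whenever a position id equals pos_pad_idx, which is what lets the first hunk drop the 1e-9 fill. A NumPy mimic of that semantics (the table and the pos_pad_idx value here are stand-ins, not taken from the real config):

import numpy as np

pos_pad_idx = 0  # assumed padding id, for illustration only

def embedding_with_padding_idx(ids, table, padding_idx):
    # Mimics padding_idx: rows looked up at the padding id come back zeroed.
    out = table[ids].copy()
    out[ids == padding_idx] = 0.0
    return out

table = np.random.rand(8, 4).astype("float32")  # stand-in position table
src_pos = np.array([[1, 2, 3, 0, 0]])           # trailing zeros are padding
src_pos_enc = embedding_with_padding_idx(src_pos, table, pos_pad_idx)
assert np.allclose(src_pos_enc[0, 3:], 0.0)     # padded positions contribute nothing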