Unverified commit 3b549867, authored by Cao Ying, committed via GitHub

Merge pull request #701 from guoshengCS/add-transformer-initializer

Add initializer for Transformer.
 from functools import partial
 import numpy as np
-import paddle.v2 as paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
@@ -31,7 +30,7 @@ def multi_head_attention(queries,
                          d_key,
                          d_value,
                          d_model,
-                         num_heads=1,
+                         n_head=1,
                          dropout_rate=0.):
     """
     Multi-Head Attention. Note that attn_bias is added to the logit before
@@ -42,41 +41,53 @@ def multi_head_attention(queries,
         raise ValueError(
             "Inputs: quries, keys and values should all be 3-D tensors.")

-    def __compute_qkv(queries, keys, values, num_heads, d_key, d_value):
+    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
         """
         Add linear projection to queries, keys, and values.
         """
         q = layers.fc(input=queries,
-                      size=d_key * num_heads,
+                      size=d_key * n_head,
+                      param_attr=fluid.initializer.Xavier(
+                          uniform=False,
+                          fan_in=d_model * d_key,
+                          fan_out=n_head * d_key),
                       bias_attr=False,
                       num_flatten_dims=2)
         k = layers.fc(input=keys,
-                      size=d_key * num_heads,
+                      size=d_key * n_head,
+                      param_attr=fluid.initializer.Xavier(
+                          uniform=False,
+                          fan_in=d_model * d_key,
+                          fan_out=n_head * d_key),
                       bias_attr=False,
                       num_flatten_dims=2)
         v = layers.fc(input=values,
-                      size=d_value * num_heads,
+                      size=d_value * n_head,
+                      param_attr=fluid.initializer.Xavier(
+                          uniform=False,
+                          fan_in=d_model * d_value,
+                          fan_out=n_head * d_value),
                       bias_attr=False,
                       num_flatten_dims=2)
         return q, k, v
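Editor's note: the Xavier (Glorot) initializers added here are given explicit fan_in and fan_out values instead of letting them be inferred from the weight shape. The sketch below is a minimal NumPy illustration of what the normal-variant Xavier draw computes under the standard formula std = sqrt(2 / (fan_in + fan_out)); the helper name and the toy sizes are illustrative, not part of this commit.

import numpy as np

def xavier_normal(shape, fan_in, fan_out, rng=np.random):
    # Glorot/Xavier normal init: std = sqrt(2 / (fan_in + fan_out)).
    std = np.sqrt(2.0 / (fan_in + fan_out))
    return rng.normal(loc=0.0, scale=std, size=shape).astype("float32")

# Toy sizes only; the commit passes fan_in=d_model * d_key and
# fan_out=n_head * d_key explicitly for the query/key projections.
d_model, d_key, n_head = 8, 4, 2
w_q = xavier_normal((d_model, n_head * d_key),
                    fan_in=d_model * d_key, fan_out=n_head * d_key)
print(w_q.shape, w_q.std())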
-    def __split_heads(x, num_heads):
+    def __split_heads(x, n_head):
         """
         Reshape the last dimension of inpunt tensor x so that it becomes two
         dimensions and then transpose. Specifically, input a tensor with shape
-        [bs, max_sequence_length, num_heads * hidden_dim] then output a tensor
-        with shape [bs, num_heads, max_sequence_length, hidden_dim].
+        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
+        with shape [bs, n_head, max_sequence_length, hidden_dim].
         """
-        if num_heads == 1:
+        if n_head == 1:
             return x

         hidden_size = x.shape[-1]
         # FIXME(guosheng): Decouple the program desc with batch_size.
         reshaped = layers.reshape(
-            x=x, shape=[batch_size, -1, num_heads, hidden_size // num_heads])
+            x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])

         # permuate the dimensions into:
-        # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
+        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
         return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
     def __combine_heads(x):
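Editor's note: the reshape-and-transpose performed by __split_heads (and undone by __combine_heads) can be checked with plain NumPy. This is only an illustrative round-trip with made-up sizes, not code from the commit.

import numpy as np

bs, seq_len, n_head, d_head = 2, 5, 4, 8
x = np.random.rand(bs, seq_len, n_head * d_head)

# [bs, seq_len, n_head * d_head] -> [bs, n_head, seq_len, d_head]
split = x.reshape(bs, seq_len, n_head, d_head).transpose(0, 2, 1, 3)

# Inverse: [bs, n_head, seq_len, d_head] -> [bs, seq_len, n_head * d_head]
combined = split.transpose(0, 2, 1, 3).reshape(bs, seq_len, n_head * d_head)
assert np.allclose(x, combined)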
@@ -95,7 +106,7 @@ def multi_head_attention(queries,
             shape=map(int,
                       [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))

-    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
+    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
         """
         Scaled Dot-Product Attention
         """
@@ -114,7 +125,7 @@ def multi_head_attention(queries,
             sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
             return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)

-        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
+        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
         product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
         weights = __softmax(layers.elementwise_add(x=product, y=attn_bias))
         if dropout_rate:
@@ -123,13 +134,13 @@ def multi_head_attention(queries,
         out = layers.matmul(weights, v)
         return out

-    q, k, v = __compute_qkv(queries, keys, values, num_heads, d_key, d_value)
+    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

-    q = __split_heads(q, num_heads)
-    k = __split_heads(k, num_heads)
-    v = __split_heads(v, num_heads)
+    q = __split_heads(q, n_head)
+    k = __split_heads(k, n_head)
+    v = __split_heads(v, n_head)

-    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
+    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
                                                   dropout_rate)

     out = __combine_heads(ctx_multiheads)
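Editor's note: in the updated code scaled_dot_product_attention scales the query by d_model ** -0.5 before the dot product with the keys. Below is a hedged single-head NumPy sketch of the computation softmax(scale * q·kᵀ + bias)·v; the function name, shapes, and the max-subtraction for numerical stability are illustrative, not a transcription of the Fluid code.

import numpy as np

def scaled_dot_product_attention_np(q, k, v, attn_bias, scale):
    # softmax(scale * q.k^T + attn_bias) . v for a single head.
    logits = scale * q.dot(k.T) + attn_bias          # [len_q, len_k]
    logits -= logits.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights.dot(v)                            # [len_q, d_value]

len_q, len_k, d_model = 3, 4, 8
q = np.random.rand(len_q, d_model)
k = np.random.rand(len_k, d_model)
v = np.random.rand(len_k, d_model)
bias = np.zeros((len_q, len_k))                      # e.g. -1e9 at masked positions
out = scaled_dot_product_attention_np(q, k, v, bias, scale=d_model ** -0.5)
print(out.shape)                                     # (3, 8)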
@@ -137,6 +148,7 @@ def multi_head_attention(queries,
     # Project back to the model size.
     proj_out = layers.fc(input=out,
                          size=d_model,
+                         param_attr=fluid.initializer.Xavier(uniform=False),
                          bias_attr=False,
                          num_flatten_dims=2)
     return proj_out
@@ -151,8 +163,14 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
     hidden = layers.fc(input=x,
                        size=d_inner_hid,
                        num_flatten_dims=2,
+                       param_attr=fluid.initializer.Uniform(
+                           low=-(d_hid**-0.5), high=(d_hid**-0.5)),
                        act="relu")
-    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
+    out = layers.fc(input=hidden,
+                    size=d_hid,
+                    num_flatten_dims=2,
+                    param_attr=fluid.initializer.Uniform(
+                        low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)))
     return out
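Editor's note: both fc layers of the position-wise feed-forward network now receive Uniform initializers bounded by ±fan_in**-0.5, where fan_in is the incoming width of each layer (d_hid for the first, d_inner_hid for the second). A small sketch of that range, assuming low/high simply bound a uniform draw; names and the "base" sizes are illustrative.

import numpy as np

def uniform_by_fan_in(shape, fan_in, rng=np.random):
    # Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
    bound = fan_in ** -0.5
    return rng.uniform(low=-bound, high=bound, size=shape).astype("float32")

d_hid, d_inner_hid = 512, 2048                                    # toy sizes
w1 = uniform_by_fan_in((d_hid, d_inner_hid), fan_in=d_hid)        # x -> hidden
w2 = uniform_by_fan_in((d_inner_hid, d_hid), fan_in=d_inner_hid)  # hidden -> out
print(w1.min(), w1.max(), w2.min(), w2.max())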
@@ -168,7 +186,11 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
         if cmd == "a":  # add residual connection
             out = out + prev_out if prev_out else out
         elif cmd == "n":  # add layer normalization
-            out = layers.layer_norm(out, begin_norm_axis=len(out.shape) - 1)
+            out = layers.layer_norm(
+                out,
+                begin_norm_axis=len(out.shape) - 1,
+                param_attr=fluid.initializer.Constant(1.),
+                bias_attr=fluid.initializer.Constant(0.))
         elif cmd == "d":  # add dropout
             if dropout:
                 out = layers.dropout(out, dropout_prob=dropout, is_test=False)
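Editor's note: layer_norm is now created with an explicit Constant(1.) scale and Constant(0.) bias, i.e. the standard identity-at-initialization layer normalization. A NumPy sketch of what that normalization computes over the last axis, purely for illustration (the epsilon value here is an assumption, not taken from the commit).

import numpy as np

def layer_norm_np(x, gain=1.0, bias=0.0, eps=1e-5):
    # Normalize over the last axis, then apply scale (gain) and shift (bias).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gain * (x - mean) / np.sqrt(var + eps) + bias

x = np.random.rand(2, 3, 8).astype("float32")
y = layer_norm_np(x)           # gain=1, bias=0 matches the Constant initializers
print(y.mean(axis=-1), y.std(axis=-1))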
@@ -195,7 +217,10 @@ def prepare_encoder(src_word,
     This module is used at the bottom of the encoder stacks.
     """
     src_word_emb = layers.embedding(
-        src_word, size=[src_vocab_size, src_emb_dim], padding_idx=src_pad_idx)
+        src_word,
+        size=[src_vocab_size, src_emb_dim],
+        padding_idx=src_pad_idx,
+        param_attr=fluid.initializer.Normal(0., 1.))
     src_pos_enc = layers.embedding(
         src_pos,
         size=[src_max_len, src_emb_dim],
@@ -462,6 +487,7 @@ def transformer(
     predict = layers.reshape(
         x=layers.fc(input=dec_output,
                     size=trg_vocab_size,
+                    param_attr=fluid.initializer.Xavier(uniform=False),
                     bias_attr=False,
                     num_flatten_dims=2),
         shape=[-1, trg_vocab_size],
......
@@ -115,7 +115,7 @@ def main():
             paddle.reader.shuffle(
                 paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
                                            ModelHyperParams.trg_vocab_size),
-                buf_size=51200),
+                buf_size=100000),
             batch_size=TrainTaskConfig.batch_size)

     # Initialize the parameters.
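Editor's note: the larger buf_size only changes how many samples paddle.reader.shuffle collects before shuffling and yielding them, so with a buffer smaller than the dataset the shuffle is approximate. A generic buffered-shuffle sketch of that idea (not the Paddle implementation):

import random

def buffered_shuffle(reader, buf_size):
    # Yield samples from `reader` after shuffling them in blocks of buf_size.
    def shuffled():
        buf = []
        for sample in reader():
            buf.append(sample)
            if len(buf) >= buf_size:
                random.shuffle(buf)
                for s in buf:
                    yield s
                buf = []
        random.shuffle(buf)        # flush the remainder
        for s in buf:
            yield s
    return shuffled

toy_reader = lambda: iter(range(10))
print(list(buffered_shuffle(toy_reader, buf_size=4)()))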
@@ -143,7 +143,7 @@ def main():
                                 fetch_list=[cost])
                 cost_val = np.array(outs[0])
                 print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
-                      " avg_cost = " + str(cost_val))
+                      " cost = " + str(cost_val))

 if __name__ == "__main__":
......