Commit 53e7c177 authored by ying

refine comments.

Parent 78e29521
# Transformer
Set the model and training configurations in `config.py`, and execute `python train.py` to train.
More details to be added.
# The dictionary sizes of the source and target languages. The dictionaries of
# the dataset used here include the <bos>, <eos> and <unk> tokens but exclude
# the <pad> token. Add 1 to include the padding token when using these as the
# size of the lookup table.
src_vocab_size = 10000
trg_vocab_size = 10000
# Represent the id of <pad> token in source language.
src_pad_idx = src_vocab_size
# Represent the id of <pad> token in target language.
trg_pad_idx = trg_vocab_size
# Represent the position value corresponding to the <pad> token.
pos_pad_idx = 0
# Represent the max length of sequences. Add 1 to include the position padding
# token for position encoding.
max_length = 50
# Represent the epoch number to train.
pass_num = 2
# Represent the number of sequences contained in a mini-batch.
batch_size = 64
# Represent the hyper-parameters for the Adam optimizer.
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.98
eps = 1e-9
# Represent the dimension of embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# Represent the size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 1024
# Represent the dimension keys are projected to for dot-product attention.
d_key = 64
# Represent the dimension values are projected to for dot-product attention.
d_value = 64
# Represent the number of heads used in multi-head attention.
n_head = 8
# Represent the number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# Represent the dropout rate used by all dropout layers.
dropout = 0.1
# Names of the position encoding tables, which will be initialized externally.
pos_enc_param_names = ("src_pos_enc_table", "trg_pos_enc_table")
# Names of all data layers listed in order.
input_data_names = ("src_word", "src_pos", "trg_word", "trg_pos",
"src_slf_attn_bias", "trg_slf_attn_bias",
"trg_src_attn_bias", "lbl_word")
# Attention is All You Need: A Paddle Fluid implementation
This is a Paddle Fluid implementation of the Transformer model in [Attention is All You Need](https://arxiv.org/abs/1706.03762) (Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, arXiv, 2017).
If you use the dataset/code in your research, please cite the paper:
```text
@inproceedings{vaswani2017attention,
title={Attention is all you need},
author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, {\L}ukasz and Polosukhin, Illia},
booktitle={Advances in Neural Information Processing Systems},
pages={6000--6010},
year={2017}
}
```
### TODO
This project is still under active development.
class TrainTaskConfig(object):
    use_gpu = False
    # the epoch number to train.
    pass_num = 2
    # number of sequences contained in a mini-batch.
    batch_size = 64
    # the hyper params for Adam optimizer.
    learning_rate = 0.001
    beta1 = 0.9
    beta2 = 0.98
    eps = 1e-9
class ModelHyperParams(object):
    # Dictionary sizes for the source and target languages. This model directly
    # uses paddle.dataset.wmt16, in which the <bos>, <eos> and <unk> tokens have
    # already been added, but the <pad> token has not. Transformer requires that
    # the sequences in a mini-batch be padded to the same length, so a <pad>
    # token is added to the original dictionary of paddle.dataset.wmt16.

    # size of the source word dictionary.
    src_vocab_size = 10000
    # index for the <pad> token in the source language.
    src_pad_idx = src_vocab_size

    # size of the target word dictionary.
    trg_vocab_size = 10000
    # index for the <pad> token in the target language.
    trg_pad_idx = trg_vocab_size

    # position value corresponding to the <pad> token.
    pos_pad_idx = 0

    # max length of sequences. Add 1 to include the position padding token for
    # position encoding.
    max_length = 50

    # the dimension for word embeddings, which is also the last dimension of
    # the input and output of multi-head attention, position-wise feed-forward
    # networks, encoder and decoder.
    d_model = 512
    # size of the hidden layer in position-wise feed-forward networks.
    d_inner_hid = 1024
    # the dimension that keys are projected to for dot-product attention.
    d_key = 64
    # the dimension that values are projected to for dot-product attention.
    d_value = 64
    # number of heads used in multi-head attention.
    n_head = 8
    # number of sub-layers to be stacked in the encoder and decoder.
    n_layer = 6
    # dropout rate used by all dropout layers.
    dropout = 0.1


# Names of the position encoding tables, which will be initialized externally.
pos_enc_param_names = (
    "src_pos_enc_table",
    "trg_pos_enc_table", )

# Names of all data layers listed in order.
input_data_names = (
    "src_word",
    "src_pos",
    "trg_word",
    "trg_pos",
    "src_slf_attn_bias",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "lbl_word", )
from functools import partial
import numpy as np

import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers

from config import TrainTaskConfig, input_data_names, pos_enc_param_names

# FIXME(guosheng): Remove the batch_size from the model.
batch_size = TrainTaskConfig.batch_size
def position_encoding_init(n_position, d_pos_vec):
    """
    Generate the initial values for the sinusoid position encoding table.
    """
    position_enc = np.array([[
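The body of this function is truncated in the diff view. For reference, a minimal numpy sketch of the standard sinusoid initialization it describes, under the assumption that position 0 is reserved for the padding position (names here are illustrative, not the repository's):

```python
import numpy as np


def sinusoid_position_encoding(n_position, d_pos_vec):
    # pos / 10000^(2i / d_pos_vec) for every position and dimension; row 0 is
    # kept as zeros, assuming position 0 is reserved for <pad>.
    position_enc = np.array([
        [pos / np.power(10000, 2 * (j // 2) / float(d_pos_vec))
         for j in range(d_pos_vec)] if pos != 0 else np.zeros(d_pos_vec)
        for pos in range(n_position)
    ])
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # even dimensions
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # odd dimensions
    return position_enc.astype("float32")
```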
@@ -30,15 +34,16 @@ def multi_head_attention(queries,
                         num_heads=1,
                         dropout_rate=0.):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing the softmax activation to mask certain selected positions so
    that they will not be considered in the attention weights.
    """
    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, num_heads, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
@@ -66,7 +71,7 @@ def multi_head_attention(queries,
            return x

        hidden_size = x.shape[-1]
        # FIXME(guosheng): Decouple the program desc with batch_size.
        reshaped = layers.reshape(
            x=x, shape=[batch_size, -1, num_heads, hidden_size // num_heads])
@@ -84,7 +89,7 @@ def multi_head_attention(queries,
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # FIXME(guosheng): Decouple the program desc with batch_size.
        return layers.reshape(
            x=trans_x,
            shape=map(int,
@@ -95,12 +100,15 @@ def multi_head_attention(queries,
        Scaled Dot-Product Attention
        """
        # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.
        # The current implementation of softmax_op only supports 2D tensors,
        # so it cannot be used directly here. The reshape_op cannot be used
        # either, since the shape of the product inferred at compile time is
        # not the actual run-time shape and thus cannot be used to set the
        # attribute of reshape_op. So a softmax is defined here as a temporary
        # solution.
        def __softmax(x, eps=1e-9):
            exp_out = layers.exp(x=x)
            sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
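For readers following the masking logic described in the docstring above: the bias tensors prepared in `train.py` hold 0 at positions that may be attended to and a large negative value (e.g. -1e9) at masked positions, so adding them to the scaled QK^T logits drives the corresponding softmax weights toward 0. A minimal single-head numpy sketch of the paper's scaled dot-product attention with such an additive bias (illustrative names, not the Fluid implementation):

```python
import numpy as np


def scaled_dot_product_attention_sketch(q, k, v, attn_bias):
    # q, k: [seq_len, d_key], v: [seq_len, d_value];
    # attn_bias: [seq_len, seq_len], 0 on visible positions, ~-1e9 on masked.
    logits = np.dot(q, k.T) / np.sqrt(q.shape[-1]) + attn_bias
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.dot(weights, v)
```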
@@ -136,9 +144,9 @@ def multi_head_attention(queries,
def positionwise_feed_forward(x, d_inner_hid, d_hid):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
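As the docstring says, this block is just two linear transformations with a ReLU in between, applied identically at every position. A compact numpy sketch of that computation (the weight and bias names are illustrative):

```python
import numpy as np


def positionwise_ffn_sketch(x, w1, b1, w2, b2):
    # x: [seq_len, d_model]; w1: [d_model, d_inner_hid]; w2: [d_inner_hid, d_model]
    hidden = np.maximum(np.dot(x, w1) + b1, 0.0)  # first linear + ReLU
    return np.dot(hidden, w2) + b2                # second linear back to d_model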
@@ -150,17 +158,18 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out else out
        elif cmd == "n":  # add layer normalization
            out = layers.layer_norm(out, begin_norm_axis=len(out.shape) - 1)
        elif cmd == "d":  # add dropout
            if dropout:
                out = layers.dropout(out, dropout_prob=dropout, is_test=False)
    return out
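The process_cmd string is consumed character by character, so the command "dan" used throughout the model applies dropout, then adds the residual connection, then layer-normalizes. A hypothetical call, assuming post_process_layer is a thin wrapper around pre_post_process_layer defined elsewhere in model.py:

```python
# Hypothetical usage: "d" -> dropout, "a" -> add residual, "n" -> layer norm.
attn_output = post_process_layer(enc_input, attn_output, "dan", 0.1)
```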
@@ -179,10 +188,11 @@ def prepare_encoder(src_word,
                    dropout=0.,
                    pos_pad_idx=0,
                    pos_enc_param_name=None):
    """Add word embeddings and position encodings.
    The output tensor has a shape of:
    [batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder stacks.
    """
    src_word_emb = layers.embedding(
        src_word, size=[src_vocab_size, src_emb_dim], padding_idx=src_pad_idx)
@@ -193,7 +203,8 @@ def prepare_encoder(src_word,
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_name, trainable=False))
    enc_input = src_word_emb + src_pos_enc
    # FIXME(guosheng): Decouple the program desc with batch_size.
    enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
    return layers.dropout(
        enc_input, dropout_prob=dropout,
@@ -206,142 +217,247 @@ prepare_decoder = partial(
    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
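In plain terms, prepare_encoder looks up the word embeddings, adds the matching rows of the non-trainable position encoding table, and applies dropout; prepare_decoder is the same computation bound to the target-side table via functools.partial. A minimal numpy sketch of the core sum (table names are illustrative; dropout omitted):

```python
import numpy as np


def prepare_encoder_sketch(src_word, src_pos, word_emb_table, pos_enc_table):
    # src_word, src_pos: [batch_size, max_src_length_in_batch] integer indices.
    # Returns [batch_size, max_src_length_in_batch, d_model].
    return word_emb_table[src_word] + pos_enc_table[src_pos]
```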
def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The encoder layer that can be stacked to form a deep encoder.
    This module consists of a multi-head (self) attention followed by
    position-wise feed-forward networks; both components are accompanied by
    post_process_layer to add residual connection, layer normalization and
    dropout.
    """
    attn_output = multi_head_attention(enc_input, enc_input, enc_input,
                                       attn_bias, d_key, d_value, d_model,
                                       n_head, dropout_rate)
    attn_output = post_process_layer(enc_input, attn_output, "dan",
                                     dropout_rate)
    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
    return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)
def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    for i in range(n_layer):
        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key,
                                   d_value, d_model, d_inner_hid, dropout_rate)
        enc_input = enc_output
    return enc_output
def decoder_layer(dec_input,
                  enc_output,
                  slf_attn_bias,
                  dec_enc_attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  dropout_rate=0.):
    """The layer to be stacked in the decoder part.
    The structure of this module is similar to that in the encoder part except
    that a multi-head attention is added to implement encoder-decoder attention.
    """
    slf_attn_output = multi_head_attention(
        dec_input,
        dec_input,
        dec_input,
        slf_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate, )
    slf_attn_output = post_process_layer(
        dec_input,
        slf_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    enc_attn_output = multi_head_attention(
        slf_attn_output,
        enc_output,
        enc_output,
        dec_enc_attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        dropout_rate, )
    enc_attn_output = post_process_layer(
        slf_attn_output,
        enc_attn_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    ffd_output = positionwise_feed_forward(
        enc_attn_output,
        d_inner_hid,
        d_model, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        "dan",  # residual connection + dropout + layer normalization
        dropout_rate, )
    return dec_output
def decoder(dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate=0.):
    """
    The decoder is composed of a stack of identical decoder_layer layers.
    """
    for i in range(n_layer):
        dec_output = decoder_layer(
            dec_input,
            enc_output,
            dec_slf_attn_bias,
            dec_enc_attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            dropout_rate, )
        dec_input = dec_output
    return dec_output
def transformer(
        src_vocab_size,
        trg_vocab_size,
        max_length,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate,
        src_pad_idx,
        trg_pad_idx,
        pos_pad_idx, ):
    # The shapes here act as placeholders and are set only to pass the
    # infer-shape check at compile time. The actual shape of src_word at run
    # time is: [batch_size * max_src_length_in_a_batch, 1].
    src_word = layers.data(
        name=input_data_names[0],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    # The actual shape of src_pos at run time is:
    # [batch_size * max_src_length_in_a_batch, 1].
    src_pos = layers.data(
        name=input_data_names[1],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    # The actual shape of trg_word at run time is:
    # [batch_size * max_trg_length_in_a_batch, 1].
    trg_word = layers.data(
        name=input_data_names[2],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    # The actual shape of trg_pos at run time is:
    # [batch_size * max_trg_length_in_a_batch, 1].
    trg_pos = layers.data(
        name=input_data_names[3],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    # The actual shape of src_slf_attn_bias at run time is:
    # [batch_size, n_head, max_src_length_in_a_batch, max_src_length_in_a_batch].
    # This input is used to remove attention weights on paddings.
    src_slf_attn_bias = layers.data(
        name=input_data_names[4],
        shape=[batch_size, n_head, max_length, max_length],
        dtype="float32",
        append_batch_size=False)
    # The actual shape of trg_slf_attn_bias at run time is:
    # [batch_size, n_head, max_trg_length_in_a_batch, max_trg_length_in_a_batch].
    # This input is used to remove attention weights on paddings and subsequent
    # words.
    trg_slf_attn_bias = layers.data(
        name=input_data_names[5],
        shape=[batch_size, n_head, max_length, max_length],
        dtype="float32",
        append_batch_size=False)
    # The actual shape of trg_src_attn_bias at run time is:
    # [batch_size, n_head, max_trg_length_in_a_batch, max_src_length_in_a_batch].
    # This input is used to remove attention weights on paddings.
    trg_src_attn_bias = layers.data(
        name=input_data_names[6],
        shape=[batch_size, n_head, max_length, max_length],
        dtype="float32",
        append_batch_size=False)
    enc_input = prepare_encoder(
        src_word,
        src_pos,
        src_vocab_size,
        d_model,
        src_pad_idx,
        max_length,
        dropout_rate, )
    enc_output = encoder(
        enc_input,
        src_slf_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate, )

    dec_input = prepare_decoder(
        trg_word,
        trg_pos,
        trg_vocab_size,
        d_model,
        trg_pad_idx,
        max_length,
        dropout_rate, )
    dec_output = decoder(
        dec_input,
        enc_output,
        trg_slf_attn_bias,
        trg_src_attn_bias,
        n_layer,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        dropout_rate, )

    # TODO(guosheng): Share the weight matrix between the embedding layers and
    # the pre-softmax linear transformation.
    predict = layers.reshape(
        x=layers.fc(input=dec_output,
@@ -350,13 +466,12 @@ def transformer(src_vocab_size, trg_vocab_size, max_length, n_layer, n_head,
                    num_flatten_dims=2),
        shape=[-1, trg_vocab_size],
        act="softmax")
    # The actual shape of gold at run time is:
    # [batch_size * max_trg_length_in_a_batch, 1].
    gold = layers.data(
        name=input_data_names[7],
        shape=[batch_size * max_length, 1],
        dtype="int64",
        append_batch_size=False)
    cost = layers.cross_entropy(input=predict, label=gold)
    return layers.mean(x=cost)
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid

from model import transformer, position_encoding_init
from config import TrainTaskConfig, ModelHyperParams, \
    pos_enc_param_names, input_data_names
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
@@ -14,12 +17,12 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
    """
    input_dict = {}

    def __pad_batch_data(insts,
                         pad_idx,
                         is_target=False,
                         return_pos=True,
                         return_attn_bias=True,
                         return_max_len=True):
        """
        Pad the instances to the max sequence length in batch, and generate the
        corresponding position data and attention bias.
@@ -66,14 +69,14 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
            tensor.set(data_list[i], place)
            input_dict[name_list[i]] = tensor

    src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, is_target=False)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, is_target=True)
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")
    lbl_word = __pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False,
                                False, False, False)

    data_to_tensor([
        src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias,
@@ -84,22 +87,30 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
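The body of __pad_batch_data is mostly elided by the diff. For orientation, a hedged numpy sketch of what such a helper typically produces for this model: padded word ids, position ids (0 for `<pad>`), and an additive attention bias that is roughly -1e9 on padded columns and, for the target side, on subsequent positions as well, matching the bias inputs declared in model.py. Names and details here are illustrative, not the repository's exact implementation.

```python
import numpy as np


def pad_batch_data_sketch(insts, pad_idx, n_head, is_target=False):
    # insts: a list of word-id lists. Positions start at 1; 0 marks <pad>.
    max_len = max(len(inst) for inst in insts)
    word = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    pos = np.array([
        list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
        for inst in insts
    ])
    # Additive bias: ~-1e9 on <pad> columns so that, once added to the
    # attention logits, the softmax weights on paddings vanish.
    pad_bias = (word == pad_idx).astype("float32") * -1e9
    attn_bias = np.tile(pad_bias[:, np.newaxis, np.newaxis, :],
                        [1, n_head, max_len, 1])
    if is_target:
        # Also mask subsequent positions for decoder self-attention.
        subsequent = np.triu(np.ones([max_len, max_len]), k=1) * -1e9
        attn_bias = attn_bias + subsequent[np.newaxis, np.newaxis, :, :]
    # The real helper additionally reshapes word and pos to
    # [batch_size * max_len, 1] to match the data layers in model.py.
    return word, pos, attn_bias.astype("float32"), max_len
```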
def main():
    avg_cost = transformer(
        ModelHyperParams.src_vocab_size + 1,
        ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
        ModelHyperParams.n_layer, ModelHyperParams.n_head,
        ModelHyperParams.d_key, ModelHyperParams.d_value,
        ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
        ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
        ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)

    optimizer = fluid.optimizer.Adam(
        learning_rate=TrainTaskConfig.learning_rate,
        beta1=TrainTaskConfig.beta1,
        beta2=TrainTaskConfig.beta2,
        epsilon=TrainTaskConfig.eps)
    optimizer.minimize(avg_cost)

    train_data = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
                                       ModelHyperParams.trg_vocab_size),
            buf_size=51200),
        batch_size=TrainTaskConfig.batch_size)
    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Initialize the parameters.
@@ -108,21 +119,21 @@ def main():
        pos_enc_param = fluid.global_scope().find_var(
            pos_enc_param_name).get_tensor()
        pos_enc_param.set(
            position_encoding_init(ModelHyperParams.max_length + 1,
                                   ModelHyperParams.d_model), place)

    for pass_id in xrange(TrainTaskConfig.pass_num):
        for batch_id, data in enumerate(train_data()):
            data_input = prepare_batch_input(
                data, input_data_names, ModelHyperParams.src_pad_idx,
                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
                ModelHyperParams.n_head, place)
            outs = exe.run(fluid.framework.default_main_program(),
                           feed=data_input,
                           fetch_list=[avg_cost])
            avg_cost_val = np.array(outs[0])
            print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
                  " avg_cost = " + str(avg_cost_val))


if __name__ == "__main__":
......