From aff794656c0308de9bf1f402d2549f63c9572951 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Wed, 7 Feb 2018 23:48:06 +0800
Subject: [PATCH] Add Transformer demo using Fluid API for NMT

---
 fluid/NMT_Transformer/README.md |   5 +
 fluid/NMT_Transformer/config.py |  47 ++++
 fluid/NMT_Transformer/model.py  | 368 ++++++++++++++++++++++++++++++++
 fluid/NMT_Transformer/train.py  | 129 +++++++++++
 4 files changed, 549 insertions(+)
 create mode 100644 fluid/NMT_Transformer/README.md
 create mode 100644 fluid/NMT_Transformer/config.py
 create mode 100644 fluid/NMT_Transformer/model.py
 create mode 100644 fluid/NMT_Transformer/train.py

diff --git a/fluid/NMT_Transformer/README.md b/fluid/NMT_Transformer/README.md
new file mode 100644
index 00000000..e5fa24c2
--- /dev/null
+++ b/fluid/NMT_Transformer/README.md
@@ -0,0 +1,5 @@
+# Transformer
+
+Set the model and training configurations in `config.py`, and execute `python train.py` to train.
+
+More details to be added.
diff --git a/fluid/NMT_Transformer/config.py b/fluid/NMT_Transformer/config.py
new file mode 100644
index 00000000..546d432f
--- /dev/null
+++ b/fluid/NMT_Transformer/config.py
@@ -0,0 +1,47 @@
+# Sizes of the source and target vocabularies. The dictionaries of the dataset
+# used here include the start, end and unknown tokens but exclude the padding
+# token, so add 1 to include the padding token when these are used as the size
+# of a lookup table.
+src_vocab_size = 10000
+trg_vocab_size = 10000
+# The id of the padding token in the source language.
+src_pad_idx = src_vocab_size
+# The id of the padding token in the target language.
+trg_pad_idx = trg_vocab_size
+# The position value corresponding to the padding token.
+pos_pad_idx = 0
+# The maximum length of sequences. Add 1 to include the position padding
+# token used by the position encoding.
+max_length = 50
+# The number of epochs to train.
+pass_num = 2
+# The number of sequences contained in a mini-batch.
+batch_size = 64
+# The parameters for the Adam optimizer.
+learning_rate = 0.001
+beta1 = 0.9
+beta2 = 0.98
+eps = 1e-9
+# The dimension of embeddings, which is also the last dimension of the input
+# and output of multi-head attention, the position-wise feed-forward networks,
+# and the encoder and decoder.
+d_model = 512
+# The size of the hidden layer in the position-wise feed-forward networks.
+d_inner_hid = 1024
+# The dimension keys are projected to in scaled dot-product attention.
+d_key = 64
+# The dimension values are projected to in scaled dot-product attention.
+d_value = 64
+# The number of heads used in multi-head attention.
+n_head = 8
+# The number of sub-layers to be stacked in the encoder and decoder.
+n_layer = 6
+# The dropout rate used by all dropout layers.
+dropout = 0.1
+
+# Names of the position encoding tables, which are initialized externally.
+pos_enc_param_names = ("src_pos_enc_table", "trg_pos_enc_table")
+# Names of all data layers, listed in order.
+input_data_names = ("src_word", "src_pos", "trg_word", "trg_pos",
+                    "src_slf_attn_bias", "trg_slf_attn_bias",
+                    "trg_src_attn_bias", "lbl_word")
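The snippet below is an illustrative aside rather than part of the patch; it restates the values above to show how the padding conventions determine the lookup-table sizes that `train.py` later uses (`src_vocab_size + 1`, `trg_vocab_size + 1` and `max_length + 1`).

```python
# Illustrative only: mirrors config.py to show the "+1" sizing conventions.
src_vocab_size = 10000            # real word ids are 0 .. src_vocab_size - 1
src_pad_idx = src_vocab_size      # the padding id is one past the last word id
word_table_rows = src_vocab_size + 1   # rows needed by the word embedding table

max_length = 50
pos_pad_idx = 0                   # position 0 is reserved for padding tokens
pos_table_rows = max_length + 1   # real positions 1..50 plus the padding row

assert src_pad_idx == word_table_rows - 1
assert 0 <= pos_pad_idx < pos_table_rows
print([word_table_rows, pos_table_rows])  # [10001, 51]
```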
diff --git a/fluid/NMT_Transformer/model.py b/fluid/NMT_Transformer/model.py
new file mode 100644
index 00000000..b28ebd98
--- /dev/null
+++ b/fluid/NMT_Transformer/model.py
@@ -0,0 +1,368 @@
+from functools import partial
+import numpy as np
+import paddle.v2 as paddle
+import paddle.v2.fluid as fluid
+import paddle.v2.fluid.layers as layers
+# TODO: Remove batch_size from the model.
+from config import batch_size, input_data_names, pos_enc_param_names
+
+
+def position_encoding_init(n_position, d_pos_vec):
+    """
+    Generate the initial values for the sinusoid position encoding table.
+    """
+    position_enc = np.array([[
+        pos / np.power(10000., 2. * (j // 2) / d_pos_vec)
+        for j in range(d_pos_vec)
+    ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
+    # Set the position encoding of the padding position to small values rather
+    # than 0s to avoid nan in the attention softmax.
+    position_enc[0, :] = 1e-9
+    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
+    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
+    return position_enc.astype("float32")
+
+
+def multi_head_attention(queries,
+                         keys,
+                         values,
+                         attn_bias,
+                         d_key,
+                         d_value,
+                         d_model,
+                         num_heads=1,
+                         dropout_rate=0.):
+    """
+    Multi-Head Attention. Note that attn_bias is added to the attention
+    logits before the softmax to mask out or bias the attention weights.
+    """
+    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
+        raise ValueError(
+            "Inputs queries, keys and values should all be 3-D tensors.")
+
+    def __compute_qkv(queries, keys, values, num_heads, d_key, d_value):
+        """
+        Add linear projections to queries, keys and values.
+        """
+        q = layers.fc(input=queries,
+                      size=d_key * num_heads,
+                      bias_attr=False,
+                      num_flatten_dims=2)
+        k = layers.fc(input=keys,
+                      size=d_key * num_heads,
+                      bias_attr=False,
+                      num_flatten_dims=2)
+        v = layers.fc(input=values,
+                      size=d_value * num_heads,
+                      bias_attr=False,
+                      num_flatten_dims=2)
+        return q, k, v
+
+    def __split_heads(x, num_heads):
+        """
+        Reshape the last dimension of the input tensor x into two dimensions
+        and then transpose. Specifically, transform an input tensor with shape
+        [bs, max_sequence_length, num_heads * hidden_dim] into an output tensor
+        with shape [bs, num_heads, max_sequence_length, hidden_dim].
+        """
+        if num_heads == 1:
+            return x
+
+        hidden_size = x.shape[-1]
+        # TODO: Decouple the program desc from batch_size.
+        reshaped = layers.reshape(
+            x=x, shape=[batch_size, -1, num_heads, hidden_size // num_heads])
+
+        # Permute the dimensions into:
+        # [batch_size, num_heads, max_sequence_len, hidden_size_per_head]
+        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
+
+    def __combine_heads(x):
+        """
+        Transpose and then reshape the last two dimensions of the input tensor
+        x into one dimension, reversing __split_heads.
+        """
+        if len(x.shape) == 3: return x
+        if len(x.shape) != 4:
+            raise ValueError("Input(x) should be a 4-D Tensor.")
+
+        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
+        # TODO: Decouple the program desc from batch_size.
+        return layers.reshape(
+            x=trans_x,
+            shape=map(int,
+                      [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))
+
+    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
+        """
+        Scaled Dot-Product Attention
+        """
+
+        # TODO: Optimize the shape in reshape_op or softmax_op.
+        # The softmax_op only supports 2-D tensors currently and cannot be
+        # used here. The reshape_op cannot be used here either, since the
+        # shape of product inferred at compile time is not the actual shape
+        # at run time and cannot be used to set the attribute of reshape_op.
+        # Thus, a softmax is defined here temporarily.
+        def __softmax(x, eps=1e-9):
+            exp_out = layers.exp(x=x)
+            # Normalize by the sum of exp over the last dimension.
+            sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
+            return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)
+
+        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
+        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
+        weights = __softmax(layers.elementwise_add(x=product, y=attn_bias))
+        if dropout_rate:
+            weights = layers.dropout(
+                weights, dropout_prob=dropout_rate, is_test=False)
+        out = layers.matmul(weights, v)
+        return out
+
+    q, k, v = __compute_qkv(queries, keys, values, num_heads, d_key, d_value)
+
+    q = __split_heads(q, num_heads)
+    k = __split_heads(k, num_heads)
+    v = __split_heads(v, num_heads)
+
+    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
+                                                  dropout_rate)
+
+    out = __combine_heads(ctx_multiheads)
+
+    # Project back to the model size.
+    proj_out = layers.fc(input=out,
+                         size=d_model,
+                         bias_attr=False,
+                         num_flatten_dims=2)
+    return proj_out
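For readers following the shape bookkeeping in `multi_head_attention`, here is a plain-NumPy sketch of the same head split / scaled dot-product with additive bias / head merge sequence. It is illustrative only (the sizes and random inputs are made up) and is not part of the patch.

```python
import numpy as np

batch, seq_len, n_head, d_key = 2, 4, 8, 64

def split_heads(x, n_head):
    # [batch, seq_len, n_head * d] -> [batch, n_head, seq_len, d]
    b, t, h = x.shape
    return x.reshape(b, t, n_head, h // n_head).transpose(0, 2, 1, 3)

def combine_heads(x):
    # [batch, n_head, seq_len, d] -> [batch, seq_len, n_head * d]
    b, n, t, d = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, t, n * d)

q = np.random.rand(batch, seq_len, n_head * d_key).astype("float32")
k = np.random.rand(batch, seq_len, n_head * d_key).astype("float32")
v = np.random.rand(batch, seq_len, n_head * d_key).astype("float32")
# Additive bias: 0 for positions that may be attended to, -1e9 for masked ones.
bias = np.zeros((batch, n_head, seq_len, seq_len), dtype="float32")

q, k, v = split_heads(q, n_head), split_heads(k, n_head), split_heads(v, n_head)
scores = np.matmul(q * d_key ** -0.5, k.transpose(0, 1, 3, 2)) + bias
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
out = combine_heads(np.matmul(weights, v))
print(out.shape)  # (2, 4, 512)
```

The softmax here is written out by hand only to mirror the temporary `__softmax` workaround above; the fluid code normalizes the exponentiated logits the same way, over the last dimension.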
+
+
+def positionwise_feed_forward(x, d_inner_hid, d_hid):
+    """
+    Position-wise Feed-Forward Networks.
+    This consists of two linear transformations with a ReLU activation in
+    between, applied to each position separately and identically.
+    """
+    hidden = layers.fc(input=x,
+                       size=d_inner_hid,
+                       bias_attr=False,
+                       num_flatten_dims=2,
+                       act="relu")
+    out = layers.fc(input=hidden,
+                    size=d_hid,
+                    bias_attr=False,
+                    num_flatten_dims=2)
+    return out
+
+
+def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
+    """
+    Optionally add residual connection, layer normalization and dropout to the
+    out tensor according to the value of process_cmd ("a", "n" and "d"
+    respectively).
+    This is used before or after multi-head attention and the position-wise
+    feed-forward networks.
+    """
+    for cmd in process_cmd:
+        if cmd == "a":  # add residual connection
+            out = out + prev_out if prev_out else out
+        elif cmd == "n":  # add layer normalization
+            out = layers.layer_norm(out, begin_norm_axis=len(out.shape) - 1)
+        elif cmd == "d":  # add dropout
+            if dropout:
+                out = layers.dropout(out, dropout_prob=dropout, is_test=False)
+    return out
+
+
+pre_process_layer = partial(pre_post_process_layer, None)
+post_process_layer = pre_post_process_layer
+
+
+def prepare_encoder(src_word,
+                    src_pos,
+                    src_vocab_size,
+                    src_emb_dim,
+                    src_pad_idx,
+                    src_max_len,
+                    dropout=0.,
+                    pos_pad_idx=0,
+                    pos_enc_param_name=None):
+    """
+    Add word embeddings and position encodings and output a tensor with shape
+    [batch_size, max_src_length_in_batch, d_model].
+    This is used at the bottom of the encoder and decoder stacks.
+    """
+    src_word_emb = layers.embedding(
+        src_word, size=[src_vocab_size, src_emb_dim], padding_idx=src_pad_idx)
+    src_pos_enc = layers.embedding(
+        src_pos,
+        size=[src_max_len, src_emb_dim],
+        param_attr=fluid.ParamAttr(
+            name=pos_enc_param_name, trainable=False))
+    enc_input = src_word_emb + src_pos_enc
+    # TODO: Decouple the program desc from batch_size.
+    enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
+    return layers.dropout(
+        enc_input, dropout_prob=dropout,
+        is_test=False) if dropout else enc_input
+
+
+prepare_encoder = partial(
+    prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
+prepare_decoder = partial(
+    prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
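To make the `process_cmd` strings concrete: `post_process_layer(prev, out, "dan", dropout)` applies dropout, then adds the residual, then layer-normalizes over the last dimension. The NumPy restatement below is illustrative only; the learned scale/shift of `layer_norm` and the framework's dropout scaling are deliberately omitted.

```python
import numpy as np

def post_process(prev_out, out, process_cmd, dropout=0.1, epsilon=1e-5):
    # Plain-NumPy restatement of pre_post_process_layer for a command string
    # such as "dan" (dropout -> add residual -> layer norm).
    rng = np.random.RandomState(0)
    for cmd in process_cmd:
        if cmd == "d":
            # Zero out elements with probability `dropout` (no rescaling here).
            mask = (rng.rand(*out.shape) >= dropout).astype(out.dtype)
            out = out * mask
        elif cmd == "a":
            # Residual connection with the sub-layer input.
            out = out + prev_out if prev_out is not None else out
        elif cmd == "n":
            # Layer normalization over the last dimension (scale/shift omitted).
            mean = out.mean(axis=-1, keepdims=True)
            std = out.std(axis=-1, keepdims=True)
            out = (out - mean) / (std + epsilon)
    return out

x = np.random.rand(2, 4, 8).astype("float32")             # sub-layer input
sublayer_out = np.random.rand(2, 4, 8).astype("float32")  # sub-layer output
y = post_process(x, sublayer_out, "dan")
print(y.shape)                                       # (2, 4, 8)
print(np.allclose(y.mean(axis=-1), 0., atol=1e-5))   # True, rows are normalized
```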
+
+
+def encoder_layer(enc_input, attn_bias, n_head, d_key, d_value, d_model,
+                  d_inner_hid, dropout):
+    """
+    The layer to be stacked in the encoder.
+    This consists of a multi-head (self) attention sub-layer followed by a
+    position-wise feed-forward sub-layer, both wrapped with post_process_layer
+    to add residual connection, layer normalization and dropout.
+    """
+    attn_output = multi_head_attention(enc_input, enc_input, enc_input,
+                                       attn_bias, d_key, d_value, d_model,
+                                       n_head, dropout)
+    attn_output = post_process_layer(enc_input, attn_output, "dan", dropout)
+    ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
+    output = post_process_layer(attn_output, ffd_output, "dan", dropout)
+    return output
+
+
+def encoder(enc_input, attn_bias, n_layer, n_head, d_key, d_value, d_model,
+            d_inner_hid, dropout):
+    """
+    The encoder is composed of a stack of identical encoder_layer layers.
+    """
+    for i in range(n_layer):
+        enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key,
+                                   d_value, d_model, d_inner_hid, dropout)
+        enc_input = enc_output
+    return enc_output
+
+
+def decoder_layer(dec_input, enc_output, slf_attn_bias, dec_enc_attn_bias,
+                  n_head, d_key, d_value, d_model, d_inner_hid, dropout):
+    """
+    The layer to be stacked in the decoder. Its structure is similar to
+    encoder_layer, with another multi-head attention sub-layer added to
+    implement encoder-decoder attention.
+    """
+    slf_attn_output = multi_head_attention(dec_input, dec_input, dec_input,
+                                           slf_attn_bias, d_key, d_value,
+                                           d_model, n_head, dropout)
+    slf_attn_output = post_process_layer(dec_input, slf_attn_output, "dan",
+                                         dropout)
+    enc_attn_output = multi_head_attention(slf_attn_output, enc_output,
+                                           enc_output, dec_enc_attn_bias,
+                                           d_key, d_value, d_model, n_head,
+                                           dropout)
+    enc_attn_output = post_process_layer(slf_attn_output, enc_attn_output,
+                                         "dan", dropout)
+    ffd_output = positionwise_feed_forward(enc_attn_output, d_inner_hid,
+                                           d_model)
+    dec_output = post_process_layer(enc_attn_output, ffd_output, "dan", dropout)
+    return dec_output
+
+
+def decoder(dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias,
+            n_layer, n_head, d_key, d_value, d_model, d_inner_hid, dropout):
+    """
+    The decoder is composed of a stack of identical decoder_layer layers.
+    """
+    for i in range(n_layer):
+        dec_output = decoder_layer(dec_input, enc_output, dec_slf_attn_bias,
+                                   dec_enc_attn_bias, n_head, d_key, d_value,
+                                   d_model, d_inner_hid, dropout)
+        dec_input = dec_output
+    return dec_output
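The bias tensors consumed by these layers are plain additive masks. The NumPy sketch below (illustrative only, using a toy batch of two sequences) roughly mirrors how `prepare_batch_input` in `train.py` builds them: a padding bias that puts `-1e9` on padded key positions, and a subsequent-word bias that hides future positions from the decoder's self-attention. Adding `-1e9` before the softmax drives the corresponding attention weights to effectively zero.

```python
import numpy as np

n_head = 8
lengths = [4, 2]          # actual lengths of two sequences in a toy batch
max_len = max(lengths)

# Padding bias: 0 for real tokens, -1e9 for padded key positions; broadcast
# over heads and query positions -> [batch, n_head, max_len, max_len].
pad_bias = np.array([[0.] * l + [-1e9] * (max_len - l) for l in lengths])
pad_bias = np.tile(pad_bias.reshape([-1, 1, 1, max_len]),
                   [1, n_head, max_len, 1])

# Subsequent-word (causal) bias for decoder self-attention: -1e9 above the
# diagonal, so query position i can only attend to positions <= i.
causal = np.triu(np.ones([max_len, max_len]), k=1) * -1e9
causal_bias = np.tile(causal.reshape([1, 1, max_len, max_len]),
                      [len(lengths), n_head, 1, 1])

print(pad_bias.shape, causal_bias.shape)  # (2, 8, 4, 4) (2, 8, 4, 4)
print(causal[0])  # first query position may only attend to itself
```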
+
+
+def transformer(src_vocab_size, trg_vocab_size, max_length, n_layer, n_head,
+                d_key, d_value, d_model, d_inner_hid, dropout, src_pad_idx,
+                trg_pad_idx, pos_pad_idx):
+    # The shapes here only act as placeholders and are set to guarantee that
+    # shape inference succeeds at compile time.
+    # The actual shape of src_word is:
+    # [batch_size * max_src_length_in_batch, 1].
+    src_word = layers.data(
+        name=input_data_names[0],
+        shape=[batch_size * max_length, 1],
+        dtype="int64",
+        append_batch_size=False)
+    # The actual shape of src_pos is:
+    # [batch_size * max_src_length_in_batch, 1].
+    src_pos = layers.data(
+        name=input_data_names[1],
+        shape=[batch_size * max_length, 1],
+        dtype="int64",
+        append_batch_size=False)
+    # The actual shape of trg_word is:
+    # [batch_size * max_trg_length_in_batch, 1].
+    trg_word = layers.data(
+        name=input_data_names[2],
+        shape=[batch_size * max_length, 1],
+        dtype="int64",
+        append_batch_size=False)
+    # The actual shape of trg_pos is:
+    # [batch_size * max_trg_length_in_batch, 1].
+    trg_pos = layers.data(
+        name=input_data_names[3],
+        shape=[batch_size * max_length, 1],
+        dtype="int64",
+        append_batch_size=False)
+    # The actual shape of src_slf_attn_bias is:
+    # [batch_size, n_head, max_src_length_in_batch, max_src_length_in_batch].
+    # It is used to avoid attention on paddings.
+    src_slf_attn_bias = layers.data(
+        name=input_data_names[4],
+        shape=[batch_size, n_head, max_length, max_length],
+        dtype="float32",
+        append_batch_size=False)
+    # The actual shape of trg_slf_attn_bias is:
+    # [batch_size, n_head, max_trg_length_in_batch, max_trg_length_in_batch].
+    # It is used to avoid attention on paddings and subsequent words.
+    trg_slf_attn_bias = layers.data(
+        name=input_data_names[5],
+        shape=[batch_size, n_head, max_length, max_length],
+        dtype="float32",
+        append_batch_size=False)
+    # The actual shape of trg_src_attn_bias is:
+    # [batch_size, n_head, max_trg_length_in_batch, max_src_length_in_batch].
+    # It is used to avoid attention on paddings.
+    trg_src_attn_bias = layers.data(
+        name=input_data_names[6],
+        shape=[batch_size, n_head, max_length, max_length],
+        dtype="float32",
+        append_batch_size=False)
+
+    enc_input = prepare_encoder(src_word, src_pos, src_vocab_size, d_model,
+                                src_pad_idx, max_length, dropout)
+    enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
+                         d_value, d_model, d_inner_hid, dropout)
+
+    dec_input = prepare_decoder(trg_word, trg_pos, trg_vocab_size, d_model,
+                                trg_pad_idx, max_length, dropout)
+    dec_output = decoder(dec_input, enc_output, trg_slf_attn_bias,
+                         trg_src_attn_bias, n_layer, n_head, d_key, d_value,
+                         d_model, d_inner_hid, dropout)
+
+    # TODO: Share the same weight matrix between the two embedding layers and
+    # the pre-softmax linear transformation.
+    predict = layers.reshape(
+        x=layers.fc(input=dec_output,
+                    size=trg_vocab_size,
+                    bias_attr=False,
+                    num_flatten_dims=2),
+        shape=[-1, trg_vocab_size],
+        act="softmax")
+    # The actual shape of gold is:
+    # [batch_size * max_trg_length_in_batch, 1].
+    gold = layers.data(
+        name=input_data_names[7],
+        shape=[batch_size * max_length, 1],
+        dtype="int64",
+        append_batch_size=False)
+    cost = layers.cross_entropy(input=predict, label=gold)
+    avg_cost = layers.mean(x=cost)
+    return avg_cost
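Before moving on to the training script, here is a quick illustrative check (not part of the patch, and assuming this directory is on `PYTHONPATH` with PaddlePaddle installed) of the position encoding table that `train.py` copies into the `src_pos_enc_table` and `trg_pos_enc_table` parameters:

```python
import numpy as np
from config import max_length, d_model
from model import position_encoding_init

pos_enc = position_encoding_init(max_length + 1, d_model)
print(pos_enc.shape)                     # (51, 512): max_length + 1 rows of d_model
print(pos_enc.dtype)                     # float32
# Row 0 is the padding position, filled with a small constant instead of zeros
# so that the attention softmax does not produce NaNs.
print(np.allclose(pos_enc[0], 1e-9))     # True
# Real positions hold interleaved sine/cosine values, so they stay in [-1, 1].
print(np.abs(pos_enc[1:]).max() <= 1.0)  # True
```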
+ """ + return_list = [] + max_len = max(len(inst) for inst in insts) + inst_data = np.array( + [inst + [pad_idx] * (max_len - len(inst)) for inst in insts]) + return_list += [inst_data.astype("int64").reshape([-1, 1])] + if return_pos: + inst_pos = np.array([[ + pos_i + 1 if w_i != pad_idx else 0 + for pos_i, w_i in enumerate(inst) + ] for inst in inst_data]) + + return_list += [inst_pos.astype("int64").reshape([-1, 1])] + if return_attn_bias: + if is_target: + # This is used to avoid attention on paddings and subsequent + # words. + slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, + max_len)) + slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape( + [-1, 1, max_len, max_len]) + slf_attn_bias_data = np.tile(slf_attn_bias_data, + [1, n_head, 1, 1]) * [-1e9] + else: + # This is used to avoid attention on paddings. + slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] * + (max_len - len(inst)) + for inst in insts]) + slf_attn_bias_data = np.tile( + slf_attn_bias_data.reshape([-1, 1, 1, max_len]), + [1, n_head, max_len, 1]) + return_list += [slf_attn_bias_data.astype("float32")] + if return_max_len: + return_list += [max_len] + return return_list if len(return_list) > 1 else return_list[0] + + def data_to_tensor(data_list, name_list, input_dict, place): + assert len(data_list) == len(name_list) + for i in range(len(name_list)): + tensor = fluid.LoDTensor() + tensor.set(data_list[i], place) + input_dict[name_list[i]] = tensor + + src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( + [inst[0] for inst in insts], src_pad_idx, is_target=False) + trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( + [inst[1] for inst in insts], trg_pad_idx, is_target=True) + trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], + [1, 1, trg_max_len, 1]).astype("float32") + lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False, + False, False, False) + + data_to_tensor([ + src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias, + trg_slf_attn_bias, trg_src_attn_bias, lbl_word + ], input_data_names, input_dict, place) + + return input_dict + + +def main(): + avg_cost = transformer(src_vocab_size + 1, trg_vocab_size + 1, + max_length + 1, n_layer, n_head, d_key, d_value, + d_model, d_inner_hid, dropout, src_pad_idx, + trg_pad_idx, pos_pad_idx) + + optimizer = fluid.optimizer.Adam( + learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=eps) + optimizer.minimize(avg_cost) + + train_data = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.wmt16.train(src_vocab_size, trg_vocab_size), + buf_size=1000), + batch_size=batch_size) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + # Initialize the parameters. + exe.run(fluid.framework.default_startup_program()) + for pos_enc_param_name in pos_enc_param_names: + pos_enc_param = fluid.global_scope().find_var( + pos_enc_param_name).get_tensor() + pos_enc_param.set( + position_encoding_init(max_length + 1, d_model), place) + + batch_id = 0 + for pass_id in xrange(pass_num): + for data in train_data(): + data_input = prepare_batch_input(data, input_data_names, + src_pad_idx, trg_pad_idx, + max_length, n_head, place) + outs = exe.run(fluid.framework.default_main_program(), + feed=data_input, + fetch_list=[avg_cost]) + avg_cost_val = np.array(outs[0]) + print("pass_id=" + str(pass_id) + " batch=" + str(batch_id) + + " avg_cost=" + str(avg_cost_val)) + batch_id += 1 + + +if __name__ == "__main__": + main() -- GitLab