From a6ec3a0d55a0cba5c954bce2961af7ce5b5520f2 Mon Sep 17 00:00:00 2001 From: guosheng Date: Fri, 8 Jun 2018 17:10:39 +0800 Subject: [PATCH] Tune the Transformer model for wmt14 --- .../transformer/config.py | 83 ++++--- .../transformer/model.py | 211 ++++++++++-------- .../transformer/train.py | 2 + 3 files changed, 176 insertions(+), 120 deletions(-) diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py index 8ab9efce..00475854 100644 --- a/fluid/neural_machine_translation/transformer/config.py +++ b/fluid/neural_machine_translation/transformer/config.py @@ -1,18 +1,18 @@ class TrainTaskConfig(object): use_gpu = True # the epoch number to train. - pass_num = 30 + pass_num = 200 # the number of sequences contained in a mini-batch. batch_size = 32 # the hyper parameters for Adam optimizer. # This static learning_rate will be multiplied to the LearningRateScheduler # derived learning rate the to get the final learning rate. - learning_rate = 1 + learning_rate = 2 beta1 = 0.9 - beta2 = 0.98 + beta2 = 0.997 eps = 1e-9 # the parameters for learning rate scheduling. - warmup_steps = 4000 + warmup_steps = 8000 # the flag indicating to use average loss or sum loss when training. use_avg_cost = True # the weight used to mix up the ground-truth distribution and the fixed @@ -33,12 +33,12 @@ class TrainTaskConfig(object): class InferTaskConfig(object): - use_gpu = True + use_gpu = False # the number of examples in one run for sequence generation. - batch_size = 10 + batch_size = 2 # the parameters for beam search. beam_size = 5 - max_length = 30 + max_out_len = 30 # the number of decoded sentences to output. n_best = 1 # the flags indicating whether to output the special tokens. @@ -55,26 +55,26 @@ class ModelHyperParams(object): # included in dict can be used to pad, since the paddings' loss will be # masked out and make no effect on parameter gradients. # size of source word dictionary. - src_vocab_size = 10000 + src_vocab_size = 50000 # size of target word dictionay - trg_vocab_size = 10000 + trg_vocab_size = 50000 # index for token - bos_idx = 0 + bos_idx = 1 # index for token - eos_idx = 1 + eos_idx = 2 # index for token - unk_idx = 2 + unk_idx = 0 # max length of sequences. # The size of position encoding table should at least plus 1, since the # sinusoid position encoding starts from 1 and 0 can be used as the padding # token for position encoding. - max_length = 50 + max_length = 256 # the dimension for word embeddings, which is also the last dimension of # the input and output of multi-head attention, position-wise feed-forward # networks, encoder and decoder. d_model = 512 # size of the hidden layer in position-wise feed-forward networks. - d_inner_hid = 1024 + d_inner_hid = 2048 # the dimension that keys are projected to for dot-product attention. d_key = 64 # the dimension that values are projected to for dot-product attention. @@ -89,7 +89,7 @@ class ModelHyperParams(object): def merge_cfg_from_list(cfg_list, g_cfgs): """ - Set the above global configurations using the cfg_list. + Set the above global configurations using the cfg_list. """ assert len(cfg_list) % 2 == 0 for key, value in zip(cfg_list[0::2], cfg_list[1::2]): @@ -103,23 +103,28 @@ def merge_cfg_from_list(cfg_list, g_cfgs): break +# The placeholder for batch_size in compile time. Must be -1 currently to be +# consistent with some ops' infer-shape output in compile time, such as the +# sequence_expand op used in beamsearch decoder. +batch_size = -1 +# The placeholder for squence length in compile time. +seq_len = ModelHyperParams.max_length # Here list the data shapes and data types of all inputs. # The shapes here act as placeholder and are set to pass the infer-shape in # compile time. input_descs = { # The actual data shape of src_word is: # [batch_size * max_src_len_in_batch, 1] - "src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"], + "src_word": [(batch_size * seq_len, 1L), "int64", 2], # The actual data shape of src_pos is: # [batch_size * max_src_len_in_batch, 1] - "src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"], + "src_pos": [(batch_size * seq_len, 1L), "int64"], # This input is used to remove attention weights on paddings in the # encoder. # The actual data shape of src_slf_attn_bias is: # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch] - "src_slf_attn_bias": - [(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1), - (ModelHyperParams.max_length + 1)), "float32"], + "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len, + seq_len), "float32"], # This shape input is used to reshape the output of embedding layer. "src_data_shape": [(3L, ), "int32"], # This shape input is used to reshape before softmax in self attention. @@ -128,24 +133,23 @@ input_descs = { "src_slf_attn_post_softmax_shape": [(4L, ), "int32"], # The actual data shape of trg_word is: # [batch_size * max_trg_len_in_batch, 1] - "trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"], + "trg_word": [(batch_size * seq_len, 1L), "int64", + 2], # lod_level is only used in fast decoder. # The actual data shape of trg_pos is: # [batch_size * max_trg_len_in_batch, 1] - "trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"], + "trg_pos": [(batch_size * seq_len, 1L), "int64"], # This input is used to remove attention weights on paddings and # subsequent words in the decoder. # The actual data shape of trg_slf_attn_bias is: # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch] - "trg_slf_attn_bias": [(1, ModelHyperParams.n_head, - (ModelHyperParams.max_length + 1), - (ModelHyperParams.max_length + 1)), "float32"], + "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len, + seq_len), "float32"], # This input is used to remove attention weights on paddings of the source # input in the encoder-decoder attention. # The actual data shape of trg_src_attn_bias is: # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch] - "trg_src_attn_bias": [(1, ModelHyperParams.n_head, - (ModelHyperParams.max_length + 1), - (ModelHyperParams.max_length + 1)), "float32"], + "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len, + seq_len), "float32"], # This shape input is used to reshape the output of embedding layer. "trg_data_shape": [(3L, ), "int32"], # This shape input is used to reshape before softmax in self attention. @@ -161,17 +165,23 @@ input_descs = { # This input is used in independent decoder program for inference. # The actual data shape of enc_output is: # [batch_size, max_src_len_in_batch, d_model] - "enc_output": [(1, (ModelHyperParams.max_length + 1), - ModelHyperParams.d_model), "float32"], + "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"], # The actual data shape of label_word is: # [batch_size * max_trg_len_in_batch, 1] - "lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"], + "lbl_word": [(batch_size * seq_len, 1L), "int64"], # This input is used to mask out the loss of paddding tokens. # The actual data shape of label_weight is: # [batch_size * max_trg_len_in_batch, 1] - "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"], + "lbl_weight": [(batch_size * seq_len, 1L), "float32"], + # These inputs are used to change the shape tensor in beam-search decoder. + "trg_slf_attn_pre_softmax_shape_delta": [(2L, ), "int32"], + "trg_slf_attn_post_softmax_shape_delta": [(4L, ), "int32"], + "init_score": [(batch_size, 1L), "float32"], } +word_emb_param_names = ( + "src_word_emb_table", + "trg_word_emb_table", ) # Names of position encoding table which will be initialized externally. pos_enc_param_names = ( "src_pos_enc_table", @@ -200,3 +210,12 @@ decoder_util_input_fields = ( label_data_input_fields = ( "lbl_word", "lbl_weight", ) +# In fast decoder, trg_pos (only containing the current time step) is generated +# by ops and trg_slf_attn_bias is not needed. +fast_decoder_data_input_fields = ( + "trg_word", + "init_score", + "trg_src_attn_bias", ) +fast_decoder_util_input_fields = decoder_util_input_fields + ( + "trg_slf_attn_pre_softmax_shape_delta", + "trg_slf_attn_post_softmax_shape_delta", ) diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py index 9c5d8adc..ec88c342 100644 --- a/fluid/neural_machine_translation/transformer/model.py +++ b/fluid/neural_machine_translation/transformer/model.py @@ -6,6 +6,8 @@ import paddle.fluid.layers as layers from config import * +WEIGHT_SHARING = True + def position_encoding_init(n_position, d_pos_vec): """ @@ -30,7 +32,8 @@ def multi_head_attention(queries, n_head=1, dropout_rate=0., pre_softmax_shape=None, - post_softmax_shape=None): + post_softmax_shape=None, + cache=None): """ Multi-Head Attention. Note that attn_bias is added to the logit before computing softmax activiation to mask certain selected positions so that @@ -44,30 +47,30 @@ def multi_head_attention(queries, """ Add linear projection to queries, keys, and values. """ - q = layers.fc(input=queries, - size=d_key * n_head, - param_attr=fluid.initializer.Xavier( - uniform=False, - fan_in=d_model * d_key, - fan_out=n_head * d_key), - bias_attr=False, - num_flatten_dims=2) - k = layers.fc(input=keys, - size=d_key * n_head, - param_attr=fluid.initializer.Xavier( - uniform=False, - fan_in=d_model * d_key, - fan_out=n_head * d_key), - bias_attr=False, - num_flatten_dims=2) - v = layers.fc(input=values, - size=d_value * n_head, - param_attr=fluid.initializer.Xavier( - uniform=False, - fan_in=d_model * d_value, - fan_out=n_head * d_value), - bias_attr=False, - num_flatten_dims=2) + q = layers.fc( + input=queries, + size=d_key * n_head, + param_attr=fluid.initializer.Xavier(uniform=True), + # fan_in=d_model * d_key, + # fan_out=n_head * d_key), + bias_attr=False, + num_flatten_dims=2) + k = layers.fc( + input=keys, + size=d_key * n_head, + param_attr=fluid.initializer.Xavier(uniform=True), + # fan_in=d_model * d_key, + # fan_out=n_head * d_key), + bias_attr=False, + num_flatten_dims=2) + v = layers.fc( + input=values, + size=d_value * n_head, + param_attr=fluid.initializer.Xavier(uniform=True), + # fan_in=d_model * d_value, + # fan_out=n_head * d_value), + bias_attr=False, + num_flatten_dims=2) return q, k, v def __split_heads(x, n_head): @@ -84,7 +87,7 @@ def multi_head_attention(queries, # The value 0 in shape attr means copying the corresponding dimension # size of the input as the output dimension size. reshaped = layers.reshape( - x=x, shape=[0, -1, n_head, hidden_size // n_head]) + x=x, shape=[0, 0, n_head, hidden_size // n_head]) # permuate the dimensions into: # [batch_size, n_head, max_sequence_len, hidden_size_per_head] @@ -104,13 +107,13 @@ def multi_head_attention(queries, # size of the input as the output dimension size. return layers.reshape( x=trans_x, - shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]])) + shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]])) def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate): """ Scaled Dot-Product Attention """ - scaled_q = layers.scale(x=q, scale=d_model**-0.5) + scaled_q = layers.scale(x=q, scale=d_key**-0.5) product = layers.matmul(x=scaled_q, y=k, transpose_y=True) weights = layers.reshape( x=layers.elementwise_add( @@ -123,11 +126,15 @@ def multi_head_attention(queries, if dropout_rate: weights = layers.dropout( weights, dropout_prob=dropout_rate, is_test=False) + out = layers.matmul(weights, v) return out q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value) + if cache is not None: # use cache and concat time steps + k = cache["k"] = layers.concat([cache["k"], k], axis=1) + v = cache["v"] = layers.concat([cache["v"], v], axis=1) q = __split_heads(q, n_head) k = __split_heads(k, n_head) v = __split_heads(v, n_head) @@ -136,7 +143,6 @@ def multi_head_attention(queries, dropout_rate) out = __combine_heads(ctx_multiheads) - # Project back to the model size. proj_out = layers.fc(input=out, size=d_model, @@ -146,23 +152,32 @@ def multi_head_attention(queries, return proj_out -def positionwise_feed_forward(x, d_inner_hid, d_hid): +def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate=0.): """ Position-wise Feed-Forward Networks. This module consists of two linear transformations with a ReLU activation in between, which is applied to each position separately and identically. """ - hidden = layers.fc(input=x, - size=d_inner_hid, - num_flatten_dims=2, - param_attr=fluid.initializer.Uniform( - low=-(d_hid**-0.5), high=(d_hid**-0.5)), - act="relu") - out = layers.fc(input=hidden, - size=d_hid, - num_flatten_dims=2, - param_attr=fluid.initializer.Uniform( - low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5))) + hidden = layers.fc( + input=x, + size=d_inner_hid, + num_flatten_dims=2, + param_attr=fluid.initializer.Xavier(uniform=True), + #param_attr=fluid.initializer.Uniform( + # low=-(d_hid**-0.5), high=(d_hid**-0.5)), + bias_attr=True, + act="relu") + if dropout_rate: + hidden = layers.dropout( + hidden, dropout_prob=dropout_rate, is_test=False) + out = layers.fc( + input=hidden, + size=d_hid, + num_flatten_dims=2, + param_attr=fluid.initializer.Xavier(uniform=True), + #param_attr=fluid.initializer.Uniform( + # low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)), + bias_attr=True) return out @@ -200,6 +215,7 @@ def prepare_encoder(src_word, src_max_len, dropout_rate=0., src_data_shape=None, + word_emb_param_name=None, pos_enc_param_name=None): """Add word embeddings and position encodings. The output tensor has a shape of: @@ -209,7 +225,10 @@ def prepare_encoder(src_word, src_word_emb = layers.embedding( src_word, size=[src_vocab_size, src_emb_dim], - param_attr=fluid.initializer.Normal(0., 1.)) + param_attr=fluid.ParamAttr( + name=word_emb_param_name, + initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))) + src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) src_pos_enc = layers.embedding( src_pos, size=[src_max_len, src_emb_dim], @@ -218,7 +237,7 @@ def prepare_encoder(src_word, enc_input = src_word_emb + src_pos_enc enc_input = layers.reshape( x=enc_input, - shape=[-1, src_max_len, src_emb_dim], + shape=[batch_size, seq_len, src_emb_dim], actual_shape=src_data_shape) return layers.dropout( enc_input, dropout_prob=dropout_rate, @@ -226,9 +245,14 @@ def prepare_encoder(src_word, prepare_encoder = partial( - prepare_encoder, pos_enc_param_name=pos_enc_param_names[0]) + prepare_encoder, + word_emb_param_name=word_emb_param_names[0], + pos_enc_param_name=pos_enc_param_names[0]) prepare_decoder = partial( - prepare_encoder, pos_enc_param_name=pos_enc_param_names[1]) + prepare_encoder, + word_emb_param_name=word_emb_param_names[0] + if WEIGHT_SHARING else word_emb_param_names[1], + pos_enc_param_name=pos_enc_param_names[1]) def encoder_layer(enc_input, @@ -247,13 +271,14 @@ def encoder_layer(enc_input, with the post_process_layer to add residual connection, layer normalization and droput. """ - attn_output = multi_head_attention( - enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model, - n_head, dropout_rate, pre_softmax_shape, post_softmax_shape) - attn_output = post_process_layer(enc_input, attn_output, "dan", - dropout_rate) - ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model) - return post_process_layer(attn_output, ffd_output, "dan", dropout_rate) + q = k = v = pre_process_layer(enc_input, "n") + attn_output = multi_head_attention(q, k, v, attn_bias, d_key, d_value, + d_model, n_head, dropout_rate, + pre_softmax_shape, post_softmax_shape) + attn_output = post_process_layer(enc_input, attn_output, "da", dropout_rate) + ffd_output = positionwise_feed_forward( + pre_process_layer(attn_output, "n"), d_inner_hid, d_model, dropout_rate) + return post_process_layer(attn_output, ffd_output, "da", dropout_rate) def encoder(enc_input, @@ -284,6 +309,7 @@ def encoder(enc_input, pre_softmax_shape, post_softmax_shape, ) enc_input = enc_output + enc_output = pre_process_layer(enc_output, "n") return enc_output @@ -300,15 +326,17 @@ def decoder_layer(dec_input, slf_attn_pre_softmax_shape=None, slf_attn_post_softmax_shape=None, src_attn_pre_softmax_shape=None, - src_attn_post_softmax_shape=None): + src_attn_post_softmax_shape=None, + cache=None): """ The layer to be stacked in decoder part. The structure of this module is similar to that in the encoder part except a multi-head attention is added to implement encoder-decoder attention. """ + q = k = v = pre_process_layer(dec_input, "n") slf_attn_output = multi_head_attention( - dec_input, - dec_input, - dec_input, + q, + k, + v, slf_attn_bias, d_key, d_value, @@ -316,14 +344,15 @@ def decoder_layer(dec_input, n_head, dropout_rate, slf_attn_pre_softmax_shape, - slf_attn_post_softmax_shape, ) + slf_attn_post_softmax_shape, + cache, ) slf_attn_output = post_process_layer( dec_input, slf_attn_output, - "dan", # residual connection + dropout + layer normalization + "da", # residual connection + dropout + layer normalization dropout_rate, ) enc_attn_output = multi_head_attention( - slf_attn_output, + pre_process_layer(slf_attn_output, "n"), enc_output, enc_output, dec_enc_attn_bias, @@ -337,16 +366,17 @@ def decoder_layer(dec_input, enc_attn_output = post_process_layer( slf_attn_output, enc_attn_output, - "dan", # residual connection + dropout + layer normalization + "da", # residual connection + dropout + layer normalization dropout_rate, ) ffd_output = positionwise_feed_forward( - enc_attn_output, + pre_process_layer(enc_attn_output, "n"), d_inner_hid, - d_model, ) + d_model, + dropout_rate, ) dec_output = post_process_layer( enc_attn_output, ffd_output, - "dan", # residual connection + dropout + layer normalization + "da", # residual connection + dropout + layer normalization dropout_rate, ) return dec_output @@ -365,27 +395,20 @@ def decoder(dec_input, slf_attn_pre_softmax_shape=None, slf_attn_post_softmax_shape=None, src_attn_pre_softmax_shape=None, - src_attn_post_softmax_shape=None): + src_attn_post_softmax_shape=None, + caches=None): """ The decoder is composed of a stack of identical decoder_layer layers. """ for i in range(n_layer): dec_output = decoder_layer( - dec_input, - enc_output, - dec_slf_attn_bias, - dec_enc_attn_bias, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - dropout_rate, - slf_attn_pre_softmax_shape, - slf_attn_post_softmax_shape, - src_attn_pre_softmax_shape, - src_attn_post_softmax_shape, ) + dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias, n_head, + d_key, d_value, d_model, d_inner_hid, dropout_rate, + slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, + src_attn_pre_softmax_shape, src_attn_post_softmax_shape, None + if caches is None else caches[i]) dec_input = dec_output + dec_output = pre_process_layer(dec_output, "n") return dec_output @@ -399,6 +422,8 @@ def make_all_inputs(input_fields): name=input_field, shape=input_descs[input_field][0], dtype=input_descs[input_field][1], + lod_level=input_descs[input_field][2] + if len(input_descs[input_field]) == 3 else 0, append_batch_size=False) inputs.append(input_var) return inputs @@ -459,7 +484,6 @@ def transformer( logits=predict, label=label, soft_label=True if label_smooth_eps else False) - # cost = layers.softmax_with_cross_entropy(logits=predict, label=gold) weighted_cost = cost * weights sum_cost = layers.reduce_sum(weighted_cost) token_num = layers.reduce_sum(weights) @@ -523,7 +547,8 @@ def wrap_decoder(trg_vocab_size, d_inner_hid, dropout_rate, dec_inputs=None, - enc_output=None): + enc_output=None, + caches=None): """ The wrapper assembles together all needed layers for the decoder. """ @@ -563,13 +588,23 @@ def wrap_decoder(trg_vocab_size, slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, - src_attn_post_softmax_shape, ) + src_attn_post_softmax_shape, + caches, ) # Return logits for training and probs for inference. - predict = layers.reshape( - x=layers.fc(input=dec_output, - size=trg_vocab_size, - bias_attr=False, - num_flatten_dims=2), - shape=[-1, trg_vocab_size], - act="softmax" if dec_inputs is None else None) + if not WEIGHT_SHARING: + predict = layers.reshape( + x=layers.fc(input=dec_output, + size=trg_vocab_size, + bias_attr=False, + num_flatten_dims=2), + shape=[-1, trg_vocab_size], + act="softmax" if dec_inputs is None else None) + else: + predict = layers.reshape( + x=layers.matmul( + x=dec_output, + y=fluid.get_var(word_emb_param_names[0]), + transpose_y=True), + shape=[-1, trg_vocab_size], + act="softmax" if dec_inputs is None else None) return predict diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index bf9edb52..34a75be6 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -288,6 +288,7 @@ def train(args): start_mark=args.special_token[0], end_mark=args.special_token[1], unk_mark=args.special_token[2], + max_length=ModelHyperParams.max_length, clip_last_batch=False) train_data = read_multiple(reader=train_data.batch_generator) @@ -319,6 +320,7 @@ def train(args): start_mark=args.special_token[0], end_mark=args.special_token[1], unk_mark=args.special_token[2], + max_length=ModelHyperParams.max_length, clip_last_batch=False, shuffle=False, shuffle_batch=False) -- GitLab