diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index e68ab17e69eff890cb8e6b028ead5e6163213761..ed8f95dde6830fd49419226d4446f8bb12be9fa6 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -79,8 +79,14 @@ class ModelHyperParams(object):
     n_head = 8
     # number of sub-layers to be stacked in the encoder and decoder.
     n_layer = 6
-    # dropout rate used by all dropout layers.
-    dropout = 0.1
+    # dropout rates of different modules.
+    prepostprocess_dropout = 0.1
+    attention_dropout = 0.1
+    relu_dropout = 0.1
+    # to process before each sub-layer
+    preprocess_cmd = "n"  # layer normalization
+    # to process after each sub-layer
+    postprocess_cmd = "da"  # dropout + residual connection
     # the flag indicating whether to share embedding and softmax weights.
     # vocabularies in source and target should be same for weight sharing.
     weight_sharing = True
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index 505bf0b0062bda27a0299ed7d844e2f05abd95b8..7d1ecab0cc72e03f13703e3f322672bd78ea05a8 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -335,7 +335,10 @@ def py_infer(test_data, trg_idx2word):
             ModelHyperParams.n_layer, ModelHyperParams.n_head,
             ModelHyperParams.d_key, ModelHyperParams.d_value,
             ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
-            ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
+            ModelHyperParams.prepostprocess_dropout,
+            ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
+            ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
+            ModelHyperParams.weight_sharing)
 
     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
@@ -344,7 +347,10 @@
             ModelHyperParams.n_layer, ModelHyperParams.n_head,
             ModelHyperParams.d_key, ModelHyperParams.d_value,
             ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
-            ModelHyperParams.dropout, ModelHyperParams.weight_sharing)
+            ModelHyperParams.prepostprocess_dropout,
+            ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
+            ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
+            ModelHyperParams.weight_sharing)
 
     # Load model parameters of encoder and decoder separately from the saved
     # transformer model.
@@ -477,7 +483,9 @@ def fast_infer(test_data, trg_idx2word):
         ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
         ModelHyperParams.n_head, ModelHyperParams.d_key,
         ModelHyperParams.d_value, ModelHyperParams.d_model,
-        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+        ModelHyperParams.d_inner_hid, ModelHyperParams.prepostprocess_dropout,
+        ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
+        ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
         ModelHyperParams.weight_sharing, InferTaskConfig.beam_size,
         InferTaskConfig.max_out_len, ModelHyperParams.eos_idx)
 
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index 46c9f7a9065765b1e5ab5fa4d66042fc3312f75a..f7e9ffe8b39907de748a590b450d9a0ee0a86626 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -37,6 +37,9 @@ def multi_head_attention(queries,
     computing softmax activiation to mask certain selected positions so that
     they will not considered in attention weights.
     """
+    keys = queries if keys is None else keys
+    values = keys if values is None else values
+
     if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
         raise ValueError(
             "Inputs: quries, keys and values should all be 3-D tensors.")
@@ -95,11 +98,11 @@
             x=trans_x,
             shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]]))
 
-    def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
+    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
         """
         Scaled Dot-Product Attention
         """
-        scaled_q = layers.scale(x=q, scale=d_model**-0.5)
+        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
         product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
         weights = layers.reshape(
             x=layers.elementwise_add(
@@ -138,7 +141,7 @@ def multi_head_attention(queries,
     return proj_out
 
 
-def positionwise_feed_forward(x, d_inner_hid, d_hid):
+def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
     """
     Position-wise Feed-Forward Networks.
     This module consists of two linear transformations with a ReLU activation
@@ -148,6 +151,9 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
                        size=d_inner_hid,
                        num_flatten_dims=2,
                        act="relu")
+    if dropout_rate:
+        hidden = layers.dropout(
+            hidden, dropout_prob=dropout_rate, is_test=False)
     out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
     return out
 
@@ -228,7 +234,11 @@ def encoder_layer(enc_input,
                   d_value,
                   d_model,
                   d_inner_hid,
-                  dropout_rate=0.,
+                  prepostprocess_dropout,
+                  attention_dropout,
+                  relu_dropout,
+                  preprocess_cmd,
+                  postprocess_cmd,
                   pre_softmax_shape=None,
                   post_softmax_shape=None):
     """The encoder layers that can be stacked to form a deep encoder.
@@ -238,12 +248,16 @@ def encoder_layer(enc_input,
     and droput.
""" attn_output = multi_head_attention( - enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model, - n_head, dropout_rate, pre_softmax_shape, post_softmax_shape) - attn_output = post_process_layer(enc_input, attn_output, "dan", - dropout_rate) - ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model) - return post_process_layer(attn_output, ffd_output, "dan", dropout_rate) + pre_process_layer(enc_input, preprocess_cmd, prepostprocess_dropout), + None, None, attn_bias, d_key, d_value, d_model, n_head, + attention_dropout, pre_softmax_shape, post_softmax_shape) + attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd, + prepostprocess_dropout) + ffd_output = positionwise_feed_forward( + pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout), + d_inner_hid, d_model, relu_dropout) + return post_process_layer(attn_output, ffd_output, postprocess_cmd, + prepostprocess_dropout) def encoder(enc_input, @@ -254,7 +268,11 @@ def encoder(enc_input, d_value, d_model, d_inner_hid, - dropout_rate=0., + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, pre_softmax_shape=None, post_softmax_shape=None): """ @@ -270,10 +288,16 @@ def encoder(enc_input, d_value, d_model, d_inner_hid, - dropout_rate, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, pre_softmax_shape, post_softmax_shape, ) enc_input = enc_output + enc_output = pre_process_layer(enc_output, preprocess_cmd, + prepostprocess_dropout) return enc_output @@ -286,7 +310,11 @@ def decoder_layer(dec_input, d_value, d_model, d_inner_hid, - dropout_rate=0., + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, slf_attn_pre_softmax_shape=None, slf_attn_post_softmax_shape=None, src_attn_pre_softmax_shape=None, @@ -297,25 +325,26 @@ def decoder_layer(dec_input, a multi-head attention is added to implement encoder-decoder attention. 
""" slf_attn_output = multi_head_attention( - dec_input, - dec_input, - dec_input, + pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout), + None, + None, slf_attn_bias, d_key, d_value, d_model, n_head, - dropout_rate, + attention_dropout, slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, cache, ) slf_attn_output = post_process_layer( dec_input, slf_attn_output, - "dan", # residual connection + dropout + layer normalization - dropout_rate, ) + postprocess_cmd, + prepostprocess_dropout, ) enc_attn_output = multi_head_attention( - slf_attn_output, + pre_process_layer(slf_attn_output, preprocess_cmd, + prepostprocess_dropout), enc_output, enc_output, dec_enc_attn_bias, @@ -323,23 +352,25 @@ def decoder_layer(dec_input, d_value, d_model, n_head, - dropout_rate, + attention_dropout, src_attn_pre_softmax_shape, src_attn_post_softmax_shape, ) enc_attn_output = post_process_layer( slf_attn_output, enc_attn_output, - "dan", # residual connection + dropout + layer normalization - dropout_rate, ) + postprocess_cmd, + prepostprocess_dropout, ) ffd_output = positionwise_feed_forward( - enc_attn_output, + pre_process_layer(enc_attn_output, preprocess_cmd, + prepostprocess_dropout), d_inner_hid, - d_model, ) + d_model, + relu_dropout, ) dec_output = post_process_layer( enc_attn_output, ffd_output, - "dan", # residual connection + dropout + layer normalization - dropout_rate, ) + postprocess_cmd, + prepostprocess_dropout, ) return dec_output @@ -353,7 +384,11 @@ def decoder(dec_input, d_value, d_model, d_inner_hid, - dropout_rate=0., + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, slf_attn_pre_softmax_shape=None, slf_attn_post_softmax_shape=None, src_attn_pre_softmax_shape=None, @@ -373,13 +408,19 @@ def decoder(dec_input, d_value, d_model, d_inner_hid, - dropout_rate, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, src_attn_post_softmax_shape, None if caches is None else caches[i], ) dec_input = dec_output + dec_output = pre_process_layer(dec_output, preprocess_cmd, + prepostprocess_dropout) return dec_output @@ -410,7 +451,11 @@ def transformer( d_value, d_model, d_inner_hid, - dropout_rate, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, weight_sharing, label_smooth_eps, ): if weight_sharing: @@ -429,7 +474,11 @@ def transformer( d_value, d_model, d_inner_hid, - dropout_rate, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, weight_sharing, enc_inputs, ) @@ -445,7 +494,11 @@ def transformer( d_value, d_model, d_inner_hid, - dropout_rate, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, weight_sharing, dec_inputs, enc_output, ) @@ -477,7 +530,11 @@ def wrap_encoder(src_vocab_size, d_value, d_model, d_inner_hid, - dropout_rate, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, weight_sharing, enc_inputs=None): """ @@ -499,7 +556,7 @@ def wrap_encoder(src_vocab_size, src_vocab_size, d_model, max_length, - dropout_rate, + prepostprocess_dropout, src_data_shape, word_emb_param_name=word_emb_param_names[0]) enc_output = encoder( @@ -511,7 +568,11 @@ def wrap_encoder(src_vocab_size, d_value, d_model, d_inner_hid, - dropout_rate, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + 
+        preprocess_cmd,
+        postprocess_cmd,
         slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape, )
     return enc_output
@@ -525,7 +586,11 @@ def wrap_decoder(trg_vocab_size,
                  d_value,
                  d_model,
                  d_inner_hid,
-                 dropout_rate,
+                 prepostprocess_dropout,
+                 attention_dropout,
+                 relu_dropout,
+                 preprocess_cmd,
+                 postprocess_cmd,
                  weight_sharing,
                  dec_inputs=None,
                  enc_output=None,
@@ -552,7 +617,7 @@ def wrap_decoder(trg_vocab_size,
         trg_vocab_size,
         d_model,
         max_length,
-        dropout_rate,
+        prepostprocess_dropout,
         trg_data_shape,
         word_emb_param_name=word_emb_param_names[0]
         if weight_sharing else word_emb_param_names[1])
@@ -567,7 +632,11 @@
         d_value,
         d_model,
         d_inner_hid,
-        dropout_rate,
+        prepostprocess_dropout,
+        attention_dropout,
+        relu_dropout,
+        preprocess_cmd,
+        postprocess_cmd,
         slf_attn_pre_softmax_shape,
         slf_attn_post_softmax_shape,
         src_attn_pre_softmax_shape,
@@ -603,7 +672,11 @@ def fast_decode(
         d_value,
         d_model,
         d_inner_hid,
-        dropout_rate,
+        prepostprocess_dropout,
+        attention_dropout,
+        relu_dropout,
+        preprocess_cmd,
+        postprocess_cmd,
         weight_sharing,
         beam_size,
         max_out_len,
@@ -612,9 +685,10 @@
     Use beam search to decode. Caches will be used to store states of history
     steps which can make the decoding faster.
     """
-    enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
-                              d_key, d_value, d_model, d_inner_hid,
-                              dropout_rate, weight_sharing)
+    enc_output = wrap_encoder(
+        src_vocab_size, max_in_len, n_layer, n_head, d_key, d_value, d_model,
+        d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout,
+        preprocess_cmd, postprocess_cmd, weight_sharing)
     start_tokens, init_scores, trg_src_attn_bias, trg_data_shape, \
         slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
         src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
@@ -679,7 +753,11 @@ def fast_decode(
                 d_value,
                 d_model,
                 d_inner_hid,
-                dropout_rate,
+                prepostprocess_dropout,
+                attention_dropout,
+                relu_dropout,
+                preprocess_cmd,
+                postprocess_cmd,
                 weight_sharing,
                 dec_inputs=(
                     pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index d2cd5a185b2b4e2a35b5a485cd2be8b6e0f488de..085c58ac6fd9c83e464af9bc912f457fbddc39a2 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -454,7 +454,9 @@ def train(args):
         ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
         ModelHyperParams.n_head, ModelHyperParams.d_key,
         ModelHyperParams.d_value, ModelHyperParams.d_model,
-        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+        ModelHyperParams.d_inner_hid, ModelHyperParams.prepostprocess_dropout,
+        ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
+        ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
         ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
     lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                          TrainTaskConfig.warmup_steps,
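
Note on the new preprocess_cmd / postprocess_cmd options: the diff does not include pre_process_layer / post_process_layer themselves, but the comments added to config.py and the removed "dan" literals indicate the command string is interpreted character by character ("a" residual connection, "n" layer normalization, "d" dropout). The snippet below is a minimal, framework-free sketch of that interpretation for illustration only; it is not the Fluid implementation, and the helper name pre_post_process and the use of numpy are assumptions.

# Sketch (assumption, not the Fluid helper): interpret a command string such
# as preprocess_cmd = "n" or postprocess_cmd = "da" one character at a time.
import numpy as np


def pre_post_process(prev_out, out, cmd, dropout_rate=0.0, epsilon=1e-6):
    """Apply the operations named in cmd to out, in the given order."""
    for c in cmd:
        if c == "a":  # residual connection with the sub-layer input
            if prev_out is not None:
                out = out + prev_out
        elif c == "n":  # layer normalization over the last dimension
            mean = out.mean(axis=-1, keepdims=True)
            std = out.std(axis=-1, keepdims=True)
            out = (out - mean) / (std + epsilon)
        elif c == "d":  # inverted dropout (training-time behaviour)
            if dropout_rate:
                mask = np.random.rand(*out.shape) >= dropout_rate
                out = out * mask / (1.0 - dropout_rate)
    return out


# Usage: preprocess_cmd = "n" normalizes a sub-layer's input, while
# postprocess_cmd = "da" applies dropout to the sub-layer's output and then
# adds the residual connection, matching the comments in config.py.
x = np.random.rand(2, 4, 8).astype("float32")
sub_layer_out = np.random.rand(2, 4, 8).astype("float32")
normed_in = pre_post_process(None, x, "n")
post_out = pre_post_process(x, sub_layer_out, "da", dropout_rate=0.1)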