From 6ef54e8e56aaf0c2d8fd87b35e8f7e897c9437a3 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Tue, 20 Mar 2018 20:42:09 +0800
Subject: [PATCH] Refine Transformer by following comments and fix the target
 self attention bias in inference.

---
 .../transformer/config.py | 20 +++----
 .../transformer/infer.py  | 34 ++++++++----
 .../transformer/model.py  | 54 ++++++++++++-------
 3 files changed, 69 insertions(+), 39 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index a1e2fd90..71e43149 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -3,34 +3,36 @@ class TrainTaskConfig(object):
     # the epoch number to train.
     pass_num = 2
-    # number of sequences contained in a mini-batch.
+    # the number of sequences contained in a mini-batch.
     batch_size = 64
-    # the hyper params for Adam optimizer.
+    # the hyper parameters for Adam optimizer.
     learning_rate = 0.001
     beta1 = 0.9
     beta2 = 0.98
     eps = 1e-9
-    # the params for learning rate scheduling
+    # the parameters for learning rate scheduling.
     warmup_steps = 4000
-    # the directory for saving inference models
-    model_dir = "transformer_model"
+    # the directory for saving trained models.
+    model_dir = "trained_models"


 class InferTaskConfig(object):
     use_gpu = False
-    # number of sequences contained in a mini-batch
+    # the number of examples in one run for sequence generation.
+    # currently the batch size can only be set to 1.
     batch_size = 1
-    # the params for beam search
+    # the parameters for beam search.
     beam_size = 5
     max_length = 30
+    # the number of decoded sentences to output.
     n_best = 1
-    # the directory for loading inference model
-    model_path = "transformer_model/pass_1.infer.model"
+    # the directory for loading the trained model.
+    model_path = "trained_models/pass_1.infer.model"


 class ModelHyperParams(object):
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index 3ffa4e77..f5fdfb33 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -66,12 +66,19 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
         src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
             -1], enc_in_data[-2], 1
+        # This is used to remove attention on subsequent words.
+        trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
+                                     trg_max_len))
+        trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
+            [-1, 1, trg_max_len, trg_max_len])
+        trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
+                             [-1e9]).astype("float32")
+        # This is used to remove attention on the paddings of source sequences.
         trg_src_attn_bias = np.tile(
             src_slf_attn_bias[:, :, ::src_max_length, :],
             [beam_size, 1, trg_max_len, 1])
         enc_output = np.tile(enc_output, [beam_size, 1, 1])
-        # No need for trg_slf_attn_bias because of no paddings.
-        return trg_words, trg_pos, None, trg_src_attn_bias, enc_output
+        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output

     def update_dec_in_data(dec_in_data, next_ids, active_beams):
         """
@@ -79,6 +86,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         input data and dropping the finished instance beams.
""" trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data + trg_cur_len = len(next_ids[0]) + 1 # include the trg_words = np.array( [ beam_backtrace( @@ -88,14 +96,22 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, dtype="int64") trg_words = trg_words.reshape([-1, 1]) trg_pos = np.array( - [range(1, len(next_ids[0]) + 2)] * len(active_beams) * beam_size, + [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size, dtype="int64").reshape([-1, 1]) active_beams_indice = ( (np.array(active_beams) * beam_size)[:, np.newaxis] + np.array(range(beam_size))[np.newaxis, :]).flatten() + # This is used to remove attention on subsequent words. + trg_slf_attn_bias = np.ones((len(active_beams) * beam_size, trg_cur_len, + trg_cur_len)) + trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape( + [-1, 1, trg_cur_len, trg_cur_len]) + trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) * + [-1e9]).astype("float32") + # This is used to remove attention on the paddings of source sequences. trg_src_attn_bias = np.tile(trg_src_attn_bias[ active_beams_indice, :, ::trg_src_attn_bias.shape[2], :], - [1, 1, len(next_ids[0]) + 1, 1]) + [1, 1, trg_cur_len, 1]) enc_output = enc_output[active_beams_indice, :, :] return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output @@ -103,9 +119,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, enc_output) for i in range(max_length): predict_all = exe.run(decoder, - feed=dict( - filter(lambda item: item[1] is not None, - zip(dec_in_names, dec_in_data))), + feed=dict(zip(dec_in_names, dec_in_data)), fetch_list=dec_out_names)[0] predict_all = np.log(predict_all) predict_all = ( @@ -206,9 +220,9 @@ def main(): encoder_input_data_names, [enc_output.name], decoder_program, decoder_input_data_names, [predict.name], InferTaskConfig.beam_size, InferTaskConfig.max_length, InferTaskConfig.n_best, - InferTaskConfig.batch_size, ModelHyperParams.n_head, - ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx, - ModelHyperParams.bos_idx, ModelHyperParams.eos_idx) + len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx, + ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx, + ModelHyperParams.eos_idx) for i in range(len(batch_seqs)): seqs = batch_seqs[i] scores = batch_scores[i] diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py index b814419e..ba5ba447 100644 --- a/fluid/neural_machine_translation/transformer/model.py +++ b/fluid/neural_machine_translation/transformer/model.py @@ -283,8 +283,15 @@ def encoder(enc_input, encoder_layer. """ for i in range(n_layer): - enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value, - d_model, d_inner_hid, dropout_rate) + enc_output = encoder_layer( + enc_input, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, ) enc_input = enc_output return enc_output @@ -381,9 +388,10 @@ def make_inputs(input_data_names, d_model, batch_size, max_length, + is_pos, slf_attn_bias_flag, src_attn_bias_flag, - pos_flag=1): + enc_output_flag=False): """ Define the input data layers for the transformer model. """ @@ -391,35 +399,43 @@ def make_inputs(input_data_names, # The shapes here act as placeholder. # The shapes set here is to pass the infer-shape in compile time. 
     word = layers.data(
-        name=input_data_names[0],
+        name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
         dtype="int64",
         append_batch_size=False)
     input_layers += [word]
     # This is used for position data or label weight.
     pos = layers.data(
-        name=input_data_names[1],
+        name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
-        dtype="int64" if pos_flag else "float32",
+        dtype="int64" if is_pos else "float32",
         append_batch_size=False)
     input_layers += [pos]
     if slf_attn_bias_flag:
-        # This is used for attention bias or encoder output.
+        # This input is used to remove attention weights on paddings for the
+        # encoder and to remove attention weights on subsequent words for the
+        # decoder.
         slf_attn_bias = layers.data(
-            name=input_data_names[2]
-            if slf_attn_bias_flag == 1 else input_data_names[-1],
-            shape=[batch_size, n_head, max_length, max_length]
-            if slf_attn_bias_flag == 1 else [batch_size, max_length, d_model],
+            name=input_data_names[len(input_layers)],
+            shape=[batch_size, n_head, max_length, max_length],
             dtype="float32",
             append_batch_size=False)
         input_layers += [slf_attn_bias]
     if src_attn_bias_flag:
+        # This input is used to remove attention weights on paddings.
         src_attn_bias = layers.data(
-            name=input_data_names[3],
+            name=input_data_names[len(input_layers)],
             shape=[batch_size, n_head, max_length, max_length],
             dtype="float32",
             append_batch_size=False)
         input_layers += [src_attn_bias]
+    if enc_output_flag:
+        enc_output = layers.data(
+            name=input_data_names[len(input_layers)],
+            shape=[batch_size, max_length, d_model],
+            dtype="float32",
+            append_batch_size=False)
+        input_layers += [enc_output]
     return input_layers

@@ -438,7 +454,7 @@ def transformer(
         trg_pad_idx,
         pos_pad_idx, ):
     enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
-                                   batch_size, max_length, 1, 0)
+                                   batch_size, max_length, True, True, False)

     enc_output = wrap_encoder(
         src_vocab_size,
@@ -455,7 +471,7 @@ def transformer(
         enc_input_layers, )

     dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
-                                   batch_size, max_length, 1, 1)
+                                   batch_size, max_length, True, True, True)

     predict = wrap_decoder(
         trg_vocab_size,
@@ -475,7 +491,7 @@ def transformer(
     # Padding index do not contribute to the total loss. The weights is used to
     # cancel padding index in calculating the loss.
     gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
-                                max_length, 0, 0, 0)
+                                max_length, False, False, False)
     cost = layers.cross_entropy(input=predict, label=gold)
     weighted_cost = cost * weights
     return layers.reduce_sum(weighted_cost), predict
@@ -500,7 +516,7 @@ def wrap_encoder(src_vocab_size,
         # This is used to implement independent encoder program in inference.
         src_word, src_pos, src_slf_attn_bias = make_inputs(
             encoder_input_data_names, n_head, d_model, batch_size, max_length,
-            True, False)
+            True, True, False)
     else:
         src_word, src_pos, src_slf_attn_bias = enc_input_layers
     enc_input = prepare_encoder(
@@ -542,11 +558,9 @@ def wrap_decoder(trg_vocab_size,
     """
     if dec_input_layers is None:
         # This is used to implement independent decoder program in inference.
-        # No need for trg_slf_attn_bias because of no paddings in inference.
-        trg_word, trg_pos, enc_output, trg_src_attn_bias = make_inputs(
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs(
             decoder_input_data_names, n_head, d_model, batch_size, max_length,
-            2, 1)
-        trg_slf_attn_bias = None
+            True, True, True, True)
     else:
         trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers
--
GitLab
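
Note on the fix: before this patch the inference path passed None as the decoder
self-attention bias ("No need for trg_slf_attn_bias because of no paddings"), so
nothing prevented a decoding step from attending to later positions of the
partially generated target. The patch instead builds the same upper-triangular
bias that training uses. Below is a minimal standalone numpy sketch of that
mask, not part of the patch; the names and the small sizes (batch_size,
beam_size, n_head, trg_len) are illustrative only.

    import numpy as np

    # Illustrative sizes only; infer.py derives these from the config and data.
    batch_size, beam_size, n_head, trg_len = 1, 2, 2, 4

    # Ones strictly above the main diagonal mark the "future" positions that
    # each decoding step must not attend to.
    bias = np.ones((batch_size * beam_size, trg_len, trg_len))
    bias = np.triu(bias, 1).reshape([-1, 1, trg_len, trg_len])
    # Broadcast over attention heads and scale by -1e9, so adding the bias to
    # the attention logits drives the softmax weights on future positions to
    # (almost) zero.
    bias = (np.tile(bias, [1, n_head, 1, 1]) * [-1e9]).astype("float32")

    print(bias.shape)  # (2, 2, 4, 4): [batch * beam, n_head, trg_len, trg_len]
    print(bias[0, 0])  # row i holds -1e9 in columns j > i and 0 elsewhere

Each decoding step rebuilds this bias at the current target length (trg_cur_len
in update_dec_in_data), which keeps the mask square and aligned with the growing
target sequence.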