From 10de2bf3e65d9f1fdb5244e59f5f9a2f3397c8a6 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Sun, 8 Apr 2018 20:59:25 +0800
Subject: [PATCH] Avoid predicting <pad> by restricting the size of the final
 fc_layer in Transformer.

---
 .../transformer/infer.py |  3 +-
 .../transformer/model.py | 11 ++--
 .../transformer/train.py | 59 +++++++++++++------
 3 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index 5634f160..8bc99dde 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -39,9 +39,10 @@ def translate_batch(exe,
     enc_in_data = pad_batch_data(
         src_words,
         src_pad_idx,
+        eos_idx,
         n_head,
         is_target=False,
-        return_pos=True,
+        is_label=False,
         return_attn_bias=True,
         return_max_len=False)
     # Append the data shape input to reshape the output of embedding layer.
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index eabd17a7..1c59c67d 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -724,10 +724,11 @@ def wrap_decoder(trg_vocab_size,
         src_attn_post_softmax_shape, )
     # Return logits for training and probs for inference.
     predict = layers.reshape(
-        x=layers.fc(input=dec_output,
-                    size=trg_vocab_size,
-                    bias_attr=False,
-                    num_flatten_dims=2),
-        shape=[-1, trg_vocab_size],
+        x=layers.fc(
+            input=dec_output,
+            size=trg_vocab_size - 1,  # To exclude <pad>.
+            bias_attr=False,
+            num_flatten_dims=2),
+        shape=[-1, trg_vocab_size - 1],
         act="softmax" if dec_inputs is None else None)
     return predict
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index 0887a954..5973e11d 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -13,9 +13,10 @@ from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \
 
 def pad_batch_data(insts,
                    pad_idx,
+                   eos_idx,
                    n_head,
                    is_target=False,
-                   return_pos=True,
+                   is_label=False,
                    return_attn_bias=True,
                    return_max_len=True):
     """
@@ -24,14 +25,22 @@ def pad_batch_data(insts,
     """
     return_list = []
     max_len = max(len(inst) for inst in insts)
-    inst_data = np.array(
-        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
+    # Since we restrict the predicted probs excluding the <pad> to avoid
+    # generating the <pad>, also replace the <pad> with others in labels.
+    inst_data = np.array([
+        inst + [eos_idx if is_label else pad_idx] * (max_len - len(inst))
+        for inst in insts
+    ])
     return_list += [inst_data.astype("int64").reshape([-1, 1])]
-    if return_pos:
-        inst_pos = np.array([[
-            pos_i + 1 if w_i != pad_idx else 0 for pos_i, w_i in enumerate(inst)
-        ] for inst in inst_data])
-
+    if is_label:  # label weight
+        inst_weight = np.array(
+            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
+        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
+    else:  # position data
+        inst_pos = np.array([
+            range(1, len(inst) + 1) + [0] * (max_len - len(inst))
+            for inst in insts
+        ])
         return_list += [inst_pos.astype("int64").reshape([-1, 1])]
     if return_attn_bias:
         if is_target:
@@ -57,14 +66,22 @@ def pad_batch_data(insts,
 
 
 def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
-                        n_head, d_model):
+                        eos_idx, n_head, d_model):
     """
     Put all padded data needed by training into a dict.
     """
     src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
-        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
+        [inst[0] for inst in insts],
+        src_pad_idx,
+        eos_idx,
+        n_head,
+        is_target=False)
     trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
-        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
+        [inst[1] for inst in insts],
+        trg_pad_idx,
+        eos_idx,
+        n_head,
+        is_target=True)
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                 [1, 1, trg_max_len, 1]).astype("float32")
 
@@ -84,9 +101,15 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
     trg_src_attn_post_softmax_shape = np.array(
         trg_src_attn_bias.shape, dtype="int32")
 
-    lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
-                              False, False, False, False)
-    lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
+    lbl_word, lbl_weight = pad_batch_data(
+        [inst[2] for inst in insts],
+        trg_pad_idx,
+        eos_idx,
+        n_head,
+        is_target=False,
+        is_label=True,
+        return_attn_bias=False,
+        return_max_len=False)
 
     input_dict = dict(
         zip(input_data_names, [
@@ -146,8 +169,8 @@
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1]
                 + label_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
-                ModelHyperParams.d_model)
+                ModelHyperParams.trg_pad_idx, ModelHyperParams.eos_idx,
+                ModelHyperParams.n_head, ModelHyperParams.d_model)
             test_sum_cost, test_token_num = exe.run(
                 test_program,
                 feed=data_input,
@@ -174,8 +197,8 @@
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1]
                 + label_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
-                ModelHyperParams.d_model)
+                ModelHyperParams.trg_pad_idx, ModelHyperParams.eos_idx,
+                ModelHyperParams.n_head, ModelHyperParams.d_model)
             lr_scheduler.update_learning_rate(data_input)
             outs = exe.run(fluid.framework.default_main_program(),
                            feed=data_input,
-- 
GitLab