diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 4e1ecec3bd285d5f5c1a1f5714ce0b050b35837e..9844465b097b08d64eea30af400545904c65e2a8 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -60,11 +60,10 @@ class ModelHyperParams(object):
     # index for <unk> token
     unk_idx = 2

-    # position value corresponding to the <pad> token.
-    pos_pad_idx = 0
-
-    # max length of sequences. It should plus 1 to include position
-    # padding token for position encoding.
+    # max length of sequences.
+    # The size of position encoding table should plus 1, since the sinusoid
+    # position encoding start from 1 and 0 can be used as the padding token
+    # for position encoding.
     max_length = 50

     # the dimension for word embeddings, which is also the last dimension of
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index 9d6eb2cb5aba4a62cdd300e137a7d10c9bbe19d8..ad7fc2fa39db15698842aae26c80d86f7592775b 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -39,7 +39,6 @@ def translate_batch(exe,
     enc_in_data = pad_batch_data(
         src_words,
         src_pad_idx,
-        eos_idx,
         n_head,
         is_target=False,
         is_label=False,
@@ -251,7 +250,7 @@ def main():
     encoder_program = fluid.Program()
     with fluid.program_guard(main_program=encoder_program):
         enc_output = encoder(
-            ModelHyperParams.src_vocab_size, ModelHyperParams.max_length,
+            ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
             ModelHyperParams.n_layer, ModelHyperParams.n_head,
             ModelHyperParams.d_key, ModelHyperParams.d_value,
             ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
@@ -259,12 +258,12 @@ def main():

     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
-        predict = decoder(ModelHyperParams.trg_vocab_size,
-                          ModelHyperParams.max_length, ModelHyperParams.n_layer,
-                          ModelHyperParams.n_head, ModelHyperParams.d_key,
-                          ModelHyperParams.d_value, ModelHyperParams.d_model,
-                          ModelHyperParams.d_inner_hid,
-                          ModelHyperParams.dropout)
+        predict = decoder(
+            ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
+            ModelHyperParams.n_layer, ModelHyperParams.n_head,
+            ModelHyperParams.d_key, ModelHyperParams.d_value,
+            ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
+            ModelHyperParams.dropout)

     # Load model parameters of encoder and decoder separately from the saved
     # transformer model.
@@ -300,9 +299,6 @@ def main():

     trg_idx2word = paddle.dataset.wmt16.get_dict(
         "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
-    # Append the <pad> token since the dict provided by dataset.wmt16 does
-    # not include it.
-    trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"

     def post_process_seq(seq,
                          bos_idx=ModelHyperParams.bos_idx,
@@ -326,19 +322,22 @@ def main():

     for batch_id, data in enumerate(test_data()):
         batch_seqs, batch_scores = translate_batch(
-            exe, [item[0] for item in data],
+            exe,
+            [item[0] for item in data],
             encoder_program,
-            encoder_input_data_names, [enc_output.name],
+            encoder_input_data_names,
+            [enc_output.name],
             decoder_program,
-            decoder_input_data_names, [predict.name],
+            decoder_input_data_names,
+            [predict.name],
             InferTaskConfig.beam_size,
             InferTaskConfig.max_length,
             InferTaskConfig.n_best,
             len(data),
             ModelHyperParams.n_head,
             ModelHyperParams.d_model,
-            ModelHyperParams.src_pad_idx,
-            ModelHyperParams.trg_pad_idx,
+            ModelHyperParams.eos_idx,  # Use eos_idx to pad.
+            ModelHyperParams.eos_idx,  # Use eos_idx to pad.
             ModelHyperParams.bos_idx,
             ModelHyperParams.eos_idx,
             ModelHyperParams.unk_idx,
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index 940c16602190fabe3d311800e10ae4a45e3bcca7..766d88db9f85d2a362c3f721ae3c09d23601bc74 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -117,7 +117,7 @@ def main():

     sum_cost, avg_cost, predict, token_num = transformer(
         ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
-        ModelHyperParams.max_length, ModelHyperParams.n_layer,
+        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
         ModelHyperParams.n_head, ModelHyperParams.d_key,
         ModelHyperParams.d_value, ModelHyperParams.d_model,
         ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
@@ -174,7 +174,7 @@ def main():
         pos_enc_param = fluid.global_scope().find_var(
             pos_enc_param_name).get_tensor()
         pos_enc_param.set(
-            position_encoding_init(ModelHyperParams.max_length,
+            position_encoding_init(ModelHyperParams.max_length + 1,
                                    ModelHyperParams.d_model), place)

     for pass_id in xrange(TrainTaskConfig.pass_num):
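
Note on the `max_length + 1` threaded through the `encoder`, `decoder`, `transformer`, and `position_encoding_init` calls above: the sinusoid position encoding table gets one extra row because positions are numbered from 1, leaving row 0 free to serve as the padding position. The snippet below is a minimal NumPy sketch of that layout, not the repo's actual `position_encoding_init`; the helper name `sinusoid_position_table` is only for illustration.

```python
import numpy as np


def sinusoid_position_table(max_length, d_model):
    """Build a sinusoid position encoding table with max_length + 1 rows.

    Row 0 is kept all-zero and acts as the padding position, so real token
    positions are numbered 1..max_length.
    """
    n_position = max_length + 1
    table = np.zeros((n_position, d_model), dtype="float32")
    for pos in range(1, n_position):
        for i in range(d_model):
            # Standard Transformer sinusoid: sin on even dims, cos on odd dims.
            angle = pos / np.power(10000.0, 2.0 * (i // 2) / d_model)
            table[pos, i] = np.sin(angle) if i % 2 == 0 else np.cos(angle)
    return table


if __name__ == "__main__":
    table = sinusoid_position_table(max_length=50, d_model=8)
    print(table.shape)  # (51, 8): 50 real positions plus the padding row 0
    print(table[0])     # all zeros, so padded positions contribute no signal
```

With `max_length = 50` the table has shape `(51, d_model)`, which is why every call site now passes `ModelHyperParams.max_length + 1` as the table size.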