From bb947505d532871096b49ed53c44110cd0edecd7 Mon Sep 17 00:00:00 2001 From: guosheng Date: Fri, 24 Aug 2018 10:07:43 +0800 Subject: [PATCH] Fix the unremoved util_input in Transformer --- .../neural_machine_translation/transformer/config.py | 3 --- fluid/neural_machine_translation/transformer/train.py | 11 +++++------ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py index 4ebc5b7b..4495f6ce 100644 --- a/fluid/neural_machine_translation/transformer/config.py +++ b/fluid/neural_machine_translation/transformer/config.py @@ -190,6 +190,3 @@ fast_decoder_data_input_fields = ( "trg_word", "init_score", "trg_src_attn_bias", ) -# fast_decoder_util_input_fields = ( -# "trg_slf_attn_pre_softmax_shape_delta", -# "trg_slf_attn_post_softmax_shape_delta", ) diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index f98f832c..52d8afa4 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -258,7 +258,7 @@ def split_data(data, num_part): def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, - util_input_names, sum_cost, token_num): + sum_cost, token_num): # Context to do validation. test_program = train_progm.clone() with fluid.program_guard(test_program): @@ -299,9 +299,9 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, split_data( data, num_part=dev_count)): data_input_dict, _ = prepare_batch_input( - data_buffer, data_input_names, util_input_names, - ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, - ModelHyperParams.n_head, ModelHyperParams.d_model) + data_buffer, data_input_names, ModelHyperParams.eos_idx, + ModelHyperParams.eos_idx, ModelHyperParams.n_head, + ModelHyperParams.d_model) feed_list.append(data_input_dict) outs = exe.run(feed=feed_list, @@ -363,8 +363,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, if args.val_file_pattern is not None: test = test_context(train_progm, avg_cost, train_exe, dev_count, - data_input_names, util_input_names, sum_cost, - token_num) + data_input_names, sum_cost, token_num) # the best cross-entropy value with label smoothing loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log( -- GitLab