未验证 提交 ffbf5f7f 作者: guochaorong 提交者: GitHub

Merge pull request #1185 from guoshengCS/fix-until-input-transformer

Remove the leftover util_input arguments and commented-out fields in Transformer
...@@ -190,6 +190,3 @@ fast_decoder_data_input_fields = ( ...@@ -190,6 +190,3 @@ fast_decoder_data_input_fields = (
"trg_word", "trg_word",
"init_score", "init_score",
"trg_src_attn_bias", ) "trg_src_attn_bias", )
# fast_decoder_util_input_fields = (
# "trg_slf_attn_pre_softmax_shape_delta",
# "trg_slf_attn_post_softmax_shape_delta", )
...@@ -258,7 +258,7 @@ def split_data(data, num_part): ...@@ -258,7 +258,7 @@ def split_data(data, num_part):
def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
util_input_names, sum_cost, token_num): sum_cost, token_num):
# Context to do validation. # Context to do validation.
test_program = train_progm.clone() test_program = train_progm.clone()
with fluid.program_guard(test_program): with fluid.program_guard(test_program):
...@@ -299,9 +299,9 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, ...@@ -299,9 +299,9 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
split_data( split_data(
data, num_part=dev_count)): data, num_part=dev_count)):
data_input_dict, _ = prepare_batch_input( data_input_dict, _ = prepare_batch_input(
data_buffer, data_input_names, util_input_names, data_buffer, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.n_head, ModelHyperParams.d_model) ModelHyperParams.d_model)
feed_list.append(data_input_dict) feed_list.append(data_input_dict)
outs = exe.run(feed=feed_list, outs = exe.run(feed=feed_list,
...@@ -363,8 +363,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -363,8 +363,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
if args.val_file_pattern is not None: if args.val_file_pattern is not None:
test = test_context(train_progm, avg_cost, train_exe, dev_count, test = test_context(train_progm, avg_cost, train_exe, dev_count,
data_input_names, util_input_names, sum_cost, data_input_names, sum_cost, token_num)
token_num)
# the best cross-entropy value with label smoothing # the best cross-entropy value with label smoothing
loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log( loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册