未验证 提交 ffbf5f7f 编写于 作者: G guochaorong 提交者: GitHub

Merge pull request #1185 from guoshengCS/fix-until-input-transformer

Fix the unremoved util_input in Transformer
......@@ -190,6 +190,3 @@ fast_decoder_data_input_fields = (
"trg_word",
"init_score",
"trg_src_attn_bias", )
# fast_decoder_util_input_fields = (
# "trg_slf_attn_pre_softmax_shape_delta",
# "trg_slf_attn_post_softmax_shape_delta", )
......@@ -258,7 +258,7 @@ def split_data(data, num_part):
def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
util_input_names, sum_cost, token_num):
sum_cost, token_num):
# Context to do validation.
test_program = train_progm.clone()
with fluid.program_guard(test_program):
......@@ -299,9 +299,9 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
split_data(
data, num_part=dev_count)):
data_input_dict, _ = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
data_buffer, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
feed_list.append(data_input_dict)
outs = exe.run(feed=feed_list,
......@@ -363,8 +363,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
if args.val_file_pattern is not None:
test = test_context(train_progm, avg_cost, train_exe, dev_count,
data_input_names, util_input_names, sum_cost,
token_num)
data_input_names, sum_cost, token_num)
# the best cross-entropy value with label smoothing
loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册