提交 baa01f6f 编写于 作者: G guosheng

Refine the validation in Transformer.

上级 afe55c9e
...@@ -594,7 +594,7 @@ def transformer( ...@@ -594,7 +594,7 @@ def transformer(
sum_cost = layers.reduce_sum(weighted_cost) sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights) token_num = layers.reduce_sum(weights)
avg_cost = sum_cost / token_num avg_cost = sum_cost / token_num
return sum_cost, avg_cost, predict return sum_cost, avg_cost, predict, token_num
def wrap_encoder(src_vocab_size, def wrap_encoder(src_vocab_size,
......
...@@ -104,7 +104,7 @@ def main(): ...@@ -104,7 +104,7 @@ def main():
place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
sum_cost, avg_cost, predict = transformer( sum_cost, avg_cost, predict, token_num = transformer(
ModelHyperParams.src_vocab_size + 1, ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1, ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head, ModelHyperParams.n_layer, ModelHyperParams.n_head,
...@@ -140,23 +140,24 @@ def main(): ...@@ -140,23 +140,24 @@ def main():
batch_size=TrainTaskConfig.batch_size) batch_size=TrainTaskConfig.batch_size)
def test(exe): def test(exe):
test_sum_costs = [] test_total_cost = 0
test_avg_costs = [] test_total_token = 0
for batch_id, data in enumerate(val_data()): for batch_id, data in enumerate(val_data()):
if len(data) != TrainTaskConfig.batch_size:
# Fix the batch size to keep comparable cost among all
# mini-batches and compute the mean.
continue
data_input = prepare_batch_input( data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] + data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.src_pad_idx, label_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head, ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model) ModelHyperParams.d_model)
test_sum_cost, test_avg_cost = exe.run( test_sum_cost, test_token_num = exe.run(
test_program, feed=data_input, fetch_list=[sum_cost, avg_cost]) test_program,
test_sum_costs.append(test_sum_cost) feed=data_input,
test_avg_costs.append(test_avg_cost) fetch_list=[sum_cost, token_num],
return np.mean(test_sum_costs), np.mean(test_avg_costs) use_program_cache=True)
test_total_cost += test_sum_cost
test_total_token += test_token_num
test_avg_cost = test_total_cost / test_total_token
test_ppl = np.exp([min(test_avg_cost, 100)])
return test_avg_cost, test_ppl
# Initialize the parameters. # Initialize the parameters.
exe.run(fluid.framework.default_startup_program()) exe.run(fluid.framework.default_startup_program())
...@@ -185,13 +186,11 @@ def main(): ...@@ -185,13 +186,11 @@ def main():
(pass_id, batch_id, sum_cost_val, avg_cost_val, (pass_id, batch_id, sum_cost_val, avg_cost_val,
np.exp([min(avg_cost_val[0], 100)]))) np.exp([min(avg_cost_val[0], 100)])))
# Validate and save the model for inference. # Validate and save the model for inference.
val_sum_cost, val_avg_cost = test(exe) val_avg_cost, val_ppl = test(exe)
pass_end_time = time.time() pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time time_consumed = pass_end_time - pass_start_time
print("epoch: %d, val sum loss: %f, val avg loss: %f, val ppl: %f, " print("epoch: %d, val avg loss: %f, val ppl: %f, "
"consumed %fs" % "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
(pass_id, val_sum_cost, val_avg_cost,
np.exp([min(val_avg_cost, 100)]), time_consumed))
fluid.io.save_inference_model( fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir, os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"), "pass_" + str(pass_id) + ".infer.model"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册