From ef88e63f85d0cc376b4cfa14060e0c00e9cade64 Mon Sep 17 00:00:00 2001 From: guosheng Date: Wed, 4 Apr 2018 16:30:38 +0800 Subject: [PATCH] Refine the training output in Transformer --- .../transformer/train.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index 684cc6f3..b6dace62 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -1,4 +1,5 @@ import os +import time import numpy as np import paddle @@ -159,6 +160,7 @@ def main(): ModelHyperParams.d_model), place) for pass_id in xrange(TrainTaskConfig.pass_num): + pass_start_time = time.time() for batch_id, data in enumerate(train_data()): # The current program desc is coupled with batch_size, thus all # mini-batches must have the same number of instances currently. @@ -175,15 +177,17 @@ def main(): fetch_list=[sum_cost, avg_cost], use_program_cache=True) sum_cost_val, avg_cost_val = np.array(outs[0]), np.array(outs[1]) - print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) + - " sum_cost = " + str(sum_cost_val) + " avg_cost = " + str( - avg_cost_val) + " ppl = " + str( - np.exp([min(avg_cost_val[0], 100)]))) + print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" % + (pass_id, batch_id, sum_cost_val, avg_cost_val, + np.exp([min(avg_cost_val[0], 100)]))) # Validate and save the model for inference. val_sum_cost, val_avg_cost = test(exe) - print("pass_id = " + str(pass_id) + " val_sum_cost = " + str( - val_sum_cost) + " val_avg_cost = " + str(val_avg_cost) + - " val_ppl = " + str(np.exp(min(val_avg_cost, 100)))) + pass_end_time = time.time() + time_consumed = pass_end_time - pass_start_time + print( + "epoch: %d, val sum loss: %f, val avg loss: %f, val ppl: %f, consumed %fs" + % (pass_id, val_sum_cost, val_avg_cost, + np.exp([min(val_avg_cost, 100)]), time_consumed)) fluid.io.save_inference_model( os.path.join(TrainTaskConfig.model_dir, "pass_" + str(pass_id) + ".infer.model"), -- GitLab