From 898f0c51fb9191858c16b48dfcfc37888d345444 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Fri, 13 Apr 2018 06:40:17 +0000 Subject: [PATCH] add speed --- .../transformer_nist_base/nmt_fluid.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py b/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py index 9191ca18..4d5010b4 100644 --- a/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py +++ b/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py @@ -233,11 +233,16 @@ def main(): position_encoding_init(ModelHyperParams.max_length + 1, ModelHyperParams.d_model), place) - for pass_id in xrange(TrainTaskConfig.pass_num): - pass_start_time = time.time() + def train_loop(exe, trainer_prog): + for pass_id in xrange(args.pass_num): + ts = time.time() + total = 0 for batch_id, data in enumerate(train_reader()): if len(data) != TrainTaskConfig.batch_size: continue + + total += len(data) + start_time = time.time() data_input = prepare_batch_input( data, encoder_input_data_names + decoder_input_data_names[:-1] + label_data_names, ModelHyperParams.eos_idx, @@ -249,15 +254,19 @@ def main(): fetch_list=[sum_cost, avg_cost], use_program_cache=True) sum_cost_val, avg_cost_val = np.array(outs[0]), np.array(outs[1]) - print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" % + print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f, speed=%.2f /s" % (pass_id, batch_id, sum_cost_val, avg_cost_val, - np.exp([min(avg_cost_val[0], 100)]))) + np.exp([min(avg_cost_val[0], 100)]), + len(data) / (time.time() - start_time))) # Validate and save the model for inference. #val_avg_cost, val_ppl = test(exe) pass_end_time = time.time() time_consumed = pass_end_time - pass_start_time print("pass_id = " + str(pass_id) + " time_consumed = " + str(time_consumed)) + print("pass_id = %d cost = %f avg_speed = %.2f sample/s" % + (pass_id, val_cost, total / (time.time() - ts))) + #print("epoch: %d, val avg loss: %f, val ppl: %f, " # "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed)) fluid.io.save_inference_model( -- GitLab