From a9b22c5dc581fa372d11d080b019a888901424ba Mon Sep 17 00:00:00 2001 From: gongweibao Date: Mon, 9 Apr 2018 09:28:10 +0000 Subject: [PATCH] clean up --- fluid/neural_machine_translation/transformer/nmt_fluid.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fluid/neural_machine_translation/transformer/nmt_fluid.py b/fluid/neural_machine_translation/transformer/nmt_fluid.py index 997e2006..44c36be9 100644 --- a/fluid/neural_machine_translation/transformer/nmt_fluid.py +++ b/fluid/neural_machine_translation/transformer/nmt_fluid.py @@ -32,8 +32,6 @@ parser.add_argument( default=TrainTaskConfig.learning_rate, help="Learning rate for training.") -parser.add_argument('--num_passes', type=int, default=50, help="No. of passes.") - parser.add_argument( '--device', type=str, @@ -208,14 +206,16 @@ def main(): return np.mean(test_costs) def train_loop(exe, trainer_prog): - ts = time.time() for pass_id in xrange(args.pass_num): + ts = time.time() + total = 0 for batch_id, data in enumerate(train_reader()): # The current program desc is coupled with batch_size, thus all # mini-batches must have the same number of instances currently. if len(data) != args.batch_size: continue + total += len(data) start_time = time.time() data_input = prepare_batch_input( data, encoder_input_data_names + decoder_input_data_names[:-1] + @@ -237,7 +237,7 @@ def main(): # Validate and save the model for inference. val_cost = test(exe) print("pass_id = %d cost = %f avg_speed = %.2f sample/s" % - (pass_id, cost_val, len(data) / (time.time() - ts))) + (pass_id, val_cost, total / (time.time() - ts))) if args.local: # Initialize the parameters. -- GitLab