From 6fdd5de2aac0e8eba046a21840127e3d650d388f Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Wed, 6 Jun 2018 13:39:09 +0800 Subject: [PATCH] update --- benchmark/fluid/fluid_benchmark.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py index 150346798e3..ca7f7dbb071 100644 --- a/benchmark/fluid/fluid_benchmark.py +++ b/benchmark/fluid/fluid_benchmark.py @@ -266,7 +266,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, # FIXME(wuyi): For use_reader_op, if the current # pass is not the last, the last batch of this pass # is also equal to args.batch_size. - num_samples += len(args.batch_size) + if args.use_reader_op: + num_samples += args.batch_size + else: + num_samples += len(data) train_losses.append(loss) print("Pass: %d, Iter: %d, Loss: %f\n" % (pass_id, iters, np.mean(train_losses))) @@ -350,9 +353,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, if iters == args.skip_batch_num: start_time = time.time() num_samples = 0 - # NOTE: if use reader ops, the input data is not splited to multiple cards - if args.use_reader_op and iters >= args.iterations / args.gpus: - break if args.use_fake_data or args.use_reader_op: try: loss, = exe.run([avg_loss.name]) @@ -362,7 +362,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, loss, = exe.run([avg_loss.name], feed=feeder.feed(data)) if args.update_method == "pserver": exe.bcast_params() - num_samples += len(data) + if args.use_reader_op: + num_samples += args.batch_size + else: + num_samples += len(data) iters += 1 if batch_id % 1 == 0: print("Pass %d, batch %d, loss %s" % -- GitLab