diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py index 150346798e38257a40f1bde4007b4ee6ffa9ccbf..ca7f7dbb0712e2c6f975557ef726cde81fce1004 100644 --- a/benchmark/fluid/fluid_benchmark.py +++ b/benchmark/fluid/fluid_benchmark.py @@ -266,7 +266,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, # FIXME(wuyi): For use_reader_op, if the current # pass is not the last, the last batch of this pass # is also equal to args.batch_size. - num_samples += len(args.batch_size) + if args.use_reader_op: + num_samples += args.batch_size + else: + num_samples += len(data) train_losses.append(loss) print("Pass: %d, Iter: %d, Loss: %f\n" % (pass_id, iters, np.mean(train_losses))) @@ -350,9 +353,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, if iters == args.skip_batch_num: start_time = time.time() num_samples = 0 - # NOTE: if use reader ops, the input data is not splited to multiple cards - if args.use_reader_op and iters >= args.iterations / args.gpus: - break if args.use_fake_data or args.use_reader_op: try: loss, = exe.run([avg_loss.name]) @@ -362,7 +362,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, loss, = exe.run([avg_loss.name], feed=feeder.feed(data)) if args.update_method == "pserver": exe.bcast_params() - num_samples += len(data) + if args.use_reader_op: + num_samples += args.batch_size + else: + num_samples += len(data) iters += 1 if batch_id % 1 == 0: print("Pass %d, batch %d, loss %s" %