未验证 提交 7679edf6 编写于 作者: G guochaorong 提交者: GitHub

Merge pull request #11374 from guochaorong/fix_fluid_benchmark

fix bugs in fluid_benchmark
...@@ -180,7 +180,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, ...@@ -180,7 +180,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))), print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))),
# evaluation # evaluation
if not args.no_test and batch_acc: if not args.no_test and batch_acc and not args.use_reader_op:
pass_test_acc = test(exe, infer_prog, test_reader, feeder, pass_test_acc = test(exe, infer_prog, test_reader, feeder,
batch_acc) batch_acc)
print(", Test Accuracy: %f" % pass_test_acc) print(", Test Accuracy: %f" % pass_test_acc)
...@@ -277,11 +277,12 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -277,11 +277,12 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
batch_id += 1 batch_id += 1
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
if not args.no_test and batch_acc: if not args.no_test and batch_acc and not args.use_reader_op:
# we have not implement record io for test
# skip test when use args.use_reader_op
test_acc = test(startup_exe, infer_prog, test_reader, feeder, test_acc = test(startup_exe, infer_prog, test_reader, feeder,
batch_acc) batch_acc)
print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc)) print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc))
exit(0)
def print_arguments(args): def print_arguments(args):
......
...@@ -199,7 +199,10 @@ def get_model(args): ...@@ -199,7 +199,10 @@ def get_model(args):
batched_train_reader = paddle.batch( batched_train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
train_reader, buf_size=5120), train_reader, buf_size=5120),
batch_size=args.batch_size * args.gpus) batch_size=args.batch_size * args.gpus,
batched_test_reader = paddle.batch(train_reader, batch_size=args.batch_size) drop_last=True)
batched_test_reader = paddle.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
return avg_cost, inference_program, optimizer, batched_train_reader, batched_test_reader, batch_acc return avg_cost, inference_program, optimizer, batched_train_reader,\
batched_test_reader, batch_acc
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册