提交 9630aa0e 编写于 作者: L Luo Tao

fix test error in fluid_benchmark.py

上级 f2c55743
...@@ -85,8 +85,8 @@ def parse_args(): ...@@ -85,8 +85,8 @@ def parse_args():
help='If set, use nvprof for CUDA.') help='If set, use nvprof for CUDA.')
parser.add_argument( parser.add_argument(
'--no_test', '--no_test',
action='store_false', action='store_true',
help='If set, test the testset during training.') help='If set, do not test the testset during training.')
parser.add_argument( parser.add_argument(
'--memory_optimize', '--memory_optimize',
action='store_true', action='store_true',
...@@ -229,9 +229,9 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, ...@@ -229,9 +229,9 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
print("Pass: %d, Iter: %d, Loss: %f\n" % print("Pass: %d, Iter: %d, Loss: %f\n" %
(pass_id, iters, np.mean(train_losses))) (pass_id, iters, np.mean(train_losses)))
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))) print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))),
# evaluation # evaluation
if not args.no_test and batch_acc != None: if not args.no_test and batch_acc:
pass_test_acc = test(exe, infer_prog, test_reader, feeder, pass_test_acc = test(exe, infer_prog, test_reader, feeder,
batch_acc) batch_acc)
print(", Test Accuracy: %f" % pass_test_acc) print(", Test Accuracy: %f" % pass_test_acc)
...@@ -310,7 +310,7 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -310,7 +310,7 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
print("Pass %d, batch %d, loss %s" % print("Pass %d, batch %d, loss %s" %
(pass_id, batch_id, np.array(loss))) (pass_id, batch_id, np.array(loss)))
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
if not args.no_test and batch_acc != None: if not args.no_test and batch_acc:
test_acc = test(startup_exe, infer_prog, test_reader, feeder, test_acc = test(startup_exe, infer_prog, test_reader, feeder,
batch_acc) batch_acc)
print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc)) print("Pass: %d, Test Accuracy: %f\n" % (pass_id, test_acc))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册