未验证 提交 9e12ab90 编写于 作者: R ruri 提交者: GitHub

fix eval bug in googlenet (#4382)

上级 90c7255c
...@@ -113,7 +113,7 @@ def eval(args): ...@@ -113,7 +113,7 @@ def eval(args):
test_program = fluid.default_main_program().clone(for_test=True) test_program = fluid.default_main_program().clone(for_test=True)
fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name, pred.name] fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
...@@ -151,7 +151,7 @@ def eval(args): ...@@ -151,7 +151,7 @@ def eval(args):
parallel_data.append(image_data) parallel_data.append(image_data)
if place_num == len(parallel_data): if place_num == len(parallel_data):
t1 = time.time() t1 = time.time()
loss_set, acc1_set, acc5_set, pred_set = exe.run( loss_set, acc1_set, acc5_set = exe.run(
compiled_program, compiled_program,
fetch_list=fetch_list, fetch_list=fetch_list,
feed=list(feeder.feed_parallel(parallel_data, place_num))) feed=list(feeder.feed_parallel(parallel_data, place_num)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册