未验证 提交 7d9dd853 编写于 作者: R ruri 提交者: GitHub

Merge pull request #1604 from shippingwang/fix_eval_multi_loss_bug

Fix multi-loss bug
......@@ -49,7 +49,7 @@ def eval(args):
# model definition
model = models.__dict__[model_name]()
if model_name is "GoogleNet":
if model_name == "GoogleNet":
out0, out1, out2 = model.net(input=image, class_dim=class_dim)
cost0 = fluid.layers.cross_entropy(input=out0, label=label)
cost1 = fluid.layers.cross_entropy(input=out1, label=label)
......@@ -71,8 +71,10 @@ def eval(args):
test_program = fluid.default_main_program().clone(for_test=True)
fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
if with_memory_optimization:
fluid.memory_optimize(fluid.default_main_program())
fluid.memory_optimize(
fluid.default_main_program(), skip_opt_set=set(fetch_list))
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -88,8 +90,6 @@ def eval(args):
val_reader = paddle.batch(reader.val(), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
test_info = [[], [], []]
cnt = 0
for batch_id, data in enumerate(val_reader()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册