未验证 提交 2ffa0a93 编写于 作者: R ruri 提交者: GitHub

Merge pull request #1625 from shippingwang/fix_infer_bug

Fixed image classification infer 
......@@ -11,7 +11,6 @@ import models
import reader
import argparse
import functools
from models.learning_rate import cosine_decay
from utility import add_arguments, print_arguments
import math
......@@ -44,7 +43,6 @@ def infer(args):
# model definition
model = models.__dict__[model_name]()
if model_name is "GoogleNet":
out, _, _ = model.net(input=image, class_dim=class_dim)
else:
......@@ -52,8 +50,10 @@ def infer(args):
test_program = fluid.default_main_program().clone(for_test=True)
fetch_list = [out.name]
if with_memory_optimization:
fluid.memory_optimize(fluid.default_main_program())
fluid.memory_optimize(
fluid.default_main_program(), skip_opt_set=set(fetch_list))
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -70,8 +70,6 @@ def infer(args):
test_reader = paddle.batch(reader.test(), batch_size=test_batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image])
fetch_list = [out.name]
TOPK = 1
for batch_id, data in enumerate(test_reader()):
result = exe.run(test_program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册