提交 3d878d02 编写于 作者: Z zhaoyuchen2018 提交者: ruri

Fix infer and eval case fail (#2854)

上级 c0a6ebd1
...@@ -44,6 +44,7 @@ add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use. ...@@ -44,6 +44,7 @@ add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.
add_arg('resize_short_size', int, 256, "Set resize short size") add_arg('resize_short_size', int, 256, "Set resize short size")
# yapf: enable # yapf: enable
def eval(args): def eval(args):
# parameters from arguments # parameters from arguments
class_dim = args.class_dim class_dim = args.class_dim
...@@ -93,10 +94,9 @@ def eval(args): ...@@ -93,10 +94,9 @@ def eval(args):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, pretrained_model) fluid.io.load_persistables(exe, pretrained_model)
val_reader = paddle.batch(reader.val(settings=args), batch_size=args.batch_size) val_reader = reader.val(settings=args, batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
test_info = [[], [], []] test_info = [[], [], []]
...@@ -127,7 +127,7 @@ def eval(args): ...@@ -127,7 +127,7 @@ def eval(args):
test_acc5 = np.sum(test_info[2]) / cnt test_acc5 = np.sum(test_info[2]) / cnt
print("Test_loss {0}, test_acc1 {1}, test_acc5 {2}".format( print("Test_loss {0}, test_acc1 {1}, test_acc5 {2}".format(
"%.5f"%test_loss, "%.5f"%test_acc1, "%.5f"%test_acc5)) "%.5f" % test_loss, "%.5f" % test_acc1, "%.5f" % test_acc5))
sys.stdout.flush() sys.stdout.flush()
......
...@@ -29,7 +29,7 @@ import paddle.fluid as fluid ...@@ -29,7 +29,7 @@ import paddle.fluid as fluid
import reader_cv2 as reader import reader_cv2 as reader
import models import models
import utils import utils
from utils.utility import add_arguments,print_arguments, check_gpu from utils.utility import add_arguments, print_arguments, check_gpu
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable # yapf: disable
...@@ -44,6 +44,7 @@ add_arg('save_inference', bool, False, "Whether to save infere ...@@ -44,6 +44,7 @@ add_arg('save_inference', bool, False, "Whether to save infere
add_arg('resize_short_size', int, 256, "Set resize short size") add_arg('resize_short_size', int, 256, "Set resize short size")
# yapf: enable # yapf: enable
def infer(args): def infer(args):
# parameters from arguments # parameters from arguments
class_dim = args.class_dim class_dim = args.class_dim
...@@ -87,10 +88,11 @@ def infer(args): ...@@ -87,10 +88,11 @@ def infer(args):
executor=exe, executor=exe,
model_filename='model', model_filename='model',
params_filename='params') params_filename='params')
print("model: ",model_name," is already saved") print("model: ", model_name, " is already saved")
exit(0) exit(0)
test_batch_size = 1 test_batch_size = 1
test_reader = paddle.batch(reader.test(settings=args), batch_size=test_batch_size)
test_reader = reader.test(settings=args, batch_size=test_batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image]) feeder = fluid.DataFeeder(place=place, feed_list=[image])
TOPK = 1 TOPK = 1
......
...@@ -243,10 +243,9 @@ def _reader_creator(settings, ...@@ -243,10 +243,9 @@ def _reader_creator(settings,
img_path = os.path.join(data_dir, img_path) img_path = os.path.join(data_dir, img_path)
batch_data.append([img_path, int(label)]) batch_data.append([img_path, int(label)])
if len(batch_data) == batch_size: if len(batch_data) == batch_size:
if mode == 'train' or mode == 'val': if mode == 'train' or mode == 'val' or mode == 'test':
yield batch_data yield batch_data
elif mode == 'test':
yield [sample[0] for sample in batch_data]
batch_data = [] batch_data = []
return read_file_list return read_file_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册