提交 3d878d02 编写于 作者: Z zhaoyuchen2018 提交者: ruri

Fix infer and eval case fail (#2854)

上级 c0a6ebd1
......@@ -26,7 +26,7 @@ import functools
import paddle
import paddle.fluid as fluid
import reader_cv2 as reader
import reader_cv2 as reader
import models
from utils.learning_rate import cosine_decay
from utils.utility import add_arguments, print_arguments, check_gpu
......@@ -44,6 +44,7 @@ add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.
add_arg('resize_short_size', int, 256, "Set resize short size")
# yapf: enable
def eval(args):
# parameters from arguments
class_dim = args.class_dim
......@@ -93,10 +94,9 @@ def eval(args):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, pretrained_model)
val_reader = paddle.batch(reader.val(settings=args), batch_size=args.batch_size)
val_reader = reader.val(settings=args, batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
test_info = [[], [], []]
......@@ -127,7 +127,7 @@ def eval(args):
test_acc5 = np.sum(test_info[2]) / cnt
print("Test_loss {0}, test_acc1 {1}, test_acc5 {2}".format(
"%.5f"%test_loss, "%.5f"%test_acc1, "%.5f"%test_acc5))
"%.5f" % test_loss, "%.5f" % test_acc1, "%.5f" % test_acc5))
sys.stdout.flush()
......
......@@ -29,7 +29,7 @@ import paddle.fluid as fluid
import reader_cv2 as reader
import models
import utils
from utils.utility import add_arguments,print_arguments, check_gpu
from utils.utility import add_arguments, print_arguments, check_gpu
parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
......@@ -44,6 +44,7 @@ add_arg('save_inference', bool, False, "Whether to save infere
add_arg('resize_short_size', int, 256, "Set resize short size")
# yapf: enable
def infer(args):
# parameters from arguments
class_dim = args.class_dim
......@@ -80,17 +81,18 @@ def infer(args):
fluid.io.load_persistables(exe, pretrained_model)
if save_inference:
fluid.io.save_inference_model(
dirname=model_name,
feeded_var_names=['image'],
main_program=test_program,
target_vars=out,
executor=exe,
model_filename='model',
params_filename='params')
print("model: ",model_name," is already saved")
dirname=model_name,
feeded_var_names=['image'],
main_program=test_program,
target_vars=out,
executor=exe,
model_filename='model',
params_filename='params')
print("model: ", model_name, " is already saved")
exit(0)
test_batch_size = 1
test_reader = paddle.batch(reader.test(settings=args), batch_size=test_batch_size)
test_reader = reader.test(settings=args, batch_size=test_batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image])
TOPK = 1
......
......@@ -243,10 +243,9 @@ def _reader_creator(settings,
img_path = os.path.join(data_dir, img_path)
batch_data.append([img_path, int(label)])
if len(batch_data) == batch_size:
if mode == 'train' or mode == 'val':
if mode == 'train' or mode == 'val' or mode == 'test':
yield batch_data
elif mode == 'test':
yield [sample[0] for sample in batch_data]
batch_data = []
return read_file_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册