未验证 提交 68d17711 编写于 作者: R ruri 提交者: GitHub

fix bugs in image classification infer (#4125)

上级 7634fe7e
......@@ -277,6 +277,7 @@ python eval.py \
**参数说明:**
* **data_dir**: 数据存储位置,默认值:`/data/ILSVRC2012/val/`
* **save_inference**: 是否保存二进制模型,默认值:`False`
* **topk**: 按照置信由高到低排序标签结果,返回的结果数量,默认值:1
* **class_map_path**: 可读标签文件路径,默认值:`./utils/tools/readable_label.txt`
......
......@@ -36,7 +36,7 @@ from utils import *
parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('data_dir', str, "./data/ILSVRC2012/", "The ImageNet data")
add_arg('data_dir', str, "./data/ILSVRC2012/val/", "The ImageNet data")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('class_dim', int, 1000, "Class number.")
parser.add_argument("--pretrained_model", default=None, required=True, type=str, help="The path to load pretrained model")
......@@ -148,6 +148,9 @@ def infer(args):
parallel_data = []
parallel_id = []
place_num = paddle.fluid.core.get_cuda_device_count() if args.use_gpu else 1
if os.path.exists(args.save_json_path):
logger.warning("path: {} Already exists! will recover it\n".format(
args.save_json_path))
with open(args.save_json_path, "w") as fout:
for batch_id, data in enumerate(test_reader()):
image_data = [[items[0]] for items in data]
......@@ -180,7 +183,7 @@ def infer(args):
info[real_id]['score'], info[real_id]['class'] = str(
res[pred_label]), str(pred_label)
logger.info(real_id, info[real_id])
logger.info("{}, {}".format(real_id, info[real_id]))
sys.stdout.flush()
fout.write(real_id + "\t" + json.dumps(info[real_id]) +
"\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册