未验证 提交 f17343b0 编写于 作者: L littletomatodonkey 提交者: GitHub

add support for eval using inference engine (#696)

上级 6c2a33fe
......@@ -267,6 +267,8 @@ Among them:
+ `enable_mkldnn`: Wheter to use `MKL-DNN`, default by `False`. When both `use_gpu` and `enable_mkldnn` are set to `True`, GPU is used to run and `enable_mkldnn` will be ignored.
+ `resize_short`: The length of the shortest side of the image that be scaled proportionally, default by `256`;
+ `resize`: The side length of the image that be center cropped from resize_shorted image, default by `224`;
+ `enable_calc_topk`: Whether to calculate top-k accuracy of the predction, default by `False`. Top-k accuracy will be printed out when set as `True`.
+ `gt_label_path`: Image name and label file, used when `enable_calc_topk` is `True` to get image list and labels.
**Note**: If you want to use `Transformer series models`, such as `DeiT_***_384`, `ViT_***_384`, etc., please pay attention to the input size of model, and need to set `resize_short=384`, `resize=384`.
......
......@@ -278,6 +278,9 @@ python tools/infer/predict.py \
+ `enable_mkldnn`:是否启用`MKL-DNN`加速,默认为`False`。注意`enable_mkldnn``use_gpu`同时为`True`时,将忽略`enable_mkldnn`,而使用GPU运行。
+ `resize_short`: 对输入图像进行等比例缩放,表示最短边的尺寸,默认值:`256`
+ `resize`: 对`resize_short`操作后的进行居中裁剪,表示裁剪的尺寸,默认值:`224`
+ `enable_calc_topk`: 是否计算预测结果的Topk精度指标,默认为`False`
+ `gt_label_path`: 图像文件名以及真值标签文件,当`enable_calc_topk`为True时生效,用于读取待预测的图像列表及其标签。
**注意**: 如果使用`Transformer`系列模型,如`DeiT_***_384`, `ViT_***_384`等,请注意模型的输入数据尺寸,需要设置参数`resize_short=384`, `resize=384`
......
......@@ -20,7 +20,8 @@ import time
import sys
sys.path.insert(0, ".")
from ppcls.utils import logger
from tools.infer.utils import parse_args, get_image_list, create_paddle_predictor, preprocess, postprocess
from tools.infer.utils import parse_args, create_paddle_predictor, preprocess, postprocess
from tools.infer.utils import get_image_list, get_image_list_from_label_file, calc_topk_acc
class Predictor(object):
......@@ -46,7 +47,19 @@ class Predictor(object):
return batch_output
def normal_predict(self):
if self.args.enable_calc_topk:
assert self.args.gt_label_path is not None and os.path.exists(self.args.gt_label_path), \
"gt_label_path shoule not be None and must exist, please check its path."
image_list, gt_labels = get_image_list_from_label_file(
self.args.image_file, self.args.gt_label_path)
predicts_map = {
"prediction": [],
"gt_label": [],
}
else:
image_list = get_image_list(self.args.image_file)
gt_labels = None
batch_input_list = []
img_name_list = []
cnt = 0
......@@ -64,6 +77,8 @@ class Predictor(object):
img_name = img_path.split("/")[-1]
img_name_list.append(img_name)
cnt += 1
if self.args.enable_calc_topk:
predicts_map["gt_label"].append(gt_labels[idx])
if cnt % args.batch_size == 0 or (idx + 1) == len(image_list):
batch_outputs = self.predict(np.array(batch_input_list))
......@@ -74,12 +89,20 @@ class Predictor(object):
clas_ids = result_dict["clas_ids"]
scores_str = "[{}]".format(", ".join("{:.2f}".format(
r) for r in result_dict["scores"]))
print(
logger.info(
"File:{}, Top-{} result: class id(s): {}, score(s): {}".
format(filename, self.args.top_k, clas_ids,
scores_str))
if self.args.enable_calc_topk:
predicts_map["prediction"].append(clas_ids)
batch_input_list = []
img_name_list = []
if self.args.enable_calc_topk:
topk_acc = calc_topk_acc(predicts_map)
for idx, acc in enumerate(topk_acc):
logger.info("Top-{} acc: {:.5f}".format(idx + 1, acc))
def benchmark_predict(self):
test_num = 500
......
......@@ -74,6 +74,12 @@ def parse_args():
# parameters for test hubserving
parser.add_argument("--server_url", type=str)
# enable_calc_metric, when set as true, topk acc will be calculated
parser.add_argument("--enable_calc_topk", type=str2bool, default=False)
# groudtruth label path
# data format for each line: $image_name $class_id
parser.add_argument("--gt_label_path", type=str, default=None)
return parser.parse_args()
......@@ -128,7 +134,6 @@ def preprocess(img, args):
def postprocess(batch_outputs, topk=5, multilabel=False):
batch_results = []
for probs in batch_outputs:
results = []
if multilabel:
index = np.where(probs >= 0.5)[0].astype('int32')
else:
......@@ -159,6 +164,42 @@ def get_image_list(img_file):
return imgs_lists
def get_image_list_from_label_file(image_path, label_file_path):
imgs_lists = []
gt_labels = []
with open(label_file_path, "r") as fin:
lines = fin.readlines()
for line in lines:
image_name, label = line.strip("\n").split()
label = int(label)
imgs_lists.append(os.path.join(image_path, image_name))
gt_labels.append(int(label))
return imgs_lists, gt_labels
def calc_topk_acc(info_map):
'''
calc_topk_acc
input:
info_map(dict): keys are prediction and gt_label
output:
topk_acc(list): top-k accuracy list
'''
gt_label = np.array(info_map["gt_label"])
prediction = np.array(info_map["prediction"])
gt_label = np.reshape(gt_label, (-1, 1)).repeat(
prediction.shape[1], axis=1)
correct = np.equal(prediction, gt_label)
topk_acc = []
for idx in range(prediction.shape[1]):
if idx > 0:
correct[:, idx] = np.logical_or(correct[:, idx],
correct[:, idx - 1])
topk_acc.append(1.0 * np.sum(correct[:, idx]) / correct.shape[0])
return topk_acc
def save_prelabel_results(class_id, input_file_path, output_dir):
output_dir = os.path.join(output_dir, str(class_id))
if not os.path.isdir(output_dir):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册