diff --git a/docs/en/tutorials/getting_started_en.md b/docs/en/tutorials/getting_started_en.md index 8f0c4bc84c5957ecd7c29d9e25c7c965a9fcb920..e3658448ca50f008a4127c8a3d6fdfb835acfa08 100644 --- a/docs/en/tutorials/getting_started_en.md +++ b/docs/en/tutorials/getting_started_en.md @@ -267,6 +267,8 @@ Among them: + `enable_mkldnn`: Wheter to use `MKL-DNN`, default by `False`. When both `use_gpu` and `enable_mkldnn` are set to `True`, GPU is used to run and `enable_mkldnn` will be ignored. + `resize_short`: The length of the shortest side of the image that be scaled proportionally, default by `256`; + `resize`: The side length of the image that be center cropped from resize_shorted image, default by `224`; ++ `enable_calc_topk`: Whether to calculate top-k accuracy of the prediction, default by `False`. Top-k accuracy will be printed out when set as `True`. ++ `gt_label_path`: Image name and label file, used when `enable_calc_topk` is `True` to get image list and labels. **Note**: If you want to use `Transformer series models`, such as `DeiT_***_384`, `ViT_***_384`, etc., please pay attention to the input size of model, and need to set `resize_short=384`, `resize=384`. 
diff --git a/docs/zh_CN/tutorials/getting_started.md b/docs/zh_CN/tutorials/getting_started.md index 9d33a66c82259460b4ff427c1d02018d24f75f55..5832f11dc80ff3516260b7041cef0949b0b4c9a1 100644 --- a/docs/zh_CN/tutorials/getting_started.md +++ b/docs/zh_CN/tutorials/getting_started.md @@ -278,6 +278,9 @@ python tools/infer/predict.py \ + `enable_mkldnn`:是否启用`MKL-DNN`加速,默认为`False`。注意`enable_mkldnn`与`use_gpu`同时为`True`时,将忽略`enable_mkldnn`,而使用GPU运行。 + `resize_short`: 对输入图像进行等比例缩放,表示最短边的尺寸,默认值:`256` + `resize`: 对`resize_short`操作后的进行居中裁剪,表示裁剪的尺寸,默认值:`224` ++ `enable_calc_topk`: 是否计算预测结果的Topk精度指标,默认为`False`,设置为`True`时将输出Top-k精度。 ++ `gt_label_path`: 图像文件名以及真值标签文件,当`enable_calc_topk`为`True`时生效,用于读取待预测的图像列表及其标签。 + **注意**: 如果使用`Transformer`系列模型,如`DeiT_***_384`, `ViT_***_384`等,请注意模型的输入数据尺寸,需要设置参数`resize_short=384`, `resize=384`。 diff --git a/tools/infer/predict.py b/tools/infer/predict.py index de624bcbc11a64af8a37c6860c34632f13d82d4f..024e9a0296450ef7e3cd362207dde0b987577eff 100644 --- a/tools/infer/predict.py +++ b/tools/infer/predict.py @@ -20,7 +20,8 @@ import time import sys sys.path.insert(0, ".") from ppcls.utils import logger -from tools.infer.utils import parse_args, get_image_list, create_paddle_predictor, preprocess, postprocess +from tools.infer.utils import parse_args, create_paddle_predictor, preprocess, postprocess +from tools.infer.utils import get_image_list, get_image_list_from_label_file, calc_topk_acc class Predictor(object): @@ -46,7 +47,19 @@ class Predictor(object): return batch_output def normal_predict(self): - image_list = get_image_list(self.args.image_file) + if self.args.enable_calc_topk: + assert self.args.gt_label_path is not None and os.path.exists(self.args.gt_label_path), \ + "gt_label_path shoule not be None and must exist, please check its path." 
+ image_list, gt_labels = get_image_list_from_label_file( + self.args.image_file, self.args.gt_label_path) + predicts_map = { + "prediction": [], + "gt_label": [], + } + else: + image_list = get_image_list(self.args.image_file) + gt_labels = None + batch_input_list = [] img_name_list = [] cnt = 0 @@ -64,6 +77,8 @@ class Predictor(object): img_name = img_path.split("/")[-1] img_name_list.append(img_name) cnt += 1 + if self.args.enable_calc_topk: + predicts_map["gt_label"].append(gt_labels[idx]) if cnt % args.batch_size == 0 or (idx + 1) == len(image_list): batch_outputs = self.predict(np.array(batch_input_list)) @@ -74,12 +89,20 @@ class Predictor(object): clas_ids = result_dict["clas_ids"] scores_str = "[{}]".format(", ".join("{:.2f}".format( r) for r in result_dict["scores"])) - print( + logger.info( "File:{}, Top-{} result: class id(s): {}, score(s): {}". format(filename, self.args.top_k, clas_ids, scores_str)) + + if self.args.enable_calc_topk: + predicts_map["prediction"].append(clas_ids) + batch_input_list = [] img_name_list = [] + if self.args.enable_calc_topk: + topk_acc = calc_topk_acc(predicts_map) + for idx, acc in enumerate(topk_acc): + logger.info("Top-{} acc: {:.5f}".format(idx + 1, acc)) def benchmark_predict(self): test_num = 500 diff --git a/tools/infer/utils.py b/tools/infer/utils.py index 8862e5f5fb0d6bfa606d2bc7a0a9952edf2a6233..fb0627b3039da6bab83e25566decd8bea653f788 100644 --- a/tools/infer/utils.py +++ b/tools/infer/utils.py @@ -74,6 +74,12 @@ def parse_args(): # parameters for test hubserving parser.add_argument("--server_url", type=str) + # enable_calc_metric, when set as true, topk acc will be calculated + parser.add_argument("--enable_calc_topk", type=str2bool, default=False) + # groudtruth label path + # data format for each line: $image_name $class_id + parser.add_argument("--gt_label_path", type=str, default=None) + return parser.parse_args() @@ -128,7 +134,6 @@ def preprocess(img, args): def postprocess(batch_outputs, topk=5, 
multilabel=False): batch_results = [] for probs in batch_outputs: - results = [] if multilabel: index = np.where(probs >= 0.5)[0].astype('int32') else: @@ -159,6 +164,42 @@ def get_image_list(img_file): return imgs_lists +def get_image_list_from_label_file(image_path, label_file_path): + imgs_lists = [] + gt_labels = [] + with open(label_file_path, "r") as fin: + lines = fin.readlines() + for line in lines: + image_name, label = line.strip("\n").split() + label = int(label) + imgs_lists.append(os.path.join(image_path, image_name)) + gt_labels.append(int(label)) + return imgs_lists, gt_labels + + +def calc_topk_acc(info_map): + ''' + calc_topk_acc + input: + info_map(dict): keys are prediction and gt_label + output: + topk_acc(list): top-k accuracy list + ''' + gt_label = np.array(info_map["gt_label"]) + prediction = np.array(info_map["prediction"]) + + gt_label = np.reshape(gt_label, (-1, 1)).repeat( + prediction.shape[1], axis=1) + correct = np.equal(prediction, gt_label) + topk_acc = [] + for idx in range(prediction.shape[1]): + if idx > 0: + correct[:, idx] = np.logical_or(correct[:, idx], + correct[:, idx - 1]) + topk_acc.append(1.0 * np.sum(correct[:, idx]) / correct.shape[0]) + return topk_acc + + def save_prelabel_results(class_id, input_file_path, output_dir): output_dir = os.path.join(output_dir, str(class_id)) if not os.path.isdir(output_dir):