From 0efcf0b67d79599297b1b0298ef53c4a1da81983 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Mon, 15 Nov 2021 04:19:52 +0000 Subject: [PATCH] refactor: adapt to predict_cls refer to commit: bd586f4a1d13aedf40baf8b1d8677ce7eebd5d66 --- paddleclas.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/paddleclas.py b/paddleclas.py index 9b03039b..3b5bfd31 100644 --- a/paddleclas.py +++ b/paddleclas.py @@ -480,7 +480,7 @@ class PaddleClas(object): input_data (Union[str, np.array]): When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet. When the type is np.array, it is the image data whose channel order is RGB. - print_pred (bool, optional): Whether print the prediction result. Defaults to False. Defaults to False. + print_pred (bool, optional): Whether print the prediction result. Defaults to False. Raises: ImageTypeError: Illegal input_data. @@ -489,12 +489,11 @@ class PaddleClas(object): Generator[list, None, None]: The prediction result(s) of input_data by batch_size. For every one image, prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names". - The format is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...] + The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...] """ if isinstance(input_data, np.ndarray): - outputs = self.cls_predictor.predict(input_data) - yield self.cls_predictor.postprocess(outputs) + yield self.cls_predictor.predict(input_data) elif isinstance(input_data, str): if input_data.startswith("http") or input_data.startswith("https"): image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR) @@ -509,7 +508,7 @@ class PaddleClas(object): image_list = get_image_list(input_data) batch_size = self._config.Global.get("batch_size", 1) - topk = self._config.PostProcess.get('topk', 1) + topk = self._config.PostProcess.Topk.get('topk', 1) img_list = [] img_path_list = [] @@ -527,16 +526,15 @@ class PaddleClas(object): cnt += 1 if cnt % batch_size == 0 or (idx + 1) == len(image_list): - outputs = self.cls_predictor.predict(img_list) - preds = self.cls_predictor.postprocess(outputs, - img_path_list) + preds = self.cls_predictor.predict(img_list) + if print_pred and preds: - for pred in preds: - filename = pred.pop("file_name") + for idx, pred in enumerate(preds): pred_str = ", ".join( [f"{k}: {pred[k]}" for k in pred]) print( - f"filename: {filename}, top-{topk}, {pred_str}") + f"filename: {img_path_list[idx]}, top-{topk}, {pred_str}" + ) img_list = [] img_path_list = [] -- GitLab