diff --git a/paddleclas.py b/paddleclas.py index 9b03039bc3c34b73d01abc74a2c843ad584120fb..3b5bfd31df145ed008a24ca2f3ee0abc79bc4f0a 100644 --- a/paddleclas.py +++ b/paddleclas.py @@ -480,7 +480,7 @@ class PaddleClas(object): input_data (Union[str, np.array]): When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet. When the type is np.array, it is the image data whose channel order is RGB. - print_pred (bool, optional): Whether print the prediction result. Defaults to False. Defaults to False. + print_pred (bool, optional): Whether print the prediction result. Defaults to False. Raises: ImageTypeError: Illegal input_data. @@ -489,12 +489,11 @@ class PaddleClas(object): Generator[list, None, None]: The prediction result(s) of input_data by batch_size. For every one image, prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names". - The format is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...] + The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...] """ if isinstance(input_data, np.ndarray): - outputs = self.cls_predictor.predict(input_data) - yield self.cls_predictor.postprocess(outputs) + yield self.cls_predictor.predict(input_data) elif isinstance(input_data, str): if input_data.startswith("http") or input_data.startswith("https"): image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR) @@ -509,7 +508,7 @@ class PaddleClas(object): image_list = get_image_list(input_data) batch_size = self._config.Global.get("batch_size", 1) - topk = self._config.PostProcess.get('topk', 1) + topk = self._config.PostProcess.Topk.get('topk', 1) img_list = [] img_path_list = [] @@ -527,16 +526,15 @@ class PaddleClas(object): cnt += 1 if cnt % batch_size == 0 or (idx + 1) == len(image_list): - outputs = self.cls_predictor.predict(img_list) - preds = self.cls_predictor.postprocess(outputs, - img_path_list) + preds = self.cls_predictor.predict(img_list) + if print_pred and preds: - for pred in preds: - filename = pred.pop("file_name") + for idx, pred in enumerate(preds): pred_str = ", ".join( [f"{k}: {pred[k]}" for k in pred]) print( - f"filename: {filename}, top-{topk}, {pred_str}") + f"filename: {img_path_list[idx]}, top-{topk}, {pred_str}" + ) img_list = [] img_path_list = []