From 6871aa80f36d09c93a8b6cf7a5c0ee4cf0bcb1e5 Mon Sep 17 00:00:00 2001 From: Tingquan Gao Date: Fri, 23 Jul 2021 14:27:09 +0800 Subject: [PATCH] fix: fix the img channel order that read by cv2.imread() (#1075) --- docs/en/whl_en.md | 2 +- docs/zh_CN/whl.md | 2 +- paddleclas.py | 30 +++++++++++++++++++----------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/docs/en/whl_en.md b/docs/en/whl_en.md index 05248cb1..e791ef08 100644 --- a/docs/en/whl_en.md +++ b/docs/en/whl_en.md @@ -5,7 +5,7 @@ * installing from pypi ```bash -pip3 install paddleclas==2.2.0 +pip3 install paddleclas==2.2.1 ``` * build own whl package and install diff --git a/docs/zh_CN/whl.md b/docs/zh_CN/whl.md index 58051c52..2a138c7f 100644 --- a/docs/zh_CN/whl.md +++ b/docs/zh_CN/whl.md @@ -5,7 +5,7 @@ * pip安装 ```bash -pip3 install paddleclas==2.2.0 +pip3 install paddleclas==2.2.1 ``` * 本地构建并安装 diff --git a/paddleclas.py b/paddleclas.py index d0bd6b14..91cd030a 100644 --- a/paddleclas.py +++ b/paddleclas.py @@ -18,6 +18,7 @@ __dir__ = os.path.dirname(__file__) sys.path.append(os.path.join(__dir__, "")) sys.path.append(os.path.join(__dir__, "deploy")) +from typing import Union, Generator import argparse import shutil import textwrap @@ -356,7 +357,7 @@ def download_with_progressbar(url, save_path): def check_model_file(model_name): - """Check the model files exist and download and untar when no exist. + """Check the model files exist and download and untar when no exist. """ storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR, model_name) @@ -410,11 +411,11 @@ class PaddleClas(object): """Init PaddleClas with config. Args: - model_name: The model name supported by PaddleClas, default by None. If specified, override config. - inference_model_dir: The directory that contained model file and params file to be used, default by None. If specified, override config. - use_gpu: Whether use GPU, default by None. If specified, override config. - batch_size: The batch size to pridict, default by None. If specified, override config. - topk: Return the top k prediction results with the highest score. + model_name (str, optional): The model name supported by PaddleClas. If specified, override config. Defaults to None. + inference_model_dir (str, optional): The directory that contained model file and params file to be used. If specified, override config. Defaults to None. + use_gpu (bool, optional): Whether use GPU. If specified, override config. Defaults to True. + batch_size (int, optional): The batch size to pridict. If specified, override config. Defaults to 1. + topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5. """ super().__init__() self._config = init_config(model_name, inference_model_dir, use_gpu, @@ -459,20 +460,26 @@ class PaddleClas(object): raise InputModelError(err) return - def predict(self, input_data, print_pred=False): + def predict(self, input_data: Union[str, np.array], + print_pred: bool=False) -> Generator[list, None, None]: """Predict input_data. Args: - input_data (str | NumPy.array): The path of image, or the directory containing images, or the URL of image from Internet. - print_pred (bool, optional): Whether print the prediction result. Defaults to False. + input_data (Union[str, np.array]): + When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet. + When the type is np.array, it is the image data whose channel order is RGB. + print_pred (bool, optional): Whether print the prediction result. Defaults to False. Defaults to False. Raises: ImageTypeError: Illegal input_data. Yields: - list: The prediction result(s) of input_data by batch_size. For every one image, prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names". The format is as follow: - [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...] + Generator[list, None, None]: + The prediction result(s) of input_data by batch_size. For every one image, + prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names". + The format is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...] """ + if isinstance(input_data, np.ndarray): outputs = self.cls_predictor.predict(input_data) yield self.cls_predictor.postprocess(outputs) @@ -502,6 +509,7 @@ class PaddleClas(object): f"Image file failed to read and has been skipped. The path: {img_path}" ) continue + img = img[:, :, ::-1] img_list.append(img) img_path_list.append(img_path) cnt += 1 -- GitLab