未验证 提交 6871aa80 编写于 作者: T Tingquan Gao 提交者: GitHub

fix: fix the img channel order that read by cv2.imread() (#1075)

上级 f82d1095
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
* installing from pypi * installing from pypi
```bash ```bash
pip3 install paddleclas==2.2.0 pip3 install paddleclas==2.2.1
``` ```
* build own whl package and install * build own whl package and install
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
* pip安装 * pip安装
```bash ```bash
pip3 install paddleclas==2.2.0 pip3 install paddleclas==2.2.1
``` ```
* 本地构建并安装 * 本地构建并安装
......
...@@ -18,6 +18,7 @@ __dir__ = os.path.dirname(__file__) ...@@ -18,6 +18,7 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, "")) sys.path.append(os.path.join(__dir__, ""))
sys.path.append(os.path.join(__dir__, "deploy")) sys.path.append(os.path.join(__dir__, "deploy"))
from typing import Union, Generator
import argparse import argparse
import shutil import shutil
import textwrap import textwrap
...@@ -410,11 +411,11 @@ class PaddleClas(object): ...@@ -410,11 +411,11 @@ class PaddleClas(object):
"""Init PaddleClas with config. """Init PaddleClas with config.
Args: Args:
model_name: The model name supported by PaddleClas, default by None. If specified, override config. model_name (str, optional): The model name supported by PaddleClas. If specified, override config. Defaults to None.
inference_model_dir: The directory that contained model file and params file to be used, default by None. If specified, override config. inference_model_dir (str, optional): The directory that contained model file and params file to be used. If specified, override config. Defaults to None.
use_gpu: Whether use GPU, default by None. If specified, override config. use_gpu (bool, optional): Whether use GPU. If specified, override config. Defaults to True.
batch_size: The batch size to predict, default by None. If specified, override config. batch_size (int, optional): The batch size to predict. If specified, override config. Defaults to 1.
topk: Return the top k prediction results with the highest score. topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
""" """
super().__init__() super().__init__()
self._config = init_config(model_name, inference_model_dir, use_gpu, self._config = init_config(model_name, inference_model_dir, use_gpu,
...@@ -459,20 +460,26 @@ class PaddleClas(object): ...@@ -459,20 +460,26 @@ class PaddleClas(object):
raise InputModelError(err) raise InputModelError(err)
return return
def predict(self, input_data, print_pred=False): def predict(self, input_data: Union[str, np.array],
print_pred: bool=False) -> Generator[list, None, None]:
"""Predict input_data. """Predict input_data.
Args: Args:
input_data (str | NumPy.array): The path of image, or the directory containing images, or the URL of image from Internet. input_data (Union[str, np.array]):
print_pred (bool, optional): Whether print the prediction result. Defaults to False. When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
When the type is np.array, it is the image data whose channel order is RGB.
print_pred (bool, optional): Whether to print the prediction result. Defaults to False.
Raises: Raises:
ImageTypeError: Illegal input_data. ImageTypeError: Illegal input_data.
Yields: Yields:
list: The prediction result(s) of input_data by batch_size. For every one image, prediction result(s) is zipped as a dict, that includes topk "class_ids", "scores" and "label_names". The format is as follows: Generator[list, None, None]:
[{"class_ids": [...], "scores": [...], "label_names": [...]}, ...] The prediction result(s) of input_data by batch_size. For every one image,
prediction result(s) is zipped as a dict, that includes topk "class_ids", "scores" and "label_names".
The format is as follows: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
""" """
if isinstance(input_data, np.ndarray): if isinstance(input_data, np.ndarray):
outputs = self.cls_predictor.predict(input_data) outputs = self.cls_predictor.predict(input_data)
yield self.cls_predictor.postprocess(outputs) yield self.cls_predictor.postprocess(outputs)
...@@ -502,6 +509,7 @@ class PaddleClas(object): ...@@ -502,6 +509,7 @@ class PaddleClas(object):
f"Image file failed to read and has been skipped. The path: {img_path}" f"Image file failed to read and has been skipped. The path: {img_path}"
) )
continue continue
img = img[:, :, ::-1]
img_list.append(img) img_list.append(img)
img_path_list.append(img_path) img_path_list.append(img_path)
cnt += 1 cnt += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册