未验证 提交 6871aa80 编写于 作者: T Tingquan Gao 提交者: GitHub

fix: fix the img channel order that read by cv2.imread() (#1075)

上级 f82d1095
......@@ -5,7 +5,7 @@
* installing from pypi
```bash
pip3 install paddleclas==2.2.0
pip3 install paddleclas==2.2.1
```
* build own whl package and install
......
......@@ -5,7 +5,7 @@
* pip安装
```bash
pip3 install paddleclas==2.2.0
pip3 install paddleclas==2.2.1
```
* 本地构建并安装
......
......@@ -18,6 +18,7 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ""))
sys.path.append(os.path.join(__dir__, "deploy"))
from typing import Union, Generator
import argparse
import shutil
import textwrap
......@@ -356,7 +357,7 @@ def download_with_progressbar(url, save_path):
def check_model_file(model_name):
"""Check the model files exist and download and untar when no exist.
"""Check the model files exist and download and untar when no exist.
"""
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
model_name)
......@@ -410,11 +411,11 @@ class PaddleClas(object):
"""Init PaddleClas with config.
Args:
model_name: The model name supported by PaddleClas, default by None. If specified, override config.
inference_model_dir: The directory that contained model file and params file to be used, default by None. If specified, override config.
use_gpu: Whether use GPU, default by None. If specified, override config.
batch_size: The batch size to pridict, default by None. If specified, override config.
topk: Return the top k prediction results with the highest score.
model_name (str, optional): The model name supported by PaddleClas. If specified, override config. Defaults to None.
inference_model_dir (str, optional): The directory that contained model file and params file to be used. If specified, override config. Defaults to None.
use_gpu (bool, optional): Whether use GPU. If specified, override config. Defaults to True.
batch_size (int, optional): The batch size to pridict. If specified, override config. Defaults to 1.
topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
"""
super().__init__()
self._config = init_config(model_name, inference_model_dir, use_gpu,
......@@ -459,20 +460,26 @@ class PaddleClas(object):
raise InputModelError(err)
return
def predict(self, input_data, print_pred=False):
def predict(self, input_data: Union[str, np.array],
print_pred: bool=False) -> Generator[list, None, None]:
"""Predict input_data.
Args:
input_data (str | NumPy.array): The path of image, or the directory containing images, or the URL of image from Internet.
print_pred (bool, optional): Whether print the prediction result. Defaults to False.
input_data (Union[str, np.array]):
When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
When the type is np.array, it is the image data whose channel order is RGB.
print_pred (bool, optional): Whether print the prediction result. Defaults to False. Defaults to False.
Raises:
ImageTypeError: Illegal input_data.
Yields:
list: The prediction result(s) of input_data by batch_size. For every one image, prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names". The format is as follow:
[{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
Generator[list, None, None]:
The prediction result(s) of input_data by batch_size. For every one image,
prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
The format is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
"""
if isinstance(input_data, np.ndarray):
outputs = self.cls_predictor.predict(input_data)
yield self.cls_predictor.postprocess(outputs)
......@@ -502,6 +509,7 @@ class PaddleClas(object):
f"Image file failed to read and has been skipped. The path: {img_path}"
)
continue
img = img[:, :, ::-1]
img_list.append(img)
img_path_list.append(img_path)
cnt += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册