Commit 10376757, authored by 悟、, committed by zengshao0622

cp update shitu whl

Parent 8eba0f0b
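# PP-ShiTu Whl Usage Guide

PaddleClas supports prediction through its Python whl package.

---

## Table of Contents

- [1. Install paddleclas](#1)
- [2. Quick start](#2)
    - [2.1 Build the index library](#2.1)
    - [2.2 Bottled drink recognition](#2.2)
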
<a name="1"></a>
## 1. Install paddleclas
* **[Recommended]** Install directly with pip:
```bash
pip3 install paddleclas
```
* To try the latest features on the PaddleClas develop branch, or to do secondary development on top of PaddleClas, build and install locally from the repository root:
```bash
python3 setup.py install
```
<a name="2"></a>
## 2. Quick start
<a name="2.1"></a>
### 2.1 Build the index library
Download the demo dataset with the following command:
```shell
# Download the demo data and extract it
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar && tar -xf drink_dataset_v2.0.tar
```
After extraction, the `drink_dataset_v2.0/` folder should have the following structure:
```log
├── drink_dataset_v2.0/
│   ├── gallery/
│   ├── index/
│   ├── index_all/
│   └── test_images/
├── ...
```
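The `gallery` folder holds the original images used to build the index library, `index` holds the index library information built from those images, and `test_images` holds the images used to test recognition.

**Build the index library in Python code**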
```python
from paddleclas import PaddleClas
build = PaddleClas(
    build_gallery=True,
    gallery_image_root='./drink_dataset_v2.0/gallery/',
    gallery_data_file='./drink_dataset_v2.0/gallery/drink_label.txt',
    index_dir='./drink_dataset_v2.0/index')
```
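Parameters:
- `build_gallery`: whether to run in index-building mode; defaults to `False`.
- `gallery_image_root`: path to the `gallery` images used to build the index library.
- `gallery_data_file`: the ground-truth label file for the gallery images.
- `index_dir`: directory where the index library is stored.

**Build the index library from the command line**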
```shell
paddleclas --build_gallery=True --model_name="PP-ShiTuV2" \
-o IndexProcess.image_root=./drink_dataset_v2.0/gallery/ \
-o IndexProcess.index_dir=./drink_dataset_v2.0/index \
-o IndexProcess.data_file=./drink_dataset_v2.0/gallery/drink_label.txt
```
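The `build_gallery` (bool) argument controls whether index-building mode is used; it defaults to `False`.

The `-o` option overrides the configuration used to build the index library. The fields are:
- `IndexProcess.image_root` (str): path to the `gallery` images used to build the index library.
- `IndexProcess.index_dir` (str): directory where the index library is stored.
- `IndexProcess.data_file` (str): the ground-truth label file for the gallery images.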
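
To make the `-o` mechanism concrete, here is a minimal sketch of how dotted `key=value` strings could be applied to a nested configuration. It is not PaddleClas's own implementation: the `apply_overrides` helper and the plain-dict config are hypothetical, and only the field names come from the list above.

```python
import ast


def apply_overrides(config, overrides):
    """Apply 'A.B.c=value' strings to a nested dict, in place (hypothetical helper)."""
    for item in overrides:
        dotted_key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value: {item!r}")
        try:
            # Interpret numbers, booleans, lists and so on when possible.
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            value = raw_value  # fall back to the plain string, e.g. a path
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return config


# A plain dict stands in for the real config object (assumption).
config = {"IndexProcess": {"index_dir": None, "image_root": None}}
apply_overrides(config, [
    "IndexProcess.image_root=./drink_dataset_v2.0/gallery/",
    "IndexProcess.index_dir=./drink_dataset_v2.0/index",
])
print(config["IndexProcess"]["index_dir"])
```

Each `-o` occurrence contributes one string, so repeating the flag simply lengthens the list passed to the helper.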
<a name="2.2"></a>
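### 2.2 Bottled drink recognition
Try bottled drink recognition by running recognition and retrieval on the image `./drink_dataset_v2.0/test_images/100.jpeg`.

The query image is shown below: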
![](../../../images/recognition/drink_data_demo/test_images/100.jpeg)

**Recognition and retrieval in Python code**
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='PP-ShiTuV2',
                  index_dir='./drink_dataset_v2.0/index')
infer_imgs = './drink_dataset_v2.0/test_images/100.jpeg'
result = clas.predict(infer_imgs, predict_type='shitu')
print(next(result))
```
Parameters:
- `model_name` (str): the model used for recognition and retrieval.
- `index_dir` (str): directory of the index library used for retrieval.

The final output is:
```
[{'bbox': [437, 71, 660, 728], 'rec_docs': '元气森林', 'rec_scores': 0.7740249}, {'bbox': [221, 72, 449, 701], 'rec_docs': '元气森林', 'rec_scores': 0.6950992}, {'bbox': [794, 104, 979, 652], 'rec_docs': '元气森林', 'rec_scores': 0.6305153}]
```
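
The result is a list of dicts with `bbox`, `rec_docs`, and `rec_scores` keys. The sketch below shows one way to post-process it; the sample data is copied from the output above, while the `summarize` helper, the `min_score` cutoff of 0.65, and the reading of `bbox` as `[x_min, y_min, x_max, y_max]` are assumptions made for illustration.

```python
# Sample output copied from the run above.
results = [
    {'bbox': [437, 71, 660, 728], 'rec_docs': '元气森林', 'rec_scores': 0.7740249},
    {'bbox': [221, 72, 449, 701], 'rec_docs': '元气森林', 'rec_scores': 0.6950992},
    {'bbox': [794, 104, 979, 652], 'rec_docs': '元气森林', 'rec_scores': 0.6305153},
]


def summarize(results, min_score=0.65):
    """Return (label, score, width, height) for detections scoring at least min_score."""
    rows = []
    for det in results:
        if det['rec_scores'] < min_score:
            continue
        # Assumed layout: [x_min, y_min, x_max, y_max].
        x_min, y_min, x_max, y_max = det['bbox']
        rows.append((det['rec_docs'], det['rec_scores'],
                     x_max - x_min, y_max - y_min))
    return rows


for label, score, width, height in summarize(results):
    print(f"{label}: score={score:.3f}, box={width}x{height}")
```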

**Recognition and retrieval from the command line**
```shell
paddleclas --model_name=PP-ShiTuV2 --predict_type=shitu \
-o Global.infer_imgs='./drink_dataset_v2.0/test_images/100.jpeg' \
-o IndexProcess.index_dir='./drink_dataset_v2.0/index'
```
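Here `model_name` is the model used for recognition and retrieval, and `predict_type` is set to `shitu`.

The `-o` option changes the query image and the index library. The fields are:
- `Global.infer_imgs` (str): path to the query image.
- `IndexProcess.index_dir` (str): directory where the index library is stored.

The final output is: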
```
[{'bbox': [437, 71, 660, 728], 'rec_docs': '元气森林', 'rec_scores': 0.7740249}, {'bbox': [221, 72, 449, 701], 'rec_docs': '元气森林', 'rec_scores': 0.6950992}, {'bbox': [794, 104, 979, 652], 'rec_docs': '元气森林', 'rec_scores': 0.6305153}], filename: ./drink_dataset_v2.0/test_images/100.jpeg
```
@@ -33,6 +33,7 @@ from .ppcls.utils import logger
 from .deploy.python.predict_cls import ClsPredictor
 from .deploy.python.predict_system import SystemPredictor
+from .deploy.python.build_gallery import GalleryBuilder
 from .deploy.utils.get_image_list import get_image_list
 from .deploy.utils import config
@@ -196,7 +197,8 @@ PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/inf
 PULC_MODELS = [
     "car_exists", "language_classification", "person_attribute",
     "person_exists", "safety_helmet", "text_image_orientation",
-    "textline_orientation", "traffic_sign", "vehicle_attribute"
+    "textline_orientation", "traffic_sign", "vehicle_attribute",
+    "table_attribute"
 ]
 SHITU_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/{}_infer.tar"
@@ -226,7 +228,9 @@ class InputModelError(Exception):
 def init_config(model_type, model_name, inference_model_dir, **kwargs):
-    if model_type == "pulc":
+    if kwargs.get("build_gallery", False):
+        cfg_path = "deploy/configs/inference_general.yaml"
+    elif model_type == "pulc":
         cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml"
     elif model_type == "shitu":
         cfg_path = "deploy/configs/inference_general.yaml"
@@ -235,7 +239,8 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
     __dir__ = os.path.dirname(__file__)
     cfg_path = os.path.join(__dir__, cfg_path)
-    cfg = config.get_config(cfg_path, show=False)
+    cfg = config.get_config(
+        cfg_path, overrides=kwargs.get("override", None), show=False)
     if cfg.Global.get("inference_model_dir"):
         cfg.Global.inference_model_dir = inference_model_dir
     else:
@@ -282,6 +287,7 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
     if "thresh" in kwargs and kwargs[
             "thresh"] and "ThreshOutput" in cfg.PostProcess:
         cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
     if cfg.get("PostProcess"):
         if "Topk" in cfg.PostProcess:
             if "topk" in kwargs and kwargs["topk"]:
@@ -301,7 +307,26 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
             if "type_threshold" in kwargs and kwargs["type_threshold"]:
                 cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
                     "type_threshold"]
+        if "TableAttribute" in cfg.PostProcess:
+            if "source_threshold" in kwargs and kwargs["source_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "source_threshold"]
+            if "number_threshold" in kwargs and kwargs["number_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "number_threshold"]
+            if "color_threshold" in kwargs and kwargs["color_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "color_threshold"]
+            if "clarity_threshold" in kwargs and kwargs["clarity_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "clarity_threshold"]
+            if "obstruction_threshold" in kwargs and kwargs[
+                    "obstruction_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "obstruction_threshold"]
+            if "angle_threshold" in kwargs and kwargs["angle_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "angle_threshold"]
     if "save_dir" in kwargs and kwargs["save_dir"]:
         cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
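
As shown in this hunk, all six `TableAttribute` branches assign to `cfg.PostProcess.VehicleAttribute.color_threshold`, so no table-specific threshold is actually set. Assuming the intent is one field per keyword argument under `cfg.PostProcess.TableAttribute`, with field names mirroring the kwarg names (an assumption, not confirmed by the commit), the mapping could be sketched as below. The `apply_table_thresholds` helper, the `SimpleNamespace` stand-in, and the 0.5 defaults are hypothetical.

```python
from types import SimpleNamespace

# Keyword names taken from the hunk above.
TABLE_THRESHOLD_KEYS = (
    "source_threshold", "number_threshold", "color_threshold",
    "clarity_threshold", "obstruction_threshold", "angle_threshold",
)


def apply_table_thresholds(table_attribute_cfg, **kwargs):
    """Copy each truthy *_threshold kwarg onto the attribute of the same name."""
    for key in TABLE_THRESHOLD_KEYS:
        value = kwargs.get(key)
        if value:  # same truthiness check as the hunk: None and 0 are skipped
            setattr(table_attribute_cfg, key, value)
    return table_attribute_cfg


# Stand-in for cfg.PostProcess.TableAttribute; the 0.5 defaults are made up.
table_cfg = SimpleNamespace(**{key: 0.5 for key in TABLE_THRESHOLD_KEYS})
apply_table_thresholds(table_cfg, source_threshold=0.8, angle_threshold=0.3)
print(table_cfg.source_threshold, table_cfg.angle_threshold,
      table_cfg.color_threshold)
```

Driving the assignments from a single tuple of names avoids the copy-and-paste pattern in which every branch ends up with the same target.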
@@ -316,10 +341,15 @@ def args_cfg():
     parser.add_argument(
         "--infer_imgs",
         type=str,
-        required=True,
+        required=False,
         help="The image(s) to be predicted.")
     parser.add_argument(
         "--model_name", type=str, help="The model name to be used.")
+    parser.add_argument(
+        "--predict_type",
+        type=str,
+        default="cls",
+        help="The predict type to be selected.")
     parser.add_argument(
         "--inference_model_dir",
         type=str,
@@ -374,7 +404,17 @@ def args_cfg():
     parser.add_argument(
         "--resize_short", type=int, help="Resize according to short size.")
     parser.add_argument("--crop_size", type=int, help="Centor crop size.")
+    parser.add_argument(
+        "--build_gallery",
+        type=str2bool,
+        default=False,
+        help="Whether build gallery.")
+    parser.add_argument(
+        '-o',
+        '--override',
+        action='append',
+        default=[],
+        help='config options to be overridden')
     args = parser.parse_args()
     return vars(args)
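
To illustrate how the new flags parse, here is a cut-down sketch that keeps only the arguments touched by this commit. The `str2bool` shown here is a hypothetical stand-in (the real helper is not part of this diff), and `build_parser` is an illustrative name.

```python
import argparse


def str2bool(value):
    """Hypothetical stand-in: map common 'true' spellings to True, anything else to False."""
    return str(value).lower() in ("true", "t", "1", "yes")


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--infer_imgs", type=str, required=False)
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--predict_type", type=str, default="cls")
    parser.add_argument("--build_gallery", type=str2bool, default=False)
    # action="append" collects every -o occurrence into one list.
    parser.add_argument("-o", "--override", action="append", default=[])
    return parser


args = vars(build_parser().parse_args([
    "--build_gallery=True",
    "--model_name=PP-ShiTuV2",
    "-o", "IndexProcess.index_dir=./drink_dataset_v2.0/index",
    "-o", "IndexProcess.image_root=./drink_dataset_v2.0/gallery/",
]))
print(args["build_gallery"], args["predict_type"], args["override"])
```

Because `--infer_imgs` is no longer required, a gallery-building invocation can omit it, and `args["infer_imgs"]` is then `None`.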
@@ -514,6 +554,10 @@ class PaddleClas(object):
     """
     def __init__(self,
+                 build_gallery: bool=False,
+                 gallery_image_root: str=None,
+                 gallery_data_file: str=None,
+                 index_dir: str=None,
                  model_name: str=None,
                  inference_model_dir: str=None,
                  **kwargs):
@@ -528,14 +572,35 @@ class PaddleClas(object):
         """
         super().__init__()
-        self.model_type, inference_model_dir = self._check_input_model(
-            model_name, inference_model_dir)
-        self._config = init_config(self.model_type, model_name,
-                                   inference_model_dir, **kwargs)
-        if self.model_type == "shitu":
-            self.predictor = SystemPredictor(self._config)
+        if build_gallery:
+            self.model_type, inference_model_dir = self._check_input_model(
+                model_name
+                if model_name else "PP-ShiTuV2", inference_model_dir)
+            self._config = init_config(self.model_type, model_name
+                                       if model_name else "PP-ShiTuV2",
+                                       inference_model_dir, **kwargs)
+            if gallery_image_root:
+                self._config.IndexProcess.image_root = gallery_image_root
+            if gallery_data_file:
+                self._config.IndexProcess.data_file = gallery_data_file
+            if index_dir:
+                self._config.IndexProcess.index_dir = index_dir
+            logger.info("Building Gallery...")
+            GalleryBuilder(self._config)
         else:
-            self.predictor = ClsPredictor(self._config)
+            self.model_type, inference_model_dir = self._check_input_model(
+                model_name, inference_model_dir)
+            self._config = init_config(self.model_type, model_name,
+                                       inference_model_dir, **kwargs)
+            if self.model_type == "shitu":
+                if index_dir:
+                    self._config.IndexProcess.index_dir = index_dir
+                self.predictor = SystemPredictor(self._config)
+            else:
+                self.predictor = ClsPredictor(self._config)
     def get_config(self):
         """Get the config.
@@ -679,6 +744,9 @@ class PaddleClas(object):
                 prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
                 The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
         """
+        if input_data == None and self._config.Global.infer_imgs:
+            input_data = self._config.Global.infer_imgs
         if isinstance(input_data, np.ndarray):
             yield self.predictor.predict(input_data)
         elif isinstance(input_data, str):
@@ -721,6 +789,8 @@ class PaddleClas(object):
                 input_data: Union[str, np.array],
                 print_pred: bool=False,
                 predict_type="cls"):
+        assert predict_type in ["cls", "shitu"
+                                ], "Predict type should be 'cls' or 'shitu'."
         if predict_type == "cls":
             return self.predict_cls(input_data, print_pred)
         elif predict_type == "shitu":
@@ -739,13 +809,14 @@ def main():
     print_info()
     cfg = args_cfg()
     clas_engine = PaddleClas(**cfg)
-    res = clas_engine.predict(
-        cfg["infer_imgs"],
-        print_pred=True,
-        predict_type="cls" if "PP-ShiTu" not in cfg["model_name"] else "shitu")
-    for _ in res:
-        pass
-    logger.info("Predict complete!")
+    if cfg["build_gallery"] == False:
+        res = clas_engine.predict(
+            cfg["infer_imgs"],
+            print_pred=True,
+            predict_type=cfg["predict_type"])
+        for _ in res:
+            pass
+        logger.info("Predict complete!")
     return