diff --git a/MANIFEST.in b/MANIFEST.in
index b0a4f6dc151b0e11d83655d3f7ef40c200a88ee8..97372da0035488913c83dfe6f2ddfb8fe0c906c3 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,7 +1,8 @@
include LICENSE.txt
include README.md
include docs/en/whl_en.md
-recursive-include deploy/python predict_cls.py preprocess.py postprocess.py det_preprocess.py
+recursive-include deploy/python *.py
+recursive-include deploy/configs *.yaml
recursive-include deploy/utils get_image_list.py config.py logger.py predictor.py
recursive-include ppcls/ *.py *.txt
\ No newline at end of file
diff --git a/deploy/configs/PULC/traffic_sign/inference_traffic_sign.yaml b/deploy/configs/PULC/traffic_sign/inference_traffic_sign.yaml
index 09c4521f2fc0e9c54d78226f13305a9daae9349d..53699718b4fdd38da86eaee4cccc584dcc87d2b7 100644
--- a/deploy/configs/PULC/traffic_sign/inference_traffic_sign.yaml
+++ b/deploy/configs/PULC/traffic_sign/inference_traffic_sign.yaml
@@ -30,6 +30,6 @@ PostProcess:
main_indicator: Topk
Topk:
topk: 5
- class_id_map_file: "../dataset/traffic_sign/label_name_id.txt"
+ class_id_map_file: "../ppcls/utils/PULC_label_list/traffic_sign_label_list.txt"
SavePreLabel:
save_dir: ./pre_label/
diff --git a/docs/en/inference_deployment/whl_deploy_en.md b/docs/en/inference_deployment/whl_deploy_en.md
index 224d41a7c1f2de9886fd830a36b8910dae0f97b6..9fd7223274b187ec187ecb4046948c5837c73b59 100644
--- a/docs/en/inference_deployment/whl_deploy_en.md
+++ b/docs/en/inference_deployment/whl_deploy_en.md
@@ -212,14 +212,14 @@ You can save the prediction result(s) as pre-label, only need to use `pre_label_
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50', save_dir='./output_pre_label/')
-infer_imgs = 'docs/images/inference_deployment/whl_' # it can be infer_imgs folder path which contains all of images you want to predict.
+infer_imgs = 'docs/images/' # it can be a folder path containing all the images you want to predict.
result=clas.predict(infer_imgs)
print(next(result))
```
* CLI
```bash
-paddleclas --model_name='ResNet50' --infer_imgs='docs/images/inference_deployment/whl_' --save_dir='./output_pre_label/'
+paddleclas --model_name='ResNet50' --infer_imgs='docs/images/' --save_dir='./output_pre_label/'
```
diff --git a/docs/zh_CN/inference_deployment/whl_deploy.md b/docs/zh_CN/inference_deployment/whl_deploy.md
index 14582ace5ce13636c7c14e7fdb9ba9ad2ebbfe90..e6ad70904853d17f89974ff62b812a3420d21a2b 100644
--- a/docs/zh_CN/inference_deployment/whl_deploy.md
+++ b/docs/zh_CN/inference_deployment/whl_deploy.md
@@ -18,7 +18,7 @@ PaddleClas 支持 Python Whl 包方式进行预测,目前 Whl 包方式仅支
- [4.6 对 `NumPy.ndarray` 格式数据进行预测](#4.6)
- [4.7 保存预测结果](#4.7)
- [4.8 指定 label name](#4.8)
-
+
## 1. 安装 paddleclas
@@ -212,14 +212,14 @@ print(next(result))
```python
from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50', save_dir='./output_pre_label/')
-infer_imgs = 'docs/images/whl/' # it can be infer_imgs folder path which contains all of images you want to predict.
+infer_imgs = 'docs/images/' # it can be a folder path containing all the images you want to predict.
result=clas.predict(infer_imgs)
print(next(result))
```
* CLI
```bash
-paddleclas --model_name='ResNet50' --infer_imgs='docs/images/whl/' --save_dir='./output_pre_label/'
+paddleclas --model_name='ResNet50' --infer_imgs='docs/images/' --save_dir='./output_pre_label/'
```
diff --git a/paddleclas.py b/paddleclas.py
index bfad1931bdec5c305000775a6af891f4d7295244..ef1c47daa119bca1693ba639420aca47fa2993eb 100644
--- a/paddleclas.py
+++ b/paddleclas.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,6 @@ import shutil
import textwrap
import tarfile
import requests
-import warnings
from functools import partial
from difflib import SequenceMatcher
@@ -32,24 +31,25 @@ import cv2
import numpy as np
from tqdm import tqdm
from prettytable import PrettyTable
+import paddle
from deploy.python.predict_cls import ClsPredictor
from deploy.utils.get_image_list import get_image_list
from deploy.utils import config
-from ppcls.arch.backbone import *
-from ppcls.utils.logger import init_logger
+import ppcls.arch.backbone as backbone
+from ppcls.utils import logger
# for building model with loading pretrained weights from backbone
-init_logger()
+logger.init_logger()
__all__ = ["PaddleClas"]
BASE_DIR = os.path.expanduser("~/.paddleclas/")
BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, "inference_model")
BASE_IMAGES_DIR = os.path.join(BASE_DIR, "images")
-BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
-MODEL_SERIES = {
+IMN_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
+IMN_MODEL_SERIES = {
"AlexNet": ["AlexNet"],
"DarkNet": ["DarkNet53"],
"DeiT": [
@@ -100,10 +100,17 @@ MODEL_SERIES = {
"MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25",
"MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld"
],
+ "PPHGNet": [
+ "PPHGNet_tiny",
+ "PPHGNet_small",
+ "PPHGNet_tiny_ssld",
+ "PPHGNet_small_ssld",
+ ],
"PPLCNet": [
"PPLCNet_x0_25", "PPLCNet_x0_35", "PPLCNet_x0_5", "PPLCNet_x0_75",
"PPLCNet_x1_0", "PPLCNet_x1_5", "PPLCNet_x2_0", "PPLCNet_x2_5"
],
+ "PPLCNetV2": ["PPLCNetV2_base"],
"RedNet": ["RedNet26", "RedNet38", "RedNet50", "RedNet101", "RedNet152"],
"RegNet": ["RegNetX_4GF"],
"Res2Net": [
@@ -168,6 +175,13 @@ MODEL_SERIES = {
]
}
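+# PULC (Practical Ultra Lightweight image Classification) models are downloaded
+# from their own URL prefix and identified by task name rather than by architecture.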
+PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/{}_infer.tar"
+PULC_MODELS = [
+ "person_exists", "person_attribute", "safety_helmet", "traffic_sign",
+ "vehicle_exists", "vehicle_attr", "textline_orientation",
+ "text_image_orientation", "language_classification"
+]
+
class ImageTypeError(Exception):
"""ImageTypeError.
@@ -185,76 +199,67 @@ class InputModelError(Exception):
super().__init__(message)
-def init_config(model_name,
- inference_model_dir,
- use_gpu=True,
- batch_size=1,
- topk=5,
- **kwargs):
- imagenet1k_map_path = os.path.join(
- os.path.abspath(__dir__), "ppcls/utils/imagenet1k_label_list.txt")
- cfg = {
- "Global": {
- "infer_imgs": kwargs["infer_imgs"]
- if "infer_imgs" in kwargs else False,
- "model_name": model_name,
- "inference_model_dir": inference_model_dir,
- "batch_size": batch_size,
- "use_gpu": use_gpu,
- "enable_mkldnn": kwargs["enable_mkldnn"]
- if "enable_mkldnn" in kwargs else False,
- "cpu_num_threads": kwargs["cpu_num_threads"]
- if "cpu_num_threads" in kwargs else 1,
- "enable_benchmark": False,
- "use_fp16": kwargs["use_fp16"] if "use_fp16" in kwargs else False,
- "ir_optim": True,
- "use_tensorrt": kwargs["use_tensorrt"]
- if "use_tensorrt" in kwargs else False,
- "gpu_mem": kwargs["gpu_mem"] if "gpu_mem" in kwargs else 8000,
- "enable_profile": False
- },
- "PreProcess": {
- "transform_ops": [{
- "ResizeImage": {
- "resize_short": kwargs["resize_short"]
- if "resize_short" in kwargs else 256
- }
- }, {
- "CropImage": {
- "size": kwargs["crop_size"]
- if "crop_size" in kwargs else 224
- }
- }, {
- "NormalizeImage": {
- "scale": 0.00392157,
- "mean": [0.485, 0.456, 0.406],
- "std": [0.229, 0.224, 0.225],
- "order": ''
- }
- }, {
- "ToCHWImage": None
- }]
- },
- "PostProcess": {
- "main_indicator": "Topk",
- "Topk": {
- "topk": topk,
- "class_id_map_file": imagenet1k_map_path
- }
- }
- }
- if "save_dir" in kwargs:
- if kwargs["save_dir"] is not None:
- cfg["PostProcess"]["SavePreLabel"] = {
- "save_dir": kwargs["save_dir"]
- }
- if "class_id_map_file" in kwargs:
- if kwargs["class_id_map_file"] is not None:
- cfg["PostProcess"]["Topk"]["class_id_map_file"] = kwargs[
+def init_config(model_type, model_name, inference_model_dir, **kwargs):
+
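+ # start from the task-specific PULC config for PULC models, otherwise from
+ # the generic classification config; explicit kwargs from the Python API or
+ # CLI then override the YAML defaults below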
+ cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml"
+ cfg_path = os.path.join(__dir__, cfg_path)
+ cfg = config.get_config(cfg_path, show=False)
+
+ cfg.Global.inference_model_dir = inference_model_dir
+
+ if "batch_size" in kwargs and kwargs["batch_size"]:
+ cfg.Global.batch_size = kwargs["batch_size"]
+
+ if "use_gpu" in kwargs and kwargs["use_gpu"]:
+ cfg.Global.use_gpu = kwargs["use_gpu"]
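+ # fall back to CPU when the installed Paddle build is not compiled with CUDA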
+ if cfg.Global.use_gpu and not paddle.device.is_compiled_with_cuda():
+ msg = "The current running environment does not support the use of GPU. CPU has been used instead."
+ logger.warning(msg)
+ cfg.Global.use_gpu = False
+
+ if "infer_imgs" in kwargs and kwargs["infer_imgs"]:
+ cfg.Global.infer_imgs = kwargs["infer_imgs"]
+ if "enable_mkldnn" in kwargs and kwargs["enable_mkldnn"]:
+ cfg.Global.enable_mkldnn = kwargs["enable_mkldnn"]
+ if "cpu_num_threads" in kwargs and kwargs["cpu_num_threads"]:
+ cfg.Global.cpu_num_threads = kwargs["cpu_num_threads"]
+ if "use_fp16" in kwargs and kwargs["use_fp16"]:
+ cfg.Global.use_fp16 = kwargs["use_fp16"]
+ if "use_tensorrt" in kwargs and kwargs["use_tensorrt"]:
+ cfg.Global.use_tensorrt = kwargs["use_tensorrt"]
+ if "gpu_mem" in kwargs and kwargs["gpu_mem"]:
+ cfg.Global.gpu_mem = kwargs["gpu_mem"]
+ if "resize_short" in kwargs and kwargs["resize_short"]:
+ cfg.PreProcess.transform_ops[0]["ResizeImage"][
+ "resize_short"] = kwargs["resize_short"]
+ if "crop_size" in kwargs and kwargs["crop_size"]:
+ cfg.PreProcess.transform_ops[1]["CropImage"]["size"] = kwargs[
+ "crop_size"]
+
+ # TODO(gaotingquan): not robust
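+ # each override below only applies when the corresponding op exists in the
+ # loaded config, since different models use different postprocess ops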
+ if "thresh" in kwargs and kwargs[
+ "thresh"] and "ThreshOutput" in cfg.PostProcess:
+ cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
+ if "Topk" in cfg.PostProcess:
+ if "topk" in kwargs and kwargs["topk"]:
+ cfg.PostProcess.Topk.topk = kwargs["topk"]
+ if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]:
+ cfg.PostProcess.Topk.class_id_map_file = kwargs[
"class_id_map_file"]
+ else:
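+ # the default path in the YAML is written relative to deploy/, so strip
+ # the leading "../" to make it resolve from the package root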
+ cfg.PostProcess.Topk.class_id_map_file = os.path.relpath(
+ cfg.PostProcess.Topk.class_id_map_file, "../")
+ if "VehicleAttribute" in cfg.PostProcess:
+ if "color_threshold" in kwargs and kwargs["color_threshold"]:
+ cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+ "color_threshold"]
+ if "type_threshold" in kwargs and kwargs["type_threshold"]:
+ cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
+ "type_threshold"]
+
+ if "save_dir" in kwargs and kwargs["save_dir"]:
+ cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
- cfg = config.AttrDict(cfg)
- config.create_attr_dict(cfg)
return cfg
@@ -275,40 +280,48 @@ def args_cfg():
type=str,
help="The directory of model files. Valid when model_name not specifed."
)
+ parser.add_argument("--use_gpu", type=str2bool, help="Whether use GPU.")
parser.add_argument(
- "--use_gpu", type=str, default=True, help="Whether use GPU.")
- parser.add_argument("--gpu_mem", type=int, default=8000, help="")
+ "--gpu_mem",
+ type=int,
+ help="The memory size of GPU allocated to predict.")
parser.add_argument(
"--enable_mkldnn",
type=str2bool,
- default=False,
help="Whether use MKLDNN. Valid when use_gpu is False")
- parser.add_argument("--cpu_num_threads", type=int, default=1, help="")
parser.add_argument(
- "--use_tensorrt", type=str2bool, default=False, help="")
- parser.add_argument("--use_fp16", type=str2bool, default=False, help="")
+ "--cpu_num_threads",
+ type=int,
+ help="The threads number when predicting on CPU.")
+ parser.add_argument(
+ "--use_tensorrt",
+ type=str2bool,
+ help="Whether to use TensorRT to accelerate prediction.")
parser.add_argument(
- "--batch_size", type=int, default=1, help="Batch size. Default by 1.")
+ "--use_fp16", type=str2bool, help="Whether use FP16 to predict.")
+ parser.add_argument("--batch_size", type=int, help="Batch size.")
parser.add_argument(
"--topk",
type=int,
- default=5,
- help="Return topk score(s) and corresponding results. Default by 5.")
+ help="Return topk score(s) and corresponding results when Topk postprocess is used."
+ )
parser.add_argument(
"--class_id_map_file",
type=str,
help="The path of file that map class_id and label.")
+ parser.add_argument(
+ "--thresh",
+ type=float,
+ help="The threshold used by the ThreshOutput postprocess.")
+ parser.add_argument("--color_threshold", type=float, help="")
+ parser.add_argument("--type_threshold", type=float, help="")
parser.add_argument(
"--save_dir",
type=str,
help="The directory to save prediction results as pre-label.")
parser.add_argument(
- "--resize_short",
- type=int,
- default=256,
- help="Resize according to short size.")
- parser.add_argument(
- "--crop_size", type=int, default=224, help="Centor crop size.")
+ "--resize_short", type=int, help="Resize according to short size.")
+ parser.add_argument("--crop_size", type=int, help="Centor crop size.")
args = parser.parse_args()
return vars(args)
@@ -317,33 +330,44 @@ def args_cfg():
def print_info():
"""Print list of supported models in formatted.
"""
- table = PrettyTable(["Series", "Name"])
+ imn_table = PrettyTable(["IMN Model Series", "Model Name"])
+ pulc_table = PrettyTable(["PULC Models"])
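+ # IMN models are grouped by series in a two-column table; PULC models fit
+ # in a single-column table of task names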
try:
sz = os.get_terminal_size()
- width = sz.columns - 30 if sz.columns > 50 else 10
+ total_width = sz.columns
+ first_width = 30
+ second_width = total_width - first_width if total_width > 50 else 10
except OSError:
- width = 100
- for series in MODEL_SERIES:
- names = textwrap.fill(" ".join(MODEL_SERIES[series]), width=width)
- table.add_row([series, names])
- width = len(str(table).split("\n")[0])
- print("{}".format("-" * width))
- print("Models supported by PaddleClas".center(width))
- print(table)
- print("Powered by PaddlePaddle!".rjust(width))
- print("{}".format("-" * width))
-
-
-def get_model_names():
+ total_width = 100
+ second_width = 100
+ for series in IMN_MODEL_SERIES:
+ names = textwrap.fill(
+ " ".join(IMN_MODEL_SERIES[series]), width=second_width)
+ imn_table.add_row([series, names])
+
+ table_width = len(str(imn_table).split("\n")[0])
+ pulc_table.add_row([
+ textwrap.fill(
+ " ".join(PULC_MODELS), width=total_width).center(table_width - 4)
+ ])
+
+ print("{}".format("-" * table_width))
+ print("Models supported by PaddleClas".center(table_width))
+ print(imn_table)
+ print(pulc_table)
+ print("Powered by PaddlePaddle!".rjust(table_width))
+ print("{}".format("-" * table_width))
+
+
+def get_imn_model_names():
"""Get the model names list.
"""
model_names = []
- for series in MODEL_SERIES:
- model_names += (MODEL_SERIES[series])
+ for series in IMN_MODEL_SERIES:
+ model_names += (IMN_MODEL_SERIES[series])
return model_names
-def similar_architectures(name="", names=[], thresh=0.1, topk=10):
+def similar_model_names(name="", names=[], thresh=0.1, topk=5):
"""Find the most similar topk model names.
"""
scores = []
@@ -378,12 +402,17 @@ def download_with_progressbar(url, save_path):
f"Something went wrong while downloading file from {url}")
-def check_model_file(model_name):
+def check_model_file(model_type, model_name):
"""Check the model files exist and download and untar when no exist.
"""
- storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
- model_name)
- url = BASE_DOWNLOAD_URL.format(model_name)
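+ # cache PULC and IMN models in separate subdirectories so a PULC task name
+ # can never collide with an IMN architecture name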
+ if model_type == "pulc":
+ storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
+ "PULC", model_name)
+ url = PULC_MODEL_BASE_DOWNLOAD_URL.format(model_name)
+ else:
+ storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
+ "IMN", model_name)
+ url = IMN_MODEL_BASE_DOWNLOAD_URL.format(model_name)
tar_file_name_list = [
"inference.pdiparams", "inference.pdiparams.info", "inference.pdmodel"
@@ -393,7 +422,7 @@ def check_model_file(model_name):
if not os.path.exists(model_file_path) or not os.path.exists(
params_file_path):
tmp_path = storage_directory(url.split("/")[-1])
- print(f"download {url} to {tmp_path}")
+ logger.info(f"download {url} to {tmp_path}")
os.makedirs(storage_directory(), exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, "r") as tarObj:
@@ -426,9 +455,6 @@ class PaddleClas(object):
def __init__(self,
model_name: str=None,
inference_model_dir: str=None,
- use_gpu: bool=True,
- batch_size: int=1,
- topk: int=5,
**kwargs):
"""Init PaddleClas with config.
@@ -440,9 +466,11 @@ class PaddleClas(object):
topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
"""
super().__init__()
- self._config = init_config(model_name, inference_model_dir, use_gpu,
- batch_size, topk, **kwargs)
- self._check_input_model()
+ self.model_type, inference_model_dir = self._check_input_model(
+ model_name, inference_model_dir)
+ self._config = init_config(self.model_type, model_name,
+ inference_model_dir, **kwargs)
+
self.cls_predictor = ClsPredictor(self._config)
def get_config(self):
@@ -450,24 +478,29 @@ class PaddleClas(object):
"""
return self._config
- def _check_input_model(self):
+ def _check_input_model(self, model_name, inference_model_dir):
"""Check input model name or model files.
"""
- candidate_model_names = get_model_names()
- input_model_name = self._config.Global.get("model_name", None)
- inference_model_dir = self._config.Global.get("inference_model_dir",
- None)
- if input_model_name is not None:
- similar_names = similar_architectures(input_model_name,
- candidate_model_names)
- similar_names_str = ", ".join(similar_names)
- if input_model_name not in candidate_model_names:
- err = f"{input_model_name} is not provided by PaddleClas. \nMaybe you want: [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!"
+ all_imn_model_names = get_imn_model_names()
+ all_pulc_model_names = PULC_MODELS
+
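+ # a known model_name resolves to a downloadable IMN or PULC model;
+ # otherwise the caller must supply a local inference_model_dir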
+ if model_name:
+ if model_name in all_imn_model_names:
+ inference_model_dir = check_model_file("imn", model_name)
+ return "imn", inference_model_dir
+ elif model_name in all_pulc_model_names:
+ inference_model_dir = check_model_file("pulc", model_name)
+ return "pulc", inference_model_dir
+ else:
+ similar_imn_names = similar_model_names(model_name,
+ all_imn_model_names)
+ similar_pulc_names = similar_model_names(model_name,
+ all_pulc_model_names)
+ similar_names_str = ", ".join(similar_imn_names +
+ similar_pulc_names)
+ err = f"{model_name} is not provided by PaddleClas. \nMaybe you want the : [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!"
raise InputModelError(err)
- self._config.Global.inference_model_dir = check_model_file(
- input_model_name)
- return
- elif inference_model_dir is not None:
+ elif inference_model_dir:
model_file_path = os.path.join(inference_model_dir,
"inference.pdmodel")
params_file_path = os.path.join(inference_model_dir,
@@ -476,11 +509,11 @@ class PaddleClas(object):
params_file_path):
err = f"There is no model file or params file in this directory: {inference_model_dir}"
raise InputModelError(err)
- return
+ return "custom", inference_model_dir
else:
err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
raise InputModelError(err)
- return
+ return None
def predict(self, input_data: Union[str, np.array],
print_pred: bool=False) -> Generator[list, None, None]:
@@ -511,22 +544,21 @@ class PaddleClas(object):
os.makedirs(image_storage_dir())
image_save_path = image_storage_dir("tmp.jpg")
download_with_progressbar(input_data, image_save_path)
- input_data = image_save_path
- warnings.warn(
+ logger.info(
f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
)
+ input_data = image_save_path
image_list = get_image_list(input_data)
batch_size = self._config.Global.get("batch_size", 1)
- topk = self._config.PostProcess.Topk.get('topk', 1)
img_list = []
img_path_list = []
cnt = 0
- for idx, img_path in enumerate(image_list):
+ for idx_img, img_path in enumerate(image_list):
img = cv2.imread(img_path)
if img is None:
- warnings.warn(
+ logger.warning(
f"Image file failed to read and has been skipped. The path: {img_path}"
)
continue
@@ -535,16 +567,15 @@ class PaddleClas(object):
img_path_list.append(img_path)
cnt += 1
- if cnt % batch_size == 0 or (idx + 1) == len(image_list):
+ if cnt % batch_size == 0 or (idx_img + 1) == len(image_list):
preds = self.cls_predictor.predict(img_list)
- if print_pred and preds:
- for idx, pred in enumerate(preds):
- pred_str = ", ".join(
- [f"{k}: {pred[k]}" for k in pred])
- print(
- f"filename: {img_path_list[idx]}, top-{topk}, {pred_str}"
- )
+ if preds:
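+ # attach the source filename so each returned prediction is self-describing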
+ for idx_pred, pred in enumerate(preds):
+ pred["filename"] = img_path_list[idx_pred]
+ if print_pred:
+ logger.info(", ".join(
+ [f"{k}: {pred[k]}" for k in pred]))
img_list = []
img_path_list = []
@@ -564,7 +595,7 @@ def main():
res = clas_engine.predict(cfg["infer_imgs"], print_pred=True)
for _ in res:
pass
- print("Predict complete!")
+ logger.info("Predict complete!")
return
diff --git a/ppcls/utils/PULC_label_list/text_image_orientation_label_list.txt b/ppcls/utils/PULC_label_list/text_image_orientation_label_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..051944a929f323a3a25f1807ac0297170513484a
--- /dev/null
+++ b/ppcls/utils/PULC_label_list/text_image_orientation_label_list.txt
@@ -0,0 +1,4 @@
+0 0
+1 90
+2 180
+3 270
diff --git a/ppcls/utils/PULC_label_list/traffic_sign_label_list.txt b/ppcls/utils/PULC_label_list/traffic_sign_label_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c1e41d539d1af5611b2b047d664000b8f41afb15
--- /dev/null
+++ b/ppcls/utils/PULC_label_list/traffic_sign_label_list.txt
@@ -0,0 +1,232 @@
+0 pl80
+1 w9
+2 p6
+3 ph4.2
+4 i8
+5 w14
+6 w33
+7 pa13
+8 im
+9 w58
+10 pl90
+11 il70
+12 p5
+13 pm55
+14 pl60
+15 ip
+16 p11
+17 pdd
+18 wc
+19 i2r
+20 w30
+21 pmr
+22 p23
+23 pl15
+24 pm10
+25 pss
+26 w1
+27 p4
+28 w38
+29 w50
+30 w34
+31 pw3.5
+32 iz
+33 w39
+34 w11
+35 p1n
+36 pr70
+37 pd
+38 pnl
+39 pg
+40 ph5.3
+41 w66
+42 il80
+43 pb
+44 pbm
+45 pm5
+46 w24
+47 w67
+48 w49
+49 pm40
+50 ph4
+51 w45
+52 i4
+53 w37
+54 ph2.6
+55 pl70
+56 ph5.5
+57 i14
+58 i11
+59 p7
+60 p29
+61 pne
+62 pr60
+63 pm13
+64 ph4.5
+65 p12
+66 p3
+67 w40
+68 pl5
+69 w13
+70 pr10
+71 p14
+72 i4l
+73 pr30
+74 pw4.2
+75 w16
+76 p17
+77 ph3
+78 i9
+79 w15
+80 w35
+81 pa8
+82 pt
+83 pr45
+84 w17
+85 pl30
+86 pcs
+87 pctl
+88 pr50
+89 ph4.4
+90 pm46
+91 pm35
+92 i15
+93 pa12
+94 pclr
+95 i1
+96 pcd
+97 pbp
+98 pcr
+99 w28
+100 ps
+101 pm8
+102 w18
+103 w2
+104 w52
+105 ph2.9
+106 ph1.8
+107 pe
+108 p20
+109 w36
+110 p10
+111 pn
+112 pa14
+113 w54
+114 ph3.2
+115 p2
+116 ph2.5
+117 w62
+118 w55
+119 pw3
+120 pw4.5
+121 i12
+122 ph4.3
+123 phclr
+124 i10
+125 pr5
+126 i13
+127 w10
+128 p26
+129 w26
+130 p8
+131 w5
+132 w42
+133 il50
+134 p13
+135 pr40
+136 p25
+137 w41
+138 pl20
+139 ph4.8
+140 pnlc
+141 ph3.3
+142 w29
+143 ph2.1
+144 w53
+145 pm30
+146 p24
+147 p21
+148 pl40
+149 w27
+150 pmb
+151 pc
+152 i6
+153 pr20
+154 p18
+155 ph3.8
+156 pm50
+157 pm25
+158 i2
+159 w22
+160 w47
+161 w56
+162 pl120
+163 ph2.8
+164 i7
+165 w12
+166 pm1.5
+167 pm2.5
+168 w32
+169 pm15
+170 ph5
+171 w19
+172 pw3.2
+173 pw2.5
+174 pl10
+175 il60
+176 w57
+177 w48
+178 w60
+179 pl100
+180 pr80
+181 p16
+182 pl110
+183 w59
+184 w64
+185 w20
+186 ph2
+187 p9
+188 il100
+189 w31
+190 w65
+191 ph2.4
+192 pr100
+193 p19
+194 ph3.5
+195 pa10
+196 pcl
+197 pl35
+198 p15
+199 w7
+200 pa6
+201 phcs
+202 w43
+203 p28
+204 w6
+205 w3
+206 w25
+207 pl25
+208 il110
+209 p1
+210 w46
+211 pn-2
+212 w51
+213 w44
+214 w63
+215 w23
+216 pm20
+217 w8
+218 pmblr
+219 w4
+220 i5
+221 il90
+222 w21
+223 p27
+224 pl50
+225 pl65
+226 w61
+227 ph2.2
+228 pm2
+229 i3
+230 pa18
+231 pw4
diff --git a/ppcls/utils/cls_demo/person_label_list.txt b/ppcls/utils/cls_demo/person_label_list.txt
deleted file mode 100644
index 8eea2b6dc2433abf303a0ea508021698559b749b..0000000000000000000000000000000000000000
--- a/ppcls/utils/cls_demo/person_label_list.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-0 nobody
-1 someone