提交 9b40ee0e 编写于 作者: H HydrogenSulfate

add shitu whl

上级 6b218caf
...@@ -32,6 +32,7 @@ from .ppcls.arch import backbone ...@@ -32,6 +32,7 @@ from .ppcls.arch import backbone
from .ppcls.utils import logger from .ppcls.utils import logger
from .deploy.python.predict_cls import ClsPredictor from .deploy.python.predict_cls import ClsPredictor
from .deploy.python.predict_system import SystemPredictor
from .deploy.utils.get_image_list import get_image_list from .deploy.utils.get_image_list import get_image_list
from .deploy.utils import config from .deploy.utils import config
...@@ -194,6 +195,14 @@ PULC_MODELS = [ ...@@ -194,6 +195,14 @@ PULC_MODELS = [
"textline_orientation", "traffic_sign", "vehicle_attribute" "textline_orientation", "traffic_sign", "vehicle_attribute"
] ]
SHITU_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/{}_infer.tar"
SHITU_MODELS = [
# "picodet_PPLCNet_x2_5_mainbody_lite_v1.0", # ShiTuV1(V2)_mainbody_det
# "general_PPLCNet_x2_5_lite_v1.0" # ShiTuV1_general_rec
# "PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0", # ShiTuV2_general_rec TODO(hesensen): add lite model
"PP-ShiTuV2"
]
class ImageTypeError(Exception): class ImageTypeError(Exception):
"""ImageTypeError. """ImageTypeError.
...@@ -213,12 +222,24 @@ class InputModelError(Exception): ...@@ -213,12 +222,24 @@ class InputModelError(Exception):
def init_config(model_type, model_name, inference_model_dir, **kwargs): def init_config(model_type, model_name, inference_model_dir, **kwargs):
cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml" if model_type == "pulc":
cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml"
elif model_type == "shitu":
cfg_path = "deploy/configs/inference_general.yaml"
else:
cfg_path = "deploy/configs/inference_cls.yaml"
__dir__ = os.path.dirname(__file__) __dir__ = os.path.dirname(__file__)
cfg_path = os.path.join(__dir__, cfg_path) cfg_path = os.path.join(__dir__, cfg_path)
cfg = config.get_config(cfg_path, show=False) cfg = config.get_config(cfg_path, show=False)
if cfg.Global.get("inference_model_dir"):
cfg.Global.inference_model_dir = inference_model_dir cfg.Global.inference_model_dir = inference_model_dir
else:
cfg.Global.rec_inference_model_dir = os.path.join(
inference_model_dir,
"PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0")
cfg.Global.det_inference_model_dir = os.path.join(
inference_model_dir, "picodet_PPLCNet_x2_5_mainbody_lite_v1.0")
if "batch_size" in kwargs and kwargs["batch_size"]: if "batch_size" in kwargs and kwargs["batch_size"]:
cfg.Global.batch_size = kwargs["batch_size"] cfg.Global.batch_size = kwargs["batch_size"]
...@@ -232,6 +253,10 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs): ...@@ -232,6 +253,10 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
if "infer_imgs" in kwargs and kwargs["infer_imgs"]: if "infer_imgs" in kwargs and kwargs["infer_imgs"]:
cfg.Global.infer_imgs = kwargs["infer_imgs"] cfg.Global.infer_imgs = kwargs["infer_imgs"]
if "index_dir" in kwargs and kwargs["index_dir"]:
cfg.IndexProcess.index_dir = kwargs["index_dir"]
if "data_file" in kwargs and kwargs["data_file"]:
cfg.IndexProcess.data_file = kwargs["data_file"]
if "enable_mkldnn" in kwargs and kwargs["enable_mkldnn"]: if "enable_mkldnn" in kwargs and kwargs["enable_mkldnn"]:
cfg.Global.enable_mkldnn = kwargs["enable_mkldnn"] cfg.Global.enable_mkldnn = kwargs["enable_mkldnn"]
if "cpu_num_threads" in kwargs and kwargs["cpu_num_threads"]: if "cpu_num_threads" in kwargs and kwargs["cpu_num_threads"]:
...@@ -253,6 +278,7 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs): ...@@ -253,6 +278,7 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
if "thresh" in kwargs and kwargs[ if "thresh" in kwargs and kwargs[
"thresh"] and "ThreshOutput" in cfg.PostProcess: "thresh"] and "ThreshOutput" in cfg.PostProcess:
cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"] cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
if cfg.get("PostProcess"):
if "Topk" in cfg.PostProcess: if "Topk" in cfg.PostProcess:
if "topk" in kwargs and kwargs["topk"]: if "topk" in kwargs and kwargs["topk"]:
cfg.PostProcess.Topk.topk = kwargs["topk"] cfg.PostProcess.Topk.topk = kwargs["topk"]
...@@ -295,6 +321,13 @@ def args_cfg(): ...@@ -295,6 +321,13 @@ def args_cfg():
type=str, type=str,
help="The directory of model files. Valid when model_name not specifed." help="The directory of model files. Valid when model_name not specifed."
) )
parser.add_argument(
"--index_dir",
type=str,
required=False,
help="The index directory path.")
parser.add_argument(
"--data_file", type=str, required=False, help="The label file path.")
parser.add_argument("--use_gpu", type=str2bool, help="Whether use GPU.") parser.add_argument("--use_gpu", type=str2bool, help="Whether use GPU.")
parser.add_argument( parser.add_argument(
"--gpu_mem", "--gpu_mem",
...@@ -347,6 +380,7 @@ def print_info(): ...@@ -347,6 +380,7 @@ def print_info():
""" """
imn_table = PrettyTable(["IMN Model Series", "Model Name"]) imn_table = PrettyTable(["IMN Model Series", "Model Name"])
pulc_table = PrettyTable(["PULC Models"]) pulc_table = PrettyTable(["PULC Models"])
shitu_table = PrettyTable(["PP-ShiTu Models"])
try: try:
sz = os.get_terminal_size() sz = os.get_terminal_size()
total_width = sz.columns total_width = sz.columns
...@@ -365,11 +399,16 @@ def print_info(): ...@@ -365,11 +399,16 @@ def print_info():
textwrap.fill( textwrap.fill(
" ".join(PULC_MODELS), width=total_width).center(table_width - 4) " ".join(PULC_MODELS), width=total_width).center(table_width - 4)
]) ])
shitu_table.add_row([
textwrap.fill(
" ".join(SHITU_MODELS), width=total_width).center(table_width - 4)
])
print("{}".format("-" * table_width)) print("{}".format("-" * table_width))
print("Models supported by PaddleClas".center(table_width)) print("Models supported by PaddleClas".center(table_width))
print(imn_table) print(imn_table)
print(pulc_table) print(pulc_table)
print(shitu_table)
print("Powered by PaddlePaddle!".rjust(table_width)) print("Powered by PaddlePaddle!".rjust(table_width))
print("{}".format("-" * table_width)) print("{}".format("-" * table_width))
...@@ -425,6 +464,10 @@ def check_model_file(model_type, model_name): ...@@ -425,6 +464,10 @@ def check_model_file(model_type, model_name):
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR, storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
"PULC", model_name) "PULC", model_name)
url = PULC_MODEL_BASE_DOWNLOAD_URL.format(model_name) url = PULC_MODEL_BASE_DOWNLOAD_URL.format(model_name)
elif model_type == "shitu":
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
"PP-ShiTu", model_name)
url = SHITU_MODEL_BASE_DOWNLOAD_URL.format(model_name)
else: else:
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR, storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
"IMN", model_name) "IMN", model_name)
...@@ -485,8 +528,10 @@ class PaddleClas(object): ...@@ -485,8 +528,10 @@ class PaddleClas(object):
model_name, inference_model_dir) model_name, inference_model_dir)
self._config = init_config(self.model_type, model_name, self._config = init_config(self.model_type, model_name,
inference_model_dir, **kwargs) inference_model_dir, **kwargs)
if self.model_type == "shitu":
self.cls_predictor = ClsPredictor(self._config) self.predictor = SystemPredictor(self._config)
else:
self.predictor = ClsPredictor(self._config)
def get_config(self): def get_config(self):
"""Get the config. """Get the config.
...@@ -498,6 +543,7 @@ class PaddleClas(object): ...@@ -498,6 +543,7 @@ class PaddleClas(object):
""" """
all_imn_model_names = get_imn_model_names() all_imn_model_names = get_imn_model_names()
all_pulc_model_names = PULC_MODELS all_pulc_model_names = PULC_MODELS
all_shitu_model_names = SHITU_MODELS
if model_name: if model_name:
if model_name in all_imn_model_names: if model_name in all_imn_model_names:
...@@ -506,6 +552,15 @@ class PaddleClas(object): ...@@ -506,6 +552,15 @@ class PaddleClas(object):
elif model_name in all_pulc_model_names: elif model_name in all_pulc_model_names:
inference_model_dir = check_model_file("pulc", model_name) inference_model_dir = check_model_file("pulc", model_name)
return "pulc", inference_model_dir return "pulc", inference_model_dir
elif model_name in all_shitu_model_names:
inference_model_dir = check_model_file(
"shitu",
"PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0")
inference_model_dir = check_model_file(
"shitu", "picodet_PPLCNet_x2_5_mainbody_lite_v1.0")
inference_model_dir = os.path.abspath(
os.path.dirname(inference_model_dir))
return "shitu", inference_model_dir
else: else:
similar_imn_names = similar_model_names(model_name, similar_imn_names = similar_model_names(model_name,
all_imn_model_names) all_imn_model_names)
...@@ -526,11 +581,12 @@ class PaddleClas(object): ...@@ -526,11 +581,12 @@ class PaddleClas(object):
raise InputModelError(err) raise InputModelError(err)
return "custom", inference_model_dir return "custom", inference_model_dir
else: else:
err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)." err = "Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
raise InputModelError(err) raise InputModelError(err)
return None return None
def predict(self, input_data: Union[str, np.array], def predict_cls(self,
input_data: Union[str, np.array],
print_pred: bool=False) -> Generator[list, None, None]: print_pred: bool=False) -> Generator[list, None, None]:
"""Predict input_data. """Predict input_data.
...@@ -551,7 +607,7 @@ class PaddleClas(object): ...@@ -551,7 +607,7 @@ class PaddleClas(object):
""" """
if isinstance(input_data, np.ndarray): if isinstance(input_data, np.ndarray):
yield self.cls_predictor.predict(input_data) yield self.predictor.predict(input_data)
elif isinstance(input_data, str): elif isinstance(input_data, str):
if input_data.startswith("http") or input_data.startswith("https"): if input_data.startswith("http") or input_data.startswith("https"):
image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR) image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
...@@ -583,7 +639,7 @@ class PaddleClas(object): ...@@ -583,7 +639,7 @@ class PaddleClas(object):
cnt += 1 cnt += 1
if cnt % batch_size == 0 or (idx_img + 1) == len(image_list): if cnt % batch_size == 0 or (idx_img + 1) == len(image_list):
preds = self.cls_predictor.predict(img_list) preds = self.predictor.predict(img_list)
if preds: if preds:
for idx_pred, pred in enumerate(preds): for idx_pred, pred in enumerate(preds):
...@@ -600,6 +656,77 @@ class PaddleClas(object): ...@@ -600,6 +656,77 @@ class PaddleClas(object):
raise ImageTypeError(err) raise ImageTypeError(err)
return return
def predict_shitu(self,
input_data: Union[str, np.array],
print_pred: bool=False) -> Generator[list, None, None]:
"""Predict input_data.
Args:
input_data (Union[str, np.array]):
When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
When the type is np.array, it is the image data whose channel order is RGB.
print_pred (bool, optional): Whether print the prediction result. Defaults to False.
Raises:
ImageTypeError: Illegal input_data.
Yields:
Generator[list, None, None]:
The prediction result(s) of input_data by batch_size. For every one image,
prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
"""
if isinstance(input_data, np.ndarray):
yield self.predictor.predict(input_data)
elif isinstance(input_data, str):
if input_data.startswith("http") or input_data.startswith("https"):
image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
if not os.path.exists(image_storage_dir()):
os.makedirs(image_storage_dir())
image_save_path = image_storage_dir("tmp.jpg")
download_with_progressbar(input_data, image_save_path)
logger.info(
f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
)
input_data = image_save_path
image_list = get_image_list(input_data)
cnt = 0
for idx_img, img_path in enumerate(image_list):
img = cv2.imread(img_path)
if img is None:
logger.warning(
f"Image file failed to read and has been skipped. The path: {img_path}"
)
continue
img = img[:, :, ::-1]
cnt += 1
preds = self.predictor.predict(
img) # [dict1, dict2, ..., dictn]
if preds:
if print_pred:
logger.info(f"{preds}, filename: {img_path}")
yield preds
else:
err = "Please input legal image! The type of image supported by PaddleClas are: NumPy.ndarray and string of local path or Ineternet URL"
raise ImageTypeError(err)
return
def predict(self,
input_data: Union[str, np.array],
print_pred: bool=False,
predict_type="cls"):
if predict_type == "cls":
return self.predict_cls(input_data, print_pred)
elif predict_type == "shitu":
assert not isinstance(input_data, (
list, tuple
)), "PP-ShiTu predictor only support single image as input now."
return self.predict_shitu(input_data, print_pred)
else:
raise ModuleNotFoundError
# for CLI # for CLI
def main(): def main():
...@@ -608,7 +735,10 @@ def main(): ...@@ -608,7 +735,10 @@ def main():
print_info() print_info()
cfg = args_cfg() cfg = args_cfg()
clas_engine = PaddleClas(**cfg) clas_engine = PaddleClas(**cfg)
res = clas_engine.predict(cfg["infer_imgs"], print_pred=True) res = clas_engine.predict(
cfg["infer_imgs"],
print_pred=True,
predict_type="cls" if "PP-ShiTu" not in cfg["model_name"] else "shitu")
for _ in res: for _ in res:
pass pass
logger.info("Predict complete!") logger.info("Predict complete!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册