提交 9b40ee0e 编写于 作者: H HydrogenSulfate

add shitu whl

上级 6b218caf
......@@ -32,6 +32,7 @@ from .ppcls.arch import backbone
from .ppcls.utils import logger
from .deploy.python.predict_cls import ClsPredictor
from .deploy.python.predict_system import SystemPredictor
from .deploy.utils.get_image_list import get_image_list
from .deploy.utils import config
......@@ -194,6 +195,14 @@ PULC_MODELS = [
"textline_orientation", "traffic_sign", "vehicle_attribute"
]
SHITU_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/{}_infer.tar"
SHITU_MODELS = [
# "picodet_PPLCNet_x2_5_mainbody_lite_v1.0", # ShiTuV1(V2)_mainbody_det
# "general_PPLCNet_x2_5_lite_v1.0" # ShiTuV1_general_rec
# "PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0", # ShiTuV2_general_rec TODO(hesensen): add lite model
"PP-ShiTuV2"
]
class ImageTypeError(Exception):
"""ImageTypeError.
......@@ -213,12 +222,24 @@ class InputModelError(Exception):
def init_config(model_type, model_name, inference_model_dir, **kwargs):
cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml"
if model_type == "pulc":
cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml"
elif model_type == "shitu":
cfg_path = "deploy/configs/inference_general.yaml"
else:
cfg_path = "deploy/configs/inference_cls.yaml"
__dir__ = os.path.dirname(__file__)
cfg_path = os.path.join(__dir__, cfg_path)
cfg = config.get_config(cfg_path, show=False)
cfg.Global.inference_model_dir = inference_model_dir
if cfg.Global.get("inference_model_dir"):
cfg.Global.inference_model_dir = inference_model_dir
else:
cfg.Global.rec_inference_model_dir = os.path.join(
inference_model_dir,
"PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0")
cfg.Global.det_inference_model_dir = os.path.join(
inference_model_dir, "picodet_PPLCNet_x2_5_mainbody_lite_v1.0")
if "batch_size" in kwargs and kwargs["batch_size"]:
cfg.Global.batch_size = kwargs["batch_size"]
......@@ -232,6 +253,10 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
if "infer_imgs" in kwargs and kwargs["infer_imgs"]:
cfg.Global.infer_imgs = kwargs["infer_imgs"]
if "index_dir" in kwargs and kwargs["index_dir"]:
cfg.IndexProcess.index_dir = kwargs["index_dir"]
if "data_file" in kwargs and kwargs["data_file"]:
cfg.IndexProcess.data_file = kwargs["data_file"]
if "enable_mkldnn" in kwargs and kwargs["enable_mkldnn"]:
cfg.Global.enable_mkldnn = kwargs["enable_mkldnn"]
if "cpu_num_threads" in kwargs and kwargs["cpu_num_threads"]:
......@@ -253,24 +278,25 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
if "thresh" in kwargs and kwargs[
"thresh"] and "ThreshOutput" in cfg.PostProcess:
cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
if "Topk" in cfg.PostProcess:
if "topk" in kwargs and kwargs["topk"]:
cfg.PostProcess.Topk.topk = kwargs["topk"]
if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]:
cfg.PostProcess.Topk.class_id_map_file = kwargs[
"class_id_map_file"]
else:
class_id_map_file_path = os.path.relpath(
cfg.PostProcess.Topk.class_id_map_file, "../")
cfg.PostProcess.Topk.class_id_map_file = os.path.join(
__dir__, class_id_map_file_path)
if "VehicleAttribute" in cfg.PostProcess:
if "color_threshold" in kwargs and kwargs["color_threshold"]:
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
"color_threshold"]
if "type_threshold" in kwargs and kwargs["type_threshold"]:
cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
"type_threshold"]
if cfg.get("PostProcess"):
if "Topk" in cfg.PostProcess:
if "topk" in kwargs and kwargs["topk"]:
cfg.PostProcess.Topk.topk = kwargs["topk"]
if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]:
cfg.PostProcess.Topk.class_id_map_file = kwargs[
"class_id_map_file"]
else:
class_id_map_file_path = os.path.relpath(
cfg.PostProcess.Topk.class_id_map_file, "../")
cfg.PostProcess.Topk.class_id_map_file = os.path.join(
__dir__, class_id_map_file_path)
if "VehicleAttribute" in cfg.PostProcess:
if "color_threshold" in kwargs and kwargs["color_threshold"]:
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
"color_threshold"]
if "type_threshold" in kwargs and kwargs["type_threshold"]:
cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
"type_threshold"]
if "save_dir" in kwargs and kwargs["save_dir"]:
cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
......@@ -295,6 +321,13 @@ def args_cfg():
type=str,
help="The directory of model files. Valid when model_name not specifed."
)
parser.add_argument(
"--index_dir",
type=str,
required=False,
help="The index directory path.")
parser.add_argument(
"--data_file", type=str, required=False, help="The label file path.")
parser.add_argument("--use_gpu", type=str2bool, help="Whether use GPU.")
parser.add_argument(
"--gpu_mem",
......@@ -347,6 +380,7 @@ def print_info():
"""
imn_table = PrettyTable(["IMN Model Series", "Model Name"])
pulc_table = PrettyTable(["PULC Models"])
shitu_table = PrettyTable(["PP-ShiTu Models"])
try:
sz = os.get_terminal_size()
total_width = sz.columns
......@@ -365,11 +399,16 @@ def print_info():
textwrap.fill(
" ".join(PULC_MODELS), width=total_width).center(table_width - 4)
])
shitu_table.add_row([
textwrap.fill(
" ".join(SHITU_MODELS), width=total_width).center(table_width - 4)
])
print("{}".format("-" * table_width))
print("Models supported by PaddleClas".center(table_width))
print(imn_table)
print(pulc_table)
print(shitu_table)
print("Powered by PaddlePaddle!".rjust(table_width))
print("{}".format("-" * table_width))
......@@ -425,6 +464,10 @@ def check_model_file(model_type, model_name):
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
"PULC", model_name)
url = PULC_MODEL_BASE_DOWNLOAD_URL.format(model_name)
elif model_type == "shitu":
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
"PP-ShiTu", model_name)
url = SHITU_MODEL_BASE_DOWNLOAD_URL.format(model_name)
else:
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
"IMN", model_name)
......@@ -485,8 +528,10 @@ class PaddleClas(object):
model_name, inference_model_dir)
self._config = init_config(self.model_type, model_name,
inference_model_dir, **kwargs)
self.cls_predictor = ClsPredictor(self._config)
if self.model_type == "shitu":
self.predictor = SystemPredictor(self._config)
else:
self.predictor = ClsPredictor(self._config)
def get_config(self):
"""Get the config.
......@@ -498,6 +543,7 @@ class PaddleClas(object):
"""
all_imn_model_names = get_imn_model_names()
all_pulc_model_names = PULC_MODELS
all_shitu_model_names = SHITU_MODELS
if model_name:
if model_name in all_imn_model_names:
......@@ -506,6 +552,15 @@ class PaddleClas(object):
elif model_name in all_pulc_model_names:
inference_model_dir = check_model_file("pulc", model_name)
return "pulc", inference_model_dir
elif model_name in all_shitu_model_names:
inference_model_dir = check_model_file(
"shitu",
"PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0")
inference_model_dir = check_model_file(
"shitu", "picodet_PPLCNet_x2_5_mainbody_lite_v1.0")
inference_model_dir = os.path.abspath(
os.path.dirname(inference_model_dir))
return "shitu", inference_model_dir
else:
similar_imn_names = similar_model_names(model_name,
all_imn_model_names)
......@@ -526,12 +581,13 @@ class PaddleClas(object):
raise InputModelError(err)
return "custom", inference_model_dir
else:
err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
err = "Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
raise InputModelError(err)
return None
def predict(self, input_data: Union[str, np.array],
print_pred: bool=False) -> Generator[list, None, None]:
def predict_cls(self,
input_data: Union[str, np.array],
print_pred: bool=False) -> Generator[list, None, None]:
"""Predict input_data.
Args:
......@@ -551,7 +607,7 @@ class PaddleClas(object):
"""
if isinstance(input_data, np.ndarray):
yield self.cls_predictor.predict(input_data)
yield self.predictor.predict(input_data)
elif isinstance(input_data, str):
if input_data.startswith("http") or input_data.startswith("https"):
image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
......@@ -583,7 +639,7 @@ class PaddleClas(object):
cnt += 1
if cnt % batch_size == 0 or (idx_img + 1) == len(image_list):
preds = self.cls_predictor.predict(img_list)
preds = self.predictor.predict(img_list)
if preds:
for idx_pred, pred in enumerate(preds):
......@@ -600,6 +656,77 @@ class PaddleClas(object):
raise ImageTypeError(err)
return
def predict_shitu(self,
input_data: Union[str, np.array],
print_pred: bool=False) -> Generator[list, None, None]:
"""Predict input_data.
Args:
input_data (Union[str, np.array]):
When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
When the type is np.array, it is the image data whose channel order is RGB.
print_pred (bool, optional): Whether print the prediction result. Defaults to False.
Raises:
ImageTypeError: Illegal input_data.
Yields:
Generator[list, None, None]:
The prediction result(s) of input_data by batch_size. For every one image,
prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
"""
if isinstance(input_data, np.ndarray):
yield self.predictor.predict(input_data)
elif isinstance(input_data, str):
if input_data.startswith("http") or input_data.startswith("https"):
image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
if not os.path.exists(image_storage_dir()):
os.makedirs(image_storage_dir())
image_save_path = image_storage_dir("tmp.jpg")
download_with_progressbar(input_data, image_save_path)
logger.info(
f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
)
input_data = image_save_path
image_list = get_image_list(input_data)
cnt = 0
for idx_img, img_path in enumerate(image_list):
img = cv2.imread(img_path)
if img is None:
logger.warning(
f"Image file failed to read and has been skipped. The path: {img_path}"
)
continue
img = img[:, :, ::-1]
cnt += 1
preds = self.predictor.predict(
img) # [dict1, dict2, ..., dictn]
if preds:
if print_pred:
logger.info(f"{preds}, filename: {img_path}")
yield preds
else:
err = "Please input legal image! The type of image supported by PaddleClas are: NumPy.ndarray and string of local path or Ineternet URL"
raise ImageTypeError(err)
return
def predict(self,
input_data: Union[str, np.array],
print_pred: bool=False,
predict_type="cls"):
if predict_type == "cls":
return self.predict_cls(input_data, print_pred)
elif predict_type == "shitu":
assert not isinstance(input_data, (
list, tuple
)), "PP-ShiTu predictor only support single image as input now."
return self.predict_shitu(input_data, print_pred)
else:
raise ModuleNotFoundError
# for CLI
def main():
......@@ -608,7 +735,10 @@ def main():
print_info()
cfg = args_cfg()
clas_engine = PaddleClas(**cfg)
res = clas_engine.predict(cfg["infer_imgs"], print_pred=True)
res = clas_engine.predict(
cfg["infer_imgs"],
print_pred=True,
predict_type="cls" if "PP-ShiTu" not in cfg["model_name"] else "shitu")
for _ in res:
pass
logger.info("Predict complete!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册