diff --git a/paddleclas.py b/paddleclas.py index 58094c8e3a146cb89346376409f1f0b6ffeeff19..1f107af4fe0e3b84f3fbc3edce1d35cad08131ca 100644 --- a/paddleclas.py +++ b/paddleclas.py @@ -32,6 +32,7 @@ from .ppcls.arch import backbone from .ppcls.utils import logger from .deploy.python.predict_cls import ClsPredictor +from .deploy.python.predict_system import SystemPredictor from .deploy.utils.get_image_list import get_image_list from .deploy.utils import config @@ -194,6 +195,14 @@ PULC_MODELS = [ "textline_orientation", "traffic_sign", "vehicle_attribute" ] +SHITU_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/{}_infer.tar" +SHITU_MODELS = [ + # "picodet_PPLCNet_x2_5_mainbody_lite_v1.0", # ShiTuV1(V2)_mainbody_det + # "general_PPLCNet_x2_5_lite_v1.0" # ShiTuV1_general_rec + # "PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0", # ShiTuV2_general_rec TODO(hesensen): add lite model + "PP-ShiTuV2" +] + class ImageTypeError(Exception): """ImageTypeError. 
@@ -213,12 +222,24 @@ class InputModelError(Exception): def init_config(model_type, model_name, inference_model_dir, **kwargs): - cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml" + if model_type == "pulc": + cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" + elif model_type == "shitu": + cfg_path = "deploy/configs/inference_general.yaml" + else: + cfg_path = "deploy/configs/inference_cls.yaml" + __dir__ = os.path.dirname(__file__) cfg_path = os.path.join(__dir__, cfg_path) cfg = config.get_config(cfg_path, show=False) - - cfg.Global.inference_model_dir = inference_model_dir + if cfg.Global.get("inference_model_dir"): + cfg.Global.inference_model_dir = inference_model_dir + else: + cfg.Global.rec_inference_model_dir = os.path.join( + inference_model_dir, + "PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0") + cfg.Global.det_inference_model_dir = os.path.join( + inference_model_dir, "picodet_PPLCNet_x2_5_mainbody_lite_v1.0") if "batch_size" in kwargs and kwargs["batch_size"]: cfg.Global.batch_size = kwargs["batch_size"] @@ -232,6 +253,10 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs): if "infer_imgs" in kwargs and kwargs["infer_imgs"]: cfg.Global.infer_imgs = kwargs["infer_imgs"] + if "index_dir" in kwargs and kwargs["index_dir"]: + cfg.IndexProcess.index_dir = kwargs["index_dir"] + if "data_file" in kwargs and kwargs["data_file"]: + cfg.IndexProcess.data_file = kwargs["data_file"] if "enable_mkldnn" in kwargs and kwargs["enable_mkldnn"]: cfg.Global.enable_mkldnn = kwargs["enable_mkldnn"] if "cpu_num_threads" in kwargs and kwargs["cpu_num_threads"]: @@ -253,24 +278,25 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs): if "thresh" in kwargs and kwargs[ "thresh"] and "ThreshOutput" in cfg.PostProcess: cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"] - if "Topk" in cfg.PostProcess: - if "topk" in 
kwargs and kwargs["topk"]: - cfg.PostProcess.Topk.topk = kwargs["topk"] - if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]: - cfg.PostProcess.Topk.class_id_map_file = kwargs[ - "class_id_map_file"] - else: - class_id_map_file_path = os.path.relpath( - cfg.PostProcess.Topk.class_id_map_file, "../") - cfg.PostProcess.Topk.class_id_map_file = os.path.join( - __dir__, class_id_map_file_path) - if "VehicleAttribute" in cfg.PostProcess: - if "color_threshold" in kwargs and kwargs["color_threshold"]: - cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[ - "color_threshold"] - if "type_threshold" in kwargs and kwargs["type_threshold"]: - cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[ - "type_threshold"] + if cfg.get("PostProcess"): + if "Topk" in cfg.PostProcess: + if "topk" in kwargs and kwargs["topk"]: + cfg.PostProcess.Topk.topk = kwargs["topk"] + if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]: + cfg.PostProcess.Topk.class_id_map_file = kwargs[ + "class_id_map_file"] + else: + class_id_map_file_path = os.path.relpath( + cfg.PostProcess.Topk.class_id_map_file, "../") + cfg.PostProcess.Topk.class_id_map_file = os.path.join( + __dir__, class_id_map_file_path) + if "VehicleAttribute" in cfg.PostProcess: + if "color_threshold" in kwargs and kwargs["color_threshold"]: + cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[ + "color_threshold"] + if "type_threshold" in kwargs and kwargs["type_threshold"]: + cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[ + "type_threshold"] if "save_dir" in kwargs and kwargs["save_dir"]: cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"] @@ -295,6 +321,13 @@ def args_cfg(): type=str, help="The directory of model files. Valid when model_name not specifed." 
) + parser.add_argument( + "--index_dir", + type=str, + required=False, + help="The index directory path.") + parser.add_argument( + "--data_file", type=str, required=False, help="The label file path.") parser.add_argument("--use_gpu", type=str2bool, help="Whether use GPU.") parser.add_argument( "--gpu_mem", @@ -347,6 +380,7 @@ def print_info(): """ imn_table = PrettyTable(["IMN Model Series", "Model Name"]) pulc_table = PrettyTable(["PULC Models"]) + shitu_table = PrettyTable(["PP-ShiTu Models"]) try: sz = os.get_terminal_size() total_width = sz.columns @@ -365,11 +399,16 @@ def print_info(): textwrap.fill( " ".join(PULC_MODELS), width=total_width).center(table_width - 4) ]) + shitu_table.add_row([ + textwrap.fill( + " ".join(SHITU_MODELS), width=total_width).center(table_width - 4) + ]) print("{}".format("-" * table_width)) print("Models supported by PaddleClas".center(table_width)) print(imn_table) print(pulc_table) + print(shitu_table) print("Powered by PaddlePaddle!".rjust(table_width)) print("{}".format("-" * table_width)) @@ -425,6 +464,10 @@ def check_model_file(model_type, model_name): storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR, "PULC", model_name) url = PULC_MODEL_BASE_DOWNLOAD_URL.format(model_name) + elif model_type == "shitu": + storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR, + "PP-ShiTu", model_name) + url = SHITU_MODEL_BASE_DOWNLOAD_URL.format(model_name) else: storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR, "IMN", model_name) @@ -485,8 +528,10 @@ class PaddleClas(object): model_name, inference_model_dir) self._config = init_config(self.model_type, model_name, inference_model_dir, **kwargs) - - self.cls_predictor = ClsPredictor(self._config) + if self.model_type == "shitu": + self.predictor = SystemPredictor(self._config) + else: + self.predictor = ClsPredictor(self._config) def get_config(self): """Get the config. 
@@ -498,6 +543,7 @@ class PaddleClas(object): """ all_imn_model_names = get_imn_model_names() all_pulc_model_names = PULC_MODELS + all_shitu_model_names = SHITU_MODELS if model_name: if model_name in all_imn_model_names: @@ -506,6 +552,15 @@ class PaddleClas(object): elif model_name in all_pulc_model_names: inference_model_dir = check_model_file("pulc", model_name) return "pulc", inference_model_dir + elif model_name in all_shitu_model_names: + inference_model_dir = check_model_file( + "shitu", + "PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0") + inference_model_dir = check_model_file( + "shitu", "picodet_PPLCNet_x2_5_mainbody_lite_v1.0") + inference_model_dir = os.path.abspath( + os.path.dirname(inference_model_dir)) + return "shitu", inference_model_dir else: similar_imn_names = similar_model_names(model_name, all_imn_model_names) @@ -526,12 +581,13 @@ class PaddleClas(object): raise InputModelError(err) return "custom", inference_model_dir else: - err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)." + err = "Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)." raise InputModelError(err) return None - def predict(self, input_data: Union[str, np.array], - print_pred: bool=False) -> Generator[list, None, None]: + def predict_cls(self, + input_data: Union[str, np.array], + print_pred: bool=False) -> Generator[list, None, None]: """Predict input_data. 
Args: @@ -551,7 +607,7 @@ class PaddleClas(object): """ if isinstance(input_data, np.ndarray): - yield self.cls_predictor.predict(input_data) + yield self.predictor.predict(input_data) elif isinstance(input_data, str): if input_data.startswith("http") or input_data.startswith("https"): image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR) @@ -583,7 +639,7 @@ cnt += 1 if cnt % batch_size == 0 or (idx_img + 1) == len(image_list): - preds = self.cls_predictor.predict(img_list) + preds = self.predictor.predict(img_list) if preds: for idx_pred, pred in enumerate(preds): @@ -600,6 +656,77 @@ raise ImageTypeError(err) return + def predict_shitu(self, + input_data: Union[str, np.array], + print_pred: bool=False) -> Generator[list, None, None]: + """Predict input_data. + Args: + input_data (Union[str, np.array]): + When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet. + When the type is np.array, it is the image data whose channel order is RGB. + print_pred (bool, optional): Whether to print the prediction result. Defaults to False. + + Raises: + ImageTypeError: Illegal input_data. + + Yields: + Generator[list, None, None]: + The prediction result(s) of input_data by batch_size. For every one image, + prediction result(s) is zipped as a dict, that includes topk "class_ids", "scores" and "label_names". + The format of batch prediction result(s) is as follows: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...] 
+ """ + if isinstance(input_data, np.ndarray): + yield self.predictor.predict(input_data) + elif isinstance(input_data, str): + if input_data.startswith("http") or input_data.startswith("https"): + image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR) + if not os.path.exists(image_storage_dir()): + os.makedirs(image_storage_dir()) + image_save_path = image_storage_dir("tmp.jpg") + download_with_progressbar(input_data, image_save_path) + logger.info( + f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}" + ) + input_data = image_save_path + image_list = get_image_list(input_data) + + cnt = 0 + for idx_img, img_path in enumerate(image_list): + img = cv2.imread(img_path) + if img is None: + logger.warning( + f"Image file failed to read and has been skipped. The path: {img_path}" + ) + continue + img = img[:, :, ::-1] + cnt += 1 + + preds = self.predictor.predict( + img) # [dict1, dict2, ..., dictn] + if preds: + if print_pred: + logger.info(f"{preds}, filename: {img_path}") + + yield preds + else: + err = "Please input legal image! The type of image supported by PaddleClas are: NumPy.ndarray and string of local path or Ineternet URL" + raise ImageTypeError(err) + return + + def predict(self, + input_data: Union[str, np.array], + print_pred: bool=False, + predict_type="cls"): + if predict_type == "cls": + return self.predict_cls(input_data, print_pred) + elif predict_type == "shitu": + assert not isinstance(input_data, ( + list, tuple + )), "PP-ShiTu predictor only support single image as input now." 
+ return self.predict_shitu(input_data, print_pred) + else: + raise ModuleNotFoundError + # for CLI def main(): @@ -608,7 +735,10 @@ def main(): print_info() cfg = args_cfg() clas_engine = PaddleClas(**cfg) - res = clas_engine.predict(cfg["infer_imgs"], print_pred=True) + res = clas_engine.predict( + cfg["infer_imgs"], + print_pred=True, + predict_type="cls" if "PP-ShiTu" not in cfg["model_name"] else "shitu") for _ in res: pass logger.info("Predict complete!")