diff --git a/paddleclas.py b/paddleclas.py index d7eb3f9430871a19c432f36789feeb3cabfce7a5..a1e3b2d77d23393be04b3b96e8715e239ef26841 100644 --- a/paddleclas.py +++ b/paddleclas.py @@ -33,6 +33,7 @@ from .ppcls.utils import logger from .deploy.python.predict_cls import ClsPredictor from .deploy.python.predict_system import SystemPredictor +from .deploy.python.build_gallery import GalleryBuilder from .deploy.utils.get_image_list import get_image_list from .deploy.utils import config @@ -227,7 +228,9 @@ class InputModelError(Exception): def init_config(model_type, model_name, inference_model_dir, **kwargs): - if model_type == "pulc": + if kwargs.get("build_gallery", False): + cfg_path = "deploy/configs/inference_general.yaml" + elif model_type == "pulc": cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" elif model_type == "shitu": cfg_path = "deploy/configs/inference_general.yaml" @@ -236,7 +239,8 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs): __dir__ = os.path.dirname(__file__) cfg_path = os.path.join(__dir__, cfg_path) - cfg = config.get_config(cfg_path, show=False) + cfg = config.get_config( + cfg_path, overrides=kwargs.get("override", None), show=False) if cfg.Global.get("inference_model_dir"): cfg.Global.inference_model_dir = inference_model_dir else: @@ -337,10 +341,15 @@ def args_cfg(): parser.add_argument( "--infer_imgs", type=str, - required=True, + required=False, help="The image(s) to be predicted.") parser.add_argument( "--model_name", type=str, help="The model name to be used.") + parser.add_argument( + "--predict_type", + type=str, + default="cls", + help="The predict type to be selected.") parser.add_argument( "--inference_model_dir", type=str, @@ -395,7 +404,17 @@ def args_cfg(): parser.add_argument( "--resize_short", type=int, help="Resize according to short size.") parser.add_argument("--crop_size", type=int, help="Centor crop size.") - + parser.add_argument( + "--build_gallery", + type=str2bool, + default=False, + help="Whether build gallery.") + parser.add_argument( + '-o', + '--override', + action='append', + default=[], + help='config options to be overridden') args = parser.parse_args() return vars(args) @@ -549,14 +568,27 @@ class PaddleClas(object): """ super().__init__() - self.model_type, inference_model_dir = self._check_input_model( - model_name, inference_model_dir) - self._config = init_config(self.model_type, model_name, - inference_model_dir, **kwargs) - if self.model_type == "shitu": - self.predictor = SystemPredictor(self._config) + if kwargs.get("build_gallery", False): + self.model_type, inference_model_dir = self._check_input_model( + model_name + if model_name else "PP-ShiTuV2", inference_model_dir) + self._config = init_config(self.model_type, model_name + if model_name else "PP-ShiTuV2", + inference_model_dir, **kwargs) + + logger.info("Building Gallery...") + GalleryBuilder(self._config) + else: - self.predictor = ClsPredictor(self._config) + self.model_type, inference_model_dir = self._check_input_model( + model_name, inference_model_dir) + self._config = init_config(self.model_type, model_name, + inference_model_dir, **kwargs) + + if self.model_type == "shitu": + self.predictor = SystemPredictor(self._config) + else: + self.predictor = ClsPredictor(self._config) def get_config(self): """Get the config. @@ -700,6 +732,9 @@ class PaddleClas(object): prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names". The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...] """ + if input_data == None and self._config.Global.infer_imgs: + input_data = self._config.Global.infer_imgs + if isinstance(input_data, np.ndarray): yield self.predictor.predict(input_data) elif isinstance(input_data, str): @@ -742,6 +777,8 @@ class PaddleClas(object): input_data: Union[str, np.array], print_pred: bool=False, predict_type="cls"): + assert predict_type in ["cls", "shitu" + ], "Predict type should be 'cls' or 'shitu'." if predict_type == "cls": return self.predict_cls(input_data, print_pred) elif predict_type == "shitu": @@ -760,13 +797,14 @@ def main(): print_info() cfg = args_cfg() clas_engine = PaddleClas(**cfg) - res = clas_engine.predict( - cfg["infer_imgs"], - print_pred=True, - predict_type="cls" if "PP-ShiTu" not in cfg["model_name"] else "shitu") - for _ in res: - pass - logger.info("Predict complete!") + if cfg["build_gallery"] == False: + res = clas_engine.predict( + cfg["infer_imgs"], + print_pred=True, + predict_type=cfg["predict_type"]) + for _ in res: + pass + logger.info("Predict complete!") return