提交 9d3c22f5 编写于 作者: 悟、's avatar 悟、 提交者: Tingquan Gao

update whl for shitu

上级 7e097f3c
......@@ -33,6 +33,7 @@ from .ppcls.utils import logger
from .deploy.python.predict_cls import ClsPredictor
from .deploy.python.predict_system import SystemPredictor
from .deploy.python.build_gallery import GalleryBuilder
from .deploy.utils.get_image_list import get_image_list
from .deploy.utils import config
......@@ -227,7 +228,9 @@ class InputModelError(Exception):
def init_config(model_type, model_name, inference_model_dir, **kwargs):
if model_type == "pulc":
if kwargs.get("build_gallery", False):
cfg_path = "deploy/configs/inference_general.yaml"
elif model_type == "pulc":
cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml"
elif model_type == "shitu":
cfg_path = "deploy/configs/inference_general.yaml"
......@@ -236,7 +239,8 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
__dir__ = os.path.dirname(__file__)
cfg_path = os.path.join(__dir__, cfg_path)
cfg = config.get_config(cfg_path, show=False)
cfg = config.get_config(
cfg_path, overrides=kwargs.get("override", None), show=False)
if cfg.Global.get("inference_model_dir"):
cfg.Global.inference_model_dir = inference_model_dir
else:
......@@ -337,10 +341,15 @@ def args_cfg():
parser.add_argument(
"--infer_imgs",
type=str,
required=True,
required=False,
help="The image(s) to be predicted.")
parser.add_argument(
"--model_name", type=str, help="The model name to be used.")
parser.add_argument(
"--predict_type",
type=str,
default="cls",
help="The predict type to be selected.")
parser.add_argument(
"--inference_model_dir",
type=str,
......@@ -395,7 +404,17 @@ def args_cfg():
parser.add_argument(
"--resize_short", type=int, help="Resize according to short size.")
parser.add_argument("--crop_size", type=int, help="Centor crop size.")
parser.add_argument(
"--build_gallery",
type=str2bool,
default=False,
help="Whether build gallery.")
parser.add_argument(
'-o',
'--override',
action='append',
default=[],
help='config options to be overridden')
args = parser.parse_args()
return vars(args)
......@@ -549,14 +568,27 @@ class PaddleClas(object):
"""
super().__init__()
self.model_type, inference_model_dir = self._check_input_model(
model_name, inference_model_dir)
self._config = init_config(self.model_type, model_name,
inference_model_dir, **kwargs)
if self.model_type == "shitu":
self.predictor = SystemPredictor(self._config)
if kwargs.get("build_gallery", False):
self.model_type, inference_model_dir = self._check_input_model(
model_name
if model_name else "PP-ShiTuV2", inference_model_dir)
self._config = init_config(self.model_type, model_name
if model_name else "PP-ShiTuV2",
inference_model_dir, **kwargs)
logger.info("Building Gallery...")
GalleryBuilder(self._config)
else:
self.predictor = ClsPredictor(self._config)
self.model_type, inference_model_dir = self._check_input_model(
model_name, inference_model_dir)
self._config = init_config(self.model_type, model_name,
inference_model_dir, **kwargs)
if self.model_type == "shitu":
self.predictor = SystemPredictor(self._config)
else:
self.predictor = ClsPredictor(self._config)
def get_config(self):
"""Get the config.
......@@ -700,6 +732,9 @@ class PaddleClas(object):
prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
"""
if input_data == None and self._config.Global.infer_imgs:
input_data = self._config.Global.infer_imgs
if isinstance(input_data, np.ndarray):
yield self.predictor.predict(input_data)
elif isinstance(input_data, str):
......@@ -742,6 +777,8 @@ class PaddleClas(object):
input_data: Union[str, np.array],
print_pred: bool=False,
predict_type="cls"):
assert predict_type in ["cls", "shitu"
], "Predict type should be 'cls' or 'shitu'."
if predict_type == "cls":
return self.predict_cls(input_data, print_pred)
elif predict_type == "shitu":
......@@ -760,13 +797,14 @@ def main():
print_info()
cfg = args_cfg()
clas_engine = PaddleClas(**cfg)
res = clas_engine.predict(
cfg["infer_imgs"],
print_pred=True,
predict_type="cls" if "PP-ShiTu" not in cfg["model_name"] else "shitu")
for _ in res:
pass
logger.info("Predict complete!")
if cfg["build_gallery"] == False:
res = clas_engine.predict(
cfg["infer_imgs"],
print_pred=True,
predict_type=cfg["predict_type"])
for _ in res:
pass
logger.info("Predict complete!")
return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册