Unverified commit 1e78e71c, authored by gaotingquan

feat: support PULC to deploy with whl

Parent 794af8c0
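This commit teaches the PaddleClas wheel to recognize the eight PULC (Practical Ultra Lightweight Classification) model names alongside the ImageNet ("IMN") series, download their inference tars, and load per-model YAML configs. A minimal sketch of the intended usage after this change; the image path is a placeholder:

```python
from paddleclas import PaddleClas

# A PULC task model: resolved via PULC_MODELS and downloaded on first use.
clf = PaddleClas(model_name="person_exists")
result = clf.predict(input_data="test.jpg", print_pred=True)  # placeholder image
print(next(result))

# ImageNet classification models keep the same entry point.
clf = PaddleClas(model_name="ResNet50")
```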
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -2,6 +2,7 @@ include LICENSE.txt
 include README.md
 include docs/en/whl_en.md
 recursive-include deploy/python predict_cls.py preprocess.py postprocess.py det_preprocess.py
+recursive-include deploy/configs *.yaml
 recursive-include deploy/utils get_image_list.py config.py logger.py predictor.py
 recursive-include ppcls/ *.py *.txt
\ No newline at end of file
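Packaging `deploy/configs/*.yaml` matters because the reworked `init_config` below no longer builds a config dict in code; it loads one of these YAML files through the bundled `deploy.utils.config.get_config` helper. A rough sketch, assuming the relative path resolves against the package root as `paddleclas.py` does:

```python
from deploy.utils import config

# PULC models get a per-model file; IMN models share inference_cls.yaml.
cfg = config.get_config(
    "deploy/configs/PULC/person_exists/inference_person_exists.yaml",
    show=False)
print(cfg.PostProcess)  # e.g. contains ThreshOutput for this binary task
```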
--- a/paddleclas.py
+++ b/paddleclas.py
@@ -37,7 +37,7 @@ from deploy.python.predict_cls import ClsPredictor
 from deploy.utils.get_image_list import get_image_list
 from deploy.utils import config
-from ppcls.arch.backbone import *
+import ppcls.arch.backbone as backbone
 from ppcls.utils.logger import init_logger
 
 # for building model with loading pretrained weights from backbone
@@ -48,8 +48,8 @@ __all__ = ["PaddleClas"]
 BASE_DIR = os.path.expanduser("~/.paddleclas/")
 BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, "inference_model")
 BASE_IMAGES_DIR = os.path.join(BASE_DIR, "images")
-BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
-MODEL_SERIES = {
+IMN_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
+IMN_MODEL_SERIES = {
     "AlexNet": ["AlexNet"],
     "DarkNet": ["DarkNet53"],
     "DeiT": [
@@ -168,6 +168,12 @@ MODEL_SERIES = {
     ]
 }
 
+PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/{}_infer.tar"
+PULC_MODELS = [
+    "person_exists", "person_attribute", "safety_helmet", "traffic_sign",
+    "car_exists", "car_attribute", "text_line", "multilingual"
+]
+
 
 class ImageTypeError(Exception):
     """ImageTypeError.
@@ -185,76 +191,61 @@ class InputModelError(Exception):
         super().__init__(message)
 
 
-def init_config(model_name,
-                inference_model_dir,
-                use_gpu=True,
-                batch_size=1,
-                topk=5,
-                **kwargs):
-    imagenet1k_map_path = os.path.join(
-        os.path.abspath(__dir__), "ppcls/utils/imagenet1k_label_list.txt")
-    cfg = {
-        "Global": {
-            "infer_imgs": kwargs["infer_imgs"]
-            if "infer_imgs" in kwargs else False,
-            "model_name": model_name,
-            "inference_model_dir": inference_model_dir,
-            "batch_size": batch_size,
-            "use_gpu": use_gpu,
-            "enable_mkldnn": kwargs["enable_mkldnn"]
-            if "enable_mkldnn" in kwargs else False,
-            "cpu_num_threads": kwargs["cpu_num_threads"]
-            if "cpu_num_threads" in kwargs else 1,
-            "enable_benchmark": False,
-            "use_fp16": kwargs["use_fp16"] if "use_fp16" in kwargs else False,
-            "ir_optim": True,
-            "use_tensorrt": kwargs["use_tensorrt"]
-            if "use_tensorrt" in kwargs else False,
-            "gpu_mem": kwargs["gpu_mem"] if "gpu_mem" in kwargs else 8000,
-            "enable_profile": False
-        },
-        "PreProcess": {
-            "transform_ops": [{
-                "ResizeImage": {
-                    "resize_short": kwargs["resize_short"]
-                    if "resize_short" in kwargs else 256
-                }
-            }, {
-                "CropImage": {
-                    "size": kwargs["crop_size"]
-                    if "crop_size" in kwargs else 224
-                }
-            }, {
-                "NormalizeImage": {
-                    "scale": 0.00392157,
-                    "mean": [0.485, 0.456, 0.406],
-                    "std": [0.229, 0.224, 0.225],
-                    "order": ''
-                }
-            }, {
-                "ToCHWImage": None
-            }]
-        },
-        "PostProcess": {
-            "main_indicator": "Topk",
-            "Topk": {
-                "topk": topk,
-                "class_id_map_file": imagenet1k_map_path
-            }
-        }
-    }
-    if "save_dir" in kwargs:
-        if kwargs["save_dir"] is not None:
-            cfg["PostProcess"]["SavePreLabel"] = {
-                "save_dir": kwargs["save_dir"]
-            }
-    if "class_id_map_file" in kwargs:
-        if kwargs["class_id_map_file"] is not None:
-            cfg["PostProcess"]["Topk"]["class_id_map_file"] = kwargs[
-                "class_id_map_file"]
-    cfg = config.AttrDict(cfg)
-    config.create_attr_dict(cfg)
+def init_config(model_type, model_name, inference_model_dir, **kwargs):
+
+    cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml"
+    cfg = config.get_config(cfg_path, show=False)
+
+    cfg.Global.inference_model_dir = inference_model_dir
+
+    if "batch_size" in kwargs and kwargs["batch_size"]:
+        cfg.Global.batch_size = kwargs["batch_size"]
+    if "use_gpu" in kwargs and kwargs["use_gpu"]:
+        cfg.Global.use_gpu = kwargs["use_gpu"]
+    if "infer_imgs" in kwargs and kwargs["infer_imgs"]:
+        cfg.Global.infer_imgs = kwargs["infer_imgs"]
+    if "enable_mkldnn" in kwargs and kwargs["enable_mkldnn"]:
+        cfg.Global.enable_mkldnn = kwargs["enable_mkldnn"]
+    if "cpu_num_threads" in kwargs and kwargs["cpu_num_threads"]:
+        cfg.Global.cpu_num_threads = kwargs["cpu_num_threads"]
+    if "use_fp16" in kwargs and kwargs["use_fp16"]:
+        cfg.Global.use_fp16 = kwargs["use_fp16"]
+    if "use_tensorrt" in kwargs and kwargs["use_tensorrt"]:
+        cfg.Global.use_tensorrt = kwargs["use_tensorrt"]
+    if "gpu_mem" in kwargs and kwargs["gpu_mem"]:
+        cfg.Global.gpu_mem = kwargs["gpu_mem"]
+    if "resize_short" in kwargs and kwargs["resize_short"]:
+        cfg.PreProcess.transform_ops[0]["ResizeImage"][
+            "resize_short"] = kwargs["resize_short"]
+    if "crop_size" in kwargs and kwargs["crop_size"]:
+        cfg.PreProcess.transform_ops[1]["CropImage"]["size"] = kwargs[
+            "crop_size"]
+
+    # TODO(gaotingquan): not robust
+    if "thresh" in kwargs and kwargs[
+            "thresh"] and "ThreshOutput" in cfg.PostProcess:
+        cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
+    if "Topk" in cfg.PostProcess:
+        if "topk" in kwargs and kwargs["topk"]:
+            cfg.PostProcess.Topk.topk = kwargs["topk"]
+        if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]:
+            cfg.PostProcess.Topk.class_id_map_file = kwargs[
+                "class_id_map_file"]
+        else:
+            cfg.PostProcess.Topk.class_id_map_file = os.path.relpath(
+                cfg.PostProcess.Topk.class_id_map_file, "../")
+    if "VehicleAttribute" in cfg.PostProcess:
+        if "color_threshold" in kwargs and kwargs["color_threshold"]:
+            cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                "color_threshold"]
+        if "type_threshold" in kwargs and kwargs["type_threshold"]:
+            cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
+                "type_threshold"]
+
+    if "save_dir" in kwargs and kwargs["save_dir"]:
+        cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
+
     return cfg
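`init_config` now starts from the packaged YAML and overrides only keyword arguments that are present and truthy, so YAML defaults survive untouched. A hedged sketch of the effect, assuming the person_exists config uses the ThreshOutput postprocessor and with a placeholder model directory:

```python
cfg = init_config(
    model_type="pulc",
    model_name="person_exists",
    inference_model_dir="/path/to/person_exists_infer",  # placeholder
    batch_size=2,
    thresh=0.9)
assert cfg.Global.batch_size == 2
assert cfg.PostProcess.ThreshOutput.thresh == 0.9
```

Note the pattern the `TODO(gaotingquan)` hints at: because every check requires `kwargs["x"]` to be truthy, an explicit falsy override such as `use_gpu=False` or `batch_size=0` is silently ignored.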
@@ -275,8 +266,7 @@ def args_cfg():
         type=str,
         help="The directory of model files. Valid when model_name not specifed."
     )
-    parser.add_argument(
-        "--use_gpu", type=str, default=True, help="Whether use GPU.")
+    parser.add_argument("--use_gpu", type=str, help="Whether use GPU.")
     parser.add_argument("--gpu_mem", type=int, default=8000, help="")
     parser.add_argument(
         "--enable_mkldnn",
@@ -287,28 +277,29 @@ def args_cfg():
     parser.add_argument(
         "--use_tensorrt", type=str2bool, default=False, help="")
     parser.add_argument("--use_fp16", type=str2bool, default=False, help="")
-    parser.add_argument(
-        "--batch_size", type=int, default=1, help="Batch size. Default by 1.")
+    parser.add_argument("--batch_size", type=int, help="Batch size.")
     parser.add_argument(
         "--topk",
         type=int,
-        default=5,
-        help="Return topk score(s) and corresponding results. Default by 5.")
+        help="Return topk score(s) and corresponding results when Topk postprocess is used."
+    )
     parser.add_argument(
         "--class_id_map_file",
         type=str,
         help="The path of file that map class_id and label.")
+    parser.add_argument(
+        "--threshold",
+        type=float,
+        help="The threshold of ThreshOutput when postprocess is used.")
+    parser.add_argument("--color_threshold", type=float, help="")
+    parser.add_argument("--type_threshold", type=float, help="")
     parser.add_argument(
         "--save_dir",
         type=str,
         help="The directory to save prediction results as pre-label.")
-    parser.add_argument(
-        "--resize_short",
-        type=int,
-        default=256,
-        help="Resize according to short size.")
-    parser.add_argument(
-        "--crop_size", type=int, default=224, help="Centor crop size.")
+    parser.add_argument(
+        "--resize_short", type=int, help="Resize according to short size.")
+    parser.add_argument("--crop_size", type=int, help="Centor crop size.")
 
     args = parser.parse_args()
     return vars(args)
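Dropping the defaults from `--use_gpu`, `--batch_size`, `--topk`, `--resize_short`, and `--crop_size` makes the packaged YAML the single source of defaults: only flags the user actually passes reach `init_config` as truthy overrides. The new `--threshold`, `--color_threshold`, and `--type_threshold` flags expose the PULC postprocessors (ThreshOutput and VehicleAttribute) on the command line, e.g. `paddleclas --model_name=person_exists --infer_imgs=test.jpg` (image path illustrative).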
@@ -317,33 +308,44 @@ def args_cfg():
 
 def print_info():
     """Print list of supported models in formatted.
     """
-    table = PrettyTable(["Series", "Name"])
+    imn_table = PrettyTable(["IMN Model Series", "Model Name"])
+    pulc_table = PrettyTable(["PULC Models"])
     try:
         sz = os.get_terminal_size()
-        width = sz.columns - 30 if sz.columns > 50 else 10
+        total_width = sz.columns
+        first_width = 30
+        second_width = total_width - first_width if total_width > 50 else 10
     except OSError:
-        width = 100
+        second_width = 100
-    for series in MODEL_SERIES:
-        names = textwrap.fill(" ".join(MODEL_SERIES[series]), width=width)
-        table.add_row([series, names])
-    width = len(str(table).split("\n")[0])
-    print("{}".format("-" * width))
-    print("Models supported by PaddleClas".center(width))
-    print(table)
-    print("Powered by PaddlePaddle!".rjust(width))
-    print("{}".format("-" * width))
+    for series in IMN_MODEL_SERIES:
+        names = textwrap.fill(
+            " ".join(IMN_MODEL_SERIES[series]), width=second_width)
+        imn_table.add_row([series, names])
+
+    table_width = len(str(imn_table).split("\n")[0])
+    pulc_table.add_row([
+        textwrap.fill(
+            " ".join(PULC_MODELS), width=total_width).center(table_width - 4)
+    ])
+
+    print("{}".format("-" * table_width))
+    print("Models supported by PaddleClas".center(table_width))
+    print(imn_table)
+    print(pulc_table)
+    print("Powered by PaddlePaddle!".rjust(table_width))
+    print("{}".format("-" * table_width))
 
 
-def get_model_names():
+def get_imn_model_names():
     """Get the model names list.
     """
     model_names = []
-    for series in MODEL_SERIES:
-        model_names += (MODEL_SERIES[series])
+    for series in IMN_MODEL_SERIES:
+        model_names += (IMN_MODEL_SERIES[series])
     return model_names
 
 
-def similar_architectures(name="", names=[], thresh=0.1, topk=10):
+def similar_model_names(name="", names=[], thresh=0.1, topk=5):
     """Find the most similar topk model names.
     """
     scores = []
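When a name matches neither table, `_check_input_model` (further down) builds its suggestion list from both name spaces via the renamed `similar_model_names`. A small illustration, with hypothetical output in the comment:

```python
suggestions = similar_model_names("person_exist",
                                  get_imn_model_names() + PULC_MODELS)
print(suggestions)  # e.g. ["person_exists", ...] — at most topk=5 names
```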
@@ -378,12 +380,17 @@ def download_with_progressbar(url, save_path):
             f"Something went wrong while downloading file from {url}")
 
 
-def check_model_file(model_name):
+def check_model_file(model_type, model_name):
     """Check the model files exist and download and untar when no exist.
     """
-    storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
-                                model_name)
-    url = BASE_DOWNLOAD_URL.format(model_name)
+    if model_type == "pulc":
+        storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
+                                    "PULC", model_name)
+        url = PULC_MODEL_BASE_DOWNLOAD_URL.format(model_name)
+    else:
+        storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
+                                    "IMN", model_name)
+        url = IMN_MODEL_BASE_DOWNLOAD_URL.format(model_name)
     tar_file_name_list = [
         "inference.pdiparams", "inference.pdiparams.info", "inference.pdmodel"
@@ -426,9 +433,6 @@ class PaddleClas(object):
     def __init__(self,
                  model_name: str=None,
                  inference_model_dir: str=None,
-                 use_gpu: bool=True,
-                 batch_size: int=1,
-                 topk: int=5,
                  **kwargs):
         """Init PaddleClas with config.
@@ -440,9 +444,11 @@ class PaddleClas(object):
             topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
         """
         super().__init__()
-        self._config = init_config(model_name, inference_model_dir, use_gpu,
-                                   batch_size, topk, **kwargs)
-        self._check_input_model()
+        self.model_type, inference_model_dir = self._check_input_model(
+            model_name, inference_model_dir)
+        self._config = init_config(self.model_type, model_name,
+                                   inference_model_dir, **kwargs)
+
         self.cls_predictor = ClsPredictor(self._config)
 
     def get_config(self):
@@ -450,24 +456,29 @@ class PaddleClas(object):
         """
         return self._config
 
-    def _check_input_model(self):
+    def _check_input_model(self, model_name, inference_model_dir):
         """Check input model name or model files.
         """
-        candidate_model_names = get_model_names()
-        input_model_name = self._config.Global.get("model_name", None)
-        inference_model_dir = self._config.Global.get("inference_model_dir",
-                                                      None)
-        if input_model_name is not None:
-            similar_names = similar_architectures(input_model_name,
-                                                  candidate_model_names)
-            similar_names_str = ", ".join(similar_names)
-            if input_model_name not in candidate_model_names:
-                err = f"{input_model_name} is not provided by PaddleClas. \nMaybe you want: [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!"
-                raise InputModelError(err)
-            self._config.Global.inference_model_dir = check_model_file(
-                input_model_name)
-            return
-        elif inference_model_dir is not None:
+        all_imn_model_names = get_imn_model_names()
+        all_pulc_model_names = PULC_MODELS
+
+        if model_name:
+            if model_name in all_imn_model_names:
+                inference_model_dir = check_model_file("imn", model_name)
+                return "imn", inference_model_dir
+            elif model_name in all_pulc_model_names:
+                inference_model_dir = check_model_file("pulc", model_name)
+                return "pulc", inference_model_dir
+            else:
+                similar_imn_names = similar_model_names(model_name,
+                                                        all_imn_model_names)
+                similar_pulc_names = similar_model_names(model_name,
+                                                         all_pulc_model_names)
+                similar_names_str = ", ".join(similar_imn_names +
+                                              similar_pulc_names)
+                err = f"{model_name} is not provided by PaddleClas. \nMaybe you want the : [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!"
+                raise InputModelError(err)
+        elif inference_model_dir:
             model_file_path = os.path.join(inference_model_dir,
                                            "inference.pdmodel")
             params_file_path = os.path.join(inference_model_dir,
@@ -476,11 +487,11 @@ class PaddleClas(object):
                     params_file_path):
                 err = f"There is no model file or params file in this directory: {inference_model_dir}"
                 raise InputModelError(err)
-            return
+            return "custom", inference_model_dir
         else:
             err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
             raise InputModelError(err)
-        return
+        return None
 
     def predict(self, input_data: Union[str, np.array],
                 print_pred: bool=False) -> Generator[list, None, None]:
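`_check_input_model` now returns a `(model_type, inference_model_dir)` pair instead of mutating `self._config` in place, and `__init__` uses the type to pick the right YAML. A sketch of the three outcomes:

```python
# model_name in IMN_MODEL_SERIES   -> ("imn", <downloaded cache dir>)
# model_name in PULC_MODELS        -> ("pulc", <downloaded cache dir>)
# only inference_model_dir given   -> ("custom", inference_model_dir)
clf = PaddleClas(model_name="safety_helmet")
print(clf.model_type)  # "pulc"
```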
@@ -511,19 +522,18 @@ class PaddleClas(object):
                 os.makedirs(image_storage_dir())
             image_save_path = image_storage_dir("tmp.jpg")
             download_with_progressbar(input_data, image_save_path)
-            input_data = image_save_path
             warnings.warn(
                 f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
             )
+            input_data = image_save_path
         image_list = get_image_list(input_data)
 
         batch_size = self._config.Global.get("batch_size", 1)
-        topk = self._config.PostProcess.Topk.get('topk', 1)
         img_list = []
         img_path_list = []
         cnt = 0
-        for idx, img_path in enumerate(image_list):
+        for idx_img, img_path in enumerate(image_list):
             img = cv2.imread(img_path)
             if img is None:
                 warnings.warn(
@@ -535,16 +545,15 @@ class PaddleClas(object):
             img_path_list.append(img_path)
             cnt += 1
 
-            if cnt % batch_size == 0 or (idx + 1) == len(image_list):
+            if cnt % batch_size == 0 or (idx_img + 1) == len(image_list):
                 preds = self.cls_predictor.predict(img_list)
 
-                if print_pred and preds:
-                    for idx, pred in enumerate(preds):
-                        pred_str = ", ".join(
-                            [f"{k}: {pred[k]}" for k in pred])
-                        print(
-                            f"filename: {img_path_list[idx]}, top-{topk}, {pred_str}"
-                        )
+                if preds:
+                    for idx_pred, pred in enumerate(preds):
+                        pred["filename"] = img_path_list[idx_pred]
+                        if print_pred:
+                            print(", ".join(
+                                [f"{k}: {pred[k]}" for k in pred]))
 
                 img_list = []
                 img_path_list = []
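Attaching the source path to each prediction (`pred["filename"]`) decouples results from console printing: callers consuming the generator get the filename even with `print_pred=False`. A minimal consumption sketch; the image path is a placeholder:

```python
clf = PaddleClas(model_name="traffic_sign")
for batch in clf.predict("test.jpg"):      # yields one list of dicts per batch
    for pred in batch:
        print(pred["filename"], pred)
```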