未验证 提交 d5173bf1 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #1987 from TingquanGao/dev/pulc_whl_deploy

[WIP] feat: support PULC to deploy with whl
include LICENSE.txt include LICENSE.txt
include README.md include README.md
include docs/en/whl_en.md include docs/en/whl_en.md
recursive-include deploy/python predict_cls.py preprocess.py postprocess.py det_preprocess.py recursive-include deploy/python *.py
recursive-include deploy/configs *.yaml
recursive-include deploy/utils get_image_list.py config.py logger.py predictor.py recursive-include deploy/utils get_image_list.py config.py logger.py predictor.py
recursive-include ppcls/ *.py *.txt recursive-include ppcls/ *.py *.txt
\ No newline at end of file
...@@ -30,6 +30,6 @@ PostProcess: ...@@ -30,6 +30,6 @@ PostProcess:
main_indicator: Topk main_indicator: Topk
Topk: Topk:
topk: 5 topk: 5
class_id_map_file: "../dataset/traffic_sign/label_name_id.txt" class_id_map_file: "../ppcls/utils/PULC_label_list/traffic_sign_label_list.txt"
SavePreLabel: SavePreLabel:
save_dir: ./pre_label/ save_dir: ./pre_label/
...@@ -212,14 +212,14 @@ You can save the prediction result(s) as pre-label, only need to use `pre_label_ ...@@ -212,14 +212,14 @@ You can save the prediction result(s) as pre-label, only need to use `pre_label_
```python ```python
from paddleclas import PaddleClas from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50', save_dir='./output_pre_label/') clas = PaddleClas(model_name='ResNet50', save_dir='./output_pre_label/')
infer_imgs = 'docs/images/inference_deployment/whl_' # it can be infer_imgs folder path which contains all of images you want to predict. infer_imgs = 'docs/images/' # it can be infer_imgs folder path which contains all of images you want to predict.
result=clas.predict(infer_imgs) result=clas.predict(infer_imgs)
print(next(result)) print(next(result))
``` ```
* CLI * CLI
```bash ```bash
paddleclas --model_name='ResNet50' --infer_imgs='docs/images/inference_deployment/whl_' --save_dir='./output_pre_label/' paddleclas --model_name='ResNet50' --infer_imgs='docs/images/' --save_dir='./output_pre_label/'
``` ```
<a name="4.8"></a> <a name="4.8"></a>
......
...@@ -212,14 +212,14 @@ print(next(result)) ...@@ -212,14 +212,14 @@ print(next(result))
```python ```python
from paddleclas import PaddleClas from paddleclas import PaddleClas
clas = PaddleClas(model_name='ResNet50', save_dir='./output_pre_label/') clas = PaddleClas(model_name='ResNet50', save_dir='./output_pre_label/')
infer_imgs = 'docs/images/whl/' # it can be infer_imgs folder path which contains all of images you want to predict. infer_imgs = 'docs/images/' # it can be infer_imgs folder path which contains all of images you want to predict.
result=clas.predict(infer_imgs) result=clas.predict(infer_imgs)
print(next(result)) print(next(result))
``` ```
* CLI * CLI
```bash ```bash
paddleclas --model_name='ResNet50' --infer_imgs='docs/images/whl/' --save_dir='./output_pre_label/' paddleclas --model_name='ResNet50' --infer_imgs='docs/images/' --save_dir='./output_pre_label/'
``` ```
<a name="4.8"></a> <a name="4.8"></a>
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -24,7 +24,6 @@ import shutil ...@@ -24,7 +24,6 @@ import shutil
import textwrap import textwrap
import tarfile import tarfile
import requests import requests
import warnings
from functools import partial from functools import partial
from difflib import SequenceMatcher from difflib import SequenceMatcher
...@@ -32,24 +31,25 @@ import cv2 ...@@ -32,24 +31,25 @@ import cv2
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from prettytable import PrettyTable from prettytable import PrettyTable
import paddle
from deploy.python.predict_cls import ClsPredictor from deploy.python.predict_cls import ClsPredictor
from deploy.utils.get_image_list import get_image_list from deploy.utils.get_image_list import get_image_list
from deploy.utils import config from deploy.utils import config
from ppcls.arch.backbone import * import ppcls.arch.backbone as backbone
from ppcls.utils.logger import init_logger from ppcls.utils import logger
# for building model with loading pretrained weights from backbone # for building model with loading pretrained weights from backbone
init_logger() logger.init_logger()
__all__ = ["PaddleClas"] __all__ = ["PaddleClas"]
BASE_DIR = os.path.expanduser("~/.paddleclas/") BASE_DIR = os.path.expanduser("~/.paddleclas/")
BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, "inference_model") BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, "inference_model")
BASE_IMAGES_DIR = os.path.join(BASE_DIR, "images") BASE_IMAGES_DIR = os.path.join(BASE_DIR, "images")
BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar" IMN_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
MODEL_SERIES = { IMN_MODEL_SERIES = {
"AlexNet": ["AlexNet"], "AlexNet": ["AlexNet"],
"DarkNet": ["DarkNet53"], "DarkNet": ["DarkNet53"],
"DeiT": [ "DeiT": [
...@@ -100,10 +100,17 @@ MODEL_SERIES = { ...@@ -100,10 +100,17 @@ MODEL_SERIES = {
"MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25", "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25",
"MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld" "MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld"
], ],
"PPHGNet": [
"PPHGNet_tiny",
"PPHGNet_small",
"PPHGNet_tiny_ssld",
"PPHGNet_small_ssld",
],
"PPLCNet": [ "PPLCNet": [
"PPLCNet_x0_25", "PPLCNet_x0_35", "PPLCNet_x0_5", "PPLCNet_x0_75", "PPLCNet_x0_25", "PPLCNet_x0_35", "PPLCNet_x0_5", "PPLCNet_x0_75",
"PPLCNet_x1_0", "PPLCNet_x1_5", "PPLCNet_x2_0", "PPLCNet_x2_5" "PPLCNet_x1_0", "PPLCNet_x1_5", "PPLCNet_x2_0", "PPLCNet_x2_5"
], ],
"PPLCNetV2": ["PPLCNetV2_base"],
"RedNet": ["RedNet26", "RedNet38", "RedNet50", "RedNet101", "RedNet152"], "RedNet": ["RedNet26", "RedNet38", "RedNet50", "RedNet101", "RedNet152"],
"RegNet": ["RegNetX_4GF"], "RegNet": ["RegNetX_4GF"],
"Res2Net": [ "Res2Net": [
...@@ -168,6 +175,13 @@ MODEL_SERIES = { ...@@ -168,6 +175,13 @@ MODEL_SERIES = {
] ]
} }
PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/{}_infer.tar"
PULC_MODELS = [
"person_exists", "person_attribute", "safety_helmet", "traffic_sign",
"vehicle_exists", "vehicle_attr", "textline_orientation",
"text_image_orientation", "language_classification"
]
class ImageTypeError(Exception): class ImageTypeError(Exception):
"""ImageTypeError. """ImageTypeError.
...@@ -185,76 +199,67 @@ class InputModelError(Exception): ...@@ -185,76 +199,67 @@ class InputModelError(Exception):
super().__init__(message) super().__init__(message)
def init_config(model_name, def init_config(model_type, model_name, inference_model_dir, **kwargs):
inference_model_dir,
use_gpu=True, cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml"
batch_size=1, cfg_path = os.path.join(__dir__, cfg_path)
topk=5, cfg = config.get_config(cfg_path, show=False)
**kwargs):
imagenet1k_map_path = os.path.join( cfg.Global.inference_model_dir = inference_model_dir
os.path.abspath(__dir__), "ppcls/utils/imagenet1k_label_list.txt")
cfg = { if "batch_size" in kwargs and kwargs["batch_size"]:
"Global": { cfg.Global.batch_size = kwargs["batch_size"]
"infer_imgs": kwargs["infer_imgs"]
if "infer_imgs" in kwargs else False, if "use_gpu" in kwargs and kwargs["use_gpu"]:
"model_name": model_name, cfg.Global.use_gpu = kwargs["use_gpu"]
"inference_model_dir": inference_model_dir, if cfg.Global.use_gpu and not paddle.device.is_compiled_with_cuda():
"batch_size": batch_size, msg = "The current running environment does not support the use of GPU. CPU has been used instead."
"use_gpu": use_gpu, logger.warning(msg)
"enable_mkldnn": kwargs["enable_mkldnn"] cfg.Global.use_gpu = False
if "enable_mkldnn" in kwargs else False,
"cpu_num_threads": kwargs["cpu_num_threads"] if "infer_imgs" in kwargs and kwargs["infer_imgs"]:
if "cpu_num_threads" in kwargs else 1, cfg.Global.infer_imgs = kwargs["infer_imgs"]
"enable_benchmark": False, if "enable_mkldnn" in kwargs and kwargs["enable_mkldnn"]:
"use_fp16": kwargs["use_fp16"] if "use_fp16" in kwargs else False, cfg.Global.enable_mkldnn = kwargs["enable_mkldnn"]
"ir_optim": True, if "cpu_num_threads" in kwargs and kwargs["cpu_num_threads"]:
"use_tensorrt": kwargs["use_tensorrt"] cfg.Global.cpu_num_threads = kwargs["cpu_num_threads"]
if "use_tensorrt" in kwargs else False, if "use_fp16" in kwargs and kwargs["use_fp16"]:
"gpu_mem": kwargs["gpu_mem"] if "gpu_mem" in kwargs else 8000, cfg.Global.use_fp16 = kwargs["use_fp16"]
"enable_profile": False if "use_tensorrt" in kwargs and kwargs["use_tensorrt"]:
}, cfg.Global.use_tensorrt = kwargs["use_tensorrt"]
"PreProcess": { if "gpu_mem" in kwargs and kwargs["gpu_mem"]:
"transform_ops": [{ cfg.Global.gpu_mem = kwargs["gpu_mem"]
"ResizeImage": { if "resize_short" in kwargs and kwargs["resize_short"]:
"resize_short": kwargs["resize_short"] cfg.PreProcess.transform_ops[0]["ResizeImage"][
if "resize_short" in kwargs else 256 "resize_short"] = kwargs["resize_short"]
} if "crop_size" in kwargs and kwargs["crop_size"]:
}, { cfg.PreProcess.transform_ops[1]["CropImage"]["size"] = kwargs[
"CropImage": { "crop_size"]
"size": kwargs["crop_size"]
if "crop_size" in kwargs else 224 # TODO(gaotingquan): not robust
} if "thresh" in kwargs and kwargs[
}, { "thresh"] and "ThreshOutput" in cfg.PostProcess:
"NormalizeImage": { cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
"scale": 0.00392157, if "Topk" in cfg.PostProcess:
"mean": [0.485, 0.456, 0.406], if "topk" in kwargs and kwargs["topk"]:
"std": [0.229, 0.224, 0.225], cfg.PostProcess.Topk.topk = kwargs["topk"]
"order": '' if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]:
} cfg.PostProcess.Topk.class_id_map_file = kwargs[
}, {
"ToCHWImage": None
}]
},
"PostProcess": {
"main_indicator": "Topk",
"Topk": {
"topk": topk,
"class_id_map_file": imagenet1k_map_path
}
}
}
if "save_dir" in kwargs:
if kwargs["save_dir"] is not None:
cfg["PostProcess"]["SavePreLabel"] = {
"save_dir": kwargs["save_dir"]
}
if "class_id_map_file" in kwargs:
if kwargs["class_id_map_file"] is not None:
cfg["PostProcess"]["Topk"]["class_id_map_file"] = kwargs[
"class_id_map_file"] "class_id_map_file"]
else:
cfg.PostProcess.Topk.class_id_map_file = os.path.relpath(
cfg.PostProcess.Topk.class_id_map_file, "../")
if "VehicleAttribute" in cfg.PostProcess:
if "color_threshold" in kwargs and kwargs["color_threshold"]:
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
"color_threshold"]
if "type_threshold" in kwargs and kwargs["type_threshold"]:
cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
"type_threshold"]
if "save_dir" in kwargs and kwargs["save_dir"]:
cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
cfg = config.AttrDict(cfg)
config.create_attr_dict(cfg)
return cfg return cfg
...@@ -275,40 +280,48 @@ def args_cfg(): ...@@ -275,40 +280,48 @@ def args_cfg():
type=str, type=str,
help="The directory of model files. Valid when model_name not specifed." help="The directory of model files. Valid when model_name not specifed."
) )
parser.add_argument("--use_gpu", type=str2bool, help="Whether use GPU.")
parser.add_argument( parser.add_argument(
"--use_gpu", type=str, default=True, help="Whether use GPU.") "--gpu_mem",
parser.add_argument("--gpu_mem", type=int, default=8000, help="") type=int,
help="The memory size of GPU allocated to predict.")
parser.add_argument( parser.add_argument(
"--enable_mkldnn", "--enable_mkldnn",
type=str2bool, type=str2bool,
default=False,
help="Whether use MKLDNN. Valid when use_gpu is False") help="Whether use MKLDNN. Valid when use_gpu is False")
parser.add_argument("--cpu_num_threads", type=int, default=1, help="")
parser.add_argument( parser.add_argument(
"--use_tensorrt", type=str2bool, default=False, help="") "--cpu_num_threads",
parser.add_argument("--use_fp16", type=str2bool, default=False, help="") type=int,
help="The threads number when predicting on CPU.")
parser.add_argument( parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size. Default by 1.") "--use_tensorrt",
type=str2bool,
help="Whether use TensorRT to accelerate. ")
parser.add_argument(
"--use_fp16", type=str2bool, help="Whether use FP16 to predict.")
parser.add_argument("--batch_size", type=int, help="Batch size.")
parser.add_argument( parser.add_argument(
"--topk", "--topk",
type=int, type=int,
default=5, help="Return topk score(s) and corresponding results when Topk postprocess is used."
help="Return topk score(s) and corresponding results. Default by 5.") )
parser.add_argument( parser.add_argument(
"--class_id_map_file", "--class_id_map_file",
type=str, type=str,
help="The path of file that map class_id and label.") help="The path of file that map class_id and label.")
parser.add_argument(
"--threshold",
type=float,
help="The threshold of ThreshOutput when postprocess is used.")
parser.add_argument("--color_threshold", type=float, help="")
parser.add_argument("--type_threshold", type=float, help="")
parser.add_argument( parser.add_argument(
"--save_dir", "--save_dir",
type=str, type=str,
help="The directory to save prediction results as pre-label.") help="The directory to save prediction results as pre-label.")
parser.add_argument( parser.add_argument(
"--resize_short", "--resize_short", type=int, help="Resize according to short size.")
type=int, parser.add_argument("--crop_size", type=int, help="Centor crop size.")
default=256,
help="Resize according to short size.")
parser.add_argument(
"--crop_size", type=int, default=224, help="Centor crop size.")
args = parser.parse_args() args = parser.parse_args()
return vars(args) return vars(args)
...@@ -317,33 +330,44 @@ def args_cfg(): ...@@ -317,33 +330,44 @@ def args_cfg():
def print_info(): def print_info():
"""Print list of supported models in formatted. """Print list of supported models in formatted.
""" """
table = PrettyTable(["Series", "Name"]) imn_table = PrettyTable(["IMN Model Series", "Model Name"])
pulc_table = PrettyTable(["PULC Models"])
try: try:
sz = os.get_terminal_size() sz = os.get_terminal_size()
width = sz.columns - 30 if sz.columns > 50 else 10 total_width = sz.columns
first_width = 30
second_width = total_width - first_width if total_width > 50 else 10
except OSError: except OSError:
width = 100 second_width = 100
for series in MODEL_SERIES: for series in IMN_MODEL_SERIES:
names = textwrap.fill(" ".join(MODEL_SERIES[series]), width=width) names = textwrap.fill(
table.add_row([series, names]) " ".join(IMN_MODEL_SERIES[series]), width=second_width)
width = len(str(table).split("\n")[0]) imn_table.add_row([series, names])
print("{}".format("-" * width))
print("Models supported by PaddleClas".center(width)) table_width = len(str(imn_table).split("\n")[0])
print(table) pulc_table.add_row([
print("Powered by PaddlePaddle!".rjust(width)) textwrap.fill(
print("{}".format("-" * width)) " ".join(PULC_MODELS), width=total_width).center(table_width - 4)
])
def get_model_names(): print("{}".format("-" * table_width))
print("Models supported by PaddleClas".center(table_width))
print(imn_table)
print(pulc_table)
print("Powered by PaddlePaddle!".rjust(table_width))
print("{}".format("-" * table_width))
def get_imn_model_names():
"""Get the model names list. """Get the model names list.
""" """
model_names = [] model_names = []
for series in MODEL_SERIES: for series in IMN_MODEL_SERIES:
model_names += (MODEL_SERIES[series]) model_names += (IMN_MODEL_SERIES[series])
return model_names return model_names
def similar_architectures(name="", names=[], thresh=0.1, topk=10): def similar_model_names(name="", names=[], thresh=0.1, topk=5):
"""Find the most similar topk model names. """Find the most similar topk model names.
""" """
scores = [] scores = []
...@@ -378,12 +402,17 @@ def download_with_progressbar(url, save_path): ...@@ -378,12 +402,17 @@ def download_with_progressbar(url, save_path):
f"Something went wrong while downloading file from {url}") f"Something went wrong while downloading file from {url}")
def check_model_file(model_name): def check_model_file(model_type, model_name):
"""Check the model files exist and download and untar when no exist. """Check the model files exist and download and untar when no exist.
""" """
if model_type == "pulc":
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
"PULC", model_name)
url = PULC_MODEL_BASE_DOWNLOAD_URL.format(model_name)
else:
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR, storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
model_name) "IMN", model_name)
url = BASE_DOWNLOAD_URL.format(model_name) url = IMN_MODEL_BASE_DOWNLOAD_URL.format(model_name)
tar_file_name_list = [ tar_file_name_list = [
"inference.pdiparams", "inference.pdiparams.info", "inference.pdmodel" "inference.pdiparams", "inference.pdiparams.info", "inference.pdmodel"
...@@ -393,7 +422,7 @@ def check_model_file(model_name): ...@@ -393,7 +422,7 @@ def check_model_file(model_name):
if not os.path.exists(model_file_path) or not os.path.exists( if not os.path.exists(model_file_path) or not os.path.exists(
params_file_path): params_file_path):
tmp_path = storage_directory(url.split("/")[-1]) tmp_path = storage_directory(url.split("/")[-1])
print(f"download {url} to {tmp_path}") logger.info(f"download {url} to {tmp_path}")
os.makedirs(storage_directory(), exist_ok=True) os.makedirs(storage_directory(), exist_ok=True)
download_with_progressbar(url, tmp_path) download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, "r") as tarObj: with tarfile.open(tmp_path, "r") as tarObj:
...@@ -426,9 +455,6 @@ class PaddleClas(object): ...@@ -426,9 +455,6 @@ class PaddleClas(object):
def __init__(self, def __init__(self,
model_name: str=None, model_name: str=None,
inference_model_dir: str=None, inference_model_dir: str=None,
use_gpu: bool=True,
batch_size: int=1,
topk: int=5,
**kwargs): **kwargs):
"""Init PaddleClas with config. """Init PaddleClas with config.
...@@ -440,9 +466,11 @@ class PaddleClas(object): ...@@ -440,9 +466,11 @@ class PaddleClas(object):
topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5. topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
""" """
super().__init__() super().__init__()
self._config = init_config(model_name, inference_model_dir, use_gpu, self.model_type, inference_model_dir = self._check_input_model(
batch_size, topk, **kwargs) model_name, inference_model_dir)
self._check_input_model() self._config = init_config(self.model_type, model_name,
inference_model_dir, **kwargs)
self.cls_predictor = ClsPredictor(self._config) self.cls_predictor = ClsPredictor(self._config)
def get_config(self): def get_config(self):
...@@ -450,24 +478,29 @@ class PaddleClas(object): ...@@ -450,24 +478,29 @@ class PaddleClas(object):
""" """
return self._config return self._config
def _check_input_model(self): def _check_input_model(self, model_name, inference_model_dir):
"""Check input model name or model files. """Check input model name or model files.
""" """
candidate_model_names = get_model_names() all_imn_model_names = get_imn_model_names()
input_model_name = self._config.Global.get("model_name", None) all_pulc_model_names = PULC_MODELS
inference_model_dir = self._config.Global.get("inference_model_dir",
None) if model_name:
if input_model_name is not None: if model_name in all_imn_model_names:
similar_names = similar_architectures(input_model_name, inference_model_dir = check_model_file("imn", model_name)
candidate_model_names) return "imn", inference_model_dir
similar_names_str = ", ".join(similar_names) elif model_name in all_pulc_model_names:
if input_model_name not in candidate_model_names: inference_model_dir = check_model_file("pulc", model_name)
err = f"{input_model_name} is not provided by PaddleClas. \nMaybe you want: [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!" return "pulc", inference_model_dir
else:
similar_imn_names = similar_model_names(model_name,
all_imn_model_names)
similar_pulc_names = similar_model_names(model_name,
all_pulc_model_names)
similar_names_str = ", ".join(similar_imn_names +
similar_pulc_names)
err = f"{model_name} is not provided by PaddleClas. \nMaybe you want the : [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!"
raise InputModelError(err) raise InputModelError(err)
self._config.Global.inference_model_dir = check_model_file( elif inference_model_dir:
input_model_name)
return
elif inference_model_dir is not None:
model_file_path = os.path.join(inference_model_dir, model_file_path = os.path.join(inference_model_dir,
"inference.pdmodel") "inference.pdmodel")
params_file_path = os.path.join(inference_model_dir, params_file_path = os.path.join(inference_model_dir,
...@@ -476,11 +509,11 @@ class PaddleClas(object): ...@@ -476,11 +509,11 @@ class PaddleClas(object):
params_file_path): params_file_path):
err = f"There is no model file or params file in this directory: {inference_model_dir}" err = f"There is no model file or params file in this directory: {inference_model_dir}"
raise InputModelError(err) raise InputModelError(err)
return return "custom", inference_model_dir
else: else:
err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)." err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
raise InputModelError(err) raise InputModelError(err)
return return None
def predict(self, input_data: Union[str, np.array], def predict(self, input_data: Union[str, np.array],
print_pred: bool=False) -> Generator[list, None, None]: print_pred: bool=False) -> Generator[list, None, None]:
...@@ -511,22 +544,21 @@ class PaddleClas(object): ...@@ -511,22 +544,21 @@ class PaddleClas(object):
os.makedirs(image_storage_dir()) os.makedirs(image_storage_dir())
image_save_path = image_storage_dir("tmp.jpg") image_save_path = image_storage_dir("tmp.jpg")
download_with_progressbar(input_data, image_save_path) download_with_progressbar(input_data, image_save_path)
input_data = image_save_path logger.info(
warnings.warn(
f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}" f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
) )
input_data = image_save_path
image_list = get_image_list(input_data) image_list = get_image_list(input_data)
batch_size = self._config.Global.get("batch_size", 1) batch_size = self._config.Global.get("batch_size", 1)
topk = self._config.PostProcess.Topk.get('topk', 1)
img_list = [] img_list = []
img_path_list = [] img_path_list = []
cnt = 0 cnt = 0
for idx, img_path in enumerate(image_list): for idx_img, img_path in enumerate(image_list):
img = cv2.imread(img_path) img = cv2.imread(img_path)
if img is None: if img is None:
warnings.warn( logger.warning(
f"Image file failed to read and has been skipped. The path: {img_path}" f"Image file failed to read and has been skipped. The path: {img_path}"
) )
continue continue
...@@ -535,16 +567,15 @@ class PaddleClas(object): ...@@ -535,16 +567,15 @@ class PaddleClas(object):
img_path_list.append(img_path) img_path_list.append(img_path)
cnt += 1 cnt += 1
if cnt % batch_size == 0 or (idx + 1) == len(image_list): if cnt % batch_size == 0 or (idx_img + 1) == len(image_list):
preds = self.cls_predictor.predict(img_list) preds = self.cls_predictor.predict(img_list)
if print_pred and preds: if preds:
for idx, pred in enumerate(preds): for idx_pred, pred in enumerate(preds):
pred_str = ", ".join( pred["filename"] = img_path_list[idx_pred]
[f"{k}: {pred[k]}" for k in pred]) if print_pred:
print( logger.info(", ".join(
f"filename: {img_path_list[idx]}, top-{topk}, {pred_str}" [f"{k}: {pred[k]}" for k in pred]))
)
img_list = [] img_list = []
img_path_list = [] img_path_list = []
...@@ -564,7 +595,7 @@ def main(): ...@@ -564,7 +595,7 @@ def main():
res = clas_engine.predict(cfg["infer_imgs"], print_pred=True) res = clas_engine.predict(cfg["infer_imgs"], print_pred=True)
for _ in res: for _ in res:
pass pass
print("Predict complete!") logger.info("Predict complete!")
return return
......
0 pl80
1 w9
2 p6
3 ph4.2
4 i8
5 w14
6 w33
7 pa13
8 im
9 w58
10 pl90
11 il70
12 p5
13 pm55
14 pl60
15 ip
16 p11
17 pdd
18 wc
19 i2r
20 w30
21 pmr
22 p23
23 pl15
24 pm10
25 pss
26 w1
27 p4
28 w38
29 w50
30 w34
31 pw3.5
32 iz
33 w39
34 w11
35 p1n
36 pr70
37 pd
38 pnl
39 pg
40 ph5.3
41 w66
42 il80
43 pb
44 pbm
45 pm5
46 w24
47 w67
48 w49
49 pm40
50 ph4
51 w45
52 i4
53 w37
54 ph2.6
55 pl70
56 ph5.5
57 i14
58 i11
59 p7
60 p29
61 pne
62 pr60
63 pm13
64 ph4.5
65 p12
66 p3
67 w40
68 pl5
69 w13
70 pr10
71 p14
72 i4l
73 pr30
74 pw4.2
75 w16
76 p17
77 ph3
78 i9
79 w15
80 w35
81 pa8
82 pt
83 pr45
84 w17
85 pl30
86 pcs
87 pctl
88 pr50
89 ph4.4
90 pm46
91 pm35
92 i15
93 pa12
94 pclr
95 i1
96 pcd
97 pbp
98 pcr
99 w28
100 ps
101 pm8
102 w18
103 w2
104 w52
105 ph2.9
106 ph1.8
107 pe
108 p20
109 w36
110 p10
111 pn
112 pa14
113 w54
114 ph3.2
115 p2
116 ph2.5
117 w62
118 w55
119 pw3
120 pw4.5
121 i12
122 ph4.3
123 phclr
124 i10
125 pr5
126 i13
127 w10
128 p26
129 w26
130 p8
131 w5
132 w42
133 il50
134 p13
135 pr40
136 p25
137 w41
138 pl20
139 ph4.8
140 pnlc
141 ph3.3
142 w29
143 ph2.1
144 w53
145 pm30
146 p24
147 p21
148 pl40
149 w27
150 pmb
151 pc
152 i6
153 pr20
154 p18
155 ph3.8
156 pm50
157 pm25
158 i2
159 w22
160 w47
161 w56
162 pl120
163 ph2.8
164 i7
165 w12
166 pm1.5
167 pm2.5
168 w32
169 pm15
170 ph5
171 w19
172 pw3.2
173 pw2.5
174 pl10
175 il60
176 w57
177 w48
178 w60
179 pl100
180 pr80
181 p16
182 pl110
183 w59
184 w64
185 w20
186 ph2
187 p9
188 il100
189 w31
190 w65
191 ph2.4
192 pr100
193 p19
194 ph3.5
195 pa10
196 pcl
197 pl35
198 p15
199 w7
200 pa6
201 phcs
202 w43
203 p28
204 w6
205 w3
206 w25
207 pl25
208 il110
209 p1
210 w46
211 pn-2
212 w51
213 w44
214 w63
215 w23
216 pm20
217 w8
218 pmblr
219 w4
220 i5
221 il90
222 w21
223 p27
224 pl50
225 pl65
226 w61
227 ph2.2
228 pm2
229 i3
230 pa18
231 pw4
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册