未验证 提交 b3922c96 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #1436 from TingquanGao/dev/update_whl

update whl
......@@ -14,3 +14,4 @@
__all__ = ['PaddleClas']
from .paddleclas import PaddleClas
from ppcls.arch.backbone import *
......@@ -38,6 +38,10 @@ from deploy.utils.get_image_list import get_image_list
from deploy.utils import config
from ppcls.arch.backbone import *
from ppcls.utils.logger import init_logger
# for building model with loading pretrained weights from backbone
init_logger()
__all__ = ["PaddleClas"]
......@@ -58,20 +62,27 @@ MODEL_SERIES = {
"DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201",
"DenseNet264"
],
"DLA": [
"DLA46_c", "DLA60x_c", "DLA34", "DLA60", "DLA60x", "DLA102", "DLA102x",
"DLA102x2", "DLA169"
],
"DPN": ["DPN68", "DPN92", "DPN98", "DPN107", "DPN131"],
"EfficientNet": [
"EfficientNetB0", "EfficientNetB0_small", "EfficientNetB1",
"EfficientNetB2", "EfficientNetB3", "EfficientNetB4", "EfficientNetB5",
"EfficientNetB6", "EfficientNetB7"
],
"ESNet": ["ESNet_x0_25", "ESNet_x0_5", "ESNet_x0_75", "ESNet_x1_0"],
"GhostNet":
["GhostNet_x0_5", "GhostNet_x1_0", "GhostNet_x1_3", "GhostNet_x1_3_ssld"],
"HarDNet": ["HarDNet39_ds", "HarDNet68_ds", "HarDNet68", "HarDNet85"],
"HRNet": [
"HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C",
"HRNet_W44_C", "HRNet_W48_C", "HRNet_W64_C", "HRNet_W18_C_ssld",
"HRNet_W48_C_ssld"
],
"Inception": ["GoogLeNet", "InceptionV3", "InceptionV4"],
"MixNet": ["MixNet_S", "MixNet_M", "MixNet_L"],
"MobileNetV1": [
"MobileNetV1_x0_25", "MobileNetV1_x0_5", "MobileNetV1_x0_75",
"MobileNetV1", "MobileNetV1_ssld"
......@@ -89,6 +100,11 @@ MODEL_SERIES = {
"MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25",
"MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld"
],
"PPLCNet": [
"PPLCNet_x0_25", "PPLCNet_x0_35", "PPLCNet_x0_5", "PPLCNet_x0_75",
"PPLCNet_x1_0", "PPLCNet_x1_5", "PPLCNet_x2_0", "PPLCNet_x2_5"
],
"RedNet": ["RedNet26", "RedNet38", "RedNet50", "RedNet101", "RedNet152"],
"RegNet": ["RegNetX_4GF"],
"Res2Net": [
"Res2Net50_14w_8s", "Res2Net50_26w_4s", "Res2Net50_vd_26w_4s",
......@@ -113,6 +129,8 @@ MODEL_SERIES = {
"ResNeXt152_32x4d", "ResNeXt152_vd_32x4d", "ResNeXt152_64x4d",
"ResNeXt152_vd_64x4d"
],
"ReXNet":
["ReXNet_1_0", "ReXNet_1_3", "ReXNet_1_5", "ReXNet_2_0", "ReXNet_3_0"],
"SENet": [
"SENet154_vd", "SE_HRNet_W64_C_ssld", "SE_ResNet18_vd",
"SE_ResNet34_vd", "SE_ResNet50_vd", "SE_ResNeXt50_32x4d",
......@@ -134,6 +152,10 @@ MODEL_SERIES = {
"SwinTransformer_small_patch4_window7_224",
"SwinTransformer_tiny_patch4_window7_224"
],
"Twins": [
"pcpvt_small", "pcpvt_base", "pcpvt_large", "alt_gvt_small",
"alt_gvt_base", "alt_gvt_large"
],
"VGG": ["VGG11", "VGG13", "VGG16", "VGG19"],
"VisionTransformer": [
"ViT_base_patch16_224", "ViT_base_patch16_384", "ViT_base_patch32_384",
......@@ -465,24 +487,23 @@ class PaddleClas(object):
"""Predict input_data.
Args:
input_data (Union[str, np.array]):
input_data (Union[str, np.array]):
When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
When the type is np.array, it is the image data whose channel order is RGB.
print_pred (bool, optional): Whether print the prediction result. Defaults to False. Defaults to False.
print_pred (bool, optional): Whether print the prediction result. Defaults to False.
Raises:
ImageTypeError: Illegal input_data.
Yields:
Generator[list, None, None]:
The prediction result(s) of input_data by batch_size. For every one image,
prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
The format is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
Generator[list, None, None]:
The prediction result(s) of input_data by batch_size. For every one image,
prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
"""
if isinstance(input_data, np.ndarray):
outputs = self.cls_predictor.predict(input_data)
yield self.cls_predictor.postprocess(outputs)
yield self.cls_predictor.predict(input_data)
elif isinstance(input_data, str):
if input_data.startswith("http") or input_data.startswith("https"):
image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
......@@ -497,7 +518,7 @@ class PaddleClas(object):
image_list = get_image_list(input_data)
batch_size = self._config.Global.get("batch_size", 1)
topk = self._config.PostProcess.get('topk', 1)
topk = self._config.PostProcess.Topk.get('topk', 1)
img_list = []
img_path_list = []
......@@ -515,16 +536,15 @@ class PaddleClas(object):
cnt += 1
if cnt % batch_size == 0 or (idx + 1) == len(image_list):
outputs = self.cls_predictor.predict(img_list)
preds = self.cls_predictor.postprocess(outputs,
img_path_list)
preds = self.cls_predictor.predict(img_list)
if print_pred and preds:
for pred in preds:
filename = pred.pop("file_name")
for idx, pred in enumerate(preds):
pred_str = ", ".join(
[f"{k}: {pred[k]}" for k in pred])
print(
f"filename: {filename}, top-{topk}, {pred_str}")
f"filename: {img_path_list[idx]}, top-{topk}, {pred_str}"
)
img_list = []
img_path_list = []
......
......@@ -65,6 +65,7 @@ from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
# help whl get all the models' api (class type) and components' api (func type)
def get_apis():
current_func = sys._getframe().f_code.co_name
current_module = sys.modules[__name__]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册