From be37ba870f737e4bce6e035fc2988ce78093c35e Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Wed, 10 Nov 2021 06:16:35 +0000 Subject: [PATCH] feat: support Twins & PPLCNet --- __init__.py | 1 + paddleclas.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/__init__.py b/__init__.py index b8b43616..2128a6cc 100644 --- a/__init__.py +++ b/__init__.py @@ -14,3 +14,4 @@ __all__ = ['PaddleClas'] from .paddleclas import PaddleClas +from ppcls.arch.backbone import * diff --git a/paddleclas.py b/paddleclas.py index 91cd030a..58a3cddd 100644 --- a/paddleclas.py +++ b/paddleclas.py @@ -38,6 +38,7 @@ from deploy.utils.get_image_list import get_image_list from deploy.utils import config from ppcls.arch.backbone import * +from ppcls.utils.logger import init_logger __all__ = ["PaddleClas"] @@ -89,6 +90,10 @@ MODEL_SERIES = { "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25", "MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld" ], + "PPLCNet": [ + "PPLCNet_x0_25", "PPLCNet_x0_35", "PPLCNet_x0_5", "PPLCNet_x0_75", + "PPLCNet_x1_0", "PPLCNet_x1_5", "PPLCNet_x2_0", "PPLCNet_x2_5" + ], "RegNet": ["RegNetX_4GF"], "Res2Net": [ "Res2Net50_14w_8s", "Res2Net50_26w_4s", "Res2Net50_vd_26w_4s", @@ -134,6 +139,10 @@ MODEL_SERIES = { "SwinTransformer_small_patch4_window7_224", "SwinTransformer_tiny_patch4_window7_224" ], + "Twins": [ + "pcpvt_small", "pcpvt_base", "pcpvt_large", "alt_gvt_small", + "alt_gvt_base", "alt_gvt_large" + ], "VGG": ["VGG11", "VGG13", "VGG16", "VGG19"], "VisionTransformer": [ "ViT_base_patch16_224", "ViT_base_patch16_384", "ViT_base_patch32_384", @@ -399,6 +408,7 @@ class PaddleClas(object): """PaddleClas. """ + init_logger(name='root') print_info() def __init__(self, @@ -465,7 +475,7 @@ class PaddleClas(object): """Predict input_data. Args: - input_data (Union[str, np.array]): + input_data (Union[str, np.array]): When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet. When the type is np.array, it is the image data whose channel order is RGB. print_pred (bool, optional): Whether print the prediction result. Defaults to False. Defaults to False. @@ -474,9 +484,9 @@ class PaddleClas(object): ImageTypeError: Illegal input_data. Yields: - Generator[list, None, None]: - The prediction result(s) of input_data by batch_size. For every one image, - prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names". + Generator[list, None, None]: + The prediction result(s) of input_data by batch_size. For every one image, + prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names". The format is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...] """ -- GitLab