diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index 6b780b7743e2d54ad7ddaaac845c531b058fac7b..fb559e19acbd2847e3e81cec1a0f8c7328991be3 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -75,7 +75,7 @@ from .model_zoo.foundation_vit import CLIP_vit_base_patch32_224, CLIP_vit_base_p from .model_zoo.convnext import ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384 from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_large_224, NextViT_small_384, NextViT_base_384, NextViT_large_384 from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224 -from .model_zoo.cvt import cvt_13_224x224, cvt_13_384x384, cvt_21_224x224, cvt_21_384x384 +from .model_zoo.cvt import CvT_13_224, CvT_13_384, CvT_21_224, CvT_21_384 from .variant_models.resnet_variant import ResNet50_last_stage_stride1 from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d diff --git a/ppcls/arch/backbone/model_zoo/cvt.py b/ppcls/arch/backbone/model_zoo/cvt.py index c5b7aeb3620a5b4a58f828620f08acd84e9dfb3e..ec31e81bf673359a33feab784608c0ee63133b52 100644 --- a/ppcls/arch/backbone/model_zoo/cvt.py +++ b/ppcls/arch/backbone/model_zoo/cvt.py @@ -22,10 +22,10 @@ from paddle.nn.initializer import XavierUniform, TruncatedNormal, Constant from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url MODEL_URLS = { - "cvt_13_224x224": "", # TODO - "cvt_13_384x384": "", # TODO - "cvt_21_224x224": "", # TODO - "cvt_21_384x384": "", # TODO + "CvT_13_224": "", # TODO + "CvT_13_384": "", # TODO + "CvT_21_224": "", # TODO + "CvT_21_384": "", # TODO } __all__ = list(MODEL_URLS.keys()) @@ -521,7 +521,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False): ) -def cvt_13_224x224(pretrained=False, use_ssld=False, **kwargs): +def CvT_13_224(pretrained=False, use_ssld=False, **kwargs): msvit_spec = dict( INIT='trunc_norm', NUM_STAGES=3, @@ -551,11 +551,11 @@ def cvt_13_224x224(pretrained=False, use_ssld=False, **kwargs): spec=msvit_spec, **kwargs) _load_pretrained( - pretrained, model, MODEL_URLS["cvt_13_224x224"], use_ssld=use_ssld) + pretrained, model, MODEL_URLS["CvT_13_224"], use_ssld=use_ssld) return model -def cvt_13_384x384(pretrained=False, use_ssld=False, **kwargs): +def CvT_13_384(pretrained=False, use_ssld=False, **kwargs): msvit_spec = dict( INIT='trunc_norm', NUM_STAGES=3, @@ -585,11 +585,11 @@ def cvt_13_384x384(pretrained=False, use_ssld=False, **kwargs): spec=msvit_spec, **kwargs) _load_pretrained( - pretrained, model, MODEL_URLS["cvt_13_384x384"], use_ssld=use_ssld) + pretrained, model, MODEL_URLS["CvT_13_384"], use_ssld=use_ssld) return model -def cvt_21_224x224(pretrained=False, use_ssld=False, **kwargs): +def CvT_21_224(pretrained=False, use_ssld=False, **kwargs): msvit_spec = dict( INIT='trunc_norm', NUM_STAGES=3, @@ -619,11 +619,11 @@ def cvt_21_224x224(pretrained=False, use_ssld=False, **kwargs): spec=msvit_spec, **kwargs) _load_pretrained( - pretrained, model, MODEL_URLS["cvt_21_224x224"], use_ssld=use_ssld) + pretrained, model, MODEL_URLS["CvT_21_224"], use_ssld=use_ssld) return model -def cvt_21_384x384(pretrained=False, use_ssld=False, **kwargs): +def CvT_21_384(pretrained=False, use_ssld=False, **kwargs): msvit_spec = dict( INIT='trunc_norm', NUM_STAGES=3, @@ -653,5 +653,5 @@ def cvt_21_384x384(pretrained=False, use_ssld=False, **kwargs): spec=msvit_spec, **kwargs) _load_pretrained( - pretrained, model, MODEL_URLS["cvt_21_384x384"], use_ssld=use_ssld) + pretrained, model, MODEL_URLS["CvT_21_384"], use_ssld=use_ssld) return model diff --git a/ppcls/configs/ImageNet/CvT/cvt_13_224x224.yaml b/ppcls/configs/ImageNet/CvT/cvt_13_224x224.yaml index 4db49c24e36120c3ab44798706369b94931b354c..e982454a624a5b695709643ba225db3c953dd1de 100644 --- a/ppcls/configs/ImageNet/CvT/cvt_13_224x224.yaml +++ b/ppcls/configs/ImageNet/CvT/cvt_13_224x224.yaml @@ -19,7 +19,7 @@ Global: # model architecture Arch: - name: cvt_13_224x224 + name: CvT_13_224 class_num: 1000 # loss function config for traing/eval process