提交 4cfd2159 编写于 作者: Y Yang Nie 提交者: Tingquan Gao

rename cvt_{depth}_{size}x{size} to CvT_{depth}_{size}

上级 5c2a5655
...@@ -75,7 +75,7 @@ from .model_zoo.foundation_vit import CLIP_vit_base_patch32_224, CLIP_vit_base_p ...@@ -75,7 +75,7 @@ from .model_zoo.foundation_vit import CLIP_vit_base_patch32_224, CLIP_vit_base_p
from .model_zoo.convnext import ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384 from .model_zoo.convnext import ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384
from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_large_224, NextViT_small_384, NextViT_base_384, NextViT_large_384 from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_large_224, NextViT_small_384, NextViT_base_384, NextViT_large_384
from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224 from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
from .model_zoo.cvt import cvt_13_224x224, cvt_13_384x384, cvt_21_224x224, cvt_21_384x384 from .model_zoo.cvt import CvT_13_224, CvT_13_384, CvT_21_224, CvT_21_384
from .variant_models.resnet_variant import ResNet50_last_stage_stride1 from .variant_models.resnet_variant import ResNet50_last_stage_stride1
from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
......
...@@ -22,10 +22,10 @@ from paddle.nn.initializer import XavierUniform, TruncatedNormal, Constant ...@@ -22,10 +22,10 @@ from paddle.nn.initializer import XavierUniform, TruncatedNormal, Constant
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = { MODEL_URLS = {
"cvt_13_224x224": "", # TODO "CvT_13_224": "", # TODO
"cvt_13_384x384": "", # TODO "CvT_13_384": "", # TODO
"cvt_21_224x224": "", # TODO "CvT_21_224": "", # TODO
"cvt_21_384x384": "", # TODO "CvT_21_384": "", # TODO
} }
__all__ = list(MODEL_URLS.keys()) __all__ = list(MODEL_URLS.keys())
...@@ -521,7 +521,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False): ...@@ -521,7 +521,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
) )
def cvt_13_224x224(pretrained=False, use_ssld=False, **kwargs): def CvT_13_224(pretrained=False, use_ssld=False, **kwargs):
msvit_spec = dict( msvit_spec = dict(
INIT='trunc_norm', INIT='trunc_norm',
NUM_STAGES=3, NUM_STAGES=3,
...@@ -551,11 +551,11 @@ def cvt_13_224x224(pretrained=False, use_ssld=False, **kwargs): ...@@ -551,11 +551,11 @@ def cvt_13_224x224(pretrained=False, use_ssld=False, **kwargs):
spec=msvit_spec, spec=msvit_spec,
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["cvt_13_224x224"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["CvT_13_224"], use_ssld=use_ssld)
return model return model
def cvt_13_384x384(pretrained=False, use_ssld=False, **kwargs): def CvT_13_384(pretrained=False, use_ssld=False, **kwargs):
msvit_spec = dict( msvit_spec = dict(
INIT='trunc_norm', INIT='trunc_norm',
NUM_STAGES=3, NUM_STAGES=3,
...@@ -585,11 +585,11 @@ def cvt_13_384x384(pretrained=False, use_ssld=False, **kwargs): ...@@ -585,11 +585,11 @@ def cvt_13_384x384(pretrained=False, use_ssld=False, **kwargs):
spec=msvit_spec, spec=msvit_spec,
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["cvt_13_384x384"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["CvT_13_384"], use_ssld=use_ssld)
return model return model
def cvt_21_224x224(pretrained=False, use_ssld=False, **kwargs): def CvT_21_224(pretrained=False, use_ssld=False, **kwargs):
msvit_spec = dict( msvit_spec = dict(
INIT='trunc_norm', INIT='trunc_norm',
NUM_STAGES=3, NUM_STAGES=3,
...@@ -619,11 +619,11 @@ def cvt_21_224x224(pretrained=False, use_ssld=False, **kwargs): ...@@ -619,11 +619,11 @@ def cvt_21_224x224(pretrained=False, use_ssld=False, **kwargs):
spec=msvit_spec, spec=msvit_spec,
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["cvt_21_224x224"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["CvT_21_224"], use_ssld=use_ssld)
return model return model
def cvt_21_384x384(pretrained=False, use_ssld=False, **kwargs): def CvT_21_384(pretrained=False, use_ssld=False, **kwargs):
msvit_spec = dict( msvit_spec = dict(
INIT='trunc_norm', INIT='trunc_norm',
NUM_STAGES=3, NUM_STAGES=3,
...@@ -653,5 +653,5 @@ def cvt_21_384x384(pretrained=False, use_ssld=False, **kwargs): ...@@ -653,5 +653,5 @@ def cvt_21_384x384(pretrained=False, use_ssld=False, **kwargs):
spec=msvit_spec, spec=msvit_spec,
**kwargs) **kwargs)
_load_pretrained( _load_pretrained(
pretrained, model, MODEL_URLS["cvt_21_384x384"], use_ssld=use_ssld) pretrained, model, MODEL_URLS["CvT_21_384"], use_ssld=use_ssld)
return model return model
...@@ -19,7 +19,7 @@ Global: ...@@ -19,7 +19,7 @@ Global:
# model architecture # model architecture
Arch: Arch:
name: cvt_13_224x224 name: CvT_13_224
class_num: 1000 class_num: 1000
# loss function config for traing/eval process # loss function config for traing/eval process
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册