提交 c492e1b2 编写于 作者: C cuicheng01

Update vision_transformer.py

上级 36a35727
......@@ -243,7 +243,7 @@ class VisionTransformer(nn.Layer):
drop_path_rate=0.,
norm_layer='nn.LayerNorm',
epsilon=1e-5,
**args):
**kwargs):
super().__init__()
self.class_dim = class_dim
......@@ -331,9 +331,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
)
def ViT_small_patch16_224(pretrained,
model,
model_url,
def ViT_small_patch16_224(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -352,9 +350,7 @@ def ViT_small_patch16_224(pretrained,
return model
def ViT_base_patch16_224(pretrained,
model,
model_url,
def ViT_base_patch16_224(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -374,9 +370,7 @@ def ViT_base_patch16_224(pretrained,
return model
def ViT_base_patch16_384(pretrained,
model,
model_url,
def ViT_base_patch16_384(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -397,9 +391,7 @@ def ViT_base_patch16_384(pretrained,
return model
def ViT_base_patch32_384(pretrained,
model,
model_url,
def ViT_base_patch32_384(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -420,9 +412,7 @@ def ViT_base_patch32_384(pretrained,
return model
def ViT_large_patch16_224(pretrained,
model,
model_url,
def ViT_large_patch16_224(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -442,9 +432,7 @@ def ViT_large_patch16_224(pretrained,
return model
def ViT_large_patch16_384(pretrained,
model,
model_url,
def ViT_large_patch16_384(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -465,9 +453,7 @@ def ViT_large_patch16_384(pretrained,
return model
def ViT_large_patch32_384(pretrained,
model,
model_url,
def ViT_large_patch32_384(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -488,9 +474,7 @@ def ViT_large_patch32_384(pretrained,
return model
def ViT_huge_patch16_224(pretrained,
model,
model_url,
def ViT_huge_patch16_224(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......@@ -508,9 +492,7 @@ def ViT_huge_patch16_224(pretrained,
return model
def ViT_huge_patch32_384(pretrained,
model,
model_url,
def ViT_huge_patch32_384(pretrained=False,
use_ssld=False,
**kwargs):
model = VisionTransformer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册