Commit c492e1b2 authored by cuicheng01

Update vision_transformer.py

Parent 36a35727
@@ -243,7 +243,7 @@ class VisionTransformer(nn.Layer):
                  drop_path_rate=0.,
                  norm_layer='nn.LayerNorm',
                  epsilon=1e-5,
-                 **args):
+                 **kwargs):
         super().__init__()
         self.class_dim = class_dim
@@ -331,9 +331,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld=False):
     )


-def ViT_small_patch16_224(pretrained,
-                          model,
-                          model_url,
+def ViT_small_patch16_224(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -352,9 +350,7 @@ def ViT_small_patch16_224(pretrained,
     return model


-def ViT_base_patch16_224(pretrained,
-                         model,
-                         model_url,
+def ViT_base_patch16_224(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -374,9 +370,7 @@ def ViT_base_patch16_224(pretrained,
     return model


-def ViT_base_patch16_384(pretrained,
-                         model,
-                         model_url,
+def ViT_base_patch16_384(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -397,9 +391,7 @@ def ViT_base_patch16_384(pretrained,
     return model


-def ViT_base_patch32_384(pretrained,
-                         model,
-                         model_url,
+def ViT_base_patch32_384(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -420,9 +412,7 @@ def ViT_base_patch32_384(pretrained,
     return model


-def ViT_large_patch16_224(pretrained,
-                          model,
-                          model_url,
+def ViT_large_patch16_224(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -442,9 +432,7 @@ def ViT_large_patch16_224(pretrained,
     return model


-def ViT_large_patch16_384(pretrained,
-                          model,
-                          model_url,
+def ViT_large_patch16_384(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -465,9 +453,7 @@ def ViT_large_patch16_384(pretrained,
     return model


-def ViT_large_patch32_384(pretrained,
-                          model,
-                          model_url,
+def ViT_large_patch32_384(pretrained=False,
                           use_ssld=False,
                           **kwargs):
     model = VisionTransformer(
@@ -488,9 +474,7 @@ def ViT_large_patch32_384(pretrained,
     return model


-def ViT_huge_patch16_224(pretrained,
-                         model,
-                         model_url,
+def ViT_huge_patch16_224(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
@@ -508,9 +492,7 @@ def ViT_huge_patch16_224(pretrained,
     return model


-def ViT_huge_patch32_384(pretrained,
-                         model,
-                         model_url,
+def ViT_huge_patch32_384(pretrained=False,
                          use_ssld=False,
                          **kwargs):
     model = VisionTransformer(
...
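For reference, a minimal usage sketch of the factory functions after this commit: `pretrained` now defaults to False, callers no longer pass `model` / `model_url`, and extra keyword arguments flow through **kwargs into the VisionTransformer constructor (where `class_dim` is stored, per the first hunk). The import path below is an assumption about where vision_transformer.py sits in your checkout, not something fixed by the diff.

# Minimal sketch, assuming PaddlePaddle is installed and this file is
# importable from the current directory (adjust the import path to match
# the PaddleClas layout in your tree).
import paddle

from vision_transformer import ViT_base_patch16_224

# After this commit the factory takes only pretrained / use_ssld / **kwargs;
# class_dim is forwarded to VisionTransformer through **kwargs.
model = ViT_base_patch16_224(pretrained=False, class_dim=1000)
model.eval()

x = paddle.randn([1, 3, 224, 224])  # one NCHW image at the 224x224 resolution
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # expected: [1, 1000] with a 1000-way classification head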