Commit 6d924f85 authored by gaotingquan, committed by cuicheng01

fix for clip

1. set bias_attr=False for the conv in PatchEmbed;
2. support return_tokens_mean for the Head of CLIP;
3. support remove_cls_token_in_forward for CLIP;
4. support the head_init_scale argument for the ViT backbone;
5. support get_num_layers() and no_weight_decay() for the ViT backbone.
Parent 2e9d8534
@@ -23,6 +23,7 @@ import paddle.nn as nn
 import sys
 from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+from ....utils import logger
 from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

 MODEL_URLS = {
@@ -77,7 +78,9 @@ _CLIP_diff = {
         'fc_norm': [],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': ['vit_base_patch16_224'],
+    },
+    'remove_cls_token_in_forward': ['vit_base_patch16_224'],
 }

 _MOCOV3_diff = {
@@ -92,7 +95,9 @@ _MOCOV3_diff = {
         'fc_norm': [],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 _CoCa_diff = {
@@ -107,7 +112,9 @@ _CoCa_diff = {
         'fc_norm': [],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 _BEiTv2_diff = {
@@ -124,7 +131,9 @@ _BEiTv2_diff = {
         'fc_norm': [],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 _CAE_diff = {
@@ -139,7 +148,9 @@ _CAE_diff = {
         'fc_norm': [],  # 3 x 197 x 786
         'return_all_tokens': [],  # 3 x 197 x 1000
         'return_patch_tokens': [],  # 3 x 196 x 1000
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 _EVA_diff = {
@@ -154,7 +165,9 @@ _EVA_diff = {
         'fc_norm': ['vit_huge_patch14'],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 _MAE_diff = {
@@ -169,7 +182,9 @@ _MAE_diff = {
         'fc_norm': ['vit_huge_patch14'],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 trunc_normal_ = TruncatedNormal(std=.02)
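These per-model "diff" dicts gate optional behavior by model size: a feature is active only for the variants listed under its key, so the two new keys are effectively no-ops everywhere except CLIP's vit_base_patch16_224. A minimal sketch of the lookup (the flattened dict below is illustrative, not the real nesting, where 'return_tokens_mean' lives under 'head'):

model_size = 'vit_base_patch16_224'
setting = {
    'return_tokens_mean': ['vit_base_patch16_224'],
    'remove_cls_token_in_forward': ['vit_base_patch16_224'],
}
# Membership tests like these drive the new flags below.
return_tokens_mean = model_size in setting['return_tokens_mean']         # True
remove_cls_token = model_size in setting['remove_cls_token_in_forward']  # True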
@@ -478,7 +493,11 @@ class PatchEmbed(nn.Layer):
         self.num_patches = num_patches

         self.proj = nn.Conv2D(
-            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+            in_chans,
+            embed_dim,
+            kernel_size=patch_size,
+            stride=patch_size,
+            bias_attr=False)

     def forward(self, x):
         B, C, H, W = x.shape
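Dropping the conv bias matches the original CLIP patch embedding, where a LayerNorm right after the projection makes the bias redundant. A quick shape check as a sketch (3/768/16/224 are assumed base-model dimensions, not values taken from this file):

import paddle
import paddle.nn as nn

proj = nn.Conv2D(3, 768, kernel_size=16, stride=16, bias_attr=False)
x = paddle.randn([1, 3, 224, 224])
# conv -> [1, 768, 14, 14], then flatten and transpose into tokens
patches = proj(x).flatten(2).transpose([0, 2, 1])
print(patches.shape)  # [1, 196, 768]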
@@ -500,6 +519,7 @@ class Head(nn.Layer):
             epsilon=1e-5) if model_size in setting['fc_norm'] else None
         self.return_all_tokens = model_size in setting['return_all_tokens']
         self.return_patch_tokens = model_size in setting['return_patch_tokens']
+        self.return_tokens_mean = model_size in setting['return_tokens_mean']
         self.fc_head = nn.Linear(embed_dim,
                                  class_num) if class_num > 0 else Identity()
...@@ -519,6 +539,8 @@ class Head(nn.Layer): ...@@ -519,6 +539,8 @@ class Head(nn.Layer):
x = x x = x
elif self.return_patch_tokens: elif self.return_patch_tokens:
x = x[:, 1:] x = x[:, 1:]
elif self.return_tokens_mean:
x = x.mean(1)
else: else:
x = x[:, 0] x = x[:, 0]
return self.fc_head(x) return self.fc_head(x)
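The return modes differ only in which tokens reach fc_head. A minimal sketch with assumed shapes (batch 2, one cls token plus 196 patch tokens, dim 768):

import paddle

tokens = paddle.randn([2, 197, 768])  # [batch, 1 + num_patches, embed_dim]
cls_feat = tokens[:, 0]               # default: cls token only -> [2, 768]
patch_feats = tokens[:, 1:]           # return_patch_tokens -> [2, 196, 768]
mean_feat = tokens.mean(1)            # return_tokens_mean -> [2, 768]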
...@@ -545,6 +567,7 @@ class VisionTransformer(nn.Layer): ...@@ -545,6 +567,7 @@ class VisionTransformer(nn.Layer):
drop_path_rate=0., drop_path_rate=0.,
norm_layer='nn.LayerNorm', norm_layer='nn.LayerNorm',
epsilon=1e-5, epsilon=1e-5,
head_init_scale=1,
**kwargs): **kwargs):
super().__init__() super().__init__()
global _model_diff global _model_diff
@@ -556,6 +579,7 @@ class VisionTransformer(nn.Layer):
         _model_diff = eval(f'_{self.model_name}_diff')

         self.class_num = class_num
+        self.head_init_scale = head_init_scale
         self.return_embed = kwargs.get('return_embed', True)
         self.num_features = self.embed_dim = embed_dim
         _img_size = to_2tuple(img_size)
@@ -634,6 +658,26 @@ class VisionTransformer(nn.Layer):
                 zeros_(m.bias)
                 ones_(m.weight)

+        if self.head_init_scale != 1:
+            if isinstance(self.head, Head) and isinstance(self.head.fc_head,
+                                                          nn.Linear):
+                paddle.assign(self.head.fc_head.weight *
+                              paddle.to_tensor(self.head_init_scale),
+                              self.head.fc_head.weight)
+                paddle.assign(self.head.fc_head.bias *
+                              paddle.to_tensor(self.head_init_scale),
+                              self.head.fc_head.bias)
+            else:
+                logger.warning(
+                    "Because the head or head.fc_head of ViT is Identity() class, the argument head_init_scale is invalid."
+                )
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
     def forward_features(self, x):
         # B = x.shape[0]
         B = paddle.shape(x)[0]
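A hedged sketch of how these three additions are consumed when fine-tuning; TinyViT is a hypothetical stand-in for the backbone, and the hyperparameters are illustrative, not values from this repo:

import paddle
import paddle.nn as nn

class TinyViT(nn.Layer):
    def __init__(self, head_init_scale=1.0):
        super().__init__()
        self.pos_embed = self.create_parameter([1, 197, 768])
        self.cls_token = self.create_parameter([1, 1, 768])
        self.blocks = nn.LayerList([nn.Linear(768, 768) for _ in range(2)])
        self.fc_head = nn.Linear(768, 1000)
        if head_init_scale != 1:
            # same idea as the commit: damp the classifier's initial output
            paddle.assign(self.fc_head.weight * head_init_scale,
                          self.fc_head.weight)
            paddle.assign(self.fc_head.bias * head_init_scale,
                          self.fc_head.bias)

    def get_num_layers(self):
        return len(self.blocks)

    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

model = TinyViT(head_init_scale=0.001)

# no_weight_decay() feeds optimizer parameter grouping; biases are
# commonly excluded from decay as well.
skip = model.no_weight_decay()
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if name in skip or name.endswith('.bias') else decay).append(param)

opt = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=[
        {'params': decay, 'weight_decay': 0.05},
        {'params': no_decay, 'weight_decay': 0.0},
    ])

# get_num_layers() is the usual hook for layer-wise lr decay schemes,
# e.g. block i gets lr * rate ** (model.get_num_layers() - i).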
@@ -651,6 +695,9 @@ class VisionTransformer(nn.Layer):
             'rel_pos_bias') else None

         for blk in self.blocks:
             x = blk(x, rel_pos_bias=rel_pos_bias)
+        if _model_size in _model_diff['remove_cls_token_in_forward']:
+            x = x[:, 1:, :]
+
         x = self.norm(x)
         return x
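Note the interplay with return_tokens_mean above: for CLIP's vit_base_patch16_224 both flags are enabled, so the cls token is dropped before the final norm and the head's mean is then taken over patch tokens only. A sketch with assumed shapes:

import paddle

x = paddle.randn([2, 197, 768])  # [batch, 1 + 14*14 patches, dim] after the blocks
x = x[:, 1:, :]                  # remove_cls_token_in_forward -> [2, 196, 768]
feat = x.mean(1)                 # then return_tokens_mean in Head -> [2, 768]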