Commit 6d924f85 authored by gaotingquan, committed by cuicheng01

fix for CLIP

1. set bias_attr to False for the conv of PatchEmbed;
2. support return_tokens_mean for the Head of CLIP;
3. support remove_cls_token_in_forward for CLIP;
4. support the head_init_scale argument for the ViT backbone;
5. support get_num_layers() and no_weight_decay() for the ViT backbone.
Parent 2e9d8534
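For context, a minimal sketch of how the new knobs fit together when building the CLIP backbone. The import path and builder name below are assumptions based on this file's location in PaddleClas, not part of this commit; class_num, head_init_scale, and return_embed are the arguments the diff itself touches.

```python
import paddle
# Assumed path/builder name; foundation_vit.py defines per-model factories.
from ppcls.arch.backbone.model_zoo.foundation_vit import CLIP_vit_base_patch16_224

# head_init_scale (new) shrinks the freshly initialized classifier head,
# a common trick when fine-tuning CLIP-pretrained ViTs.
model = CLIP_vit_base_patch16_224(
    class_num=1000,
    head_init_scale=0.001,
    return_embed=False)  # route the output through the classification Head

x = paddle.randn([1, 3, 224, 224])
logits = model(x)  # for this model size the Head mean-pools the tokens
```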
@@ -23,6 +23,7 @@ import paddle.nn as nn
 import sys
 from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+from ....utils import logger
 from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

 MODEL_URLS = {
@@ -77,7 +78,9 @@ _CLIP_diff = {
         'fc_norm': [],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': ['vit_base_patch16_224'],
+    },
+    'remove_cls_token_in_forward': ['vit_base_patch16_224'],
 }

 _MOCOV3_diff = {
@@ -92,7 +95,9 @@ _MOCOV3_diff = {
         'fc_norm': [],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 _CoCa_diff = {
@@ -107,7 +112,9 @@ _CoCa_diff = {
         'fc_norm': [],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 _BEiTv2_diff = {
@@ -124,7 +131,9 @@ _BEiTv2_diff = {
         'fc_norm': [],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 _CAE_diff = {
@@ -139,7 +148,9 @@ _CAE_diff = {
         'fc_norm': [],  # 3 x 197 x 786
         'return_all_tokens': [],  # 3 x 197 x 1000
         'return_patch_tokens': [],  # 3 x 196 x 1000
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 _EVA_diff = {
@@ -154,7 +165,9 @@ _EVA_diff = {
         'fc_norm': ['vit_huge_patch14'],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 _MAE_diff = {
@@ -169,7 +182,9 @@ _MAE_diff = {
         'fc_norm': ['vit_huge_patch14'],
         'return_all_tokens': [],
         'return_patch_tokens': [],
-    }
+        'return_tokens_mean': [],
+    },
+    'remove_cls_token_in_forward': [],
 }

 trunc_normal_ = TruncatedNormal(std=.02)
@@ -478,7 +493,11 @@ class PatchEmbed(nn.Layer):
         self.num_patches = num_patches

         self.proj = nn.Conv2D(
-            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+            in_chans,
+            embed_dim,
+            kernel_size=patch_size,
+            stride=patch_size,
+            bias_attr=False)

     def forward(self, x):
         B, C, H, W = x.shape
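A side note on the PatchEmbed change: with bias_attr=False, Paddle creates the conv without any bias parameter, matching the bias-free patch-projection conv in the original CLIP. A quick standalone check (sketch):

```python
import paddle
import paddle.nn as nn

# bias_attr=False -> no bias parameter is created at all.
proj = nn.Conv2D(3, 768, kernel_size=16, stride=16, bias_attr=False)
assert proj.bias is None

x = paddle.randn([1, 3, 224, 224])
y = proj(x)
print(y.shape)  # [1, 768, 14, 14]: one 768-d embedding per 16x16 patch
```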
@@ -500,6 +519,7 @@ class Head(nn.Layer):
             epsilon=1e-5) if model_size in setting['fc_norm'] else None
         self.return_all_tokens = model_size in setting['return_all_tokens']
         self.return_patch_tokens = model_size in setting['return_patch_tokens']
+        self.return_tokens_mean = model_size in setting['return_tokens_mean']
         self.fc_head = nn.Linear(embed_dim,
                                  class_num) if class_num > 0 else Identity()
@@ -519,6 +539,8 @@
             x = x
         elif self.return_patch_tokens:
             x = x[:, 1:]
+        elif self.return_tokens_mean:
+            x = x.mean(1)
         else:
             x = x[:, 0]
         return self.fc_head(x)
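A shape walk-through of the Head branches (a sketch with illustrative sizes: one cls token plus 196 patch tokens, embed_dim 768):

```python
import paddle

x = paddle.randn([2, 197, 768])  # [batch, cls + patch tokens, embed_dim]

print(x.shape)          # return_all_tokens   -> [2, 197, 768]
print(x[:, 1:].shape)   # return_patch_tokens -> [2, 196, 768]
print(x.mean(1).shape)  # return_tokens_mean  -> [2, 768] (new in this commit)
print(x[:, 0].shape)    # default: cls token  -> [2, 768]
```

For CLIP's vit_base_patch16_224, return_tokens_mean is paired with remove_cls_token_in_forward (see the last hunk below), so the mean is effectively taken over patch tokens only.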
@@ -545,6 +567,7 @@ class VisionTransformer(nn.Layer):
                  drop_path_rate=0.,
                  norm_layer='nn.LayerNorm',
                  epsilon=1e-5,
+                 head_init_scale=1,
                  **kwargs):
         super().__init__()
         global _model_diff
@@ -556,6 +579,7 @@
         _model_diff = eval(f'_{self.model_name}_diff')

         self.class_num = class_num
+        self.head_init_scale = head_init_scale
         self.return_embed = kwargs.get('return_embed', True)
         self.num_features = self.embed_dim = embed_dim
         _img_size = to_2tuple(img_size)
@@ -634,6 +658,26 @@
             zeros_(m.bias)
             ones_(m.weight)
+
+        if self.head_init_scale != 1:
+            if isinstance(self.head, Head) and isinstance(self.head.fc_head,
+                                                          nn.Linear):
+                paddle.assign(self.head.fc_head.weight *
+                              paddle.to_tensor(self.head_init_scale),
+                              self.head.fc_head.weight)
+                paddle.assign(self.head.fc_head.bias *
+                              paddle.to_tensor(self.head_init_scale),
+                              self.head.fc_head.bias)
+            else:
+                logger.warning(
+                    "Because the head or head.fc_head of ViT is Identity() class, the argument head_init_scale is invalid."
+                )
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}

     def forward_features(self, x):
         # B = x.shape[0]
         B = paddle.shape(x)[0]
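get_num_layers() and no_weight_decay() expose the interface that layer-wise lr decay and weight-decay filtering recipes usually expect. A sketch of a typical consumer (the grouping logic is an assumption about the training side, not code from this commit; `model` is the backbone from the first sketch):

```python
import paddle

skip = model.no_weight_decay()  # {'pos_embed', 'cls_token'}
decay, no_decay = [], []
for name, param in model.named_parameters():
    if name in skip or param.ndim == 1:  # also exempt norm/bias vectors
        no_decay.append(param)
    else:
        decay.append(param)

# Paddle optimizers accept parameter groups with per-group weight_decay.
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=[{'params': decay, 'weight_decay': 0.05},
                {'params': no_decay, 'weight_decay': 0.0}])
```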
@@ -651,6 +695,9 @@
                               'rel_pos_bias') else None
         for blk in self.blocks:
             x = blk(x, rel_pos_bias=rel_pos_bias)
+
+        if _model_size in _model_diff['remove_cls_token_in_forward']:
+            x = x[:, 1:, :]
         x = self.norm(x)
         return x
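The effect of the new remove_cls_token_in_forward branch, isolated (sketch): the cls token is sliced off before the final norm, so everything downstream, including the Head above, sees only patch tokens.

```python
import paddle

tokens = paddle.randn([2, 197, 768])  # [cls] + 196 patch tokens
tokens = tokens[:, 1:, :]             # what the new branch does
print(tokens.shape)                   # [2, 196, 768]
print(tokens.mean(1).shape)           # Head with return_tokens_mean -> [2, 768]
```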