提交 2823e48b 编写于 作者: G gaotingquan 提交者: cuicheng01

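fix head_init_scale: apply the head scaling in __init__ after self.apply(self._init_weights), instead of inside _init_weights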

上级 042d1e7e
......@@ -579,7 +579,6 @@ class VisionTransformer(nn.Layer):
_model_diff = eval(f'_{self.model_name}_diff')
self.class_num = class_num
self.head_init_scale = head_init_scale
self.return_embed = kwargs.get('return_embed', True)
self.num_features = self.embed_dim = embed_dim
_img_size = to_2tuple(img_size)
......@@ -647,8 +646,21 @@ class VisionTransformer(nn.Layer):
trunc_normal_(self.pos_embed)
if not _model_size in _model_diff['remove_cls_token']:
trunc_normal_(self.cls_token)
self.apply(self._init_weights)
if head_init_scale != 1:
if not self.return_embed and class_num > 0:
self.head.fc_head.weight.set_value(
self.head.fc_head.weight *
paddle.to_tensor(head_init_scale))
self.head.fc_head.bias.set_value(
self.head.fc_head.bias * paddle.to_tensor(head_init_scale))
else:
logger.warning(
"Because the head or head.fc_head of ViT is Identity() class, the argument head_init_scale is invalid."
)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight)
......@@ -658,20 +670,6 @@ class VisionTransformer(nn.Layer):
zeros_(m.bias)
ones_(m.weight)
if self.head_init_scale != 1:
if isinstance(self.head, Head) and isinstance(self.head.fc_head,
nn.Linear):
paddle.assign(self.head.fc_head.weight *
paddle.to_tensor(self.head_init_scale),
self.head.fc_head.weight)
paddle.assign(self.head.fc_head.bias *
paddle.to_tensor(self.head_init_scale),
self.head.fc_head.bias)
else:
logger.warning(
"Because the head or head.fc_head of ViT is Identity() class, the argument head_init_scale is invalid."
)
def get_num_layers(self):
    """Return the number of entries in ``self.blocks``.

    Returns:
        int: ``len(self.blocks)``.
    """
    # The body must be indented under the ``def``; with ``return`` at the
    # same column as ``def`` the definition does not parse.
    return len(self.blocks)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册