From 2823e48be50f8d7844611d767a5fb2dd1ab03c79 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Wed, 24 May 2023 14:58:33 +0000 Subject: [PATCH] fix head_init_scale --- .../arch/backbone/model_zoo/foundation_vit.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/ppcls/arch/backbone/model_zoo/foundation_vit.py b/ppcls/arch/backbone/model_zoo/foundation_vit.py index 4f360e2d..12cc699e 100644 --- a/ppcls/arch/backbone/model_zoo/foundation_vit.py +++ b/ppcls/arch/backbone/model_zoo/foundation_vit.py @@ -579,7 +579,6 @@ class VisionTransformer(nn.Layer): _model_diff = eval(f'_{self.model_name}_diff') self.class_num = class_num - self.head_init_scale = head_init_scale self.return_embed = kwargs.get('return_embed', True) self.num_features = self.embed_dim = embed_dim _img_size = to_2tuple(img_size) @@ -647,8 +646,21 @@ class VisionTransformer(nn.Layer): trunc_normal_(self.pos_embed) if not _model_size in _model_diff['remove_cls_token']: trunc_normal_(self.cls_token) + self.apply(self._init_weights) + if head_init_scale != 1: + if not self.return_embed and class_num > 0: + self.head.fc_head.weight.set_value( + self.head.fc_head.weight * + paddle.to_tensor(head_init_scale)) + self.head.fc_head.bias.set_value( + self.head.fc_head.bias * paddle.to_tensor(head_init_scale)) + else: + logger.warning( + "Because the head or head.fc_head of ViT is Identity() class, the argument head_init_scale is invalid." + ) + def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight) @@ -658,20 +670,6 @@ class VisionTransformer(nn.Layer): zeros_(m.bias) ones_(m.weight) - if self.head_init_scale != 1: - if isinstance(self.head, Head) and isinstance(self.head.fc_head, - nn.Linear): - paddle.assign(self.head.fc_head.weight * - paddle.to_tensor(self.head_init_scale), - self.head.fc_head.weight) - paddle.assign(self.head.fc_head.bias * - paddle.to_tensor(self.head_init_scale), - self.head.fc_head.bias) - else: - logger.warning( - "Because the head or head.fc_head of ViT is Identity() class, the argument head_init_scale is invalid." - ) - def get_num_layers(self): return len(self.blocks) -- GitLab