From e07674603ebe3bee505afa8737d6df2df0459345 Mon Sep 17 00:00:00 2001
From: Tingquan Gao <35441050@qq.com>
Date: Wed, 21 Jul 2021 20:50:04 +0800
Subject: [PATCH] Update gvt.py

---
 ppcls/arch/backbone/model_zoo/gvt.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/ppcls/arch/backbone/model_zoo/gvt.py b/ppcls/arch/backbone/model_zoo/gvt.py
index 1818540b..c3171228 100644
--- a/ppcls/arch/backbone/model_zoo/gvt.py
+++ b/ppcls/arch/backbone/model_zoo/gvt.py
@@ -78,9 +78,9 @@ class GroupAttention(nn.Layer):
         total_groups = h_group * w_group
         x = x.reshape([B, h_group, self.ws, w_group, self.ws, C]).transpose(
             [0, 1, 3, 2, 4, 5])
-        qkv = self.qkv(x).reshape(
-            [B, total_groups, -1, 3, self.num_heads,
-             C // self.num_heads]).transpose([3, 0, 1, 4, 2, 5])
+        qkv = self.qkv(x).reshape([
+            B, total_groups, self.ws**2, 3, self.num_heads, C // self.num_heads
+        ]).transpose([3, 0, 1, 4, 2, 5])
         q, k, v = qkv[0], qkv[1], qkv[2]
         attn = (q @k.transpose([0, 1, 2, 4, 3])) * self.scale
 
@@ -135,14 +135,15 @@ class Attention(nn.Layer):
 
         if self.sr_ratio > 1:
             x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
-            x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1])
+            tmp_n = H * W // self.sr_ratio**2
+            x_ = self.sr(x_).reshape([B, C, tmp_n]).transpose([0, 2, 1])
             x_ = self.norm(x_)
             kv = self.kv(x_).reshape(
-                [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
+                [B, tmp_n, 2, self.num_heads, C // self.num_heads]).transpose(
                     [2, 0, 3, 1, 4])
         else:
             kv = self.kv(x).reshape(
-                [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
+                [B, N, 2, self.num_heads, C // self.num_heads]).transpose(
                     [2, 0, 3, 1, 4])
         k, v = kv[0], kv[1]
 
@@ -317,7 +318,6 @@ class PyramidVisionTransformer(nn.Layer):
                 self.create_parameter(
                     shape=[1, patch_num, embed_dims[i]],
                     default_initializer=zeros_))
-            self.add_parameter(f"pos_embeds_{i}", self.pos_embeds[i])
             self.pos_drops.append(nn.Dropout(p=drop_rate))
 
         dpr = [
@@ -433,7 +433,7 @@ class CPVTV2(PyramidVisionTransformer):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 class_num=1000,
+                 num_classes=1000,
                  embed_dims=[64, 128, 256, 512],
                  num_heads=[1, 2, 4, 8],
                  mlp_ratios=[4, 4, 4, 4],
@@ -446,7 +446,7 @@ class CPVTV2(PyramidVisionTransformer):
                  depths=[3, 4, 6, 3],
                  sr_ratios=[8, 4, 2, 1],
                  block_cls=Block):
-        super().__init__(img_size, patch_size, in_chans, class_num,
+        super().__init__(img_size, patch_size, in_chans, num_classes,
                          embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
                          drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
                          depths, sr_ratios, block_cls)
@@ -488,7 +488,7 @@ class CPVTV2(PyramidVisionTransformer):
                     x = self.pos_block[i](x, H, W)  # PEG here
 
             if i < len(self.depths) - 1:
-                x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
+                x = x.reshape([B, H, W, x.shape[-1]]).transpose([0, 3, 1, 2])
 
         x = self.norm(x)
         return x.mean(axis=1)  # GAP here
@@ -499,7 +499,7 @@ class PCPVT(CPVTV2):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 class_num=1000,
+                 num_classes=1000,
                  embed_dims=[64, 128, 256],
                  num_heads=[1, 2, 4],
                  mlp_ratios=[4, 4, 4],
@@ -512,7 +512,7 @@ class PCPVT(CPVTV2):
                  depths=[4, 4, 4],
                  sr_ratios=[4, 2, 1],
                  block_cls=SBlock):
-        super().__init__(img_size, patch_size, in_chans, class_num,
+        super().__init__(img_size, patch_size, in_chans, num_classes,
                          embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
                          drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
                          depths, sr_ratios, block_cls)
-- 
GitLab