Unverified · Commit e0767460 authored by Tingquan Gao, committed via GitHub

Update gvt.py

Parent 88d0d4ca
@@ -78,9 +78,9 @@ class GroupAttention(nn.Layer):
         total_groups = h_group * w_group
         x = x.reshape([B, h_group, self.ws, w_group, self.ws, C]).transpose(
             [0, 1, 3, 2, 4, 5])
-        qkv = self.qkv(x).reshape(
-            [B, total_groups, -1, 3, self.num_heads,
-             C // self.num_heads]).transpose([3, 0, 1, 4, 2, 5])
+        qkv = self.qkv(x).reshape([
+            B, total_groups, self.ws**2, 3, self.num_heads, C // self.num_heads
+        ]).transpose([3, 0, 1, 4, 2, 5])
         q, k, v = qkv[0], qkv[1], qkv[2]
         attn = (q @k.transpose([0, 1, 2, 4, 3])) * self.scale
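Note on the hunk above: a ws x ws window holds exactly self.ws**2 tokens, so spelling the dimension out is equivalent to letting reshape infer it with -1. A minimal shape check with made-up sizes (B, H, W, C, ws, num_heads below are illustrative values, not taken from gvt.py):

import paddle

# Hypothetical sizes, for illustration only.
B, H, W, C, ws, num_heads = 2, 8, 8, 64, 4, 4
h_group, w_group = H // ws, W // ws
total_groups = h_group * w_group

# Stand-in for self.qkv(x): one ws x ws window per group, 3 * C channels.
qkv_like = paddle.randn([B, h_group, w_group, ws, ws, 3 * C])

# Explicit form used by the commit: each window contributes ws**2 tokens.
explicit = qkv_like.reshape(
    [B, total_groups, ws**2, 3, num_heads, C // num_heads])
# Previous form with an inferred dimension; same shape in eager mode.
inferred = qkv_like.reshape(
    [B, total_groups, -1, 3, num_heads, C // num_heads])
assert explicit.shape == inferred.shape  # [2, 4, 16, 3, 4, 16]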
@@ -135,14 +135,15 @@ class Attention(nn.Layer):
         if self.sr_ratio > 1:
             x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
-            x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1])
+            tmp_n = H * W // self.sr_ratio**2
+            x_ = self.sr(x_).reshape([B, C, tmp_n]).transpose([0, 2, 1])
             x_ = self.norm(x_)
             kv = self.kv(x_).reshape(
-                [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
+                [B, tmp_n, 2, self.num_heads, C // self.num_heads]).transpose(
                     [2, 0, 3, 1, 4])
         else:
             kv = self.kv(x).reshape(
-                [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
+                [B, N, 2, self.num_heads, C // self.num_heads]).transpose(
                     [2, 0, 3, 1, 4])
         k, v = kv[0], kv[1]
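Here the fixed sequence length comes from the spatial-reduction layer: with reduction ratio sr_ratio, the H x W key/value map shrinks to (H / sr_ratio) x (W / sr_ratio), i.e. H * W // sr_ratio**2 tokens. A minimal sketch, assuming self.sr behaves like a Conv2D with kernel and stride equal to sr_ratio (an assumption for this illustration, with made-up sizes):

import paddle
import paddle.nn as nn

# Hypothetical sizes, for illustration only.
B, H, W, C, sr_ratio = 2, 16, 16, 64, 4

# Assumed stand-in for self.sr: kernel and stride both equal to sr_ratio.
sr = nn.Conv2D(C, C, kernel_size=sr_ratio, stride=sr_ratio)

x_ = paddle.randn([B, C, H, W])
tmp_n = H * W // sr_ratio**2           # 16 * 16 // 4**2 = 16 tokens
out = sr(x_).reshape([B, C, tmp_n])    # same result as reshape([B, C, -1])
assert out.shape == [B, C, tmp_n]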
@@ -317,7 +318,6 @@ class PyramidVisionTransformer(nn.Layer):
                 self.create_parameter(
                     shape=[1, patch_num, embed_dims[i]],
                     default_initializer=zeros_))
-            self.add_parameter(f"pos_embeds_{i}", self.pos_embeds[i])
             self.pos_drops.append(nn.Dropout(p=drop_rate))
         dpr = [
@@ -433,7 +433,7 @@ class CPVTV2(PyramidVisionTransformer):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 class_num=1000,
+                 num_classes=1000,
                  embed_dims=[64, 128, 256, 512],
                  num_heads=[1, 2, 4, 8],
                  mlp_ratios=[4, 4, 4, 4],
@@ -446,7 +446,7 @@ class CPVTV2(PyramidVisionTransformer):
                  depths=[3, 4, 6, 3],
                  sr_ratios=[8, 4, 2, 1],
                  block_cls=Block):
-        super().__init__(img_size, patch_size, in_chans, class_num,
+        super().__init__(img_size, patch_size, in_chans, num_classes,
                          embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
                          drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
                          depths, sr_ratios, block_cls)
@@ -488,7 +488,7 @@ class CPVTV2(PyramidVisionTransformer):
                     x = self.pos_block[i](x, H, W)  # PEG here
             if i < len(self.depths) - 1:
-                x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
+                x = x.reshape([B, H, W, x.shape[-1]]).transpose([0, 3, 1, 2])
         x = self.norm(x)
         return x.mean(axis=1)  # GAP here
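The forward_features change follows the same pattern: the channel dimension is read from x.shape[-1] rather than inferred with -1. A quick check with made-up sizes:

import paddle

# Hypothetical sizes, for illustration only.
B, H, W, C = 2, 7, 7, 64
x = paddle.randn([B, H * W, C])  # token sequence, as in forward_features

a = x.reshape([B, H, W, x.shape[-1]]).transpose([0, 3, 1, 2])
b = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
assert a.shape == b.shape == [B, C, H, W]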
@@ -499,7 +499,7 @@ class PCPVT(CPVTV2):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 class_num=1000,
+                 num_classes=1000,
                  embed_dims=[64, 128, 256],
                  num_heads=[1, 2, 4],
                  mlp_ratios=[4, 4, 4],
@@ -512,7 +512,7 @@ class PCPVT(CPVTV2):
                  depths=[4, 4, 4],
                  sr_ratios=[4, 2, 1],
                  block_cls=SBlock):
-        super().__init__(img_size, patch_size, in_chans, class_num,
+        super().__init__(img_size, patch_size, in_chans, num_classes,
                          embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
                          drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
                          depths, sr_ratios, block_cls)