提交 3a871524，作者：cuicheng01

Update gvt.py

上级（父提交）：e0767460
......@@ -56,10 +56,10 @@ class GroupAttention(nn.Layer):
ws=1):
super().__init__()
if ws == 1:
raise Exception(f"ws {ws} should not be 1")
raise Exception("ws {ws} should not be 1")
if dim % num_heads != 0:
raise Exception(
f"dim {dim} should be divided by num_heads {num_heads}.")
"dim {dim} should be divided by num_heads {num_heads}.")
self.dim = dim
self.num_heads = num_heads
......@@ -82,11 +82,11 @@ class GroupAttention(nn.Layer):
B, total_groups, self.ws**2, 3, self.num_heads, C // self.num_heads
]).transpose([3, 0, 1, 4, 2, 5])
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @k.transpose([0, 1, 2, 4, 3])) * self.scale
attn = (q @ k.transpose([0, 1, 2, 4, 3])) * self.scale
attn = nn.Softmax(axis=-1)(attn)
attn = self.attn_drop(attn)
attn = (attn @v).transpose([0, 1, 3, 2, 4]).reshape(
attn = (attn @ v).transpose([0, 1, 3, 2, 4]).reshape(
[B, h_group, w_group, self.ws, self.ws, C])
x = attn.transpose([0, 1, 3, 2, 4, 5]).reshape([B, N, C])
......@@ -147,11 +147,11 @@ class Attention(nn.Layer):
[2, 0, 3, 1, 4])
k, v = kv[0], kv[1]
attn = (q @k.transpose([0, 1, 3, 2])) * self.scale
attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
attn = nn.Softmax(axis=-1)(attn)
attn = self.attn_drop(attn)
x = (attn @v).transpose([0, 2, 1, 3]).reshape([B, N, C])
x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
......@@ -281,7 +281,7 @@ class PyramidVisionTransformer(nn.Layer):
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
class_num=1000,
embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8],
mlp_ratios=[4, 4, 4, 4],
......@@ -295,7 +295,7 @@ class PyramidVisionTransformer(nn.Layer):
sr_ratios=[8, 4, 2, 1],
block_cls=Block):
super().__init__()
self.num_classes = num_classes
self.class_num = class_num
self.depths = depths
# patch_embed
......@@ -354,7 +354,7 @@ class PyramidVisionTransformer(nn.Layer):
# classification head
self.head = nn.Linear(embed_dims[-1],
num_classes) if num_classes > 0 else Identity()
class_num) if class_num > 0 else Identity()
# init weights
for pos_emb in self.pos_embeds:
......@@ -433,7 +433,7 @@ class CPVTV2(PyramidVisionTransformer):
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
class_num=1000,
embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8],
mlp_ratios=[4, 4, 4, 4],
......@@ -446,10 +446,10 @@ class CPVTV2(PyramidVisionTransformer):
depths=[3, 4, 6, 3],
sr_ratios=[8, 4, 2, 1],
block_cls=Block):
super().__init__(img_size, patch_size, in_chans, num_classes,
embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
depths, sr_ratios, block_cls)
super().__init__(img_size, patch_size, in_chans, class_num, embed_dims,
num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
attn_drop_rate, drop_path_rate, norm_layer, depths,
sr_ratios, block_cls)
del self.pos_embeds
del self.cls_token
self.pos_block = nn.LayerList(
......@@ -499,7 +499,7 @@ class PCPVT(CPVTV2):
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
class_num=1000,
embed_dims=[64, 128, 256],
num_heads=[1, 2, 4],
mlp_ratios=[4, 4, 4],
......@@ -512,10 +512,10 @@ class PCPVT(CPVTV2):
depths=[4, 4, 4],
sr_ratios=[4, 2, 1],
block_cls=SBlock):
super().__init__(img_size, patch_size, in_chans, num_classes,
embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
depths, sr_ratios, block_cls)
super().__init__(img_size, patch_size, in_chans, class_num, embed_dims,
num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
attn_drop_rate, drop_path_rate, norm_layer, depths,
sr_ratios, block_cls)
class ALTGVT(PCPVT):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
请先完成此消息的编辑！
想要评论请先注册。