Commit 3a871524 authored by cuicheng01

Update gvt.py

Parent e0767460
@@ -56,10 +56,10 @@ class GroupAttention(nn.Layer):
                  ws=1):
         super().__init__()
         if ws == 1:
-            raise Exception(f"ws {ws} should not be 1")
+            raise Exception("ws {ws} should not be 1")
         if dim % num_heads != 0:
             raise Exception(
-                f"dim {dim} should be divided by num_heads {num_heads}.")
+                "dim {dim} should be divided by num_heads {num_heads}.")
         self.dim = dim
         self.num_heads = num_heads
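Note: the two string forms in this hunk are not equivalent at runtime; only the f-prefixed literal interpolates the variables, while the plain literal keeps the braces verbatim. A minimal illustration:

ws = 2

# f-string: the placeholder is evaluated.
print(f"ws {ws} should not be 1")   # -> ws 2 should not be 1

# Plain string: the braces are kept literally.
print("ws {ws} should not be 1")    # -> ws {ws} should not be 1

# str.format is an equivalent alternative to the f-string.
print("ws {} should not be 1".format(ws))  # -> ws 2 should not be 1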
@@ -82,11 +82,11 @@ class GroupAttention(nn.Layer):
             B, total_groups, self.ws**2, 3, self.num_heads, C // self.num_heads
         ]).transpose([3, 0, 1, 4, 2, 5])
         q, k, v = qkv[0], qkv[1], qkv[2]
-        attn = (q @k.transpose([0, 1, 2, 4, 3])) * self.scale
+        attn = (q @ k.transpose([0, 1, 2, 4, 3])) * self.scale
         attn = nn.Softmax(axis=-1)(attn)
         attn = self.attn_drop(attn)
-        attn = (attn @v).transpose([0, 1, 3, 2, 4]).reshape(
+        attn = (attn @ v).transpose([0, 1, 3, 2, 4]).reshape(
             [B, h_group, w_group, self.ws, self.ws, C])
         x = attn.transpose([0, 1, 3, 2, 4, 5]).reshape([B, N, C])
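For context, the two `@` expressions reformatted here are standard scaled dot-product attention, applied independently inside each ws x ws window. A minimal standalone sketch with the same tensor layout (all sizes below are illustrative, not taken from the commit):

import paddle
import paddle.nn.functional as F

# Illustrative sizes: batch B, an h_group x w_group grid of windows,
# window side ws, num_heads heads of head_dim channels each.
B, h_group, w_group, ws, num_heads, head_dim = 2, 2, 2, 7, 4, 16
total_groups = h_group * w_group
scale = head_dim**-0.5

# q, k, v shaped as after the qkv reshape/transpose in GroupAttention:
# [B, total_groups, num_heads, ws*ws, head_dim]
q, k, v = (paddle.randn([B, total_groups, num_heads, ws * ws, head_dim])
           for _ in range(3))

# Scores between the ws*ws tokens of each window only.
attn = (q @ k.transpose([0, 1, 2, 4, 3])) * scale  # [..., ws*ws, ws*ws]
attn = F.softmax(attn, axis=-1)
out = attn @ v                                     # [..., ws*ws, head_dim]
print(out.shape)  # [2, 4, 4, 49, 16]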
@@ -147,11 +147,11 @@ class Attention(nn.Layer):
             [2, 0, 3, 1, 4])
         k, v = kv[0], kv[1]
-        attn = (q @k.transpose([0, 1, 3, 2])) * self.scale
+        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
         attn = nn.Softmax(axis=-1)(attn)
         attn = self.attn_drop(attn)
-        x = (attn @v).transpose([0, 2, 1, 3]).reshape([B, N, C])
+        x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -281,7 +281,7 @@ class PyramidVisionTransformer(nn.Layer):
                  img_size=224,
                  patch_size=16,
                  in_chans=3,
-                 num_classes=1000,
+                 class_num=1000,
                  embed_dims=[64, 128, 256, 512],
                  num_heads=[1, 2, 4, 8],
                  mlp_ratios=[4, 4, 4, 4],
@@ -295,7 +295,7 @@ class PyramidVisionTransformer(nn.Layer):
                  sr_ratios=[8, 4, 2, 1],
                  block_cls=Block):
         super().__init__()
-        self.num_classes = num_classes
+        self.class_num = class_num
         self.depths = depths
         # patch_embed
@@ -354,7 +354,7 @@ class PyramidVisionTransformer(nn.Layer):
         # classification head
         self.head = nn.Linear(embed_dims[-1],
-                              num_classes) if num_classes > 0 else Identity()
+                              class_num) if class_num > 0 else Identity()
         # init weights
         for pos_emb in self.pos_embeds:
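The head construction touched in this hunk follows a common backbone pattern: a linear classifier when class_num is positive, otherwise a pass-through so the model can serve as a feature extractor. Standalone sketch (Identity here stands in for the helper defined elsewhere in gvt.py):

import paddle.nn as nn

class Identity(nn.Layer):
    """Pass-through layer, mirroring the one defined in gvt.py."""
    def forward(self, x):
        return x

embed_dims = [64, 128, 256, 512]
class_num = 1000  # 0 disables the classification head

# Linear classifier on the last stage's embedding dim, or a no-op.
head = (nn.Linear(embed_dims[-1], class_num)
        if class_num > 0 else Identity())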
@@ -433,7 +433,7 @@ class CPVTV2(PyramidVisionTransformer):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 num_classes=1000,
+                 class_num=1000,
                  embed_dims=[64, 128, 256, 512],
                  num_heads=[1, 2, 4, 8],
                  mlp_ratios=[4, 4, 4, 4],
@@ -446,10 +446,10 @@ class CPVTV2(PyramidVisionTransformer):
                  depths=[3, 4, 6, 3],
                  sr_ratios=[8, 4, 2, 1],
                  block_cls=Block):
-        super().__init__(img_size, patch_size, in_chans, num_classes,
-                         embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
-                         drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
-                         depths, sr_ratios, block_cls)
+        super().__init__(img_size, patch_size, in_chans, class_num, embed_dims,
+                         num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
+                         attn_drop_rate, drop_path_rate, norm_layer, depths,
+                         sr_ratios, block_cls)
         del self.pos_embeds
         del self.cls_token
         self.pos_block = nn.LayerList(
@@ -499,7 +499,7 @@ class PCPVT(CPVTV2):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 num_classes=1000,
+                 class_num=1000,
                  embed_dims=[64, 128, 256],
                  num_heads=[1, 2, 4],
                  mlp_ratios=[4, 4, 4],
@@ -512,10 +512,10 @@ class PCPVT(CPVTV2):
                  depths=[4, 4, 4],
                  sr_ratios=[4, 2, 1],
                  block_cls=SBlock):
-        super().__init__(img_size, patch_size, in_chans, num_classes,
-                         embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
-                         drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
-                         depths, sr_ratios, block_cls)
+        super().__init__(img_size, patch_size, in_chans, class_num, embed_dims,
+                         num_heads, mlp_ratios, qkv_bias, qk_scale, drop_rate,
+                         attn_drop_rate, drop_path_rate, norm_layer, depths,
+                         sr_ratios, block_cls)
 class ALTGVT(PCPVT):
...
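A hedged usage sketch of the renamed keyword; the import path below is assumed from the PaddleClas source layout, adjust to your checkout:

# Hypothetical usage; module path assumed, not stated in the commit.
from ppcls.arch.backbone.model_zoo.gvt import CPVTV2

# After this commit the keyword is class_num, consistent with the rest
# of the PaddleClas model zoo; the old num_classes keyword would now
# raise a TypeError.
model = CPVTV2(img_size=224, in_chans=3, class_num=1000)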