Commit 03a3da16, authored by cuicheng01

Update tnt.py

Parent: f731ac54
@@ -44,7 +44,7 @@ def drop_path(x, drop_prob=0., training=False):
         return x
     keep_prob = paddle.to_tensor(1 - drop_prob)
     shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
-    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
+    random_tensor = paddle.add(keep_prob, paddle.rand(shape, dtype=x.dtype))
     random_tensor = paddle.floor(random_tensor)  # binarize
     output = x.divide(keep_prob) * random_tensor
     return output
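The only functional change in this hunk replaces the Python `+` with an explicit `paddle.add` call inside `drop_path`. As a sanity check, here is a minimal sketch (not part of the commit; the input shape and drop probability are invented for illustration) showing the two forms produce identical results:

```python
import paddle

x = paddle.randn((4, 3, 8))                        # illustrative input, assumed shape
keep_prob = paddle.to_tensor(1 - 0.1)              # assumed drop_prob = 0.1
noise = paddle.rand((x.shape[0], 1, 1), dtype=x.dtype)

old = keep_prob + noise                            # original operator form
new = paddle.add(keep_prob, noise)                 # explicit form used by this commit
print(paddle.allclose(old, new).item())            # True
```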
@@ -114,14 +114,15 @@ class Attention(nn.Layer):
             (2, 0, 3, 1, 4))
         q, k = qk[0], qk[1]
-        v = self.v(x).reshape((B, N, self.num_heads, -1)).transpose(
+        v = self.v(x).reshape((B, N, self.num_heads, x.shape[-1] // self.num_heads)).transpose(
             (0, 2, 1, 3))
-        attn = (q @k.transpose((0, 1, 3, 2))) * self.scale
+        attn = paddle.matmul(q, k.transpose((0, 1, 3, 2))) * self.scale
         attn = nn.functional.softmax(attn, axis=-1)
         attn = self.attn_drop(attn)
-        x = (attn @v).transpose((0, 2, 1, 3)).reshape((B, N, -1))
+        x = paddle.matmul(attn, v)
+        x = x.transpose((0, 2, 1, 3)).reshape((B, N, x.shape[-1] * x.shape[-3]))
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
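This hunk swaps the `@` matrix-multiply operator for `paddle.matmul` and spells out the reshape sizes that were previously inferred with `-1`. A minimal standalone sketch (head count and dimensions are made up for illustration) confirming the old and new formulations of the attention output agree:

```python
import paddle

B, N, num_heads, head_dim = 2, 16, 4, 8            # illustrative sizes, not from the model
q = paddle.randn((B, num_heads, N, head_dim))
k = paddle.randn((B, num_heads, N, head_dim))
v = paddle.randn((B, num_heads, N, head_dim))

attn_old = q @ k.transpose((0, 1, 3, 2))                # original "@" form
attn_new = paddle.matmul(q, k.transpose((0, 1, 3, 2)))  # explicit form from this commit
print(paddle.allclose(attn_old, attn_new).item())       # True

out_old = (attn_old @ v).transpose((0, 2, 1, 3)).reshape((B, N, -1))
out_new = paddle.matmul(attn_new, v)
out_new = out_new.transpose((0, 2, 1, 3)).reshape(
    (B, N, out_new.shape[-1] * out_new.shape[-3]))      # head_dim * num_heads
print(paddle.allclose(out_old, out_new).item())         # True
```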
@@ -182,18 +183,18 @@ class Block(nn.Layer):
     def forward(self, pixel_embed, patch_embed):
         # inner
-        pixel_embed = pixel_embed + self.drop_path(
-            self.attn_in(self.norm_in(pixel_embed)))
-        pixel_embed = pixel_embed + self.drop_path(
-            self.mlp_in(self.norm_mlp_in(pixel_embed)))
+        pixel_embed = paddle.add(pixel_embed, self.drop_path(
+            self.attn_in(self.norm_in(pixel_embed))))
+        pixel_embed = paddle.add(pixel_embed, self.drop_path(
+            self.mlp_in(self.norm_mlp_in(pixel_embed))))
         # outer
         B, N, C = patch_embed.shape
-        patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(
-            self.norm1_proj(pixel_embed).reshape((B, N - 1, -1)))
-        patch_embed = patch_embed + self.drop_path(
-            self.attn_out(self.norm_out(patch_embed)))
-        patch_embed = patch_embed + self.drop_path(
-            self.mlp(self.norm_mlp(patch_embed)))
+        patch_embed[:, 1:] = paddle.add(patch_embed[:, 1:], self.proj(
+            self.norm1_proj(pixel_embed).reshape((B, N - 1, -1))))
+        patch_embed = paddle.add(patch_embed, self.drop_path(
+            self.attn_out(self.norm_out(patch_embed))))
+        patch_embed = paddle.add(patch_embed, self.drop_path(
+            self.mlp(self.norm_mlp(patch_embed))))
         return pixel_embed, patch_embed
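Same pattern here: the residual additions move from `+` to `paddle.add`, including the one written through a slice assignment on `patch_embed[:, 1:]`. A small sketch (shapes are illustrative, not the model's) checking that the sliced-assignment variant behaves the same:

```python
import paddle

patch_embed = paddle.randn((2, 5, 6))              # (B, N, C), assumed sizes
delta = paddle.randn((2, 4, 6))                    # update for tokens 1..N-1

ref = patch_embed.clone()
ref[:, 1:] = ref[:, 1:] + delta                    # original operator form
out = patch_embed.clone()
out[:, 1:] = paddle.add(out[:, 1:], delta)         # explicit form from this commit
print(paddle.allclose(ref, out).item())            # True
```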
@@ -222,10 +223,9 @@ class PixelEmbed(nn.Layer):
         x = self.proj(x)
         x = nn.functional.unfold(x, self.new_patch_size, self.new_patch_size)
         x = x.transpose((0, 2, 1)).reshape(
-            (B * self.num_patches, self.in_dim, self.new_patch_size,
-             self.new_patch_size))
+            (-1, self.in_dim, self.new_patch_size, self.new_patch_size))
         x = x + pixel_pos
-        x = x.reshape((B * self.num_patches, self.in_dim, -1)).transpose(
+        x = x.reshape((-1, self.in_dim, self.new_patch_size * self.new_patch_size)).transpose(
             (0, 2, 1))
         return x
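In `PixelEmbed` the explicit `B * self.num_patches` leading dimension becomes `-1`, presumably so the reshape does not depend on a concrete batch size. A small sketch (sizes invented for illustration) showing that, for a fixed batch, the two reshapes produce the same tensor:

```python
import paddle

B, num_patches, in_dim, new_patch_size = 2, 4, 3, 5    # illustrative sizes
x = paddle.randn((B * num_patches, in_dim * new_patch_size * new_patch_size))

old = x.reshape((B * num_patches, in_dim, new_patch_size, new_patch_size))
new = x.reshape((-1, in_dim, new_patch_size, new_patch_size))
print(old.shape == new.shape, paddle.allclose(old, new).item())  # True True
```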
@@ -328,7 +328,7 @@ class TNT(nn.Layer):
         patch_embed = self.norm2_proj(
             self.proj(
                 self.norm1_proj(
-                    pixel_embed.reshape((B, self.num_patches, -1)))))
+                    pixel_embed.reshape((-1, self.num_patches, pixel_embed.shape[-1] * pixel_embed.shape[-2])))))
         patch_embed = paddle.concat(
             (self.cls_token.expand((B, -1, -1)), patch_embed), axis=1)
         patch_embed = patch_embed + self.patch_pos