未验证 提交 c363fa1d 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #1210 from cuicheng01/release/2.2

[Cherry-pick]fix tnt inference bug when bs > 1
...@@ -193,9 +193,11 @@ class Block(nn.Layer): ...@@ -193,9 +193,11 @@ class Block(nn.Layer):
self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))) self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))))
# outer # outer
B, N, C = patch_embed.shape B, N, C = patch_embed.shape
patch_embed[:, 1:] = paddle.add( norm1_proj = self.norm1_proj(pixel_embed)
patch_embed[:, 1:], norm1_proj = norm1_proj.reshape(
self.proj(self.norm1_proj(pixel_embed).reshape((B, N - 1, -1)))) (B, N - 1, norm1_proj.shape[1] * norm1_proj.shape[2]))
patch_embed[:, 1:] = paddle.add(patch_embed[:, 1:],
self.proj(norm1_proj))
patch_embed = paddle.add( patch_embed = paddle.add(
patch_embed, patch_embed,
self.drop_path(self.attn_out(self.norm_out(patch_embed)))) self.drop_path(self.attn_out(self.norm_out(patch_embed))))
...@@ -328,7 +330,7 @@ class TNT(nn.Layer): ...@@ -328,7 +330,7 @@ class TNT(nn.Layer):
ones_(m.weight) ones_(m.weight)
def forward_features(self, x): def forward_features(self, x):
B = x.shape[0] B = paddle.shape(x)[0]
pixel_embed = self.pixel_embed(x, self.pixel_pos) pixel_embed = self.pixel_embed(x, self.pixel_pos)
patch_embed = self.norm2_proj( patch_embed = self.norm2_proj(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册