提交 3f0b406f 编写于 作者: C cuicheng01

fix tnt inference bug when bs > 1

上级 ce39aea9
......@@ -193,9 +193,11 @@ class Block(nn.Layer):
self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))))
# outer
B, N, C = patch_embed.shape
patch_embed[:, 1:] = paddle.add(
patch_embed[:, 1:],
self.proj(self.norm1_proj(pixel_embed).reshape((B, N - 1, -1))))
norm1_proj = self.norm1_proj(pixel_embed)
norm1_proj = norm1_proj.reshape(
(B, N - 1, norm1_proj.shape[1] * norm1_proj.shape[2]))
patch_embed[:, 1:] = paddle.add(patch_embed[:, 1:],
self.proj(norm1_proj))
patch_embed = paddle.add(
patch_embed,
self.drop_path(self.attn_out(self.norm_out(patch_embed))))
......@@ -328,7 +330,7 @@ class TNT(nn.Layer):
ones_(m.weight)
def forward_features(self, x):
B = x.shape[0]
B = paddle.shape(x)[0]
pixel_embed = self.pixel_embed(x, self.pixel_pos)
patch_embed = self.norm2_proj(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册