diff --git a/ppcls/arch/backbone/model_zoo/tnt.py b/ppcls/arch/backbone/model_zoo/tnt.py index 352e1d2d6ceded283af0520c16d5aa7bd8d77f6a..13e9b5c4a0831b8c87a01cb50b43405fd8481be9 100644 --- a/ppcls/arch/backbone/model_zoo/tnt.py +++ b/ppcls/arch/backbone/model_zoo/tnt.py @@ -193,9 +193,11 @@ class Block(nn.Layer): self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))) # outer B, N, C = patch_embed.shape - patch_embed[:, 1:] = paddle.add( - patch_embed[:, 1:], - self.proj(self.norm1_proj(pixel_embed).reshape((B, N - 1, -1)))) + norm1_proj = self.norm1_proj(pixel_embed) + norm1_proj = norm1_proj.reshape( + (B, N - 1, norm1_proj.shape[1] * norm1_proj.shape[2])) + patch_embed[:, 1:] = paddle.add(patch_embed[:, 1:], + self.proj(norm1_proj)) patch_embed = paddle.add( patch_embed, self.drop_path(self.attn_out(self.norm_out(patch_embed)))) @@ -328,7 +330,7 @@ class TNT(nn.Layer): ones_(m.weight) def forward_features(self, x): - B = x.shape[0] + B = paddle.shape(x)[0] pixel_embed = self.pixel_embed(x, self.pixel_pos) patch_embed = self.norm2_proj(