Commit 4e988692 authored by gaotingquan, committed by Tingquan Gao

fix concat error under fp16: cast cls_token to the input dtype before concat

Parent 2c3ebe7b
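The three diffs below apply the same one-line fix to the PyramidVisionTransformer, TNT, and VisionTransformer backbones. The learnable `cls_token` parameter is created in float32, so when the forward pass runs in float16 its expanded copy no longer matches the dtype of the patch embeddings, and `paddle.concat` rejects mixed-dtype inputs. A minimal sketch of the failure and the fix follows; the shapes and variable names are illustrative, not taken from the repository:

```python
import paddle

# Illustrative repro: a float32 parameter concatenated with float16
# activations. Shapes are arbitrary stand-ins for ViT-style tokens.
x = paddle.randn([4, 196, 768]).astype("float16")  # fp16 patch tokens
cls_token = paddle.create_parameter(               # fp32 parameter
    shape=[1, 1, 768], dtype="float32")

cls_tokens = cls_token.expand([4, -1, -1])
# paddle.concat([cls_tokens, x], axis=1)  # fails: inputs must share a dtype

# The fix applied in this commit: cast to the runtime dtype of x first,
# so the concat sees two float16 tensors regardless of precision mode.
cls_tokens = cls_tokens.astype(x.dtype)
out = paddle.concat([cls_tokens, x], axis=1)
print(out.dtype)  # paddle.float16
```

Casting to `x.dtype` rather than hard-coding `float16` keeps the line a no-op when the model runs in float32.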
@@ -376,7 +376,7 @@ class PyramidVisionTransformer(nn.Layer):
         for i in range(len(self.depths)):
             x, (H, W) = self.patch_embeds[i](x)
             if i == len(self.depths) - 1:
-                cls_tokens = self.cls_token.expand([B, -1, -1])
+                cls_tokens = self.cls_token.expand([B, -1, -1]).astype(x.dtype)
                 x = paddle.concat([cls_tokens, x], axis=1)
             x = x + self.pos_embeds[i]
             x = self.pos_drops[i](x)
@@ -350,7 +350,9 @@ class TNT(nn.Layer):
             pixel_embed.reshape((-1, self.num_patches, pixel_embed.
                                  shape[-1] * pixel_embed.shape[-2])))))
         patch_embed = paddle.concat(
-            (self.cls_token.expand((B, -1, -1)), patch_embed), axis=1)
+            (self.cls_token.expand((B, -1, -1)).astype(patch_embed.dtype),
+             patch_embed),
+            axis=1)
         patch_embed = patch_embed + self.patch_pos
         patch_embed = self.pos_drop(patch_embed)
         for blk in self.blocks:
@@ -302,7 +302,7 @@ class VisionTransformer(nn.Layer):
         # B = x.shape[0]
         B = paddle.shape(x)[0]
         x = self.patch_embed(x)
-        cls_tokens = self.cls_token.expand((B, -1, -1))
+        cls_tokens = self.cls_token.expand((B, -1, -1)).astype(x.dtype)
         x = paddle.concat((cls_tokens, x), axis=1)
         x = x + self.pos_embed
         x = self.pos_drop(x)
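For context, a hedged sketch of how the mismatch arises inside a layer under automatic mixed precision: on GPU, `paddle.amp.auto_cast` runs matmuls in float16, so a projection's output comes back as fp16, while parameters created by the layer keep their float32 dtype. The class and variable names below are illustrative, not PaddleClas code, and whether an unguarded concat actually raises depends on the device, Paddle version, and AMP level, which is why the explicit `astype` is the robust choice:

```python
import paddle
import paddle.nn as nn

class TinyTokenizer(nn.Layer):
    """Illustrative stand-in for the patched backbones (not repository code)."""

    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Created in the default dtype (float32); auto_cast does not touch it.
        self.cls_token = self.create_parameter(shape=[1, 1, dim])

    def forward(self, x):
        x = self.proj(x)  # float16 under auto_cast on GPU
        cls = self.cls_token.expand([x.shape[0], -1, -1])
        # Cast defensively, exactly as the commit does: a no-op in fp32,
        # the required conversion in fp16.
        return paddle.concat([cls.astype(x.dtype), x], axis=1)

model = TinyTokenizer()
with paddle.amp.auto_cast():
    out = model(paddle.randn([2, 3, 16]))
print(out.shape, out.dtype)
```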