提交 6fdaf94a 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix concat error when fp16

上级 1433161e
......@@ -89,8 +89,8 @@ class DistilledVisionTransformer(VisionTransformer):
B = paddle.shape(x)[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand((B, -1, -1))
dist_token = self.dist_token.expand((B, -1, -1))
cls_tokens = self.cls_token.expand((B, -1, -1)).astype(x.dtype)
dist_token = self.dist_token.expand((B, -1, -1)).astype(x.dtype)
x = paddle.concat((cls_tokens, dist_token, x), axis=1)
x = x + self.pos_embed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册