From 6fdaf94a0d5516f293c37e391ace8cbeec9f7271 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Tue, 4 Apr 2023 09:21:58 +0000 Subject: [PATCH] fix concat error when fp16 --- ppcls/arch/backbone/model_zoo/distilled_vision_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppcls/arch/backbone/model_zoo/distilled_vision_transformer.py b/ppcls/arch/backbone/model_zoo/distilled_vision_transformer.py index 2ccd3c7b..e3ca9016 100644 --- a/ppcls/arch/backbone/model_zoo/distilled_vision_transformer.py +++ b/ppcls/arch/backbone/model_zoo/distilled_vision_transformer.py @@ -89,8 +89,8 @@ class DistilledVisionTransformer(VisionTransformer): B = paddle.shape(x)[0] x = self.patch_embed(x) - cls_tokens = self.cls_token.expand((B, -1, -1)) - dist_token = self.dist_token.expand((B, -1, -1)) + cls_tokens = self.cls_token.expand((B, -1, -1)).astype(x.dtype) + dist_token = self.dist_token.expand((B, -1, -1)).astype(x.dtype) x = paddle.concat((cls_tokens, dist_token, x), axis=1) x = x + self.pos_embed -- GitLab