diff --git a/ppcls/arch/backbone/model_zoo/distilled_vision_transformer.py b/ppcls/arch/backbone/model_zoo/distilled_vision_transformer.py index 2ccd3c7bee9d55fb0337dd67b187108e4451cb8e..e3ca9016cd962be97a6f89874bfde3047e4c088b 100644 --- a/ppcls/arch/backbone/model_zoo/distilled_vision_transformer.py +++ b/ppcls/arch/backbone/model_zoo/distilled_vision_transformer.py @@ -89,8 +89,8 @@ class DistilledVisionTransformer(VisionTransformer): B = paddle.shape(x)[0] x = self.patch_embed(x) - cls_tokens = self.cls_token.expand((B, -1, -1)) - dist_token = self.dist_token.expand((B, -1, -1)) + cls_tokens = self.cls_token.expand((B, -1, -1)).astype(x.dtype) + dist_token = self.dist_token.expand((B, -1, -1)).astype(x.dtype) x = paddle.concat((cls_tokens, dist_token, x), axis=1) x = x + self.pos_embed