diff --git a/ppcls/arch/backbone/model_zoo/gvt.py b/ppcls/arch/backbone/model_zoo/gvt.py index d1afbecaac6f27d95a9780209cdb5ac33ce911b0..3e0592389b82fd989388ad17020b57d8d550d475 100644 --- a/ppcls/arch/backbone/model_zoo/gvt.py +++ b/ppcls/arch/backbone/model_zoo/gvt.py @@ -376,7 +376,7 @@ class PyramidVisionTransformer(nn.Layer): for i in range(len(self.depths)): x, (H, W) = self.patch_embeds[i](x) if i == len(self.depths) - 1: - cls_tokens = self.cls_token.expand([B, -1, -1]) + cls_tokens = self.cls_token.expand([B, -1, -1]).astype(x.dtype) x = paddle.concat([cls_tokens, x], dim=1) x = x + self.pos_embeds[i] x = self.pos_drops[i](x) diff --git a/ppcls/arch/backbone/model_zoo/tnt.py b/ppcls/arch/backbone/model_zoo/tnt.py index c313a14023ea97838fa39c25d146c9f84865e330..6025a1d2d9827af25be1101d6c113eb4ad759b93 100644 --- a/ppcls/arch/backbone/model_zoo/tnt.py +++ b/ppcls/arch/backbone/model_zoo/tnt.py @@ -350,7 +350,9 @@ class TNT(nn.Layer): pixel_embed.reshape((-1, self.num_patches, pixel_embed. shape[-1] * pixel_embed.shape[-2]))))) patch_embed = paddle.concat( - (self.cls_token.expand((B, -1, -1)), patch_embed), axis=1) + (self.cls_token.expand((B, -1, -1)).astype(patch_embed.dtype), + patch_embed), + axis=1) patch_embed = patch_embed + self.patch_pos patch_embed = self.pos_drop(patch_embed) for blk in self.blocks: diff --git a/ppcls/arch/backbone/model_zoo/vision_transformer.py b/ppcls/arch/backbone/model_zoo/vision_transformer.py index fbec1fcb48d75322f36e4c8abd277132d5825b92..e12b66e8263af212b45837df3e186d0a2fc6da52 100644 --- a/ppcls/arch/backbone/model_zoo/vision_transformer.py +++ b/ppcls/arch/backbone/model_zoo/vision_transformer.py @@ -302,7 +302,7 @@ class VisionTransformer(nn.Layer): # B = x.shape[0] B = paddle.shape(x)[0] x = self.patch_embed(x) - cls_tokens = self.cls_token.expand((B, -1, -1)) + cls_tokens = self.cls_token.expand((B, -1, -1)).astype(x.dtype) x = paddle.concat((cls_tokens, x), axis=1) x = x + self.pos_embed x = self.pos_drop(x)