From 38bdf6504fa7e098d0bb350c1b41bcb84c506ca8 Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Wed, 19 Oct 2022 20:58:33 +0800 Subject: [PATCH] [dev] fix deformable_detr src_mask bug (#7148) --- ppdet/modeling/transformers/deformable_transformer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ppdet/modeling/transformers/deformable_transformer.py b/ppdet/modeling/transformers/deformable_transformer.py index 0c2089a8b..db07e0327 100644 --- a/ppdet/modeling/transformers/deformable_transformer.py +++ b/ppdet/modeling/transformers/deformable_transformer.py @@ -448,7 +448,6 @@ class DeformableTransformer(nn.Layer): return {'backbone_num_channels': [i.channels for i in input_shape], } def _get_valid_ratio(self, mask): - mask = mask.astype(paddle.float32) _, H, W = mask.shape valid_ratio_h = paddle.sum(mask[:, :, 0], 1) / H valid_ratio_w = paddle.sum(mask[:, 0, :], 1) / W @@ -477,18 +476,16 @@ class DeformableTransformer(nn.Layer): src = src.flatten(2).transpose([0, 2, 1]) src_flatten.append(src) if src_mask is not None: - mask = F.interpolate( - src_mask.unsqueeze(0).astype(src.dtype), - size=(h, w))[0].astype('bool') + mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0] else: - mask = paddle.ones([bs, h, w], dtype='bool') + mask = paddle.ones([bs, h, w]) valid_ratios.append(self._get_valid_ratio(mask)) pos_embed = self.position_embedding(mask).flatten(2).transpose( [0, 2, 1]) lvl_pos_embed = pos_embed + self.level_embed.weight[level].reshape( [1, 1, -1]) lvl_pos_embed_flatten.append(lvl_pos_embed) - mask = mask.astype(src.dtype).flatten(1) + mask = mask.flatten(1) mask_flatten.append(mask) src_flatten = paddle.concat(src_flatten, 1) mask_flatten = paddle.concat(mask_flatten, 1) -- GitLab