diff --git a/ppdet/modeling/heads/mask_head.py b/ppdet/modeling/heads/mask_head.py index 604847a2d07224314b2eba700eefa00729b4f95f..57dd1a024899c99a46c31eebe739b3ec83434dac 100644 --- a/ppdet/modeling/heads/mask_head.py +++ b/ppdet/modeling/heads/mask_head.py @@ -221,7 +221,7 @@ class MaskHead(nn.Layer): mask_feat = self.head(rois_feat) mask_logit = self.mask_fcn_logits(mask_feat) if self.num_classes == 1: - mask_out = F.sigmoid(mask_logit) + mask_out = F.sigmoid(mask_logit)[:, 0, :, :] else: num_masks = paddle.shape(mask_logit)[0] index = paddle.arange(num_masks).cast('int32')