diff --git a/ppdet/modeling/heads/mask_head.py b/ppdet/modeling/heads/mask_head.py index 08b7d3d26f046572c042f37ca800f91245d1ba3a..939debbaae129293551394b5571f7da158a0cccb 100644 --- a/ppdet/modeling/heads/mask_head.py +++ b/ppdet/modeling/heads/mask_head.py @@ -222,7 +222,7 @@ class MaskHead(nn.Layer): mask_feat = self.head(rois_feat) mask_logit = self.mask_fcn_logits(mask_feat) if self.num_classes == 1: - mask_out = F.sigmoid(mask_logit) + mask_out = F.sigmoid(mask_logit)[:, 0, :, :] else: num_masks = paddle.shape(mask_logit)[0] index = paddle.arange(num_masks).cast('int32')