diff --git a/ppdet/modeling/head/mask_head.py b/ppdet/modeling/head/mask_head.py index 7c483ab84f3dd23ac0447e220ca4831414b5ec48..656f52ceba39637814335ebd797ab5c97dea737f 100644 --- a/ppdet/modeling/head/mask_head.py +++ b/ppdet/modeling/head/mask_head.py @@ -85,13 +85,14 @@ class MaskFeat(Layer): stage=0, bbox_head_feat_func=None, mode='train'): - if self.share_bbox_feat and mask_index: + if self.share_bbox_feat and mask_index is not None: rois_feat = paddle.gather(bbox_feat, mask_index) - if bbox_head_feat_func is not None and mode == 'infer': - rois_feat = bbox_head_feat_func(rois_feat) else: rois_feat = self.mask_roi_extractor(body_feats, bboxes, spatial_scale) + if self.share_bbox_feat and bbox_head_feat_func is not None and mode == 'infer': + rois_feat = bbox_head_feat_func(rois_feat) + # upsample mask_feat = self.upsample_module[stage](rois_feat) return mask_feat