diff --git a/ppdet/modeling/head/mask_head.py b/ppdet/modeling/head/mask_head.py index f57a98de0a202845ff2edec4821ef6e19b743d81..7c483ab84f3dd23ac0447e220ca4831414b5ec48 100644 --- a/ppdet/modeling/head/mask_head.py +++ b/ppdet/modeling/head/mask_head.py @@ -87,11 +87,11 @@ class MaskFeat(Layer): mode='train'): if self.share_bbox_feat and mask_index: rois_feat = paddle.gather(bbox_feat, mask_index) + if bbox_head_feat_func is not None and mode == 'infer': + rois_feat = bbox_head_feat_func(rois_feat) else: rois_feat = self.mask_roi_extractor(body_feats, bboxes, spatial_scale) - if bbox_head_feat_func is not None and mode == 'infer': - rois_feat = bbox_head_feat_func(rois_feat) # upsample mask_feat = self.upsample_module[stage](rois_feat) return mask_feat