From d49ee3ee31219c242ba23be78a4693489a70fdec Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Wed, 9 Dec 2020 18:54:44 +0800 Subject: [PATCH] fix mask_rcnn (#1855) --- ppdet/modeling/head/mask_head.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ppdet/modeling/head/mask_head.py b/ppdet/modeling/head/mask_head.py index 7c483ab84..656f52ceb 100644 --- a/ppdet/modeling/head/mask_head.py +++ b/ppdet/modeling/head/mask_head.py @@ -85,13 +85,14 @@ class MaskFeat(Layer): stage=0, bbox_head_feat_func=None, mode='train'): - if self.share_bbox_feat and mask_index: + if self.share_bbox_feat and mask_index is not None: rois_feat = paddle.gather(bbox_feat, mask_index) - if bbox_head_feat_func is not None and mode == 'infer': - rois_feat = bbox_head_feat_func(rois_feat) else: rois_feat = self.mask_roi_extractor(body_feats, bboxes, spatial_scale) + if self.share_bbox_feat and bbox_head_feat_func is not None and mode == 'infer': + rois_feat = bbox_head_feat_func(rois_feat) + # upsample mask_feat = self.upsample_module[stage](rois_feat) return mask_feat -- GitLab