未验证 提交 d49ee3ee 编写于 作者: W wangguanzhong 提交者: GitHub

fix mask_rcnn (#1855)

上级 73ac7c1d
...@@ -85,13 +85,14 @@ class MaskFeat(Layer): ...@@ -85,13 +85,14 @@ class MaskFeat(Layer):
stage=0, stage=0,
bbox_head_feat_func=None, bbox_head_feat_func=None,
mode='train'): mode='train'):
if self.share_bbox_feat and mask_index: if self.share_bbox_feat and mask_index is not None:
rois_feat = paddle.gather(bbox_feat, mask_index) rois_feat = paddle.gather(bbox_feat, mask_index)
if bbox_head_feat_func is not None and mode == 'infer':
rois_feat = bbox_head_feat_func(rois_feat)
else: else:
rois_feat = self.mask_roi_extractor(body_feats, bboxes, rois_feat = self.mask_roi_extractor(body_feats, bboxes,
spatial_scale) spatial_scale)
if self.share_bbox_feat and bbox_head_feat_func is not None and mode == 'infer':
rois_feat = bbox_head_feat_func(rois_feat)
# upsample # upsample
mask_feat = self.upsample_module[stage](rois_feat) mask_feat = self.upsample_module[stage](rois_feat)
return mask_feat return mask_feat
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册