未验证 提交 d1ce58ca 编写于 作者: W wangguanzhong 提交者: GitHub

fix mask infer (#1846)

上级 08c993d7
...@@ -87,11 +87,11 @@ class MaskFeat(Layer): ...@@ -87,11 +87,11 @@ class MaskFeat(Layer):
mode='train'): mode='train'):
if self.share_bbox_feat and mask_index: if self.share_bbox_feat and mask_index:
rois_feat = paddle.gather(bbox_feat, mask_index) rois_feat = paddle.gather(bbox_feat, mask_index)
if bbox_head_feat_func is not None and mode == 'infer':
rois_feat = bbox_head_feat_func(rois_feat)
else: else:
rois_feat = self.mask_roi_extractor(body_feats, bboxes, rois_feat = self.mask_roi_extractor(body_feats, bboxes,
spatial_scale) spatial_scale)
if bbox_head_feat_func is not None and mode == 'infer':
rois_feat = bbox_head_feat_func(rois_feat)
# upsample # upsample
mask_feat = self.upsample_module[stage](rois_feat) mask_feat = self.upsample_module[stage](rois_feat)
return mask_feat return mask_feat
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册