未验证 提交 f4932680 编写于 作者: F FDInSky 提交者: GitHub

fix generate_proposals (#23797) (#24225)

* test=develop fix generate_proposals
上级 28fa467b
......@@ -2742,7 +2742,8 @@ def generate_proposals(scores,
nms_thresh=0.5,
min_size=0.1,
eta=1.0,
name=None):
name=None,
return_rois_num=False):
"""
**Generate proposal Faster-RCNN**
......@@ -2789,7 +2790,10 @@ def generate_proposals(scores,
width < min_size. The data type must be float32. `0.1` by default.
eta(float): Apply in adaptive NMS, if adaptive `threshold > 0.5`,
`adaptive_threshold = adaptive_threshold * eta` in each iteration.
return_rois_num(bool): When setting True, it will return a 1D Tensor with shape [N, ] that includes Rois's
num of each image in one batch. The N is the image's num. For example, the tensor has values [4,5] that represents
the first image has 4 Rois, the second image has 5 Rois. It only used in rcnn model.
'False' by default.
Returns:
tuple:
A tuple with format ``(rpn_rois, rpn_roi_probs)``.
......@@ -2843,7 +2847,10 @@ def generate_proposals(scores,
rpn_roi_probs.stop_gradient = True
rpn_rois_lod.stop_gradient = True
if return_rois_num:
return rpn_rois, rpn_roi_probs, rpn_rois_lod
else:
return rpn_rois, rpn_roi_probs
def box_clip(input, im_info, name=None):
......
......@@ -480,7 +480,7 @@ class TestGenerateProposals(unittest.TestCase):
name='bbox_deltas',
shape=[num_anchors * 4, 8, 8],
dtype='float32')
rpn_rois, rpn_roi_probs, _ = fluid.layers.generate_proposals(
rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals(
name='generate_proposals',
scores=scores,
bbox_deltas=bbox_deltas,
......
......@@ -282,8 +282,6 @@ class TestGenerateProposalsOp(OpTest):
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.lod]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.lod]),
'RpnRoisLod': (np.asarray(
self.lod, dtype=np.int32))
}
def test_check_output(self):
......@@ -328,5 +326,35 @@ class TestGenerateProposalsOp(OpTest):
self.nms_thresh, self.min_size, self.eta)
class TestGenerateProposalsOutLodOp(TestGenerateProposalsOp):
def set_data(self):
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {
'Scores': self.scores,
'BboxDeltas': self.bbox_deltas,
'ImInfo': self.im_info.astype(np.float32),
'Anchors': self.anchors,
'Variances': self.variances
}
self.attrs = {
'pre_nms_topN': self.pre_nms_topN,
'post_nms_topN': self.post_nms_topN,
'nms_thresh': self.nms_thresh,
'min_size': self.min_size,
'eta': self.eta,
'return_rois_num': True
}
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.lod]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.lod]),
'RpnRoisLod': (np.asarray(
self.lod, dtype=np.int32))
}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册