提交 b62a17bb 编写于 作者: J jerrywgz

add nms api

上级 f660553d
...@@ -458,7 +458,8 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -458,7 +458,8 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
"predicted locations of M bounding bboxes, N is the batch size. " "predicted locations of M bounding bboxes, N is the batch size. "
"Each bounding box has four coordinate values and the layout is " "Each bounding box has four coordinate values and the layout is "
"[xmin, ymin, xmax, ymax], when box size equals to 4." "[xmin, ymin, xmax, ymax], when box size equals to 4."
"2. (LoDTensor) A 3-D Tensor with shape [N, M, 4]"); "2. (LoDTensor) A 3-D Tensor with shape [N, M, 4]"
"N is the number of boxes, M is the class number");
AddInput("Scores", AddInput("Scores",
"Two types of scores are supported:" "Two types of scores are supported:"
"1. (Tensor) A 3-D Tensor with shape [N, C, M] represents the " "1. (Tensor) A 3-D Tensor with shape [N, C, M] represents the "
...@@ -467,8 +468,7 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -467,8 +468,7 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
"there are total M scores which corresponding M bounding boxes. " "there are total M scores which corresponding M bounding boxes. "
" Please note, M is equal to the 1st dimension of BBoxes. " " Please note, M is equal to the 1st dimension of BBoxes. "
"2. (LoDTensor) A 2-D LoDTensor with shape" "2. (LoDTensor) A 2-D LoDTensor with shape"
"[N, num_class]. N is the number of bbox and" "[N, num_class]. N is the number of bbox");
"M represents the scores of bboxes in each class.");
AddAttr<int>( AddAttr<int>(
"background_label", "background_label",
"(int, defalut: 0) " "(int, defalut: 0) "
...@@ -497,7 +497,7 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -497,7 +497,7 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
"Number of total bboxes to be kept per image after NMS " "Number of total bboxes to be kept per image after NMS "
"step. -1 means keeping all bboxes after NMS step."); "step. -1 means keeping all bboxes after NMS step.");
AddAttr<bool>("normalized", AddAttr<bool>("normalized",
"(bool, default false) " "(bool, default true) "
"Whether detections are normalized.") "Whether detections are normalized.")
.SetDefault(true); .SetDefault(true);
AddOutput("Out", AddOutput("Out",
......
...@@ -48,6 +48,7 @@ __all__ = [ ...@@ -48,6 +48,7 @@ __all__ = [
'box_coder', 'box_coder',
'polygon_box_transform', 'polygon_box_transform',
'yolov3_loss', 'yolov3_loss',
'multiclass_nms',
] ]
...@@ -1810,3 +1811,37 @@ def generate_proposals(scores, ...@@ -1810,3 +1811,37 @@ def generate_proposals(scores,
rpn_roi_probs.stop_gradient = True rpn_roi_probs.stop_gradient = True
return rpn_rois, rpn_roi_probs return rpn_rois, rpn_roi_probs
def multiclass_nms(bboxes,
scores,
score_threshold,
nms_top_k,
nms_threshold,
keep_top_k,
normalized=True,
nms_eta=1.,
background_label=0):
"""
"""
helper = LayerHelper('multiclass_nms', **locals())
output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
helper.append_op(
type="multiclass_nms",
inputs={'BBoxes': bboxes,
'Scores': scores},
attrs={
'background_label': background_label,
'score_threshold': score_threshold,
'nms_top_k': nms_top_k,
'nms_threshold': nms_threshold,
'nms_eta': nms_eta,
'keep_top_k': keep_top_k,
'nms_eta': nms_eta,
'normalized': normalized
},
outputs={'Out': output})
output.stop_gradient = True
return output
...@@ -401,5 +401,16 @@ class TestYoloDetection(unittest.TestCase): ...@@ -401,5 +401,16 @@ class TestYoloDetection(unittest.TestCase):
self.assertIsNotNone(loss) self.assertIsNotNone(loss)
class TestMulticlassNMS(unittest.TestCase):
def test_multiclass_nms(self):
program = Program()
with program_guard(program):
bboxes = layers.data(
name='bboxes', shape=[-1, 10, 4], dtype='float32')
scores = layers.data(name='scores', shape=[-1, 10], dtype='float32')
output = layers.multiclass_nms(bboxes, scores, 0.3, 400, 0.7, 200)
self.assertIsNotNone(output)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册