diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 424e9017dccef7317c94fa3584cad841237e7109..e655059087262732cd20e30665a29f0f8093ef39 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -388,6 +388,32 @@ class MultiClassNMS(object): self.background_label = background_label +@register +@serializable +class MatrixNMS(object): + __op__ = ops.matrix_nms + __append_doc__ = True + + def __init__(self, + score_threshold=.05, + post_threshold=.05, + nms_top_k=-1, + keep_top_k=100, + use_gaussian=False, + gaussian_sigma=2., + normalized=False, + background_label=0): + super(MatrixNMS, self).__init__() + self.score_threshold = score_threshold + self.post_threshold = post_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + self.normalized = normalized + self.use_gaussian = use_gaussian + self.gaussian_sigma = gaussian_sigma + self.background_label = background_label + + @register @serializable class YOLOBox(object): diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index caae2597a5401f7d703429ac68afdfea3d1df362..a9ae84643a9f6e377e3b1a9453c895c406972564 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -36,7 +36,7 @@ __all__ = [ #'multiclass_nms', 'distribute_fpn_proposals', 'collect_fpn_proposals', - #'matrix_nms', + 'matrix_nms', ] @@ -669,3 +669,155 @@ def yolo_box( }, attrs=attrs) return boxes, scores + + +def matrix_nms(bboxes, + scores, + score_threshold, + post_threshold, + nms_top_k, + keep_top_k, + use_gaussian=False, + gaussian_sigma=2., + background_label=0, + normalized=True, + return_index=False, + return_rois_num=True, + name=None): + """ + **Matrix NMS** + + This operator does matrix non maximum suppression (NMS). + + First selects a subset of candidate bounding boxes that have higher scores + than score_threshold (if provided), then the top k candidate is selected if + nms_top_k is larger than -1. Score of the remaining candidate are then + decayed according to the Matrix NMS scheme. + Aftern NMS step, at most keep_top_k number of total bboxes are to be kept + per image if keep_top_k is larger than -1. + + Args: + bboxes (Tensor): A 3-D Tensor with shape [N, M, 4] represents the + predicted locations of M bounding bboxes, + N is the batch size. Each bounding box has four + coordinate values and the layout is + [xmin, ymin, xmax, ymax], when box size equals to 4. + The data type is float32 or float64. + scores (Tensor): A 3-D Tensor with shape [N, C, M] + represents the predicted confidence predictions. + N is the batch size, C is the class number, M is + number of bounding boxes. For each category there + are total M scores which corresponding M bounding + boxes. Please note, M is equal to the 2nd dimension + of BBoxes. The data type is float32 or float64. + score_threshold (float): Threshold to filter out bounding boxes with + low confidence score. + post_threshold (float): Threshold to filter out bounding boxes with + low confidence score AFTER decaying. + nms_top_k (int): Maximum number of detections to be kept according to + the confidences after the filtering detections based + on score_threshold. + keep_top_k (int): Number of total bboxes to be kept per image after NMS + step. -1 means keeping all bboxes after NMS step. + use_gaussian (bool): Use Gaussian as the decay function. Default: False + gaussian_sigma (float): Sigma for Gaussian decay function. Default: 2.0 + background_label (int): The index of background label, the background + label will be ignored. If set to -1, then all + categories will be considered. Default: 0 + normalized (bool): Whether detections are normalized. Default: True + return_index(bool): Whether return selected index. Default: False + return_rois_num(bool): whether return rois_num. Default: True + name(str): Name of the matrix nms op. Default: None. + + Returns: + A tuple with three Tensor: (Out, Index, RoisNum) if return_index is True, + otherwise, a tuple with two Tensor (Out, RoisNum) is returned. + + Out (Tensor): A 2-D Tensor with shape [No, 6] containing the + detection results. + Each row has 6 values: [label, confidence, xmin, ymin, xmax, ymax] + (After version 1.3, when no boxes detected, the lod is changed + from {0} to {1}) + + Index (Tensor): A 2-D Tensor with shape [No, 1] containing the + selected indices, which are absolute values cross batches. + + rois_num (Tensor): A 1-D Tensor with shape [N] containing + the number of detected boxes in each image. + + Examples: + .. code-block:: python + + + import paddle + from ppdet.modeling import ops + + boxes = paddle.static.data(name='bboxes', shape=[None,81, 4], + dtype='float32', lod_level=1) + scores = paddle.static.data(name='scores', shape=[None,81], + dtype='float32', lod_level=1) + out = ops.matrix_nms(bboxes=boxes, scores=scores, background_label=0, + score_threshold=0.5, post_threshold=0.1, + nms_top_k=400, keep_top_k=200, normalized=False) + + """ + check_variable_and_dtype(bboxes, 'BBoxes', ['float32', 'float64'], + 'matrix_nms') + check_variable_and_dtype(scores, 'Scores', ['float32', 'float64'], + 'matrix_nms') + check_type(score_threshold, 'score_threshold', float, 'matrix_nms') + check_type(post_threshold, 'post_threshold', float, 'matrix_nms') + check_type(nms_top_k, 'nums_top_k', int, 'matrix_nms') + check_type(keep_top_k, 'keep_top_k', int, 'matrix_nms') + check_type(normalized, 'normalized', bool, 'matrix_nms') + check_type(use_gaussian, 'use_gaussian', bool, 'matrix_nms') + check_type(gaussian_sigma, 'gaussian_sigma', float, 'matrix_nms') + check_type(background_label, 'background_label', int, 'matrix_nms') + + if in_dygraph_mode(): + attrs = ('background_label', background_label, 'score_threshold', + score_threshold, 'post_threshold', post_threshold, 'nms_top_k', + nms_top_k, 'gaussian_sigma', gaussian_sigma, 'use_gaussian', + use_gaussian, 'keep_top_k', keep_top_k, 'normalized', + normalized) + out, index, rois_num = core.ops.matrix_nms(bboxes, scores, *attrs) + if return_index: + if return_rois_num: + return out, index, rois_num + return out, index + if return_rois_num: + return out, rois_num + return out + + helper = LayerHelper('matrix_nms', **locals()) + output = helper.create_variable_for_type_inference(dtype=bboxes.dtype) + index = helper.create_variable_for_type_inference(dtype='int') + outputs = {'Out': output, 'Index': index} + if return_rois_num: + rois_num = helper.create_variable_for_type_inference(dtype='int') + outputs['RoisNum'] = rois_num + + helper.append_op( + type="matrix_nms", + inputs={'BBoxes': bboxes, + 'Scores': scores}, + attrs={ + 'background_label': background_label, + 'score_threshold': score_threshold, + 'post_threshold': post_threshold, + 'nms_top_k': nms_top_k, + 'gaussian_sigma': gaussian_sigma, + 'use_gaussian': use_gaussian, + 'keep_top_k': keep_top_k, + 'normalized': normalized + }, + outputs=outputs) + output.stop_gradient = True + + if return_index: + if return_rois_num: + return output, index, rois_num + return output, index + if return_rois_num: + return output, rois_num + return output diff --git a/ppdet/modeling/tests/test_ops.py b/ppdet/modeling/tests/test_ops.py index 3fc3f0c885ff25f0cba4f74d107d284c8e68fb59..87cbf8dc6e93be393715531a4eed3063f3aad951 100644 --- a/ppdet/modeling/tests/test_ops.py +++ b/ppdet/modeling/tests/test_ops.py @@ -43,6 +43,14 @@ def make_rois(h, w, rois_num, output_size): return rois +def softmax(x): + # clip to shiftx, otherwise, when calc loss with + # log(exp(shiftx)), may get log(0)=INF + shiftx = (x - np.max(x)).clip(-64.) + exps = np.exp(shiftx) + return exps / np.sum(exps) + + class TestCollectFpnProposals(LayerTest): def test_collect_fpn_proposals(self): multi_bboxes_np = [] @@ -355,19 +363,15 @@ class TestIoUSimilarity(LayerTest): x_np = make_rois(h, w, [20], output_size) y_np = make_rois(h, w, [10], output_size) with self.static_graph(): - program = Program() - with program_guard(program): - x = paddle.static.data(name='x', shape=[20, 4], dtype='float32') - y = paddle.static.data(name='y', shape=[10, 4], dtype='float32') - - iou = ops.iou_similarity(x=x, y=y) - iou_np, = self.get_static_graph_result( - feed={ - 'x': x_np, - 'y': y_np, - }, - fetch_list=[iou], - with_lod=False) + x = paddle.static.data(name='x', shape=[20, 4], dtype='float32') + y = paddle.static.data(name='y', shape=[10, 4], dtype='float32') + + iou = ops.iou_similarity(x=x, y=y) + iou_np, = self.get_static_graph_result( + feed={ + 'x': x_np, + 'y': y_np, + }, fetch_list=[iou], with_lod=False) with self.dynamic_graph(): x_dy = base.to_variable(x_np) @@ -459,5 +463,80 @@ class TestYOLO_Box(LayerTest): scale_x_y=1.2) +class TestMatrixNMS(LayerTest): + def test_matrix_nms(self): + N, M, C = 7, 1200, 21 + BOX_SIZE = 4 + nms_top_k = 400 + keep_top_k = 200 + score_threshold = 0.01 + post_threshold = 0. + + scores_np = np.random.random((N * M, C)).astype('float32') + scores_np = np.apply_along_axis(softmax, 1, scores_np) + scores_np = np.reshape(scores_np, (N, M, C)) + scores_np = np.transpose(scores_np, (0, 2, 1)) + + boxes_np = np.random.random((N, M, BOX_SIZE)).astype('float32') + boxes_np[:, :, 0:2] = boxes_np[:, :, 0:2] * 0.5 + boxes_np[:, :, 2:4] = boxes_np[:, :, 2:4] * 0.5 + 0.5 + + with self.static_graph(): + boxes = paddle.static.data( + name='boxes', shape=[N, M, BOX_SIZE], dtype='float32') + scores = paddle.static.data( + name='scores', shape=[N, C, M], dtype='float32') + out, index, _ = ops.matrix_nms( + bboxes=boxes, + scores=scores, + score_threshold=score_threshold, + post_threshold=post_threshold, + nms_top_k=nms_top_k, + keep_top_k=keep_top_k, + return_index=True) + out_np, index_np = self.get_static_graph_result( + feed={'boxes': boxes_np, + 'scores': scores_np}, + fetch_list=[out, index], + with_lod=True) + + with self.dynamic_graph(): + boxes_dy = base.to_variable(boxes_np) + scores_dy = base.to_variable(scores_np) + + out_dy, index_dy, _ = ops.matrix_nms( + bboxes=boxes_dy, + scores=scores_dy, + score_threshold=score_threshold, + post_threshold=post_threshold, + nms_top_k=nms_top_k, + keep_top_k=keep_top_k, + return_index=True) + out_dy_np = out_dy.numpy() + index_dy_np = index_dy.numpy() + + self.assertTrue(np.array_equal(out_np, out_dy_np)) + self.assertTrue(np.array_equal(index_np, index_dy_np)) + + def test_matrix_nms_error(self): + paddle.enable_static() + program = Program() + with program_guard(program): + bboxes = paddle.static.data( + name='bboxes', shape=[7, 1200, 4], dtype='float32') + scores = paddle.static.data( + name='data_error', shape=[7, 21, 1200], dtype='int32') + self.assertRaises( + TypeError, + ops.matrix_nms, + bboxes=bboxes, + scores=scores, + score_threshold=0.01, + post_threshold=0., + nms_top_k=400, + keep_top_k=200, + return_index=True) + + if __name__ == '__main__': unittest.main()