From ed21c2c3d16802f96dbd04b2c9ebff661780319c Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 27 Oct 2020 10:20:15 +0800 Subject: [PATCH] add iou_similarity (#1605) --- ppdet/modeling/ops.py | 65 +++++++++++++++++++++++++++++- ppdet/modeling/tests/test_base.py | 4 +- ppdet/modeling/tests/test_ops.py | 66 ++++++++++++++++++++----------- 3 files changed, 108 insertions(+), 27 deletions(-) diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index fcfd9ef8a..653bb519a 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -30,7 +30,7 @@ __all__ = [ #'prior_box', #'anchor_generator', #'generate_proposals', - #'iou_similarity', + 'iou_similarity', #'box_coder', #'yolo_box', #'multiclass_nms', @@ -240,6 +240,69 @@ def roi_align(input, return align_out +def iou_similarity(x, y, box_normalized=True, name=None): + """ + Computes intersection-over-union (IOU) between two box lists. + Box list 'X' should be a LoDTensor and 'Y' is a common Tensor, + boxes in 'Y' are shared by all instance of the batched inputs of X. + Given two boxes A and B, the calculation of IOU is as follows: + + $$ + IOU(A, B) = + \\frac{area(A\\cap B)}{area(A)+area(B)-area(A\\cap B)} + $$ + + Args: + x (Tensor): Box list X is a 2-D Tensor with shape [N, 4] holds N + boxes, each box is represented as [xmin, ymin, xmax, ymax], + the shape of X is [N, 4]. [xmin, ymin] is the left top + coordinate of the box if the input is image feature map, they + are close to the origin of the coordinate system. + [xmax, ymax] is the right bottom coordinate of the box. + The data type is float32 or float64. + y (Tensor): Box list Y holds M boxes, each box is represented as + [xmin, ymin, xmax, ymax], the shape of X is [N, 4]. + [xmin, ymin] is the left top coordinate of the box if the + input is image feature map, and [xmax, ymax] is the right + bottom coordinate of the box. The data type is float32 or float64. + box_normalized(bool): Whether treat the priorbox as a normalized box. + Set true by default. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output of iou_similarity op, a tensor with shape [N, M] + representing pairwise iou scores. The data type is same with x. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + paddle.enable_static() + + x = paddle.data(name='x', shape=[None, 4], dtype='float32') + y = paddle.data(name='y', shape=[None, 4], dtype='float32') + iou = ops.iou_similarity(x=x, y=y) + """ + + if in_dygraph_mode(): + out = core.ops.iou_similarity(x, y, 'box_normalized', box_normalized) + return out + + helper = LayerHelper("iou_similarity", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type="iou_similarity", + inputs={"X": x, + "Y": y}, + attrs={"box_normalized": box_normalized}, + outputs={"Out": out}) + return out + + def collect_fpn_proposals(multi_rois, multi_scores, min_level, diff --git a/ppdet/modeling/tests/test_base.py b/ppdet/modeling/tests/test_base.py index 28fce2b41..22191502d 100644 --- a/ppdet/modeling/tests/test_base.py +++ b/ppdet/modeling/tests/test_base.py @@ -48,7 +48,7 @@ class LayerTest(unittest.TestCase): program = Program() with fluid.scope_guard(scope): with fluid.program_guard(program): - paddle.manual_seed(self.seed) + paddle.seed(self.seed) paddle.framework.random._manual_program_seed(self.seed) yield @@ -68,6 +68,6 @@ class LayerTest(unittest.TestCase): def dynamic_graph(self, force_to_use_cpu=False): with fluid.dygraph.guard( self._get_place(force_to_use_cpu=force_to_use_cpu)): - paddle.manual_seed(self.seed) + paddle.seed(self.seed) paddle.framework.random._manual_program_seed(self.seed) yield diff --git a/ppdet/modeling/tests/test_ops.py b/ppdet/modeling/tests/test_ops.py index b7426617e..7c290b88d 100644 --- a/ppdet/modeling/tests/test_ops.py +++ b/ppdet/modeling/tests/test_ops.py @@ -31,6 +31,18 @@ import ppdet.modeling.ops as ops from ppdet.modeling.tests.test_base import LayerTest +def make_rois(h, w, rois_num, output_size): + rois = np.zeros((0, 4)).astype('float32') + for roi_num in rois_num: + roi = np.zeros((roi_num, 4)).astype('float32') + roi[:, 0] = np.random.randint(0, h - output_size[0], size=roi_num) + roi[:, 1] = np.random.randint(0, w - output_size[1], size=roi_num) + roi[:, 2] = np.random.randint(roi[:, 0] + output_size[0], h) + roi[:, 3] = np.random.randint(roi[:, 1] + output_size[1], w) + rois = np.vstack((rois, roi)) + return rois + + class TestCollectFpnProposals(LayerTest): def test_collect_fpn_proposals(self): multi_bboxes_np = [] @@ -223,7 +235,7 @@ class TestROIAlign(LayerTest): inputs_np = np.random.rand(b, c, h, w).astype('float32') rois_num = [4, 6] output_size = (7, 7) - rois_np = self.make_rois(h, w, rois_num, output_size) + rois_np = make_rois(h, w, rois_num, output_size) rois_num_np = np.array(rois_num).astype('int32') with self.static_graph(): inputs = paddle.static.data( @@ -261,17 +273,6 @@ class TestROIAlign(LayerTest): self.assertTrue(np.array_equal(output_np, output_dy_np)) - def make_rois(self, h, w, rois_num, output_size): - rois = np.zeros((0, 4)).astype('float32') - for roi_num in rois_num: - roi = np.zeros((roi_num, 4)).astype('float32') - roi[:, 0] = np.random.randint(0, h - output_size[0], size=roi_num) - roi[:, 1] = np.random.randint(0, w - output_size[1], size=roi_num) - roi[:, 2] = np.random.randint(roi[:, 0] + output_size[0], h) - roi[:, 3] = np.random.randint(roi[:, 1] + output_size[1], w) - rois = np.vstack((rois, roi)) - return rois - def test_roi_align_error(self): program = Program() with program_guard(program): @@ -293,7 +294,7 @@ class TestROIPool(LayerTest): inputs_np = np.random.rand(b, c, h, w).astype('float32') rois_num = [4, 6] output_size = (7, 7) - rois_np = self.make_rois(h, w, rois_num, output_size) + rois_np = make_rois(h, w, rois_num, output_size) rois_num_np = np.array(rois_num).astype('int32') with self.static_graph(): inputs = paddle.static.data( @@ -331,17 +332,6 @@ class TestROIPool(LayerTest): self.assertTrue(np.array_equal(output_np, output_dy_np)) - def make_rois(self, h, w, rois_num, output_size): - rois = np.zeros((0, 4)).astype('float32') - for roi_num in rois_num: - roi = np.zeros((roi_num, 4)).astype('float32') - roi[:, 0] = np.random.randint(0, h - output_size[0], size=roi_num) - roi[:, 1] = np.random.randint(0, w - output_size[1], size=roi_num) - roi[:, 2] = np.random.randint(roi[:, 0] + output_size[0], h) - roi[:, 3] = np.random.randint(roi[:, 1] + output_size[1], w) - rois = np.vstack((rois, roi)) - return rois - def test_roi_pool_error(self): program = Program() with program_guard(program): @@ -357,5 +347,33 @@ class TestROIPool(LayerTest): output_size=(7, 7)) +class TestIoUSimilarity(LayerTest): + def test_iou_similarity(self): + b, c, h, w = 2, 12, 20, 20 + inputs_np = np.random.rand(b, c, h, w).astype('float32') + output_size = (7, 7) + x_np = make_rois(h, w, [20], output_size) + y_np = make_rois(h, w, [10], output_size) + with self.static_graph(): + x = paddle.static.data(name='x', shape=[20, 4], dtype='float32') + y = paddle.static.data(name='y', shape=[10, 4], dtype='float32') + + iou = ops.iou_similarity(x=x, y=y) + iou_np, = self.get_static_graph_result( + feed={ + 'x': x_np, + 'y': y_np, + }, fetch_list=[iou], with_lod=False) + + with self.dynamic_graph(): + x_dy = base.to_variable(x_np) + y_dy = base.to_variable(y_np) + + iou_dy = ops.iou_similarity(x=x_dy, y=y_dy) + iou_dy_np = iou_dy.numpy() + + self.assertTrue(np.array_equal(iou_np, iou_dy_np)) + + if __name__ == '__main__': unittest.main() -- GitLab