From f068e08d725faf61ccf3128efd70fdcd89cd8a1c Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Tue, 28 Sep 2021 20:18:26 +0800 Subject: [PATCH] add roi_align (#35102) * add roi_align in vision/ops.py --- python/paddle/tests/test_ops_roi_align.py | 108 +++++++++++++++ python/paddle/vision/ops.py | 159 ++++++++++++++++++++++ 2 files changed, 267 insertions(+) create mode 100644 python/paddle/tests/test_ops_roi_align.py diff --git a/python/paddle/tests/test_ops_roi_align.py b/python/paddle/tests/test_ops_roi_align.py new file mode 100644 index 00000000000..4a37831a0cc --- /dev/null +++ b/python/paddle/tests/test_ops_roi_align.py @@ -0,0 +1,108 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +import paddle +from paddle.vision.ops import roi_align, RoIAlign + + +class TestRoIAlign(unittest.TestCase): + def setUp(self): + self.data = np.random.rand(1, 256, 32, 32).astype('float32') + boxes = np.random.rand(3, 4) + boxes[:, 2] += boxes[:, 0] + 3 + boxes[:, 3] += boxes[:, 1] + 4 + self.boxes = boxes.astype('float32') + self.boxes_num = np.array([3], dtype=np.int32) + + def roi_align_functional(self, output_size): + if isinstance(output_size, int): + output_shape = (3, 256, output_size, output_size) + else: + output_shape = (3, 256, output_size[0], output_size[1]) + + if paddle.in_dynamic_mode(): + data = paddle.to_tensor(self.data) + boxes = paddle.to_tensor(self.boxes) + boxes_num = paddle.to_tensor(self.boxes_num) + + align_out = roi_align( + data, boxes, boxes_num=boxes_num, output_size=output_size) + np.testing.assert_equal(align_out.shape, output_shape) + + else: + data = paddle.static.data( + shape=self.data.shape, dtype=self.data.dtype, name='data') + boxes = paddle.static.data( + shape=self.boxes.shape, dtype=self.boxes.dtype, name='boxes') + boxes_num = paddle.static.data( + shape=self.boxes_num.shape, + dtype=self.boxes_num.dtype, + name='boxes_num') + + align_out = roi_align( + data, boxes, boxes_num=boxes_num, output_size=output_size) + + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + + align_out = exe.run(paddle.static.default_main_program(), + feed={ + 'data': self.data, + 'boxes': self.boxes, + 'boxes_num': self.boxes_num + }, + fetch_list=[align_out]) + + np.testing.assert_equal(align_out[0].shape, output_shape) + + def test_roi_align_functional_dynamic(self): + self.roi_align_functional(3) + self.roi_align_functional(output_size=(3, 4)) + + def test_roi_align_functional_static(self): + paddle.enable_static() + self.roi_align_functional(3) + paddle.disable_static() + + def test_RoIAlign(self): + roi_align_c = RoIAlign(output_size=(4, 3)) + data = paddle.to_tensor(self.data) + boxes = paddle.to_tensor(self.boxes) + boxes_num = paddle.to_tensor(self.boxes_num) + + align_out = roi_align_c(data, boxes, boxes_num) + np.testing.assert_equal(align_out.shape, (3, 256, 4, 3)) + + def test_value(self, ): + data = np.array([i for i in range(1, 17)]).reshape(1, 1, 4, + 4).astype(np.float32) + boxes = np.array( + [[1., 1., 2., 2.], [1.5, 1.5, 3., 3.]]).astype(np.float32) + boxes_num = np.array([2]).astype(np.int32) + output = np.array([[[[6.]]], [[[9.75]]]], dtype=np.float32) + + data = paddle.to_tensor(data) + boxes = paddle.to_tensor(boxes) + boxes_num = paddle.to_tensor(boxes_num) + + roi_align_c = RoIAlign(output_size=1) + align_out = roi_align_c(data, boxes, boxes_num) + np.testing.assert_almost_equal(align_out.numpy(), output) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 84dcdfa4cfc..965cf8b55e7 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -34,6 +34,8 @@ __all__ = [ #noqa 'RoIPool', 'psroi_pool', 'PSRoIPool', + 'roi_align', + 'RoIAlign', ] @@ -1138,3 +1140,160 @@ class RoIPool(Layer): def extra_repr(self): main_str = 'output_size={_output_size}, spatial_scale={_spatial_scale}' return main_str.format(**self.__dict__) + + +def roi_align(x, + boxes, + boxes_num, + output_size, + spatial_scale=1.0, + sampling_ratio=-1, + aligned=True, + name=None): + """ + This operator implements the roi_align layer. + Region of Interest (RoI) Align operator (also known as RoI Align) is to + perform bilinear interpolation on inputs of nonuniform sizes to obtain + fixed-size feature maps (e.g. 7*7), as described in Mask R-CNN. + + Dividing each region proposal into equal-sized sections with the pooled_width + and pooled_height. Location remains the origin result. + + In each ROI bin, the value of the four regularly sampled locations are + computed directly through bilinear interpolation. The output is the mean of + four locations. Thus avoid the misaligned problem. + + Args: + x (Tensor): Input feature, 4D-Tensor with the shape of [N,C,H,W], + where N is the batch size, C is the input channel, H is Height, + W is weight. The data type is float32 or float64. + boxes (Tensor): Boxes (RoIs, Regions of Interest) to pool over. It + should be a 2-D Tensor of shape (num_boxes, 4). The data type is + float32 or float64. Given as [[x1, y1, x2, y2], ...], (x1, y1) is + the top left coordinates, and (x2, y2) is the bottom right coordinates. + boxes_num (Tensor): The number of boxes contained in each picture in + the batch, the data type is int32. + output_size (int or Tuple[int, int]): The pooled output size(h, w), data + type is int32. If int, h and w are both equal to output_size. + spatial_scale (float32): Multiplicative spatial scale factor to translate + ROI coords from their input scale to the scale used when pooling. + Default: 1.0 + sampling_ratio (int32): number of sampling points in the interpolation + grid used to compute the output value of each pooled output bin. + If > 0, then exactly ``sampling_ratio x sampling_ratio`` sampling + points per bin are used. + If <= 0, then an adaptive number of grid points are used (computed + as ``ceil(roi_width / output_width)``, and likewise for height). + Default: -1 + aligned (bool): If False, use the legacy implementation. If True, pixel + shift the box coordinates it by -0.5 for a better alignment with the + two neighboring pixel indices. This version is used in Detectron2. + Default: True + name(str, optional): For detailed information, please refer to : + ref:`api_guide_Name`. Usually name is no need to set and None by + default. + + Returns: + Tensor: The output of ROIAlignOp is a 4-D tensor with shape (num_boxes, + channels, pooled_h, pooled_w). The data type is float32 or float64. + + Examples: + .. code-block:: python + + import paddle + from paddle.vision.ops import roi_align + + data = paddle.rand([1, 256, 32, 32]) + boxes = paddle.rand([3, 4]) + boxes[:, 2] += boxes[:, 0] + 3 + boxes[:, 3] += boxes[:, 1] + 4 + boxes_num = paddle.to_tensor([3]).astype('int32') + align_out = roi_align(data, boxes, boxes_num, output_size=3) + assert align_out.shape == [3, 256, 3, 3] + """ + + check_type(output_size, 'output_size', (int, tuple), 'roi_align') + if isinstance(output_size, int): + output_size = (output_size, output_size) + + pooled_height, pooled_width = output_size + if in_dygraph_mode(): + assert boxes_num is not None, "boxes_num should not be None in dygraph mode." + align_out = core.ops.roi_align( + x, boxes, boxes_num, "pooled_height", pooled_height, "pooled_width", + pooled_width, "spatial_scale", spatial_scale, "sampling_ratio", + sampling_ratio, "aligned", aligned) + return align_out + + else: + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'roi_align') + check_variable_and_dtype(boxes, 'boxes', ['float32', 'float64'], + 'roi_align') + helper = LayerHelper('roi_align', **locals()) + dtype = helper.input_dtype() + align_out = helper.create_variable_for_type_inference(dtype) + inputs = { + "X": x, + "ROIs": boxes, + } + if boxes_num is not None: + inputs['RoisNum'] = boxes_num + helper.append_op( + type="roi_align", + inputs=inputs, + outputs={"Out": align_out}, + attrs={ + "pooled_height": pooled_height, + "pooled_width": pooled_width, + "spatial_scale": spatial_scale, + "sampling_ratio": sampling_ratio, + "aligned": aligned, + }) + return align_out + + +class RoIAlign(Layer): + """ + This interface is used to construct a callable object of the `RoIAlign` class. + Please refer to :ref:`api_paddle_vision_ops_roi_align`. + + Args: + output_size (int or tuple[int, int]): The pooled output size(h, w), + data type is int32. If int, h and w are both equal to output_size. + spatial_scale (float32, optional): Multiplicative spatial scale factor + to translate ROI coords from their input scale to the scale used + when pooling. Default: 1.0 + + Returns: + align_out (Tensor): The output of ROIAlign operator is a 4-D tensor with + shape (num_boxes, channels, pooled_h, pooled_w). + + Examples: + .. code-block:: python + + import paddle + from paddle.vision.ops import RoIAlign + + data = paddle.rand([1, 256, 32, 32]) + boxes = paddle.rand([3, 4]) + boxes[:, 2] += boxes[:, 0] + 3 + boxes[:, 3] += boxes[:, 1] + 4 + boxes_num = paddle.to_tensor([3]).astype('int32') + roi_align = RoIAlign(output_size=(4, 3)) + align_out = roi_align(data, boxes, boxes_num) + assert align_out.shape == [3, 256, 4, 3] + """ + + def __init__(self, output_size, spatial_scale=1.0): + super(RoIAlign, self).__init__() + self._output_size = output_size + self._spatial_scale = spatial_scale + + def forward(self, x, boxes, boxes_num, aligned=True): + return roi_align( + x=x, + boxes=boxes, + boxes_num=boxes_num, + output_size=self._output_size, + spatial_scale=self._spatial_scale, + aligned=aligned) -- GitLab