diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py
index ed33eb966fb43f16eca40ff3e0ba8937939ee1b3..0dbfd3bc5fa5f4c78121a5cb8b3896120361388a 100644
--- a/ppdet/modeling/layers.py
+++ b/ppdet/modeling/layers.py
@@ -298,7 +298,7 @@ class RoIExtractor(object):
         roi, rois_num = rois
         cur_l = 0
         if self.start_level == self.end_level:
-            rois_feat = fluid.layers.roi_align(
+            rois_feat = ops.roi_align(
                 feats[self.start_level],
                 roi,
                 self.resolution,
@@ -319,7 +319,7 @@ class RoIExtractor(object):
 
             rois_feat_list = []
             for lvl in range(self.start_level, self.end_level + 1):
-                roi_feat = fluid.layers.roi_align(
+                roi_feat = ops.roi_align(
                     feats[lvl],
                     rois_dist[lvl],
                     self.resolution,
diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py
index 0c26a148b9c6b8bc78ea6e11a5315cf362880cac..5066c21643ded97557ae93bf88ddc5cf69a6bc8a 100644
--- a/ppdet/modeling/ops.py
+++ b/ppdet/modeling/ops.py
@@ -25,8 +25,8 @@ import numpy as np
 from functools import reduce
 
 __all__ = [
-    #'roi_pool',
-    #'roi_align',
+    'roi_pool',
+    'roi_align',
     #'prior_box',
     #'anchor_generator',
     #'generate_proposals',
@@ -40,6 +40,217 @@ __all__ = [
 ]
 
 
+def roi_pool(input,
+             rois,
+             output_size,
+             spatial_scale=1.0,
+             rois_num=None,
+             name=None):
+    """
+
+    This operator implements the roi_pooling layer.
+    Region of interest pooling (also known as RoI pooling) is to perform max pooling on inputs of nonuniform sizes to obtain fixed-size feature maps (e.g. 7*7).
+
+    The operator has three steps:
+
+        1. Dividing each region proposal into equal-sized sections with output_size(h, w);
+        2. Finding the largest value in each section;
+        3. Copying these max values to the output buffer.
+
+    For more information, please refer to https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
+
+    Args:
+        input (Tensor): Input feature, 4D-Tensor with the shape of [N,C,H,W],
+            where N is the batch size, C is the input channel, H is height, W is width.
+            The data type is float32 (the only dtype listed in the static graph dtype check).
+        rois (Tensor): ROIs (Regions of Interest) to pool over.
+            2D-Tensor or 2D-LoDTensor with the shape of [num_rois,4], the lod level is 1.
+            Given as [[x1, y1, x2, y2], ...], (x1, y1) is the top left coordinates,
+            and (x2, y2) is the bottom right coordinates. The data type is float32.
+        output_size (int or tuple[int, int]): The pooled output size(h, w), data type is int32. If int, h and w are both equal to output_size.
+        spatial_scale (float, optional): Multiplicative spatial scale factor to translate ROI coords from their input scale to the scale used when pooling. Default: 1.0
+        rois_num (Tensor, optional): The number of RoIs in each image.
+            It must be provided in dygraph mode. Default: None
+        name(str, optional): For detailed information, please refer
+            to :ref:`api_guide_Name`. Usually name is no need to set and
+            None by default.
+
+
+    Returns:
+        tuple(Tensor, Tensor): ``(pool_out, argmaxes)``. ``pool_out`` is the pooled feature,
+            a 4D-Tensor with the shape of [num_rois, C, output_size[0], output_size[1]].
+            ``argmaxes`` is the ``Argmax`` output of the underlying ``roi_pool`` op
+            (created as an int32 variable in static graph mode).
+
+
+    Examples:
+
+    ..  code-block:: python
+
+        import paddle
+        from ppdet.modeling import ops
+        paddle.enable_static()
+
+        x = paddle.static.data(
+                name='data', shape=[None, 256, 32, 32], dtype='float32')
+        rois = paddle.static.data(
+                name='rois', shape=[None, 4], dtype='float32')
+        rois_num = paddle.static.data(name='rois_num', shape=[None], dtype='int32')
+
+        pool_out, argmaxes = ops.roi_pool(
+                input=x,
+                rois=rois,
+                output_size=(1, 1),
+                spatial_scale=1.0,
+                rois_num=rois_num)
+    """
+    check_type(output_size, 'output_size', (int, tuple), 'roi_pool')
+    # A single int means a square (h == w) output.
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    pooled_height, pooled_width = output_size
+    # Dygraph mode: call the op directly through core.ops; rois_num is mandatory here.
+    if in_dygraph_mode():
+        assert rois_num is not None, "rois_num should not be None in dygraph mode."
+        pool_out, argmaxes = core.ops.roi_pool(
+            input, rois, rois_num, "pooled_height", pooled_height,
+            "pooled_width", pooled_width, "spatial_scale", spatial_scale)
+        return pool_out, argmaxes
+
+    # Static graph mode: check dtypes, then append the op through LayerHelper.
+    check_variable_and_dtype(input, 'input', ['float32'], 'roi_pool')
+    check_variable_and_dtype(rois, 'rois', ['float32'], 'roi_pool')
+    helper = LayerHelper('roi_pool', **locals())
+    dtype = helper.input_dtype()
+    pool_out = helper.create_variable_for_type_inference(dtype)
+    argmaxes = helper.create_variable_for_type_inference(dtype='int32')
+
+    inputs = {
+        "X": input,
+        "ROIs": rois,
+    }
+    # RoisNum is only passed to the op when the caller supplies it.
+    if rois_num is not None:
+        inputs['RoisNum'] = rois_num
+    helper.append_op(
+        type="roi_pool",
+        inputs=inputs,
+        outputs={"Out": pool_out,
+                 "Argmax": argmaxes},
+        attrs={
+            "pooled_height": pooled_height,
+            "pooled_width": pooled_width,
+            "spatial_scale": spatial_scale
+        })
+    return pool_out, argmaxes
+
+
+def roi_align(input,
+              rois,
+              output_size,
+              spatial_scale=1.0,
+              sampling_ratio=-1,
+              rois_num=None,
+              name=None):
+    """
+
+    Region of interest align (also known as RoI align) is to perform
+    bilinear interpolation on inputs of nonuniform sizes to obtain
+    fixed-size feature maps (e.g. 7*7)
+
+    Dividing each region proposal into equal-sized sections with
+    the pooled_width and pooled_height. Location remains the origin
+    result.
+
+    In each ROI bin, the value of the four regularly sampled locations
+    are computed directly through bilinear interpolation. The output is
+    the mean of four locations.
+    Thus avoid the misaligned problem.
+
+    Args:
+        input (Tensor): Input feature, 4D-Tensor with the shape of [N,C,H,W],
+            where N is the batch size, C is the input channel, H is height, W is width.
+            The data type is float32 or float64.
+        rois (Tensor): ROIs (Regions of Interest) to pool over. It should be
+            a 2-D Tensor or 2-D LoDTensor of shape (num_rois, 4), the lod level is 1.
+            The data type is float32 or float64. Given as [[x1, y1, x2, y2], ...],
+            (x1, y1) is the top left coordinates, and (x2, y2) is the bottom right coordinates.
+        output_size (int or tuple[int, int]): The pooled output size(h, w), data type is int32. If int, h and w are both equal to output_size.
+        spatial_scale (float, optional): Multiplicative spatial scale factor to translate ROI coords from their input scale to the scale used when pooling. Default: 1.0
+        sampling_ratio (int, optional): Number of sampling points in the interpolation grid of each bin.
+            If <= 0, the number of points is chosen adaptively from the RoI size and the pooled size. Default: -1
+        rois_num (Tensor, optional): The number of RoIs in each image.
+            It must be provided in dygraph mode. Default: None
+        name(str, optional): For detailed information, please refer
+            to :ref:`api_guide_Name`. Usually name is no need to set and
+            None by default.
+
+    Returns:
+        Tensor: The output of ROIAlignOp, a 4-D tensor with shape (num_rois, channels, pooled_h, pooled_w). The data type is float32 or float64.
+
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            from ppdet.modeling import ops
+            paddle.enable_static()
+
+            x = paddle.static.data(
+                name='data', shape=[None, 256, 32, 32], dtype='float32')
+            rois = paddle.static.data(
+                name='rois', shape=[None, 4], dtype='float32')
+            rois_num = paddle.static.data(name='rois_num', shape=[None], dtype='int32')
+            align_out = ops.roi_align(input=x,
+                                      rois=rois,
+                                      output_size=(7, 7),
+                                      spatial_scale=0.5,
+                                      sampling_ratio=-1,
+                                      rois_num=rois_num)
+    """
+    # A single int means a square (h == w) output.
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    pooled_height, pooled_width = output_size
+
+    # Dygraph mode: call the op directly through core.ops; rois_num is mandatory here.
+    if in_dygraph_mode():
+        assert rois_num is not None, "rois_num should not be None in dygraph mode."
+        align_out = core.ops.roi_align(
+            input, rois, rois_num, "pooled_height", pooled_height,
+            "pooled_width", pooled_width, "spatial_scale", spatial_scale,
+            "sampling_ratio", sampling_ratio)
+        return align_out
+
+    # Static graph mode: check dtypes, then append the op through LayerHelper.
+    check_variable_and_dtype(input, 'input', ['float32', 'float64'],
+                             'roi_align')
+    check_variable_and_dtype(rois, 'rois', ['float32', 'float64'], 'roi_align')
+    helper = LayerHelper('roi_align', **locals())
+    dtype = helper.input_dtype()
+    align_out = helper.create_variable_for_type_inference(dtype)
+    inputs = {
+        "X": input,
+        "ROIs": rois,
+    }
+    # RoisNum is only passed to the op when the caller supplies it.
+    if rois_num is not None:
+        inputs['RoisNum'] = rois_num
+    helper.append_op(
+        type="roi_align",
+        inputs=inputs,
+        outputs={"Out": align_out},
+        attrs={
+            "pooled_height": pooled_height,
+            "pooled_width": pooled_width,
+            "spatial_scale": spatial_scale,
+            "sampling_ratio": sampling_ratio
+        })
+    return align_out
+
+
 def collect_fpn_proposals(multi_rois,
                           multi_scores,
                           min_level,
diff --git a/ppdet/modeling/tests/test_ops.py b/ppdet/modeling/tests/test_ops.py
index 0c6b9dbe00aa288cc5a71bc3ba199f5bea5ed8d1..b7426617e4bb962a4302cb7d6237cf9133c839db 100644
--- a/ppdet/modeling/tests/test_ops.py
+++ b/ppdet/modeling/tests/test_ops.py
@@ -217,5 +217,149 @@ class TestDistributeFpnProposals(LayerTest):
                 refer_scale=224)
 
 
+class TestROIAlign(LayerTest):
+    def test_roi_align(self):
+        b, c, h, w = 2, 12, 20, 20
+        inputs_np = np.random.rand(b, c, h, w).astype('float32')
+        rois_num = [4, 6]
+        output_size = (7, 7)
+        rois_np = self.make_rois(h, w, rois_num, output_size)
+        rois_num_np = np.array(rois_num).astype('int32')
+        with self.static_graph():
+            inputs = paddle.static.data(
+                name='inputs', shape=[b, c, h, w], dtype='float32')
+            rois = paddle.static.data(
+                name='rois', shape=[10, 4], dtype='float32')
+            rois_num = paddle.static.data(
+                name='rois_num', shape=[None], dtype='int32')
+
+            output = ops.roi_align(
+                input=inputs,
+                rois=rois,
+                output_size=output_size,
+                rois_num=rois_num)
+            output_np, = self.get_static_graph_result(
+                feed={
+                    'inputs': inputs_np,
+                    'rois': rois_np,
+                    'rois_num': rois_num_np
+                },
+                fetch_list=[output],
+                with_lod=False)
+
+        with self.dynamic_graph():
+            inputs_dy = base.to_variable(inputs_np)
+            rois_dy = base.to_variable(rois_np)
+            rois_num_dy = base.to_variable(rois_num_np)
+
+            output_dy = ops.roi_align(
+                input=inputs_dy,
+                rois=rois_dy,
+                output_size=output_size,
+                rois_num=rois_num_dy)
+            output_dy_np = output_dy.numpy()
+
+        self.assertTrue(np.array_equal(output_np, output_dy_np))
+
+    def make_rois(self, h, w, rois_num, output_size):
+        """Return a float32 array of random [x1, y1, x2, y2] RoIs for all images."""
+        rois = np.zeros((0, 4)).astype('float32')
+        for roi_num in rois_num:
+            roi = np.zeros((roi_num, 4)).astype('float32')
+            # x coordinates are bounded by the width, y coordinates by the height.
+            roi[:, 0] = np.random.randint(0, w - output_size[1], size=roi_num)
+            roi[:, 1] = np.random.randint(0, h - output_size[0], size=roi_num)
+            roi[:, 2] = np.random.randint(roi[:, 0] + output_size[1], w)
+            roi[:, 3] = np.random.randint(roi[:, 1] + output_size[0], h)
+            rois = np.vstack((rois, roi))
+        return rois
+
+    def test_roi_align_error(self):
+        program = Program()
+        with program_guard(program):
+            inputs = paddle.static.data(
+                name='inputs', shape=[2, 12, 20, 20], dtype='float32')
+            rois = paddle.static.data(
+                name='data_error', shape=[10, 4], dtype='int32', lod_level=1)
+            self.assertRaises(
+                TypeError,
+                ops.roi_align,
+                input=inputs,
+                rois=rois,
+                output_size=(7, 7))
+
+
+class TestROIPool(LayerTest):
+    def test_roi_pool(self):
+        b, c, h, w = 2, 12, 20, 20
+        inputs_np = np.random.rand(b, c, h, w).astype('float32')
+        rois_num = [4, 6]
+        output_size = (7, 7)
+        rois_np = self.make_rois(h, w, rois_num, output_size)
+        rois_num_np = np.array(rois_num).astype('int32')
+        with self.static_graph():
+            inputs = paddle.static.data(
+                name='inputs', shape=[b, c, h, w], dtype='float32')
+            rois = paddle.static.data(
+                name='rois', shape=[10, 4], dtype='float32')
+            rois_num = paddle.static.data(
+                name='rois_num', shape=[None], dtype='int32')
+
+            output, _ = ops.roi_pool(
+                input=inputs,
+                rois=rois,
+                output_size=output_size,
+                rois_num=rois_num)
+            output_np, = self.get_static_graph_result(
+                feed={
+                    'inputs': inputs_np,
+                    'rois': rois_np,
+                    'rois_num': rois_num_np
+                },
+                fetch_list=[output],
+                with_lod=False)
+
+        with self.dynamic_graph():
+            inputs_dy = base.to_variable(inputs_np)
+            rois_dy = base.to_variable(rois_np)
+            rois_num_dy = base.to_variable(rois_num_np)
+
+            output_dy, _ = ops.roi_pool(
+                input=inputs_dy,
+                rois=rois_dy,
+                output_size=output_size,
+                rois_num=rois_num_dy)
+            output_dy_np = output_dy.numpy()
+
+        self.assertTrue(np.array_equal(output_np, output_dy_np))
+
+    def make_rois(self, h, w, rois_num, output_size):
+        """Return a float32 array of random [x1, y1, x2, y2] RoIs for all images."""
+        rois = np.zeros((0, 4)).astype('float32')
+        for roi_num in rois_num:
+            roi = np.zeros((roi_num, 4)).astype('float32')
+            # x coordinates are bounded by the width, y coordinates by the height.
+            roi[:, 0] = np.random.randint(0, w - output_size[1], size=roi_num)
+            roi[:, 1] = np.random.randint(0, h - output_size[0], size=roi_num)
+            roi[:, 2] = np.random.randint(roi[:, 0] + output_size[1], w)
+            roi[:, 3] = np.random.randint(roi[:, 1] + output_size[0], h)
+            rois = np.vstack((rois, roi))
+        return rois
+
+    def test_roi_pool_error(self):
+        program = Program()
+        with program_guard(program):
+            inputs = paddle.static.data(
+                name='inputs', shape=[2, 12, 20, 20], dtype='float32')
+            rois = paddle.static.data(
+                name='data_error', shape=[10, 4], dtype='int32', lod_level=1)
+            self.assertRaises(
+                TypeError,
+                ops.roi_pool,
+                input=inputs,
+                rois=rois,
+                output_size=(7, 7))
+
+
 if __name__ == '__main__':
     unittest.main()