未验证 提交 ed331ba2 编写于 作者: F Feng Ni 提交者: GitHub

remove ppdet ops roi pool align, add vision roi pool align (#6154)

上级 170fa7d2
......@@ -29,7 +29,7 @@ class RoIAlign(object):
RoI Align module
For more details, please refer to the document of roi_align in
in ppdet/modeing/ops.py
in https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/vision/ops.py
Args:
resolution (int): The output size, default 14
......@@ -76,12 +76,12 @@ class RoIAlign(object):
def __call__(self, feats, roi, rois_num):
roi = paddle.concat(roi) if len(roi) > 1 else roi[0]
if len(feats) == 1:
rois_feat = ops.roi_align(
feats[self.start_level],
roi,
self.resolution,
self.spatial_scale[0],
rois_num=rois_num,
rois_feat = paddle.vision.ops.roi_align(
x=feats[self.start_level],
boxes=roi,
boxes_num=rois_num,
output_size=self.resolution,
spatial_scale=self.spatial_scale[0],
aligned=self.aligned)
else:
offset = 2
......@@ -96,13 +96,13 @@ class RoIAlign(object):
rois_num=rois_num)
rois_feat_list = []
for lvl in range(self.start_level, self.end_level + 1):
roi_feat = ops.roi_align(
feats[lvl],
rois_dist[lvl],
self.resolution,
self.spatial_scale[lvl],
roi_feat = paddle.vision.ops.roi_align(
x=feats[lvl],
boxes=rois_dist[lvl],
boxes_num=rois_num_dist[lvl],
output_size=self.resolution,
spatial_scale=self.spatial_scale[lvl],
sampling_ratio=self.sampling_ratio,
rois_num=rois_num_dist[lvl],
aligned=self.aligned)
rois_feat_list.append(roi_feat)
rois_feat_shuffle = paddle.concat(rois_feat_list)
......
......@@ -23,8 +23,6 @@ from paddle import in_dynamic_mode
from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_dtype, check_type, check_dtype
__all__ = [
'roi_pool',
'roi_align',
'prior_box',
'generate_proposals',
'box_coder',
......@@ -117,215 +115,6 @@ def batch_norm(ch,
return norm_layer
@paddle.jit.not_to_static
def roi_pool(input,
rois,
output_size,
spatial_scale=1.0,
rois_num=None,
name=None):
"""
This operator implements the roi_pooling layer.
Region of interest pooling (also known as RoI pooling) is to perform max pooling on inputs of nonuniform sizes to obtain fixed-size feature maps (e.g. 7*7).
The operator has three steps:
1. Dividing each region proposal into equal-sized sections with output_size(h, w);
2. Finding the largest value in each section;
3. Copying these max values to the output buffer.
For more information, please refer to https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
Args:
input (Tensor): Input feature, 4D-Tensor with the shape of [N,C,H,W],
where N is the batch size, C is the input channel, H is Height, W is weight.
The data type is float32 or float64.
rois (Tensor): ROIs (Regions of Interest) to pool over.
2D-Tensor or 2D-LoDTensor with the shape of [num_rois,4], the lod level is 1.
Given as [[x1, y1, x2, y2], ...], (x1, y1) is the top left coordinates,
and (x2, y2) is the bottom right coordinates.
output_size (int or tuple[int, int]): The pooled output size(h, w), data type is int32. If int, h and w are both equal to output_size.
spatial_scale (float, optional): Multiplicative spatial scale factor to translate ROI coords from their input scale to the scale used when pooling. Default: 1.0
rois_num (Tensor): The number of RoIs in each image. Default: None
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
Tensor: The pooled feature, 4D-Tensor with the shape of [num_rois, C, output_size[0], output_size[1]].
Examples:
.. code-block:: python
import paddle
from ppdet.modeling import ops
paddle.enable_static()
x = paddle.static.data(
name='data', shape=[None, 256, 32, 32], dtype='float32')
rois = paddle.static.data(
name='rois', shape=[None, 4], dtype='float32')
rois_num = paddle.static.data(name='rois_num', shape=[None], dtype='int32')
pool_out = ops.roi_pool(
input=x,
rois=rois,
output_size=(1, 1),
spatial_scale=1.0,
rois_num=rois_num)
"""
check_type(output_size, 'output_size', (int, tuple), 'roi_pool')
if isinstance(output_size, int):
output_size = (output_size, output_size)
pooled_height, pooled_width = output_size
if in_dynamic_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
pool_out, argmaxes = _C_ops.roi_pool(
input, rois, rois_num, "pooled_height", pooled_height,
"pooled_width", pooled_width, "spatial_scale", spatial_scale)
return pool_out, argmaxes
else:
check_variable_and_dtype(input, 'input', ['float32'], 'roi_pool')
check_variable_and_dtype(rois, 'rois', ['float32'], 'roi_pool')
helper = LayerHelper('roi_pool', **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
argmaxes = helper.create_variable_for_type_inference(dtype='int32')
inputs = {
"X": input,
"ROIs": rois,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
helper.append_op(
type="roi_pool",
inputs=inputs,
outputs={"Out": pool_out,
"Argmax": argmaxes},
attrs={
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale
})
return pool_out, argmaxes
@paddle.jit.not_to_static
def roi_align(input,
rois,
output_size,
spatial_scale=1.0,
sampling_ratio=-1,
rois_num=None,
aligned=True,
name=None):
"""
Region of interest align (also known as RoI align) is to perform
bilinear interpolation on inputs of nonuniform sizes to obtain
fixed-size feature maps (e.g. 7*7)
Dividing each region proposal into equal-sized sections with
the pooled_width and pooled_height. Location remains the origin
result.
In each ROI bin, the value of the four regularly sampled locations
are computed directly through bilinear interpolation. The output is
the mean of four locations.
Thus avoid the misaligned problem.
Args:
input (Tensor): Input feature, 4D-Tensor with the shape of [N,C,H,W],
where N is the batch size, C is the input channel, H is Height, W is weight.
The data type is float32 or float64.
rois (Tensor): ROIs (Regions of Interest) to pool over.It should be
a 2-D Tensor or 2-D LoDTensor of shape (num_rois, 4), the lod level is 1.
The data type is float32 or float64. Given as [[x1, y1, x2, y2], ...],
(x1, y1) is the top left coordinates, and (x2, y2) is the bottom right coordinates.
output_size (int or tuple[int, int]): The pooled output size(h, w), data type is int32. If int, h and w are both equal to output_size.
spatial_scale (float32, optional): Multiplicative spatial scale factor to translate ROI coords
from their input scale to the scale used when pooling. Default: 1.0
sampling_ratio(int32, optional): number of sampling points in the interpolation grid.
If <=0, then grid points are adaptive to roi_width and pooled_w, likewise for height. Default: -1
rois_num (Tensor): The number of RoIs in each image. Default: None
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
Tensor:
Output: The output of ROIAlignOp is a 4-D tensor with shape (num_rois, channels, pooled_h, pooled_w). The data type is float32 or float64.
Examples:
.. code-block:: python
import paddle
from ppdet.modeling import ops
paddle.enable_static()
x = paddle.static.data(
name='data', shape=[None, 256, 32, 32], dtype='float32')
rois = paddle.static.data(
name='rois', shape=[None, 4], dtype='float32')
rois_num = paddle.static.data(name='rois_num', shape=[None], dtype='int32')
align_out = ops.roi_align(input=x,
rois=rois,
output_size=(7, 7),
spatial_scale=0.5,
sampling_ratio=-1,
rois_num=rois_num)
"""
check_type(output_size, 'output_size', (int, tuple), 'roi_align')
if isinstance(output_size, int):
output_size = (output_size, output_size)
pooled_height, pooled_width = output_size
if in_dynamic_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
align_out = _C_ops.roi_align(
input, rois, rois_num, "pooled_height", pooled_height,
"pooled_width", pooled_width, "spatial_scale", spatial_scale,
"sampling_ratio", sampling_ratio, "aligned", aligned)
return align_out
else:
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'roi_align')
check_variable_and_dtype(rois, 'rois', ['float32', 'float64'],
'roi_align')
helper = LayerHelper('roi_align', **locals())
dtype = helper.input_dtype()
align_out = helper.create_variable_for_type_inference(dtype)
inputs = {
"X": input,
"ROIs": rois,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
helper.append_op(
type="roi_align",
inputs=inputs,
outputs={"Out": align_out},
attrs={
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"spatial_scale": spatial_scale,
"sampling_ratio": sampling_ratio,
"aligned": aligned,
})
return align_out
@paddle.jit.not_to_static
def distribute_fpn_proposals(fpn_rois,
min_level,
......
......@@ -128,11 +128,11 @@ class TestROIAlign(LayerTest):
rois_num = paddle.static.data(
name='rois_num', shape=[None], dtype='int32')
output = ops.roi_align(
input=inputs,
rois=rois,
output_size=output_size,
rois_num=rois_num)
output = paddle.vision.ops.roi_align(
x=inputs,
boxes=rois,
boxes_num=rois_num,
output_size=output_size)
output_np, = self.get_static_graph_result(
feed={
'inputs': inputs_np,
......@@ -147,11 +147,11 @@ class TestROIAlign(LayerTest):
rois_dy = paddle.to_tensor(rois_np)
rois_num_dy = paddle.to_tensor(rois_num_np)
output_dy = ops.roi_align(
input=inputs_dy,
rois=rois_dy,
output_size=output_size,
rois_num=rois_num_dy)
output_dy = paddle.vision.ops.roi_align(
x=inputs_dy,
boxes=rois_dy,
boxes_num=rois_num_dy,
output_size=output_size)
output_dy_np = output_dy.numpy()
self.assertTrue(np.array_equal(output_np, output_dy_np))
......@@ -164,7 +164,7 @@ class TestROIAlign(LayerTest):
name='data_error', shape=[10, 4], dtype='int32', lod_level=1)
self.assertRaises(
TypeError,
ops.roi_align,
paddle.vision.ops.roi_align,
input=inputs,
rois=rois,
output_size=(7, 7))
......@@ -188,11 +188,11 @@ class TestROIPool(LayerTest):
rois_num = paddle.static.data(
name='rois_num', shape=[None], dtype='int32')
output, _ = ops.roi_pool(
input=inputs,
rois=rois,
output_size=output_size,
rois_num=rois_num)
output = paddle.vision.ops.roi_pool(
x=inputs,
boxes=rois,
boxes_num=rois_num,
output_size=output_size)
output_np, = self.get_static_graph_result(
feed={
'inputs': inputs_np,
......@@ -207,11 +207,11 @@ class TestROIPool(LayerTest):
rois_dy = paddle.to_tensor(rois_np)
rois_num_dy = paddle.to_tensor(rois_num_np)
output_dy, _ = ops.roi_pool(
input=inputs_dy,
rois=rois_dy,
output_size=output_size,
rois_num=rois_num_dy)
output_dy = paddle.vision.ops.roi_pool(
x=inputs_dy,
boxes=rois_dy,
boxes_num=rois_num_dy,
output_size=output_size)
output_dy_np = output_dy.numpy()
self.assertTrue(np.array_equal(output_np, output_dy_np))
......@@ -224,7 +224,7 @@ class TestROIPool(LayerTest):
name='data_error', shape=[10, 4], dtype='int32', lod_level=1)
self.assertRaises(
TypeError,
ops.roi_pool,
paddle.vision.ops.roi_pool,
input=inputs,
rois=rois,
output_size=(7, 7))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册