未验证 提交 6cdf2c96 编写于 作者: B Bai Yifan 提交者: GitHub

mig deformable_conv to deform_conv2d (#27841)

* mig deformable_conv to deform_conv2d
上级 e388e603
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest
......@@ -260,6 +261,7 @@ class TestWithGroup(TestModulatedDeformableConvOp):
class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
def test_error(self):
def test_invalid_input():
paddle.enable_static()
input = [1, 3, 32, 32]
offset = fluid.data(
name='offset', shape=[None, 3, 32, 32], dtype='float32')
......@@ -271,6 +273,7 @@ class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
self.assertRaises(TypeError, test_invalid_input)
def test_invalid_offset():
paddle.enable_static()
input = fluid.data(
name='input', shape=[None, 3, 32, 32], dtype='int32')
offset = fluid.data(
......@@ -283,5 +286,36 @@ class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
self.assertRaises(TypeError, test_invalid_offset)
class TestDeformConv2dAPI(unittest.TestCase):
def test_api(self):
def test_deform_conv2d_v1():
paddle.enable_static()
input = paddle.static.data(
name='input_v1', shape=[None, 3, 32, 32], dtype='float32')
offset = paddle.static.data(
name='offset_v1', shape=[None, 4, 32, 32], dtype='float32')
out = paddle.static.nn.deform_conv2d(
input, offset, None, num_filters=4, filter_size=1)
assert (out.shape == (-1, 4, 32, 32))
test_deform_conv2d_v1()
def test_deform_conv2d_v2():
paddle.enable_static()
input = paddle.static.data(
name='input_v2', shape=[None, 3, 32, 32], dtype='float32')
offset = paddle.static.data(
name='offset_v2', shape=[None, 4, 32, 32], dtype='float32')
mask = paddle.static.data(
name='mask_v2', shape=[None, 2, 32, 32], dtype='float32')
out = paddle.static.nn.deform_conv2d(
input, offset, mask, num_filters=4, filter_size=1)
assert (out.shape == (-1, 4, 32, 32))
test_deform_conv2d_v2()
if __name__ == '__main__':
unittest.main()
......@@ -63,7 +63,7 @@ class TestDirectory(unittest.TestCase):
'paddle.static.nn.conv3d', 'paddle.static.nn.conv3d_transpose',
'paddle.static.nn.create_parameter',
'paddle.static.nn.crf_decoding', 'paddle.static.nn.data_norm',
'paddle.static.nn.deformable_conv', 'paddle.static.nn.group_norm',
'paddle.static.nn.deform_conv2d', 'paddle.static.nn.group_norm',
'paddle.static.nn.instance_norm', 'paddle.static.nn.layer_norm',
'paddle.static.nn.multi_box_head', 'paddle.static.nn.nce',
'paddle.static.nn.prelu', 'paddle.static.nn.row_conv',
......
......@@ -25,7 +25,7 @@ __all__ = [
'create_parameter',
'crf_decoding',
'data_norm',
'deformable_conv',
'deform_conv2d',
'group_norm',
'instance_norm',
'layer_norm',
......@@ -39,6 +39,7 @@ __all__ = [
]
from .common import fc #DEFINE_ALIAS
from .common import deform_conv2d #DEFINE_ALIAS
from ...fluid.layers import batch_norm #DEFINE_ALIAS
from ...fluid.layers import bilinear_tensor_product #DEFINE_ALIAS
......@@ -50,7 +51,6 @@ from ...fluid.layers import conv3d_transpose #DEFINE_ALIAS
from ...fluid.layers import create_parameter #DEFINE_ALIAS
from ...fluid.layers import crf_decoding #DEFINE_ALIAS
from ...fluid.layers import data_norm #DEFINE_ALIAS
from ...fluid.layers import deformable_conv #DEFINE_ALIAS
from ...fluid.layers import group_norm #DEFINE_ALIAS
from ...fluid.layers import instance_norm #DEFINE_ALIAS
from ...fluid.layers import layer_norm #DEFINE_ALIAS
......
......@@ -15,7 +15,7 @@
import paddle
from paddle.fluid.framework import static_only
__all__ = ['fc']
__all__ = ['fc', 'deform_conv2d']
@static_only
......@@ -163,3 +163,180 @@ def fc(x,
bias_attr=bias_attr,
act=activation,
name=name)
@static_only
def deform_conv2d(x,
offset,
mask,
num_filters,
filter_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=1,
weight_attr=None,
bias_attr=None,
name=None):
"""
Compute 2-D deformable convolution on 4-D input.
Given input image x, output feature map y, the deformable convolution operation can be expressed as follow:
Deformable Convolution v2:
.. math::
y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k) * \Delta m_k}
Deformable Convolution v1:
.. math::
y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k)}
Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location,
Which :math:`\Delta m_k` is one in deformable convolution v1. Please refer to `Deformable ConvNets v2: More Deformable, Better Results
<https://arxiv.org/abs/1811.11168v2>`_ and `Deformable Convolutional Networks <https://arxiv.org/abs/1703.06211>`_.
Example:
- Input:
X shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
Offset shape: :math:`(N, 2 * deformable\_groups * H_f * H_w, H_{in}, W_{in})`
Mask shape: :math:`(N, deformable\_groups * H_f * H_w, H_{in}, W_{in})`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args:
x (Tensor): The input image with [N, C, H, W] format. A Tensor with type
float32, float64.
offset (Tensor): The input coordinate offset of deformable convolution layer.
A Tensor with type float32, float64.
Mask (Tensor, Optional): The input mask of deformable convolution layer.
A Tensor with type float32, float64. It should be None when you use
deformable convolution v1.
num_filters(int): The number of filter. It is as same as the output
image channel.
filter_size (int|tuple): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
padding (int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups (int): The groups number of the deformable conv layer. According to
grouped convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1.
deformable_groups (int): The number of deformable group partitions.
Default: deformable_groups = 1.
im2col_step (int): Maximum number of images per im2col computation;
The total batch size should be devisable by this value or smaller
than this value; if you face out of memory problem, you can try
to use a smaller value here.
Default: im2col_step = 1.
weight_attr (ParamAttr, Optional): The parameter attribute for learnable parameters/weights
of deformable conv. If it is set to None or one attribute of ParamAttr,
deformable conv will create ParamAttr as weight_attr.
If the Initializer of the weight_attr is not set, the parameter is
initialized with :math:`Normal(0.0, std)`, and the
:math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr|bool, Optional): The parameter attribute for the bias of
deformable conv layer. If it is set to False, no bias will be added
to the output units. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
name(str, Optional): For details, please refer to :ref:`api_guide_Name`.
Generally, no setting is required. Default: None.
Returns:
Tensor: The tensor storing the deformable convolution \
result. A Tensor with type float32, float64.
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
groups mismatch.
Examples:
.. code-block:: python
#deformable conv v2:
import paddle
paddle.enable_static()
C_in, H_in, W_in = 3, 32, 32
filter_size, deformable_groups = 3, 1
data = paddle.static.data(name='data', shape=[None, C_in, H_in, W_in], dtype='float32')
offset = paddle.static.data(name='offset', shape=[None, 2*deformable_groups*filter_size**2, H_in, W_in], dtype='float32')
mask = paddle.static.data(name='mask', shape=[None, deformable_groups*filter_size**2, H_in, W_in], dtype='float32')
out = paddle.static.nn.deform_conv2d(x=data, offset=offset, mask=mask,
num_filters=2, filter_size=filter_size, padding=1)
#deformable conv v1:
import paddle
paddle.enable_static()
C_in, H_in, W_in = 3, 32, 32
filter_size, deformable_groups = 3, 1
data = paddle.static.data(name='data', shape=[None, C_in, H_in, W_in], dtype='float32')
offset = paddle.static.data(name='offset', shape=[None, 2*deformable_groups*filter_size**2, H_in, W_in], dtype='float32')
out = paddle.static.nn.deform_conv2d(x=data, offset=offset, mask=None,
num_filters=2, filter_size=filter_size, padding=1)
"""
if mask is None:
return paddle.fluid.layers.deformable_conv(
input=x,
offset=offset,
mask=mask,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
deformable_groups=deformable_groups,
im2col_step=im2col_step,
param_attr=weight_attr,
bias_attr=bias_attr,
modulated=False,
name=name)
else:
return paddle.fluid.layers.deformable_conv(
input=x,
offset=offset,
mask=mask,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
deformable_groups=deformable_groups,
im2col_step=im2col_step,
param_attr=weight_attr,
bias_attr=bias_attr,
modulated=True,
name=name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册