diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py index e685d7b5f53b0b373eae9efee040db54479e6661..eed637b1d5da1bc1c68fa2bdaa0ae9314a737121 100644 --- a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py +++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np +import paddle import paddle.fluid.core as core import paddle.fluid as fluid from op_test import OpTest @@ -260,6 +261,7 @@ class TestWithGroup(TestModulatedDeformableConvOp): class TestModulatedDeformableConvInvalidInput(unittest.TestCase): def test_error(self): def test_invalid_input(): + paddle.enable_static() input = [1, 3, 32, 32] offset = fluid.data( name='offset', shape=[None, 3, 32, 32], dtype='float32') @@ -271,6 +273,7 @@ class TestModulatedDeformableConvInvalidInput(unittest.TestCase): self.assertRaises(TypeError, test_invalid_input) def test_invalid_offset(): + paddle.enable_static() input = fluid.data( name='input', shape=[None, 3, 32, 32], dtype='int32') offset = fluid.data( @@ -283,5 +286,36 @@ class TestModulatedDeformableConvInvalidInput(unittest.TestCase): self.assertRaises(TypeError, test_invalid_offset) +class TestDeformConv2dAPI(unittest.TestCase): + def test_api(self): + def test_deform_conv2d_v1(): + paddle.enable_static() + input = paddle.static.data( + name='input_v1', shape=[None, 3, 32, 32], dtype='float32') + offset = paddle.static.data( + name='offset_v1', shape=[None, 4, 32, 32], dtype='float32') + out = paddle.static.nn.deform_conv2d( + input, offset, None, num_filters=4, filter_size=1) + + assert (out.shape == (-1, 4, 32, 32)) + + test_deform_conv2d_v1() + + def test_deform_conv2d_v2(): + paddle.enable_static() + input = paddle.static.data( + name='input_v2', shape=[None, 3, 32, 32], dtype='float32') + offset = paddle.static.data( + name='offset_v2', shape=[None, 4, 32, 32], dtype='float32') + mask = paddle.static.data( + name='mask_v2', shape=[None, 2, 32, 32], dtype='float32') + out = paddle.static.nn.deform_conv2d( + input, offset, mask, num_filters=4, filter_size=1) + + assert (out.shape == (-1, 4, 32, 32)) + + test_deform_conv2d_v2() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_directory_migration.py b/python/paddle/fluid/tests/unittests/test_directory_migration.py index fd014f3b4ecaf481c52738d3926718a67c329adc..28232e9ba4dc0b4afbdf7c6731e0bf68197daa7b 100644 --- a/python/paddle/fluid/tests/unittests/test_directory_migration.py +++ b/python/paddle/fluid/tests/unittests/test_directory_migration.py @@ -63,7 +63,7 @@ class TestDirectory(unittest.TestCase): 'paddle.static.nn.conv3d', 'paddle.static.nn.conv3d_transpose', 'paddle.static.nn.create_parameter', 'paddle.static.nn.crf_decoding', 'paddle.static.nn.data_norm', - 'paddle.static.nn.deformable_conv', 'paddle.static.nn.group_norm', + 'paddle.static.nn.deform_conv2d', 'paddle.static.nn.group_norm', 'paddle.static.nn.instance_norm', 'paddle.static.nn.layer_norm', 'paddle.static.nn.multi_box_head', 'paddle.static.nn.nce', 'paddle.static.nn.prelu', 'paddle.static.nn.row_conv', diff --git a/python/paddle/static/nn/__init__.py b/python/paddle/static/nn/__init__.py index cd089432b1ca37dcff3c70f6b48834487f285926..3ae65e879f7235456f7a75f3d9547a27bc530ee4 100644 --- a/python/paddle/static/nn/__init__.py +++ b/python/paddle/static/nn/__init__.py @@ -25,7 +25,7 @@ __all__ = [ 'create_parameter', 'crf_decoding', 'data_norm', - 'deformable_conv', + 'deform_conv2d', 'group_norm', 'instance_norm', 'layer_norm', @@ -39,6 +39,7 @@ __all__ = [ ] from .common import fc #DEFINE_ALIAS +from .common import deform_conv2d #DEFINE_ALIAS from ...fluid.layers import batch_norm #DEFINE_ALIAS from ...fluid.layers import bilinear_tensor_product #DEFINE_ALIAS @@ -50,7 +51,6 @@ from ...fluid.layers import conv3d_transpose #DEFINE_ALIAS from ...fluid.layers import create_parameter #DEFINE_ALIAS from ...fluid.layers import crf_decoding #DEFINE_ALIAS from ...fluid.layers import data_norm #DEFINE_ALIAS -from ...fluid.layers import deformable_conv #DEFINE_ALIAS from ...fluid.layers import group_norm #DEFINE_ALIAS from ...fluid.layers import instance_norm #DEFINE_ALIAS from ...fluid.layers import layer_norm #DEFINE_ALIAS diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 59ffacdaebed51d7fa4855e850df331c54b60290..93a603f4770a79766b043c800eb642c80115a543 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -15,7 +15,7 @@ import paddle from paddle.fluid.framework import static_only -__all__ = ['fc'] +__all__ = ['fc', 'deform_conv2d'] @static_only @@ -163,3 +163,180 @@ def fc(x, bias_attr=bias_attr, act=activation, name=name) + + +@static_only +def deform_conv2d(x, + offset, + mask, + num_filters, + filter_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=1, + weight_attr=None, + bias_attr=None, + name=None): + """ + + Compute 2-D deformable convolution on 4-D input. + Given input image x, output feature map y, the deformable convolution operation can be expressed as follow: + + + Deformable Convolution v2: + + .. math:: + + y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k) * \Delta m_k} + + Deformable Convolution v1: + + .. math:: + + y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k)} + + Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location, + Which :math:`\Delta m_k` is one in deformable convolution v1. Please refer to `Deformable ConvNets v2: More Deformable, Better Results + `_ and `Deformable Convolutional Networks `_. + + Example: + - Input: + + X shape: :math:`(N, C_{in}, H_{in}, W_{in})` + + Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)` + + Offset shape: :math:`(N, 2 * deformable\_groups * H_f * H_w, H_{in}, W_{in})` + + Mask shape: :math:`(N, deformable\_groups * H_f * H_w, H_{in}, W_{in})` + + - Output: + + Output shape: :math:`(N, C_{out}, H_{out}, W_{out})` + + Where + + .. math:: + + H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\ + W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1 + + Args: + x (Tensor): The input image with [N, C, H, W] format. A Tensor with type + float32, float64. + offset (Tensor): The input coordinate offset of deformable convolution layer. + A Tensor with type float32, float64. + Mask (Tensor, Optional): The input mask of deformable convolution layer. + A Tensor with type float32, float64. It should be None when you use + deformable convolution v1. + num_filters(int): The number of filter. It is as same as the output + image channel. + filter_size (int|tuple): The filter size. If filter_size is a tuple, + it must contain two integers, (filter_size_H, filter_size_W). + Otherwise, the filter will be a square. + stride (int|tuple): The stride size. If stride is a tuple, it must + contain two integers, (stride_H, stride_W). Otherwise, the + stride_H = stride_W = stride. Default: stride = 1. + padding (int|tuple): The padding size. If padding is a tuple, it must + contain two integers, (padding_H, padding_W). Otherwise, the + padding_H = padding_W = padding. Default: padding = 0. + dilation (int|tuple): The dilation size. If dilation is a tuple, it must + contain two integers, (dilation_H, dilation_W). Otherwise, the + dilation_H = dilation_W = dilation. Default: dilation = 1. + groups (int): The groups number of the deformable conv layer. According to + grouped convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. Default: groups=1. + deformable_groups (int): The number of deformable group partitions. + Default: deformable_groups = 1. + im2col_step (int): Maximum number of images per im2col computation; + The total batch size should be devisable by this value or smaller + than this value; if you face out of memory problem, you can try + to use a smaller value here. + Default: im2col_step = 1. + weight_attr (ParamAttr, Optional): The parameter attribute for learnable parameters/weights + of deformable conv. If it is set to None or one attribute of ParamAttr, + deformable conv will create ParamAttr as weight_attr. + If the Initializer of the weight_attr is not set, the parameter is + initialized with :math:`Normal(0.0, std)`, and the + :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None. + bias_attr (ParamAttr|bool, Optional): The parameter attribute for the bias of + deformable conv layer. If it is set to False, no bias will be added + to the output units. If it is set to None or one attribute of ParamAttr, conv2d + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. Default: None. + name(str, Optional): For details, please refer to :ref:`api_guide_Name`. + Generally, no setting is required. Default: None. + Returns: + Tensor: The tensor storing the deformable convolution \ + result. A Tensor with type float32, float64. + Raises: + ValueError: If the shapes of input, filter_size, stride, padding and + groups mismatch. + Examples: + .. code-block:: python + + #deformable conv v2: + + import paddle + paddle.enable_static() + + C_in, H_in, W_in = 3, 32, 32 + filter_size, deformable_groups = 3, 1 + data = paddle.static.data(name='data', shape=[None, C_in, H_in, W_in], dtype='float32') + offset = paddle.static.data(name='offset', shape=[None, 2*deformable_groups*filter_size**2, H_in, W_in], dtype='float32') + mask = paddle.static.data(name='mask', shape=[None, deformable_groups*filter_size**2, H_in, W_in], dtype='float32') + out = paddle.static.nn.deform_conv2d(x=data, offset=offset, mask=mask, + num_filters=2, filter_size=filter_size, padding=1) + + #deformable conv v1: + + import paddle + paddle.enable_static() + + C_in, H_in, W_in = 3, 32, 32 + filter_size, deformable_groups = 3, 1 + data = paddle.static.data(name='data', shape=[None, C_in, H_in, W_in], dtype='float32') + offset = paddle.static.data(name='offset', shape=[None, 2*deformable_groups*filter_size**2, H_in, W_in], dtype='float32') + out = paddle.static.nn.deform_conv2d(x=data, offset=offset, mask=None, + num_filters=2, filter_size=filter_size, padding=1) + """ + + if mask is None: + return paddle.fluid.layers.deformable_conv( + input=x, + offset=offset, + mask=mask, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + deformable_groups=deformable_groups, + im2col_step=im2col_step, + param_attr=weight_attr, + bias_attr=bias_attr, + modulated=False, + name=name) + else: + return paddle.fluid.layers.deformable_conv( + input=x, + offset=offset, + mask=mask, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + deformable_groups=deformable_groups, + im2col_step=im2col_step, + param_attr=weight_attr, + bias_attr=bias_attr, + modulated=True, + name=name)