diff --git a/python/paddle/fluid/tests/unittests/test_lrn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_op.py index da3b2b7a2066aacf5c25fec160615d5a3706d75a..29e0a8d6f02db323fc6befa9fca588247741ba24 100644 --- a/python/paddle/fluid/tests/unittests/test_lrn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lrn_op.py @@ -14,6 +14,7 @@ from __future__ import print_function +import paddle import unittest import numpy as np import paddle.fluid as fluid @@ -154,5 +155,197 @@ class TestLRNOpError(unittest.TestCase): self.assertRaises(TypeError, fluid.layers.lrn, in_w) +class TestLocalResponseNormFAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_3d_input(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + in_np1 = np.random.random([3, 40, 40]).astype("float32") + in_np2 = np.transpose(in_np1, (0, 2, 1)) + + input1 = fluid.data( + name="input1", shape=[3, 40, 40], dtype="float32") + input2 = fluid.data( + name="input2", shape=[3, 40, 40], dtype="float32") + res1 = paddle.nn.functional.local_response_norm( + x=input1, size=5, data_format='NCL') + res2 = paddle.nn.functional.local_response_norm( + x=input2, size=5, data_format='NLC') + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input1": in_np1, + "input2": in_np2}, + fetch_list=[res1, res2]) + + fetches1_tran = np.transpose(fetches[1], (0, 2, 1)) + self.assertTrue(np.allclose(fetches[0], fetches1_tran)) + + def check_static_4d_input(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input1 = fluid.data( + name="input1", shape=[3, 3, 40, 40], dtype="float32") + input2 = fluid.data( + name="input2", shape=[3, 40, 40, 3], dtype="float32") + + res1 = paddle.nn.functional.local_response_norm( + x=input1, size=5, data_format='NCHW') + res2 = paddle.nn.functional.local_response_norm( + x=input2, size=5, data_format='NHWC') + + in_np1 = np.random.random([3, 3, 40, 40]).astype("float32") + in_np2 = np.transpose(in_np1, (0, 2, 3, 1)) + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input1": in_np1, + "input2": in_np2}, + fetch_list=[res1, res2]) + + fetches1_tran = np.transpose(fetches[1], (0, 3, 1, 2)) + self.assertTrue(np.allclose(fetches[0], fetches1_tran)) + + def check_static_5d_input(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input1 = fluid.data( + name="input1", shape=[3, 3, 3, 40, 40], dtype="float32") + input2 = fluid.data( + name="input2", shape=[3, 3, 40, 40, 3], dtype="float32") + res1 = paddle.nn.functional.local_response_norm( + x=input1, size=5, data_format='NCDHW') + res2 = paddle.nn.functional.local_response_norm( + x=input2, size=5, data_format='NDHWC') + + in_np1 = np.random.random([3, 3, 3, 40, 40]).astype("float32") + in_np2 = np.transpose(in_np1, (0, 2, 3, 4, 1)) + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input1": in_np1, + "input2": in_np2}, + fetch_list=[res1, res2]) + + fetches1_tran = np.transpose(fetches[1], (0, 4, 1, 2, 3)) + self.assertTrue(np.allclose(fetches[0], fetches1_tran)) + + def test_static(self): + for place in self.places: + self.check_static_3d_input(place=place) + self.check_static_4d_input(place=place) + self.check_static_5d_input(place=place) + + def check_dygraph_3d_input(self, place): + with fluid.dygraph.guard(place): + in_np1 = np.random.random([3, 40, 40]).astype("float32") + in_np2 = np.transpose(in_np1, (0, 2, 1)) + + in1 = paddle.to_tensor(in_np1) + in2 = paddle.to_tensor(in_np2) + + res1 = paddle.nn.functional.local_response_norm( + x=in1, size=5, data_format='NCL') + res2 = paddle.nn.functional.local_response_norm( + x=in2, size=5, data_format='NLC') + + res2_tran = np.transpose(res2.numpy(), (0, 2, 1)) + self.assertTrue(np.allclose(res1.numpy(), res2_tran)) + + def check_dygraph_4d_input(self, place): + with fluid.dygraph.guard(place): + in_np1 = np.random.random([3, 3, 40, 40]).astype("float32") + in_np2 = np.transpose(in_np1, (0, 2, 3, 1)) + + in1 = paddle.to_tensor(in_np1) + in2 = paddle.to_tensor(in_np2) + + res1 = paddle.nn.functional.local_response_norm( + x=in1, size=5, data_format='NCHW') + res2 = paddle.nn.functional.local_response_norm( + x=in2, size=5, data_format='NHWC') + + res2_tran = np.transpose(res2.numpy(), (0, 3, 1, 2)) + self.assertTrue(np.allclose(res1.numpy(), res2_tran)) + + def check_dygraph_5d_input(self, place): + with fluid.dygraph.guard(place): + in_np1 = np.random.random([3, 3, 3, 40, 40]).astype("float32") + in_np2 = np.transpose(in_np1, (0, 2, 3, 4, 1)) + + in1 = paddle.to_tensor(in_np1) + in2 = paddle.to_tensor(in_np2) + + res1 = paddle.nn.functional.local_response_norm( + x=in1, size=5, data_format='NCDHW') + res2 = paddle.nn.functional.local_response_norm( + x=in2, size=5, data_format='NDHWC') + + res2_tran = np.transpose(res2.numpy(), (0, 4, 1, 2, 3)) + self.assertTrue(np.allclose(res1.numpy(), res2_tran)) + + def test_dygraph(self): + for place in self.places: + self.check_dygraph_3d_input(place) + self.check_dygraph_4d_input(place) + self.check_dygraph_5d_input(place) + + +class TestLocalResponseNormFAPIError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_Variable(): + # the input of lrn must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + paddle.nn.functional.local_response_norm(x1, size=5) + + self.assertRaises(TypeError, test_Variable) + + def test_datatype(): + x = fluid.data(name='x', shape=[3, 4, 5, 6], dtype="int32") + paddle.nn.functional.local_response_norm(x, size=5) + + self.assertRaises(TypeError, test_datatype) + + def test_dataformat(): + x = fluid.data(name='x', shape=[3, 4, 5, 6], dtype="float32") + paddle.nn.functional.local_response_norm( + x, size=5, data_format="NCTHW") + + self.assertRaises(ValueError, test_dataformat) + + def test_dim(): + x = fluid.data(name='x', shape=[3, 4], dtype="float32") + paddle.nn.functional.local_response_norm(x, size=5) + + self.assertRaises(ValueError, test_dim) + + +class TestLocalResponseNormCAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + in1 = paddle.rand(shape=(3, 3, 40, 40), dtype="float32") + in2 = paddle.transpose(in1, [0, 2, 3, 1]) + + m1 = paddle.nn.LocalResponseNorm(size=5, data_format='NCHW') + m2 = paddle.nn.LocalResponseNorm(size=5, data_format='NHWC') + + res1 = m1(in1) + res2 = m2(in2) + + res2_tran = np.transpose(res2.numpy(), (0, 3, 1, 2)) + self.assertTrue(np.allclose(res1.numpy(), res2_tran)) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 184dd327764412f246d3d316209ac49115155fdf..b506b52ec9a4897d93f9f8587614c0a9d4856aff 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -146,6 +146,7 @@ from .layer.norm import InstanceNorm3d #DEFINE_ALIAS from .layer.norm import BatchNorm1d #DEFINE_ALIAS from .layer.norm import BatchNorm2d #DEFINE_ALIAS from .layer.norm import BatchNorm3d #DEFINE_ALIAS +from .layer.norm import LocalResponseNorm #DEFINE_ALIAS from .layer.rnn import RNNCellBase #DEFINE_ALIAS from .layer.rnn import SimpleRNNCell #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index b12fa9a6c936f6e0d6c58a22c50324cd5fcf6318..23acb3c50f229966cccb412464dbee502369454e 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -157,7 +157,7 @@ from .loss import ctc_loss #DEFINE_ALIAS from .norm import batch_norm #DEFINE_ALIAS from .norm import instance_norm #DEFINE_ALIAS from .norm import layer_norm #DEFINE_ALIAS -from .norm import lrn #DEFINE_ALIAS +from .norm import local_response_norm #DEFINE_ALIAS from .norm import normalize #DEFINE_ALIAS # from .norm import spectral_norm #DEFINE_ALIAS from .pooling import pool2d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 75d1b549b08d506326db9f0577d67fdf257123c0..c2e01cb82fbcb695334b59293659e4fdd8081b37 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -19,7 +19,6 @@ from ...fluid.data_feeder import check_variable_and_dtype, check_type from ...fluid.layer_helper import LayerHelper from ...fluid.framework import in_dygraph_mode, core from ...framework import create_parameter -from ...fluid.layers import lrn #DEFINE_ALIAS from ...fluid.initializer import Constant from ...fluid.param_attr import ParamAttr from ...fluid import core, dygraph_utils @@ -29,7 +28,7 @@ __all__ = [ # 'data_norm', 'instance_norm', 'layer_norm', - 'lrn', + 'local_response_norm', 'normalize', # 'spectral_norm' ] @@ -403,3 +402,109 @@ def instance_norm(x, helper.append_op( type="instance_norm", inputs=inputs, outputs=outputs, attrs=attrs) return instance_norm_out + + +def local_response_norm(x, + size, + alpha=1e-4, + beta=0.75, + k=1., + data_format="NCHW", + name=None): + """ + Local Response Normalization performs a type of "lateral inhibition" by normalizing over local input regions. + For more information, please refer to `ImageNet Classification with Deep Convolutional Neural Networks `_ + + The formula is as follows: + + .. math:: + + Output(i, x, y) = Input(i, x, y) / \\left(k + \\alpha \\sum\\limits^{\\min(C-1, i + size/2)}_{j = \\max(0, i - size/2)}(Input(j, x, y))^2\\right)^{\\beta} + + In the above equation: + + - :math:`size` : The number of channels to sum over. + - :math:`k` : The offset (avoid being divided by 0). + - :math:`\\alpha` : The scaling parameter. + - :math:`\\beta` : The exponent parameter. + + + Args: + x (Tensor): The input 3-D/4-D/5-D tensor. The data type is float32. + size (int): The number of channels to sum over. + alpha (float, optional): The scaling parameter, positive. Default:1e-4 + beta (float, optional): The exponent, positive. Default:0.75 + k (float, optional): An offset, positive. Default: 1.0 + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: + If x is 3-D Tensor, the string could be `"NCL"` or `"NLC"` . When it is `"NCL"`, + the data is stored in the order of: `[batch_size, input_channels, feature_length]`. + If x is 4-D Tensor, the string could be `"NCHW"`, `"NHWC"`. When it is `"NCHW"`, + the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`. + If x is 5-D Tensor, the string could be `"NCDHW"`, `"NDHWC"` . When it is `"NCDHW"`, + the data is stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`. + name (str, optional): Name for the operation (optional, default is None). For more information, + please refer to :ref:`api_guide_Name`. + + Returns: + A tensor storing the transformation result with the same shape and data type as input. + + + Examples: + + .. code-block:: python + + import paddle + + x = paddle.rand(shape=(3, 3, 112, 112), dtype="float32") + y = paddle.nn.functional.local_response_norm(x, size=5) + print(y.shape) # [3, 3, 112, 112] + """ + if not in_dygraph_mode(): + check_variable_and_dtype(x, 'x', ['float32'], 'local_response_norm') + if data_format not in ['NCL', 'NLC', 'NCHW', 'NHWC', 'NCDHW', 'NDHWC']: + raise ValueError( + "data_format should be in one of [NCL, NCHW, NCDHW, NLC, NHWC, NDHWC], " \ + "but got {}".format(data_format)) + + sizes = x.shape + dim = len(sizes) + if dim < 3: + raise ValueError( + 'Expected 3D or higher dimensionality input, but got {} dimensions'. + format(dim)) + + channel_last = True if data_format[-1] == "C" else False + + div = paddle.unsqueeze(paddle.multiply(x, x), axis=1) + if not channel_last: + pad4d_shape = [0, 0, size // 2, (size - 1) // 2] + pool2d_shape = (size, 1) + reshape_shape = [sizes[0], 1, sizes[1], sizes[2], -1] + pad5d_shape = [0, 0, 0, 0, size // 2, (size - 1) // 2] + pool3d_shape = (size, 1, 1) + else: + pad4d_shape = [size // 2, (size - 1) // 2, 0, 0] + pool2d_shape = (1, size) + reshape_shape = [sizes[0], 1, sizes[1], -1, sizes[-1]] + pad5d_shape = [size // 2, (size - 1) // 2, 0, 0, 0, 0] + pool3d_shape = (1, 1, size) + + if dim == 3: + div = paddle.nn.functional.pad(div, pad=pad4d_shape) + div = paddle.nn.functional.avg_pool2d( + div, kernel_size=pool2d_shape, stride=1) + div = paddle.squeeze(div, axis=1) + else: + div = paddle.reshape(div, shape=reshape_shape) + div = paddle.nn.functional.pad(div, + pad=pad5d_shape, + data_format='NCDHW') + div = paddle.nn.functional.avg_pool3d( + div, kernel_size=pool3d_shape, stride=1) + div = paddle.reshape(paddle.squeeze(div, axis=1), sizes) + + div = paddle.scale(div, scale=alpha, bias=k) + div = paddle.pow(div, beta) + res = paddle.divide(x, div, name=name) + return res diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index d5abaa4de5ef2378d4b6cce268d717a5143b1a19..8a234e779e27985e36baeede2f8645835e9debe9 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -103,6 +103,7 @@ from .norm import GroupNorm #DEFINE_ALIAS from .norm import LayerNorm #DEFINE_ALIAS from .norm import SpectralNorm #DEFINE_ALIAS from .norm import InstanceNorm #DEFINE_ALIAS +from .norm import LocalResponseNorm #DEFINE_ALIAS # from .rnn import RNNCell #DEFINE_ALIAS # from .rnn import GRUCell #DEFINE_ALIAS # from .rnn import LSTMCell #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 2000fbf388f88d1da7119402104706a433cebf06..50f7904c417e912d1c470768251b7739a1cff07c 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -51,11 +51,22 @@ import numpy as np import numbers import warnings from ...fluid.dygraph.base import no_grad +from .. import functional as F __all__ = [ - 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm', - 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d', - 'InstanceNorm2d', 'InstanceNorm3d', 'SyncBatchNorm' + 'BatchNorm', + 'GroupNorm', + 'LayerNorm', + 'SpectralNorm', + 'InstanceNorm', + 'BatchNorm1d', + 'BatchNorm2d', + 'BatchNorm3d', + 'InstanceNorm1d', + 'InstanceNorm2d', + 'InstanceNorm3d', + 'SyncBatchNorm', + 'LocalResponseNorm', ] @@ -1147,3 +1158,63 @@ class SyncBatchNorm(_BatchNormBase): cls.convert_sync_batchnorm(sublayer)) del layer return layer_output + + +class LocalResponseNorm(layers.Layer): + """ + Local Response Normalization performs a type of "lateral inhibition" by normalizing over local input regions. + For more information, please refer to `ImageNet Classification with Deep Convolutional Neural Networks `_ + + See more details in :ref:`api_paddle_nn_functional_local_response_norm` . + + Parameters: + size (int): The number of channels to sum over. + alpha (float, optional): The scaling parameter, positive. Default:1e-4 + beta (float, optional): The exponent, positive. Default:0.75 + k (float, optional): An offset, positive. Default: 1.0 + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: + If input is 3-D Tensor, the string could be `"NCL"` or `"NLC"` . When it is `"NCL"`, + the data is stored in the order of: `[batch_size, input_channels, feature_length]`. + If input is 4-D Tensor, the string could be `"NCHW"`, `"NHWC"`. When it is `"NCHW"`, + the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`. + If input is 5-D Tensor, the string could be `"NCDHW"`, `"NDHWC"` . When it is `"NCDHW"`, + the data is stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`. + name (str, optional): Name for the operation (optional, default is None). For more information, + please refer to :ref:`api_guide_Name`. + + Shape: + - input: 3-D/4-D/5-D tensor. + - output: 3-D/4-D/5-D tensor, the same shape as input. + + Examples: + + .. code-block:: python + + import paddle + + x = paddle.rand(shape=(3, 3, 112, 112), dtype="float32") + m = paddle.nn.LocalResponseNorm(size=5) + y = m(x) + print(y.shape) # [3, 3, 112, 112] + """ + + def __init__(self, + size, + alpha=0.0001, + beta=0.75, + k=1.0, + data_format="NCHW", + name=None): + super(LocalResponseNorm, self).__init__() + self.size = size + self.alpha = alpha + self.beta = beta + self.k = k + self.data_format = data_format + self.name = name + + def forward(self, input): + out = F.local_response_norm(input, self.size, self.alpha, self.beta, + self.k, self.data_format, self.name) + return out