Unverified commit 74092635, authored by huangjun12, committed by GitHub

Add local_response_norm in nn.functional and nn.layer (#27725)

* add local_response_norm in nn.functional and nn.layer, test=develop

* update layers to functional api, test=develop

* fix ci coverage, test=develop

* fix unittests bug, test=develop
Parent 1b121773
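For orientation, here is a minimal usage sketch of the two entry points this commit adds (the functional and the layer form); the input shape is illustrative:

import paddle

x = paddle.rand(shape=(3, 3, 112, 112), dtype="float32")

# functional form, added under paddle.nn.functional
y1 = paddle.nn.functional.local_response_norm(x, size=5)

# layer form, added under paddle.nn
m = paddle.nn.LocalResponseNorm(size=5)
y2 = m(x)

print(y1.shape, y2.shape)  # [3, 3, 112, 112] [3, 3, 112, 112]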
@@ -14,6 +14,7 @@
from __future__ import print_function
import paddle
import unittest
import numpy as np
import paddle.fluid as fluid
@@ -154,5 +155,197 @@ class TestLRNOpError(unittest.TestCase):
self.assertRaises(TypeError, fluid.layers.lrn, in_w)
class TestLocalResponseNormFAPI(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
def check_static_3d_input(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
in_np1 = np.random.random([3, 40, 40]).astype("float32")
in_np2 = np.transpose(in_np1, (0, 2, 1))
input1 = fluid.data(
name="input1", shape=[3, 40, 40], dtype="float32")
input2 = fluid.data(
name="input2", shape=[3, 40, 40], dtype="float32")
res1 = paddle.nn.functional.local_response_norm(
x=input1, size=5, data_format='NCL')
res2 = paddle.nn.functional.local_response_norm(
x=input2, size=5, data_format='NLC')
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input1": in_np1,
"input2": in_np2},
fetch_list=[res1, res2])
fetches1_tran = np.transpose(fetches[1], (0, 2, 1))
self.assertTrue(np.allclose(fetches[0], fetches1_tran))
def check_static_4d_input(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input1 = fluid.data(
name="input1", shape=[3, 3, 40, 40], dtype="float32")
input2 = fluid.data(
name="input2", shape=[3, 40, 40, 3], dtype="float32")
res1 = paddle.nn.functional.local_response_norm(
x=input1, size=5, data_format='NCHW')
res2 = paddle.nn.functional.local_response_norm(
x=input2, size=5, data_format='NHWC')
in_np1 = np.random.random([3, 3, 40, 40]).astype("float32")
in_np2 = np.transpose(in_np1, (0, 2, 3, 1))
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input1": in_np1,
"input2": in_np2},
fetch_list=[res1, res2])
fetches1_tran = np.transpose(fetches[1], (0, 3, 1, 2))
self.assertTrue(np.allclose(fetches[0], fetches1_tran))
def check_static_5d_input(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input1 = fluid.data(
name="input1", shape=[3, 3, 3, 40, 40], dtype="float32")
input2 = fluid.data(
name="input2", shape=[3, 3, 40, 40, 3], dtype="float32")
res1 = paddle.nn.functional.local_response_norm(
x=input1, size=5, data_format='NCDHW')
res2 = paddle.nn.functional.local_response_norm(
x=input2, size=5, data_format='NDHWC')
in_np1 = np.random.random([3, 3, 3, 40, 40]).astype("float32")
in_np2 = np.transpose(in_np1, (0, 2, 3, 4, 1))
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input1": in_np1,
"input2": in_np2},
fetch_list=[res1, res2])
fetches1_tran = np.transpose(fetches[1], (0, 4, 1, 2, 3))
self.assertTrue(np.allclose(fetches[0], fetches1_tran))
def test_static(self):
for place in self.places:
self.check_static_3d_input(place=place)
self.check_static_4d_input(place=place)
self.check_static_5d_input(place=place)
def check_dygraph_3d_input(self, place):
with fluid.dygraph.guard(place):
in_np1 = np.random.random([3, 40, 40]).astype("float32")
in_np2 = np.transpose(in_np1, (0, 2, 1))
in1 = paddle.to_tensor(in_np1)
in2 = paddle.to_tensor(in_np2)
res1 = paddle.nn.functional.local_response_norm(
x=in1, size=5, data_format='NCL')
res2 = paddle.nn.functional.local_response_norm(
x=in2, size=5, data_format='NLC')
res2_tran = np.transpose(res2.numpy(), (0, 2, 1))
self.assertTrue(np.allclose(res1.numpy(), res2_tran))
def check_dygraph_4d_input(self, place):
with fluid.dygraph.guard(place):
in_np1 = np.random.random([3, 3, 40, 40]).astype("float32")
in_np2 = np.transpose(in_np1, (0, 2, 3, 1))
in1 = paddle.to_tensor(in_np1)
in2 = paddle.to_tensor(in_np2)
res1 = paddle.nn.functional.local_response_norm(
x=in1, size=5, data_format='NCHW')
res2 = paddle.nn.functional.local_response_norm(
x=in2, size=5, data_format='NHWC')
res2_tran = np.transpose(res2.numpy(), (0, 3, 1, 2))
self.assertTrue(np.allclose(res1.numpy(), res2_tran))
def check_dygraph_5d_input(self, place):
with fluid.dygraph.guard(place):
in_np1 = np.random.random([3, 3, 3, 40, 40]).astype("float32")
in_np2 = np.transpose(in_np1, (0, 2, 3, 4, 1))
in1 = paddle.to_tensor(in_np1)
in2 = paddle.to_tensor(in_np2)
res1 = paddle.nn.functional.local_response_norm(
x=in1, size=5, data_format='NCDHW')
res2 = paddle.nn.functional.local_response_norm(
x=in2, size=5, data_format='NDHWC')
res2_tran = np.transpose(res2.numpy(), (0, 4, 1, 2, 3))
self.assertTrue(np.allclose(res1.numpy(), res2_tran))
def test_dygraph(self):
for place in self.places:
self.check_dygraph_3d_input(place)
self.check_dygraph_4d_input(place)
self.check_dygraph_5d_input(place)
class TestLocalResponseNormFAPIError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
# the input of local_response_norm must be a Variable
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
paddle.nn.functional.local_response_norm(x1, size=5)
self.assertRaises(TypeError, test_Variable)
def test_datatype():
x = fluid.data(name='x', shape=[3, 4, 5, 6], dtype="int32")
paddle.nn.functional.local_response_norm(x, size=5)
self.assertRaises(TypeError, test_datatype)
def test_dataformat():
x = fluid.data(name='x', shape=[3, 4, 5, 6], dtype="float32")
paddle.nn.functional.local_response_norm(
x, size=5, data_format="NCTHW")
self.assertRaises(ValueError, test_dataformat)
def test_dim():
x = fluid.data(name='x', shape=[3, 4], dtype="float32")
paddle.nn.functional.local_response_norm(x, size=5)
self.assertRaises(ValueError, test_dim)
class TestLocalResponseNormCAPI(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
def test_dygraph(self):
for place in self.places:
with fluid.dygraph.guard(place):
in1 = paddle.rand(shape=(3, 3, 40, 40), dtype="float32")
in2 = paddle.transpose(in1, [0, 2, 3, 1])
m1 = paddle.nn.LocalResponseNorm(size=5, data_format='NCHW')
m2 = paddle.nn.LocalResponseNorm(size=5, data_format='NHWC')
res1 = m1(in1)
res2 = m2(in2)
res2_tran = np.transpose(res2.numpy(), (0, 3, 1, 2))
self.assertTrue(np.allclose(res1.numpy(), res2_tran))
if __name__ == "__main__":
unittest.main()
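All of the tests above exercise one invariant: running the op on a channels-first tensor and on its channels-last transpose must agree once the outputs are mapped back to a common layout. A standalone sketch of that check, assuming only the functional API added in this commit:

import numpy as np
import paddle

np.random.seed(0)
x_nchw = np.random.rand(2, 4, 8, 8).astype("float32")
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))

y1 = paddle.nn.functional.local_response_norm(
    paddle.to_tensor(x_nchw), size=3, data_format="NCHW")
y2 = paddle.nn.functional.local_response_norm(
    paddle.to_tensor(x_nhwc), size=3, data_format="NHWC")

# transpose the NHWC output back to NCHW before comparing
np.testing.assert_allclose(
    y1.numpy(), np.transpose(y2.numpy(), (0, 3, 1, 2)), rtol=1e-5, atol=1e-6)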
@@ -146,6 +146,7 @@ from .layer.norm import InstanceNorm3d #DEFINE_ALIAS
from .layer.norm import BatchNorm1d #DEFINE_ALIAS
from .layer.norm import BatchNorm2d #DEFINE_ALIAS
from .layer.norm import BatchNorm3d #DEFINE_ALIAS
from .layer.norm import LocalResponseNorm #DEFINE_ALIAS
from .layer.rnn import RNNCellBase #DEFINE_ALIAS
from .layer.rnn import SimpleRNNCell #DEFINE_ALIAS
@@ -157,7 +157,7 @@ from .loss import ctc_loss #DEFINE_ALIAS
from .norm import batch_norm #DEFINE_ALIAS
from .norm import instance_norm #DEFINE_ALIAS
from .norm import layer_norm #DEFINE_ALIAS
-from .norm import lrn #DEFINE_ALIAS
+from .norm import local_response_norm #DEFINE_ALIAS
from .norm import normalize #DEFINE_ALIAS
# from .norm import spectral_norm #DEFINE_ALIAS
from .pooling import pool2d #DEFINE_ALIAS
@@ -19,7 +19,6 @@ from ...fluid.data_feeder import check_variable_and_dtype, check_type
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode, core
from ...framework import create_parameter
-from ...fluid.layers import lrn #DEFINE_ALIAS
from ...fluid.initializer import Constant
from ...fluid.param_attr import ParamAttr
from ...fluid import core, dygraph_utils
@@ -29,7 +28,7 @@ __all__ = [
# 'data_norm',
'instance_norm',
'layer_norm',
-'lrn',
+'local_response_norm',
'normalize',
# 'spectral_norm'
]
@@ -403,3 +402,109 @@ def instance_norm(x,
helper.append_op(
type="instance_norm", inputs=inputs, outputs=outputs, attrs=attrs)
return instance_norm_out
def local_response_norm(x,
size,
alpha=1e-4,
beta=0.75,
k=1.,
data_format="NCHW",
name=None):
"""
Local Response Normalization performs a type of "lateral inhibition" by normalizing over local input regions.
For more information, please refer to `ImageNet Classification with Deep Convolutional Neural Networks <https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_
The formula is as follows:
.. math::
Output(i, x, y) = Input(i, x, y) / \\left(k + \\alpha \\sum\\limits^{\\min(C-1, i + size/2)}_{j = \\max(0, i - size/2)}(Input(j, x, y))^2\\right)^{\\beta}
In the above equation:
- :math:`size` : The number of channels to sum over.
- :math:`k` : The offset (to avoid division by zero).
- :math:`\\alpha` : The scaling parameter.
- :math:`\\beta` : The exponent parameter.
Args:
x (Tensor): The input 3-D/4-D/5-D tensor. The data type is float32.
size (int): The number of channels to sum over.
alpha (float, optional): The scaling parameter, positive. Default: 1e-4.
beta (float, optional): The exponent, positive. Default: 0.75.
k (float, optional): An offset, positive. Default: 1.0.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from:
If x is 3-D Tensor, the string could be `"NCL"` or `"NLC"` . When it is `"NCL"`,
the data is stored in the order of: `[batch_size, input_channels, feature_length]`.
If x is 4-D Tensor, the string could be `"NCHW"`, `"NHWC"`. When it is `"NCHW"`,
the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`.
If x is 5-D Tensor, the string could be `"NCDHW"`, `"NDHWC"` . When it is `"NCDHW"`,
the data is stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
name (str, optional): Name for the operation (optional, default is None). For more information,
please refer to :ref:`api_guide_Name`.
Returns:
A tensor storing the transformation result with the same shape and data type as input.
Examples:
.. code-block:: python
import paddle
x = paddle.rand(shape=(3, 3, 112, 112), dtype="float32")
y = paddle.nn.functional.local_response_norm(x, size=5)
print(y.shape) # [3, 3, 112, 112]
"""
if not in_dygraph_mode():
check_variable_and_dtype(x, 'x', ['float32'], 'local_response_norm')
if data_format not in ['NCL', 'NLC', 'NCHW', 'NHWC', 'NCDHW', 'NDHWC']:
raise ValueError(
"data_format should be in one of [NCL, NCHW, NCDHW, NLC, NHWC, NDHWC], " \
"but got {}".format(data_format))
sizes = x.shape
dim = len(sizes)
if dim < 3:
raise ValueError(
'Expected 3D or higher dimensionality input, but got {} dimensions'.
format(dim))
channel_last = data_format[-1] == "C"
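# Square the input and add a dummy axis: summing x^2 over a window of
# `size` channels can then be done with a spatial avg_pool (for 5-D
# inputs, the reshape below first folds the trailing spatial dims so a
# 3-D pool can slide over the channel axis). Note that avg_pool yields
# the window *mean* of x^2, not the raw sum.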
div = paddle.unsqueeze(paddle.multiply(x, x), axis=1)
if not channel_last:
pad4d_shape = [0, 0, size // 2, (size - 1) // 2]
pool2d_shape = (size, 1)
reshape_shape = [sizes[0], 1, sizes[1], sizes[2], -1]
pad5d_shape = [0, 0, 0, 0, size // 2, (size - 1) // 2]
pool3d_shape = (size, 1, 1)
else:
pad4d_shape = [size // 2, (size - 1) // 2, 0, 0]
pool2d_shape = (1, size)
reshape_shape = [sizes[0], 1, sizes[1], -1, sizes[-1]]
pad5d_shape = [size // 2, (size - 1) // 2, 0, 0, 0, 0]
pool3d_shape = (1, 1, size)
if dim == 3:
div = paddle.nn.functional.pad(div, pad=pad4d_shape)
div = paddle.nn.functional.avg_pool2d(
div, kernel_size=pool2d_shape, stride=1)
div = paddle.squeeze(div, axis=1)
else:
div = paddle.reshape(div, shape=reshape_shape)
div = paddle.nn.functional.pad(div,
pad=pad5d_shape,
data_format='NCDHW')
div = paddle.nn.functional.avg_pool3d(
div, kernel_size=pool3d_shape, stride=1)
div = paddle.reshape(paddle.squeeze(div, axis=1), sizes)
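# div now holds the per-channel window mean of x^2; form
# (k + alpha * mean)^beta and divide the input by it elementwise.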
div = paddle.scale(div, scale=alpha, bias=k)
div = paddle.pow(div, beta)
res = paddle.divide(x, div, name=name)
return res
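As a cross-check of the pad/avg_pool construction above: the pooled quantity is the mean of the squares over a zero-padded window of `size` channels, so alpha scales the window mean (equivalently, alpha/size times the sum in the docstring's formula). A minimal NumPy sketch of the same computation for NCHW input, under that reading:

import numpy as np
import paddle

def lrn_ref(x, size, alpha=1e-4, beta=0.75, k=1.0):
    # naive NCHW reference mirroring the zero-padded avg_pool:
    # window [i - size//2, i + (size - 1)//2], always divided by size
    n, c, h, w = x.shape
    sq = x ** 2
    out = np.empty_like(x)
    for i in range(c):
        lo = max(0, i - size // 2)
        hi = min(c, i + (size - 1) // 2 + 1)
        window_sum = sq[:, lo:hi].sum(axis=1)
        out[:, i] = x[:, i] / (k + alpha * window_sum / size) ** beta
    return out

# cross-check against the functional API added in this commit
x_np = np.random.rand(2, 7, 5, 5).astype("float32")
y_api = paddle.nn.functional.local_response_norm(
    paddle.to_tensor(x_np), size=5)
np.testing.assert_allclose(
    y_api.numpy(), lrn_ref(x_np, size=5), rtol=1e-5, atol=1e-6)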
@@ -103,6 +103,7 @@ from .norm import GroupNorm #DEFINE_ALIAS
from .norm import LayerNorm #DEFINE_ALIAS
from .norm import SpectralNorm #DEFINE_ALIAS
from .norm import InstanceNorm #DEFINE_ALIAS
from .norm import LocalResponseNorm #DEFINE_ALIAS
# from .rnn import RNNCell #DEFINE_ALIAS
# from .rnn import GRUCell #DEFINE_ALIAS
# from .rnn import LSTMCell #DEFINE_ALIAS
@@ -51,11 +51,22 @@ import numpy as np
import numbers
import warnings
from ...fluid.dygraph.base import no_grad
from .. import functional as F
__all__ = [
-'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm',
-'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
-'InstanceNorm2d', 'InstanceNorm3d', 'SyncBatchNorm'
+'BatchNorm',
+'GroupNorm',
+'LayerNorm',
+'SpectralNorm',
+'InstanceNorm',
+'BatchNorm1d',
+'BatchNorm2d',
+'BatchNorm3d',
+'InstanceNorm1d',
+'InstanceNorm2d',
+'InstanceNorm3d',
+'SyncBatchNorm',
+'LocalResponseNorm',
]
@@ -1147,3 +1158,63 @@ class SyncBatchNorm(_BatchNormBase):
cls.convert_sync_batchnorm(sublayer))
del layer
return layer_output
class LocalResponseNorm(layers.Layer):
"""
Local Response Normalization performs a type of "lateral inhibition" by normalizing over local input regions.
For more information, please refer to `ImageNet Classification with Deep Convolutional Neural Networks <https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_
See more details in :ref:`api_paddle_nn_functional_local_response_norm` .
Parameters:
size (int): The number of channels to sum over.
alpha (float, optional): The scaling parameter, positive. Default: 1e-4.
beta (float, optional): The exponent, positive. Default: 0.75.
k (float, optional): An offset, positive. Default: 1.0.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from:
If input is 3-D Tensor, the string could be `"NCL"` or `"NLC"` . When it is `"NCL"`,
the data is stored in the order of: `[batch_size, input_channels, feature_length]`.
If input is 4-D Tensor, the string could be `"NCHW"`, `"NHWC"`. When it is `"NCHW"`,
the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`.
If input is 5-D Tensor, the string could be `"NCDHW"`, `"NDHWC"` . When it is `"NCDHW"`,
the data is stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
name (str, optional): Name for the operation (optional, default is None). For more information,
please refer to :ref:`api_guide_Name`.
Shape:
- input: 3-D/4-D/5-D tensor.
- output: 3-D/4-D/5-D tensor, the same shape as input.
Examples:
.. code-block:: python
import paddle
x = paddle.rand(shape=(3, 3, 112, 112), dtype="float32")
m = paddle.nn.LocalResponseNorm(size=5)
y = m(x)
print(y.shape) # [3, 3, 112, 112]
"""
def __init__(self,
size,
alpha=0.0001,
beta=0.75,
k=1.0,
data_format="NCHW",
name=None):
super(LocalResponseNorm, self).__init__()
self.size = size
self.alpha = alpha
self.beta = beta
self.k = k
self.data_format = data_format
self.name = name
def forward(self, input):
out = F.local_response_norm(input, self.size, self.alpha, self.beta,
self.k, self.data_format, self.name)
return out
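A short usage sketch of the layer form with channels-last input (shape and hyperparameters illustrative):

import paddle

x = paddle.rand(shape=(2, 32, 32, 16), dtype="float32")  # [N, H, W, C]
m = paddle.nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=1.0,
                                data_format="NHWC")
y = m(x)
print(y.shape)  # [2, 32, 32, 16]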