diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 8ed93065c81f845a2fc376c478f9ed6b6e558f4e..3da448522740d36d2fe4b90e2d29d7b735a0c1b4 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -33,9 +33,9 @@ import logging __all__ = [ 'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Dropout', 'Embedding', - 'GRUUnit', 'LayerNorm', 'NCE', 'PRelu', 'BilinearTensorProduct', - 'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm', 'SpectralNorm', - 'TreeConv' + 'GRUUnit', 'InstanceNorm', 'LayerNorm', 'NCE', 'PRelu', + 'BilinearTensorProduct', 'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm', + 'SpectralNorm', 'TreeConv' ] @@ -971,6 +971,132 @@ class Linear(layers.Layer): return self._helper.append_activation(pre_activation, act=self._act) +class InstanceNorm(layers.Layer): + """ + This interface is used to construct a callable object of the ``InstanceNorm`` class. + For more details, refer to code examples. + + Can be used as a normalizer function for convolution or fully_connected operations. + The required data format for this layer is one of the following: + + DataLayout: NCHW `[batch, in_channels, in_height, in_width]` + + Refer to `Instance Normalization: The Missing Ingredient for Fast Stylization `_ + for more details. + + :math:`input` is the input features over a mini-batch. + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW} x_i \\qquad &//\\ + \\ mean\ of\ one\ feature\ map\ in\ mini-batch \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ variance\ of\ one\ feature\ map\ in\ mini-batch \\\\ + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + Note: + `H` means height of feature map, `W` means width of feature map. + + Parameters: + num_channels(int): Indicate the number of channels of the input ``Tensor``. + epsilon(float, optional): A value added to the denominator for + numerical stability. Default is 1e-5. + param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale` + of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as param_attr, the name of scale can be set in ParamAttr. + If the Initializer of the param_attr is not set, the parameter is initialized + one. Default: None. + bias_attr(ParamAttr, optional): The parameter attribute for the bias of instance_norm. + If it is set to None or one attribute of ParamAttr, instance_norm + will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr. + If the Initializer of the bias_attr is not set, the bias is initialized zero. + Default: None. + dtype(str, optional): Indicate the data type of the input ``Tensor``, + which can be float32 or float64. Default: float32. + + Returns: + None. + + Examples: + + .. code-block:: python + + import paddle.fluid as fluid + from paddle.fluid.dygraph.base import to_variable + import numpy as np + import paddle + + # x's shape is [1, 3, 1, 2] + x = np.array([[[[1.0, 8.0]], [[10.0, 5.0]], [[4.0, 6.0]]]]).astype('float32') + with fluid.dygraph.guard(): + x = to_variable(x) + instanceNorm = paddle.nn.InstanceNorm(3) + ret = instanceNorm(x) + # ret's shape is [1, 3, 1, 2]; value is [-1 1 0.999999 -0.999999 -0.999995 0.999995] + print(ret) + + """ + + def __init__(self, + num_channels, + epsilon=1e-5, + param_attr=None, + bias_attr=None, + dtype='float32'): + super(InstanceNorm, self).__init__() + assert bias_attr is not False, "bias_attr should not be False in InstanceNorm." + + self._epsilon = epsilon + self._param_attr = param_attr + self._bias_attr = bias_attr + self._dtype = dtype + + self.scale = self.create_parameter( + attr=self._param_attr, + shape=[num_channels], + dtype=self._dtype, + default_initializer=Constant(1.0), + is_bias=False) + self.bias = self.create_parameter( + attr=self._bias_attr, + shape=[num_channels], + dtype=self._dtype, + default_initializer=Constant(0.0), + is_bias=True) + + def forward(self, input): + if in_dygraph_mode(): + out, _, _ = core.ops.instance_norm(input, self.scale, self.bias, + 'epsilon', self._epsilon) + return out + + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + "InstanceNorm") + + attrs = {"epsilon": self._epsilon} + + inputs = {"X": [input], "Scale": [self.scale], "Bias": [self.bias]} + + saved_mean = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True) + saved_variance = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True) + instance_norm_out = self._helper.create_variable_for_type_inference( + self._dtype) + + outputs = { + "Y": [instance_norm_out], + "SavedMean": [saved_mean], + "SavedVariance": [saved_variance] + } + + self._helper.append_op( + type="instance_norm", inputs=inputs, outputs=outputs, attrs=attrs) + return instance_norm_out + + class BatchNorm(layers.Layer): """ This interface is used to construct a callable object of the ``BatchNorm`` class. diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 7cb67b8bd55c2782a639833c9beb0728af061a08..6982202be1b71032a2353df7e45a6bcbab72be6e 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1258,6 +1258,61 @@ class TestLayer(LayerTest): self.assertTrue(np.allclose(static_ret, dy_rlt_value)) self.assertTrue(np.allclose(static_ret, static_ret2)) + def test_instance_norm(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + + shape = (2, 4, 3, 3) + + input = np.random.random(shape).astype('float32') + + with self.static_graph(): + X = fluid.layers.data( + name='X', shape=shape, dtype='float32', append_batch_size=False) + ret = layers.instance_norm(input=X) + static_ret = self.get_static_graph_result( + feed={'X': input}, fetch_list=[ret])[0] + + with self.static_graph(): + X = fluid.layers.data( + name='X', shape=shape, dtype='float32', append_batch_size=False) + instanceNorm = nn.InstanceNorm(num_channels=shape[1]) + ret = instanceNorm(X) + static_ret2 = self.get_static_graph_result( + feed={'X': input}, fetch_list=[ret])[0] + + with self.dynamic_graph(): + instanceNorm = nn.InstanceNorm(num_channels=shape[1]) + dy_ret = instanceNorm(base.to_variable(input)) + dy_rlt_value = dy_ret.numpy() + + with self.dynamic_graph(): + instanceNorm = paddle.nn.InstanceNorm(num_channels=shape[1]) + dy_ret = instanceNorm(base.to_variable(input)) + dy_rlt_value2 = dy_ret.numpy() + + self.assertTrue(np.allclose(static_ret, dy_rlt_value)) + self.assertTrue(np.allclose(static_ret, dy_rlt_value2)) + self.assertTrue(np.allclose(static_ret, static_ret2)) + + with self.static_graph(): + # the input of InstanceNorm must be Variable. + def test_Variable(): + instanceNorm = paddle.nn.InstanceNorm(num_channels=shape[1]) + ret1 = instanceNorm(input) + + self.assertRaises(TypeError, test_Variable) + + # the input dtype of InstanceNorm must be float32 or float64 + def test_type(): + input = np.random.random(shape).astype('int32') + instanceNorm = paddle.nn.InstanceNorm(num_channels=shape[1]) + ret2 = instanceNorm(input) + + self.assertRaises(TypeError, test_type) + def test_spectral_norm(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 1155484fdca4f78ad4870f7f16ded65150632927..ee43e8633a3f6877620358bd782d6437d96582ba 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -16,6 +16,11 @@ # including layers, linear, conv, rnn etc. # __all__ = [] +from .layer import norm + +__all__ = [] +__all__ += norm.__all__ + # TODO: define alias in nn directory # from .clip import ErrorClipByValue #DEFINE_ALIAS # from .clip import GradientClipByGlobalNorm #DEFINE_ALIAS @@ -73,6 +78,7 @@ from .layer.conv import Conv2D, Conv2DTranspose, Conv3D, Conv3DTranspose #DEFIN # from .layer.norm import BatchNorm #DEFINE_ALIAS # from .layer.norm import GroupNorm #DEFINE_ALIAS # from .layer.norm import LayerNorm #DEFINE_ALIAS +from .layer.norm import InstanceNorm #DEFINE_ALIAS # from .layer.norm import SpectralNorm #DEFINE_ALIAS # from .layer.activation import PReLU #DEFINE_ALIAS # from .layer.activation import ReLU #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 8e6eeec089e41d6b609d0c241699ae31cb1aec07..c61fd56a931f9c5ac550b04e78c63d7344b84190 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -16,6 +16,8 @@ from . import loss from . import conv +from . import norm from .loss import * from .conv import * +from .norm import * diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index d209e93054a778865e270cda673b8b055d373d4f..d02807773b96aa40193f4505a382a6d622bea7eb 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -17,3 +17,6 @@ # 'GroupNorm', # 'LayerNorm', # 'SpectralNorm'] +__all__ = ['InstanceNorm'] + +from ...fluid.dygraph.nn import InstanceNorm