diff --git a/python/paddle/fluid/tests/unittests/test_normalize.py b/python/paddle/fluid/tests/unittests/test_normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..6595a29b24ae23c9b38538035c9593ba77eecdb7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_normalize.py @@ -0,0 +1,102 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +import paddle.nn.functional as F +import paddle.fluid as fluid +import paddle.fluid.core as core +import numpy as np + + +def p_normalize(x, axis=1, p=2, epsilon=1e-12, keepdims=True): + if len(x.shape) == 1: + axis = 0 + xp = np.power(np.abs(x), p) + s = np.sum(xp, axis=axis, keepdims=keepdims) + r = np.maximum(np.power(s, 1.0 / p), epsilon) + return x / r + + +class TestNNFunctionalNormalize(unittest.TestCase): + def setUp(self): + self.input_np = np.random.random(size=(10, 10)).astype(np.float32) + self.input_np2 = np.array([0.0, 0.0]).astype(np.float32) + self.expected0 = p_normalize(self.input_np) + self.expected1 = p_normalize(self.input_np, p=1.5) + self.expected2 = p_normalize(self.input_np, axis=0) + self.expected3 = p_normalize(self.input_np2) + + def run_imperative(self): + x = paddle.to_variable(self.input_np) + y = F.normalize(x) + self.assertTrue(np.allclose(y.numpy(), self.expected0)) + + y = F.normalize(x, p=1.5) + self.assertTrue(np.allclose(y.numpy(), self.expected1)) + + y = F.normalize(x, axis=0) + self.assertTrue(np.allclose(y.numpy(), self.expected2)) + + x = paddle.to_variable(self.input_np2) + y = F.normalize(x) + self.assertTrue(np.allclose(y.numpy(), self.expected3)) + + def run_static(self, use_gpu=False): + x = paddle.data(name='input', shape=[10, 10], dtype='float32') + x2 = paddle.data(name='input2', shape=[2], dtype='float32') + result0 = F.normalize(x) + result1 = F.normalize(x, p=1.5) + result2 = F.normalize(x, axis=0) + result3 = F.normalize(x, name='aaa') + result4 = F.normalize(x2) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + static_result = exe.run( + feed={"input": self.input_np, + "input2": self.input_np2}, + fetch_list=[result0, result1, result2, result4]) + + self.assertTrue(np.allclose(static_result[0], self.expected0)) + self.assertTrue(np.allclose(static_result[1], self.expected1)) + self.assertTrue(np.allclose(static_result[2], self.expected2)) + self.assertTrue('aaa' in result3.name) + self.assertTrue(np.allclose(static_result[3], self.expected3)) + + def test_cpu(self): + paddle.disable_static(place=paddle.fluid.CPUPlace()) + self.run_imperative() + paddle.enable_static() + + with fluid.program_guard(fluid.Program()): + self.run_static() + + def test_gpu(self): + if not fluid.core.is_compiled_with_cuda(): + return + + paddle.disable_static(place=paddle.fluid.CUDAPlace(0)) + self.run_imperative() + paddle.enable_static() + + with fluid.program_guard(fluid.Program()): + self.run_static(use_gpu=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index fa85b19426cd2e4b5f02d8540a5ddc545ada2aa5..bc71b8bdf06d2885327bc722b9594208f49478e8 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -150,6 +150,7 @@ from .loss import teacher_student_sigmoid_loss #DEFINE_ALIAS from .norm import l2_normalize #DEFINE_ALIAS # from .norm import layer_norm #DEFINE_ALIAS from .norm import lrn #DEFINE_ALIAS +from .norm import normalize #DEFINE_ALIAS # from .norm import spectral_norm #DEFINE_ALIAS from .pooling import pool2d #DEFINE_ALIAS from .pooling import pool3d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index e08c707b8daa6bae8bc30b2753852d41319cebb4..f8bc0b1b54e9639a41d3975f894b3a9873014b38 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -176,61 +176,61 @@ def margin_ranking_loss(input, return result_out -def l1_loss(x, label, reduction='mean', name=None): +def l1_loss(input, label, reduction='mean', name=None): """ - This operator computes the L1 Loss of Tensor ``x`` and ``label`` as follows. + This operator computes the L1 Loss of Tensor ``input`` and ``label`` as follows. - If :attr:`reduction` set to ``'none'``, the loss is: + If `reduction` set to ``'none'``, the loss is: .. math:: - Out = \lvert x - label\rvert + Out = \lvert input - label\rvert - If :attr:`reduction` set to ``'mean'``, the loss is: + If `reduction` set to ``'mean'``, the loss is: .. math:: - Out = MEAN(\lvert x - label\rvert) + Out = MEAN(\lvert input - label\rvert) - If :attr:`reduction` set to ``'sum'``, the loss is: + If `reduction` set to ``'sum'``, the loss is: .. math:: - Out = SUM(\lvert x - label\rvert) + Out = SUM(\lvert input - label\rvert) Parameters: - x (Tensor): The input tensor. The shapes is [N, *], where N is batch size and `*` means any number of additional dimensions. It's data type should be float32, float64, int32, int64. - label (Tensor): label. The shapes is [N, *], same shape as ``x`` . It's data type should be float32, float64, int32, int64. + input (Tensor): The input tensor. The shapes is [N, *], where N is batch size and `*` means any number of additional dimensions. It's data type should be float32, float64, int32, int64. + label (Tensor): label. The shapes is [N, *], same shape as ``input`` . It's data type should be float32, float64, int32, int64. reduction (str, optional): Indicate the reduction to apply to the loss, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. - If :attr:`reduction` is ``'none'``, the unreduced loss is returned; - If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. - If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. + If `reduction` is ``'none'``, the unreduced loss is returned; + If `reduction` is ``'mean'``, the reduced mean loss is returned. + If `reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor, the L1 Loss of Tensor ``x`` and ``label``. - If :attr:`reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``x`` . - If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1], which means the output is a scalar. + Tensor, the L1 Loss of Tensor ``input`` and ``label``. + If `reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``input`` . + If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1]. Examples: .. code-block:: python import paddle import numpy as np paddle.disable_static() - x_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32") + input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32") label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32") - x = paddle.to_variable(x_data) + input = paddle.to_variable(input_data) label = paddle.to_variable(label_data) - l1_loss = paddle.nn.functional.l1_loss(x, label) + l1_loss = paddle.nn.functional.l1_loss(input, label) print(l1_loss.numpy()) # [0.35] - l1_loss = paddle.nn.functional.l1_loss(x, label, reduction='none') + l1_loss = paddle.nn.functional.l1_loss(input, label, reduction='none') print(l1_loss.numpy()) # [[0.20000005 0.19999999] # [0.2 0.79999995]] - l1_loss = paddle.nn.functional.l1_loss(x, label, reduction='sum') + l1_loss = paddle.nn.functional.l1_loss(input, label, reduction='sum') print(l1_loss.numpy()) # [1.4] """ @@ -241,7 +241,7 @@ def l1_loss(x, label, reduction='mean', name=None): if in_dygraph_mode(): unreduced = _elementwise_op_in_dygraph( - x, label, axis=-1, act='abs', op_name='elementwise_sub') + input, label, axis=-1, act='abs', op_name='elementwise_sub') if reduction == 'mean': return core.ops.mean(unreduced) elif reduction == 'sum': @@ -251,18 +251,18 @@ def l1_loss(x, label, reduction='mean', name=None): return unreduced fluid.data_feeder.check_variable_and_dtype( - x, 'x', ['float32', 'float64', 'int32', 'int64'], 'l1_loss') + input, 'input', ['float32', 'float64', 'int32', 'int64'], 'l1_loss') fluid.data_feeder.check_variable_and_dtype( label, 'label', ['float32', 'float64', 'int32', 'int64'], 'l1_loss') if reduction == 'sum': - unreduced = paddle.elementwise_sub(x, label, act='abs') + unreduced = paddle.elementwise_sub(input, label, act='abs') return paddle.sum(unreduced, name=name) elif reduction == 'mean': - unreduced = paddle.elementwise_sub(x, label, act='abs') + unreduced = paddle.elementwise_sub(input, label, act='abs') return paddle.mean(unreduced, name=name) else: - return paddle.elementwise_sub(x, label, act='abs', name=name) + return paddle.elementwise_sub(input, label, act='abs', name=name) def nll_loss(input, diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 04b031b91ce387c1d8266d53725090d23b592f8c..0b007041b4ab336ae355f5d338a0d7dca9b5380e 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -13,6 +13,11 @@ # limitations under the License. # TODO: define normalization api +import paddle +import paddle.fluid as fluid +from ...fluid.data_feeder import check_variable_and_dtype, check_type +from ...fluid.layer_helper import LayerHelper +from ...fluid.framework import in_dygraph_mode, core from ...fluid.layers import l2_normalize #DEFINE_ALIAS from ...fluid.layers import lrn #DEFINE_ALIAS @@ -24,5 +29,84 @@ __all__ = [ 'l2_normalize', # 'layer_norm', 'lrn', + 'normalize', # 'spectral_norm' ] + + +def normalize(x, p=2, axis=1, epsilon=1e-12, name=None): + """ + This op normalizes ``x`` along dimension ``axis`` using :math:`L_p` norm. This layer computes + + .. math:: + + y = \frac{x}{ \max\left( \lvert \lvert x \rvert \rvert_p, epsilon\right) } + + .. math:: + \lvert \lvert x \rvert \rvert_p = \left(\sum_i {\lvert x_i\rvert^p} \right)^{1/p} + + where, :math:`\sum_i{\lvert x_i\rvert^p}` is calculated along the ``axis`` dimension. + + + Args: + x (Tensor): The input tensor could be N-D tensor, and the input data type could be float32 or float64. + p (float|int, optional): The exponent value in the norm formulation. Default: 2 + axis (int, optional): The axis on which to apply normalization. If ``x`` is 1-D tensor, ``axis`` is fixed to 0. If `axis < 0`, \ + the dimension to normalization is `x.ndim + axis`. -1 is the last dimension. + epsilon (float, optional): Small float added to denominator to avoid dividing by zero. Default is 1e-12. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, the output has the same shape and data type with ``x``. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + import paddle.nn.functional as F + + paddle.disable_static() + x = np.arange(6, dtype=np.float32).reshape(2,3) + x = paddle.to_variable(x) + y = F.normalize(x) + print(y.numpy()) + # [[0. 0.4472136 0.8944272 ] + # [0.42426404 0.5656854 0.7071067 ]] + + y = F.normalize(x, p=1.5) + print(y.numpy()) + # [[0. 0.40862012 0.81724024] + # [0.35684016 0.4757869 0.5947336 ]] + + y = F.normalize(x, axis=0) + print(y.numpy()) + # [[0. 0.24253564 0.37139067] + # [1. 0.97014254 0.9284767 ]] + """ + if len(x.shape) == 1: + axis = 0 + if in_dygraph_mode(): + eps = fluid.dygraph.base.to_variable([epsilon], dtype=x.dtype) + out = core.ops.p_norm(x, 'axis', axis, 'porder', + float(p), 'keepdim', True, 'epsilon', epsilon) + return x / core.ops.elementwise_max(out, eps) + + check_type(p, 'p', (float, int), 'normalize') + check_type(axis, 'axis', (int), 'normalize') + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'normalize') + + attrs = { + 'axis': axis, + 'porder': float(p), + 'keepdim': True, + 'epsilon': epsilon, + } + helper = LayerHelper('p_norm', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='p_norm', inputs={'X': x}, outputs={'Out': out}, attrs=attrs) + eps = out.block.create_var(dtype=out.dtype) + paddle.fill_constant([1], out.dtype, epsilon, out=eps) + return paddle.elementwise_div(x, paddle.maximum(out, eps), name=name) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index bc4f32f9c31860b97dada0bd228428193e51138b..5067264ee792dd642e22a3ecde6eb7c1264d3875 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -256,39 +256,39 @@ class MSELoss(fluid.dygraph.layers.Layer): class L1Loss(fluid.dygraph.Layer): """ This interface is used to construct a callable object of the ``L1Loss`` class. - The L1Loss layer calculates the L1 Loss of ``x`` and ``label`` as follows. + The L1Loss layer calculates the L1 Loss of ``input`` and ``label`` as follows. - If :attr:`reduction` set to ``'none'``, the loss is: + If `reduction` set to ``'none'``, the loss is: .. math:: - Out = \lvert x - label\rvert + Out = \lvert input - label\rvert - If :attr:`reduction` set to ``'mean'``, the loss is: + If `reduction` set to ``'mean'``, the loss is: .. math:: - Out = MEAN(\lvert x - label\rvert) + Out = MEAN(\lvert input - label\rvert) - If :attr:`reduction` set to ``'sum'``, the loss is: + If `reduction` set to ``'sum'``, the loss is: .. math:: - Out = SUM(\lvert x - label\rvert) + Out = SUM(\lvert input - label\rvert) Parameters: reduction (str, optional): Indicate the reduction to apply to the loss, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. - If :attr:`reduction` is ``'none'``, the unreduced loss is returned; - If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. - If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. + If `reduction` is ``'none'``, the unreduced loss is returned; + If `reduction` is ``'mean'``, the reduced mean loss is returned. + If `reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Shape: - x (Tensor): The input tensor. The shapes is [N, *], where N is batch size and `*` means any number of additional dimensions. It's data type should be float32, float64, int32, int64. - label (Tensor): label. The shapes is [N, *], same shape as ``x`` . It's data type should be float32, float64, int32, int64. - output (Tensor): The L1 Loss of ``x`` and ``label``. - If :attr:`reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``x`` . - If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1], which means the output is a scalar. + input (Tensor): The input tensor. The shapes is [N, *], where N is batch size and `*` means any number of additional dimensions. It's data type should be float32, float64, int32, int64. + label (Tensor): label. The shapes is [N, *], same shape as ``input`` . It's data type should be float32, float64, int32, int64. + output (Tensor): The L1 Loss of ``input`` and ``label``. + If `reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``input`` . + If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1]. Examples: .. code-block:: python @@ -296,23 +296,23 @@ class L1Loss(fluid.dygraph.Layer): import numpy as np paddle.disable_static() - x_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32") + input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32") label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32") - x = paddle.to_variable(x_data) + input = paddle.to_variable(input_data) label = paddle.to_variable(label_data) l1_loss = paddle.nn.loss.L1Loss() - output = l1_loss(x, label) + output = l1_loss(input, label) print(output.numpy()) # [0.35] l1_loss = paddle.nn.loss.L1Loss(reduction='sum') - output = l1_loss(x, label) + output = l1_loss(input, label) print(output.numpy()) # [1.4] l1_loss = paddle.nn.loss.L1Loss(reduction='none') - output = l1_loss(x, label) + output = l1_loss(input, label) print(output.numpy()) # [[0.20000005 0.19999999] # [0.2 0.79999995]] @@ -327,9 +327,9 @@ class L1Loss(fluid.dygraph.Layer): self.reduction = reduction self.name = name - def forward(self, x, label): + def forward(self, input, label): return paddle.nn.functional.l1_loss( - x, label, self.reduction, name=self.name) + input, label, self.reduction, name=self.name) class BCELoss(fluid.dygraph.Layer):