diff --git a/python/paddle/fluid/tests/unittests/test_label_smooth_functional.py b/python/paddle/fluid/tests/unittests/test_label_smooth_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..54f5e64fda4b6b5b1dc5bc04885fe21992d25f65 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_label_smooth_functional.py @@ -0,0 +1,127 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from paddle import fluid, nn +import paddle.fluid.dygraph as dg +import paddle.nn.functional as F +import paddle.fluid.initializer as I +import unittest + + +class LabelSmoothTestCase(unittest.TestCase): + def __init__(self, + methodName='runTest', + label_shape=(20, 1), + prior_dist=None, + epsilon=0.1, + dtype="float32"): + super(LabelSmoothTestCase, self).__init__(methodName) + + self.label_shape = label_shape + self.prior_dist = prior_dist + self.dtype = dtype + self.epsilon = epsilon + + def setUp(self): + self.label = np.random.randn(*(self.label_shape)).astype(self.dtype) + + def fluid_layer(self, place): + paddle.enable_static() + main = fluid.Program() + start = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, start): + label_var = fluid.data( + "input", self.label_shape, dtype=self.dtype) + y_var = fluid.layers.label_smooth( + label_var, + prior_dist=self.prior_dist, + epsilon=self.epsilon, + dtype=self.dtype) + feed_dict = {"input": self.label} + exe = fluid.Executor(place) + exe.run(start) + y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var]) + return y_np + + def functional(self, place): + paddle.enable_static() + main = fluid.Program() + start = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, start): + label_var = fluid.data( + "input", self.label_shape, dtype=self.dtype) + y_var = F.label_smooth( + label_var, prior_dist=self.prior_dist, epsilon=self.epsilon) + feed_dict = {"input": self.label} + exe = fluid.Executor(place) + exe.run(start) + y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var]) + return y_np + + def paddle_dygraph_layer(self): + paddle.disable_static() + label_var = dg.to_variable(self.label) + y_var = F.label_smooth( + label_var, prior_dist=self.prior_dist, epsilon=self.epsilon) + y_np = y_var.numpy() + return y_np + + def _test_equivalence(self, place): + place = fluid.CPUPlace() + result1 = self.fluid_layer(place) + result2 = self.functional(place) + result3 = self.paddle_dygraph_layer() + np.testing.assert_array_almost_equal(result1, result2) + np.testing.assert_array_almost_equal(result2, result3) + + def runTest(self): + place = fluid.CPUPlace() + self._test_equivalence(place) + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + self._test_equivalence(place) + + +class LabelSmoothErrorTestCase(LabelSmoothTestCase): + def runTest(self): + place = fluid.CPUPlace() + with dg.guard(place): + with self.assertRaises(ValueError): + self.paddle_dygraph_layer() + + +def add_cases(suite): + suite.addTest(LabelSmoothTestCase(methodName='runTest')) + suite.addTest( + LabelSmoothTestCase( + methodName='runTest', label_shape=[2, 3, 1])) + + +def add_error_cases(suite): + suite.addTest(LabelSmoothErrorTestCase(methodName='runTest', epsilon=2)) + + +def load_tests(loader, standard_tests, pattern): + suite = unittest.TestSuite() + add_cases(suite) + add_error_cases(suite) + return suite + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 7ceb2af45ccdfdf68b335771b60b790e5803ebac..3c99823fa897dc0a2f320cfce5e53997d557b951 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -62,7 +62,7 @@ from .common import dropout3d #DEFINE_ALIAS from .common import alpha_dropout #DEFINE_ALIAS # from .common import embedding #DEFINE_ALIAS # from .common import fc #DEFINE_ALIAS -from .common import label_smooth #DEFINE_ALIAS +from .common import label_smooth from .common import one_hot #DEFINE_ALIAS from .common import pad #DEFINE_ALIAS from .common import pad_constant_like #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 9f7fb0185133f580deba64634b62d82955670641..7d2ed0cdcf83a4fc5fb78f65c09c5b26d6b1ac23 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -20,7 +20,6 @@ from paddle.fluid.layers.tensor import Variable, fill_constant, zeros, concat from ...fluid.layers import core from ...fluid import dygraph_utils # TODO: define the common functions to build a neural network -from ...fluid.layers import label_smooth #DEFINE_ALIAS from ...fluid import one_hot #DEFINE_ALIAS from ...fluid.layers import pad2d #DEFINE_ALIAS from ...fluid.layers import unfold #DEFINE_ALIAS @@ -1482,3 +1481,83 @@ def linear(x, weight, bias=None, name=None): else: res = tmp return res + + +def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): + """ + Label smoothing is a mechanism to regularize the classifier layer and is called + label-smoothing regularization (LSR). + + Label smoothing is proposed to encourage the model to be less confident, + since optimizing the log-likelihood of the correct label directly may + cause overfitting and reduce the ability of the model to adapt. Label + smoothing replaces the ground-truth label :math:`y` with the weighted sum + of itself and some fixed distribution :math:`\mu`. For class :math:`k`, + i.e. + + .. math:: + + \\tilde{y_k} = (1 - \epsilon) * y_k + \epsilon * \mu_k, + + where :math:`1 - \epsilon` and :math:`\epsilon` are the weights + respectively, and :math:`\\tilde{y}_k` is the smoothed label. Usually + uniform distribution is used for :math:`\mu`. + + See more details about label smoothing in https://arxiv.org/abs/1512.00567. + + Parameters: + label(Tensor): The input variable containing the label data. The + label data should use one-hot representation. It's + a multidimensional tensor with a shape of + :math:`[N_1, ..., Depth]`, where Depth is class number. The dtype can be "float32" and "float64". + prior_dist(Tensor, optional): The prior distribution to be used to smooth + labels. If not provided, an uniform distribution + is used. It's a multidimensional tensor with a shape of + :math:`[1, class\_num]` . The default value is None. + epsilon(float, optional): The weight used to mix up the original ground-truth + distribution and the fixed distribution. The default value is + 0.1. + name(str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to + :ref:`api_guide_Name`. + + Returns: + Tensor: The tensor containing the smoothed labels. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + x_data = np.array([[[0, 1, 0], + [ 1, 0, 1]]]).astype("float32") + print(x_data.shape) + paddle.disable_static() + x = paddle.to_tensor(x_data, stop_gradient=False) + output = paddle.nn.functional.label_smooth(x) + print(output.numpy()) + + #[[[0.03333334 0.93333334 0.03333334] + # [0.93333334 0.03333334 0.93333334]]] + """ + if epsilon > 1. or epsilon < 0.: + raise ValueError("The value of epsilon must be between 0 and 1.") + + if in_dygraph_mode(): + return core.ops.label_smooth(label, prior_dist, 'epsilon', + float(epsilon)) + + check_variable_and_dtype(label, 'label', ['float32', 'float64'], + 'label_smooth') + + helper = LayerHelper("label_smooth", **locals()) + label.stop_gradient = True + smooth_label = helper.create_variable_for_type_inference(label.dtype) + helper.append_op( + type="label_smooth", + inputs={"X": label, + "PriorDist": prior_dist} if prior_dist else {"X": label}, + outputs={"Out": smooth_label}, + attrs={"epsilon": float(epsilon)}) + return smooth_label