diff --git a/python/paddle/fluid/tests/unittests/test_multi_label_soft_margin_loss.py b/python/paddle/fluid/tests/unittests/test_multi_label_soft_margin_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eae5eb97dbede4a0d303bf364a13e6b1e545e7d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_multi_label_soft_margin_loss.py
@@ -0,0 +1,252 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numpy as np
+import unittest
+
+
+def call_MultiLabelSoftMarginLoss_layer(
+    input,
+    label,
+    weight=None,
+    reduction='mean',
+):
+    multilabel_margin_loss = paddle.nn.MultiLabelSoftMarginLoss(
+        weight=weight, reduction=reduction)
+    res = multilabel_margin_loss(
+        input=input,
+        label=label,
+    )
+    return res
+
+
+def call_MultiLabelSoftMarginLoss_functional(
+    input,
+    label,
+    weight=None,
+    reduction='mean',
+):
+    res = paddle.nn.functional.multi_label_soft_margin_loss(
+        input,
+        label,
+        reduction=reduction,
+        weight=weight,
+    )
+    return res
+
+
+def test_static(place,
+                input_np,
+                label_np,
+                weight_np=None,
+                reduction='mean',
+                functional=False):
+    paddle.enable_static()
+    prog = paddle.static.Program()
+    startup_prog = paddle.static.Program()
+    with paddle.static.program_guard(prog, startup_prog):
+        input = paddle.static.data(name='input',
+                                   shape=input_np.shape,
+                                   dtype='float64')
+        label = paddle.static.data(name='label',
+                                   shape=label_np.shape,
+                                   dtype='float64')
+        feed_dict = {
+            "input": input_np,
+            "label": label_np,
+        }
+        weight = None
+        if weight_np is not None:
+            weight = paddle.static.data(name='weight',
+                                        shape=weight_np.shape,
+                                        dtype='float64')
+            feed_dict['weight'] = weight_np
+
+        if functional:
+            res = call_MultiLabelSoftMarginLoss_functional(input=input,
+                                                           label=label,
+                                                           weight=weight,
+                                                           reduction=reduction)
+        else:
+            res = call_MultiLabelSoftMarginLoss_layer(input=input,
+                                                      label=label,
+                                                      weight=weight,
+                                                      reduction=reduction)
+
+        exe = paddle.static.Executor(place)
+        static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
+    return static_result
+
+
+def test_dygraph(place,
+                 input_np,
+                 label_np,
+                 weight=None,
+                 reduction='mean',
+                 functional=False):
+    with paddle.fluid.dygraph.base.guard():
+        input = paddle.to_tensor(input_np)
+        label = paddle.to_tensor(label_np)
+        if weight is not None:
+            weight = paddle.to_tensor(weight)
+
+        if functional:
+            dy_res = call_MultiLabelSoftMarginLoss_functional(
+                input=input, label=label, weight=weight, reduction=reduction)
+        else:
+            dy_res = call_MultiLabelSoftMarginLoss_layer(input=input,
+                                                         label=label,
+                                                         weight=weight,
+                                                         reduction=reduction)
+        dy_result = dy_res.numpy()
+        return dy_result
+
+
+def calc_multilabel_margin_loss(
+    input,
+    label,
+    weight=None,
+    reduction="mean",
+):
+
+    def LogSigmoid(x):
+        return np.log(1 / (1 + np.exp(-x)))
+
+    loss = -(label * LogSigmoid(input) + (1 - label) * LogSigmoid(-input))
+
+    if weight is not None:
+        loss = loss * weight
+
+    loss = loss.mean(axis=-1)  # only return N loss values
+
+    if reduction == "none":
+        return loss
+    elif reduction == "mean":
+        return np.mean(loss)
+    elif reduction == "sum":
+        return np.sum(loss)
+
+
+class TestMultiLabelMarginLoss(unittest.TestCase):
+
+    def test_MultiLabelSoftMarginLoss(self):
+        input = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64)
+        label = np.random.randint(0, 2, size=(5, 5)).astype(np.float64)
+
+        places = ['cpu']
+        if paddle.device.is_compiled_with_cuda():
+            places.append('gpu')
+        reductions = ['sum', 'mean', 'none']
+        for place in places:
+            for reduction in reductions:
+                expected = calc_multilabel_margin_loss(input=input,
+                                                       label=label,
+                                                       reduction=reduction)
+
+                dy_result = test_dygraph(place=place,
+                                         input_np=input,
+                                         label_np=label,
+                                         reduction=reduction)
+
+                static_result = test_static(place=place,
+                                            input_np=input,
+                                            label_np=label,
+                                            reduction=reduction)
+                self.assertTrue(np.allclose(static_result, expected))
+                self.assertTrue(np.allclose(static_result, dy_result))
+                self.assertTrue(np.allclose(dy_result, expected))
+                static_functional = test_static(place=place,
+                                                input_np=input,
+                                                label_np=label,
+                                                reduction=reduction,
+                                                functional=True)
+                dy_functional = test_dygraph(place=place,
+                                             input_np=input,
+                                             label_np=label,
+                                             reduction=reduction,
+                                             functional=True)
+                self.assertTrue(np.allclose(static_functional, expected))
+                self.assertTrue(np.allclose(static_functional, dy_functional))
+                self.assertTrue(np.allclose(dy_functional, expected))
+
+    def test_MultiLabelSoftMarginLoss_error(self):
+        paddle.disable_static()
+        self.assertRaises(ValueError,
+                          paddle.nn.MultiLabelSoftMarginLoss,
+                          reduction="unsupport reduction")
+        input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
+        label = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
+        self.assertRaises(ValueError,
+                          paddle.nn.functional.multi_label_soft_margin_loss,
+                          input=input,
+                          label=label,
+                          reduction="unsupport reduction")
+        paddle.enable_static()
+
+    def test_MultiLabelSoftMarginLoss_weights(self):
+        input = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64)
+        label = np.random.randint(0, 2, size=(5, 5)).astype(np.float64)
+        weight = np.random.randint(0, 2, size=(5, 5)).astype(np.float64)
+        place = 'cpu'
+        reduction = 'mean'
+        expected = calc_multilabel_margin_loss(input=input,
+                                               label=label,
+                                               weight=weight,
+                                               reduction=reduction)
+
+        dy_result = test_dygraph(place=place,
+                                 input_np=input,
+                                 label_np=label,
+                                 weight=weight,
+                                 reduction=reduction)
+
+        static_result = test_static(place=place,
+                                    input_np=input,
+                                    label_np=label,
+                                    weight_np=weight,
+                                    reduction=reduction)
+        self.assertTrue(np.allclose(static_result, expected))
+        self.assertTrue(np.allclose(static_result, dy_result))
+        self.assertTrue(np.allclose(dy_result, expected))
+        static_functional = test_static(place=place,
+                                        input_np=input,
+                                        label_np=label,
+                                        weight_np=weight,
+                                        reduction=reduction,
+                                        functional=True)
+        dy_functional = test_dygraph(place=place,
+                                     input_np=input,
+                                     label_np=label,
+                                     weight=weight,
+                                     reduction=reduction,
+                                     functional=True)
+        self.assertTrue(np.allclose(static_functional, expected))
+        self.assertTrue(np.allclose(static_functional, dy_functional))
+        self.assertTrue(np.allclose(dy_functional, expected))
+
+    def test_MultiLabelSoftMarginLoss_dimension(self):
+        paddle.disable_static()
+
+        input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32')
+        label = paddle.to_tensor([[0.2, 0.1]], dtype='float32')
+        self.assertRaises(ValueError,
+                          paddle.nn.functional.multi_label_soft_margin_loss,
+                          input=input,
+                          label=label)
+        paddle.enable_static()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 8b29659a1f400bc1b7aff1bfa329912834dc6554..a3ae38c9794325ecad511b03f72b6e4cd9615937 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -104,6 +104,7 @@ from .layer.loss import NLLLoss  # noqa: F401
 from .layer.loss import BCELoss  # noqa: F401
 from .layer.loss import KLDivLoss  # noqa: F401
 from .layer.loss import MarginRankingLoss  # noqa: F401
+from .layer.loss import MultiLabelSoftMarginLoss
 from .layer.loss import CTCLoss  # noqa: F401
 from .layer.loss import SmoothL1Loss  # noqa: F401
 from .layer.loss import HingeEmbeddingLoss  # noqa: F401
@@ -312,6 +313,7 @@ __all__ = [  # noqa
     'MaxUnPool1D',
     'MaxUnPool2D',
     'MaxUnPool3D',
+    'MultiLabelSoftMarginLoss',
     'HingeEmbeddingLoss',
     'Identity',
     'CosineEmbeddingLoss',
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index cdb1135eba800242c096795d747a964dbf8103c0..a9c1d24e2c6fcc60dcda879a860f6025dde0ad6e 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -91,6 +91,7 @@ from .loss import square_error_cost  # noqa: F401
 from .loss import ctc_loss  # noqa: F401
 from .loss import hinge_embedding_loss  # noqa: F401
 from .loss import cosine_embedding_loss  # noqa: F401
+from .loss import multi_label_soft_margin_loss
 from .loss import triplet_margin_with_distance_loss
 from .loss import triplet_margin_loss
 from .norm import batch_norm  # noqa: F401
@@ -206,6 +207,7 @@ __all__ = [  # noqa
     'log_loss',
     'mse_loss',
     'margin_ranking_loss',
+    'multi_label_soft_margin_loss',
     'nll_loss',
     'npair_loss',
     'sigmoid_focal_loss',
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 2f37f8a50f4d1a4b8bbab1830f4e9d9030b74a3f..2537a9f3ae610cf33af34c7f615f03fd4a34cd37 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -2668,6 +2668,86 @@ def sigmoid_focal_loss(logit,
     return loss
 
 
+def multi_label_soft_margin_loss(input,
+                                 label,
+                                 weight=None,
+                                 reduction="mean",
+                                 name=None):
+    r"""
+    Calculate the multi-label soft margin loss of ``input`` and ``label``. For every
+    sample, the loss is ``-(label * log_sigmoid(input) + (1 - label) * log_sigmoid(-input))``
+    averaged over the class dimension, optionally rescaled by ``weight`` and then
+    reduced according to ``reduction``.
+
+    Parameters:
+        input (Tensor): Input tensor, the data type is float32 or float64. The shape is (N, C), where C is the number of classes; if the input has more than 2 dimensions, the shape is (N, C, D1, D2, ..., Dk), k >= 1.
+        label (Tensor): Label tensor, the data type is float32 or float64. The shape of label is the same as the shape of input.
+        weight (Tensor, optional): a manual rescaling weight given to each class.
+            If given, it has to be a Tensor of size C and the data type is float32 or float64.
+            Default is ``None``.
+        reduction (str, optional): Indicate how to average the loss by batch_size,
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
+            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
+            Default: ``'mean'``
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means number of classes. Available dtypes are float32 and float64. The sum operation operates over all the elements.
+        label: N-D Tensor, same shape as the input.
+        weight: N-D Tensor, the shape is [N, 1].
+        output: scalar. If :attr:`reduction` is ``'none'``, the shape is [N].
+
+    Returns:
+        Tensor, the tensor variable storing the multi_label_soft_margin_loss of input and label.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn.functional as F
+            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
+            # label elements in {1., -1.}
+            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)
+            loss = F.multi_label_soft_margin_loss(input, label, reduction='none')
+            print(loss)
+            # Tensor([3.49625897, 0.71111226, 0.43989015])
+            loss = F.multi_label_soft_margin_loss(input, label, reduction='mean')
+            print(loss)
+            # Tensor([1.54908717])
+    """
+    if reduction not in ['sum', 'mean', 'none']:
+        raise ValueError(
+            "'reduction' in 'multi_label_soft_margin_loss' should be 'sum', 'mean' or 'none', "
+            "but received {}.".format(reduction))
+
+    if not (input.shape == label.shape):
+        raise ValueError("The input and label should have same dimension, "
+                         "but received {}!={}".format(input.shape, label.shape))
+
+    if not _non_static_mode():
+        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
+                                 'multilabel_soft_margin_loss')
+        check_variable_and_dtype(label, 'label', ['float32', 'float64'],
+                                 'multilabel_soft_margin_loss')
+
+    loss = -(label * paddle.nn.functional.log_sigmoid(input) +
+             (1 - label) * paddle.nn.functional.log_sigmoid(-input))
+
+    if weight is not None:
+        if not _non_static_mode():
+            check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
+                                     'multilabel_soft_margin_loss')
+        loss = loss * weight
+
+    loss = loss.mean(axis=-1)  # only return N loss values
+
+    if reduction == "none":
+        return loss
+    elif reduction == "mean":
+        return paddle.mean(loss)
+    elif reduction == "sum":
+        return paddle.sum(loss)
+
+
 def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):
     r"""
     This operator calculates hinge_embedding_loss. Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`(containing 1 or -1).
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index e9ccee1bd3829df9ab4f727d802ea1b500261a69..e6f6a6508488ba6040516226af32fa76cb47a4bb 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -76,6 +76,7 @@ from .loss import NLLLoss  # noqa: F401
 from .loss import BCELoss  # noqa: F401
 from .loss import KLDivLoss  # noqa: F401
 from .loss import MarginRankingLoss  # noqa: F401
+from .loss import MultiLabelSoftMarginLoss
 from .loss import CTCLoss  # noqa: F401
 from .loss import SmoothL1Loss  # noqa: F401
 from .loss import HingeEmbeddingLoss  # noqa: F401
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 1e72548ecc13823b8bb9b18895192cf86a2a30a2..aeb213bb2c1f5efdbf3aba3d6b6abcd2cdcec4a1 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1217,6 +1217,84 @@ class SmoothL1Loss(Layer):
                              name=self.name)
 
 
+class MultiLabelSoftMarginLoss(Layer):
+    r"""Creates a criterion that measures a multi-label one-versus-all soft margin loss
+        between input :math:`x` (a 2D mini-batch `Tensor` of class scores) and label
+        :math:`y` (a 2D `Tensor` of the same shape). For each sample in the mini-batch:
+
+        .. math::
+            \text{loss}(x, y) = -\frac{1}{C}\sum_{i}\left(y[i] \cdot \log\big(\sigma(x[i])\big) + (1 - y[i]) \cdot \log\big(\sigma(-x[i])\big)\right)
+
+        where :math:`C` is the number of classes, :math:`\sigma` is the sigmoid function,
+        and the sum runs over all classes :math:`i \in \left\{0, \; \cdots , \; C - 1\right\}`.
+        :math:`y` and :math:`x` must have the same shape.
+
+        Parameters:
+            weight (Tensor, optional): a manual rescaling weight given to each class.
+                If given, it has to be a Tensor of size C and the data type is float32 or float64.
+                Default is ``None``.
+            reduction (str, optional): Indicate how to average the loss by batch_size,
+                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
+                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
+                Default: ``'mean'``
+            name (str, optional): Name for the operation (optional, default is None).
+                For more information, please refer to :ref:`api_guide_Name`.
+
+        Call parameters:
+            input (Tensor): Input tensor, the data type is float32 or float64. The shape is (N, C), where C is the number of classes; if the input has more than 2 dimensions, the shape is (N, C, D1, D2, ..., Dk), k >= 1.
+            label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.
+
+        Shape:
+            input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means number of classes. Available dtypes are float32 and float64. The sum operation operates over all the elements.
+            label: N-D Tensor, same shape as the input.
+            output: scalar. If :attr:`reduction` is ``'none'``, the shape is [N].
+
+        Returns:
+            A callable object of MultiLabelSoftMarginLoss.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+                import paddle.nn as nn
+
+                input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
+                label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)
+
+                multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='none')
+                loss = multi_label_soft_margin_loss(input, label)
+                print(loss)
+                # Tensor([3.49625897, 0.71111226, 0.43989015])
+
+                multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='mean')
+                loss = multi_label_soft_margin_loss(input, label)
+                print(loss)
+                # Tensor([1.54908717])
+    """
+
+    def __init__(self, weight=None, reduction="mean", name=None):
+        super(MultiLabelSoftMarginLoss, self).__init__()
+        if reduction not in ['sum', 'mean', 'none']:
+            raise ValueError(
+                "'reduction' in 'MultiLabelSoftMarginLoss' should be 'sum', 'mean' or 'none', "
+                "but received {}.".format(reduction))
+        self.weight = weight
+        self.reduction = reduction
+        self.name = name
+
+    def forward(self, input, label):
+        return F.multi_label_soft_margin_loss(input,
+                                              label,
+                                              weight=self.weight,
+                                              reduction=self.reduction,
+                                              name=self.name)
+
+
 class HingeEmbeddingLoss(Layer):
     r"""
     This operator calculates hinge_embedding_loss. Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`(containing 1 or -1).
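
As a quick cross-check of the formula above, the reference values quoted in both docstring examples can be reproduced with a minimal NumPy sketch that mirrors calc_multilabel_margin_loss from the new test file. The log_sigmoid helper below is local to this sketch, not a Paddle API:

    import numpy as np


    def log_sigmoid(x):
        # plain reference helper, equivalent to paddle.nn.functional.log_sigmoid
        return np.log(1.0 / (1.0 + np.exp(-x)))


    x = np.array([[1., -2., 3.], [0., -1., 2.], [1., 0., 1.]])
    y = np.array([[-1., 1., -1.], [1., 1., 1.], [1., -1., 1.]])

    # -(y * log(sigmoid(x)) + (1 - y) * log(sigmoid(-x))), averaged over the class axis
    per_class = -(y * log_sigmoid(x) + (1.0 - y) * log_sigmoid(-x))
    per_sample = per_class.mean(axis=-1)

    print(per_sample)         # approx. [3.49625897, 0.71111226, 0.43989015], i.e. reduction='none'
    print(per_sample.mean())  # approx. 1.54908717, i.e. reduction='mean'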