From 8474392d7f8e30b3a30a952fc823d9e43c82ed87 Mon Sep 17 00:00:00 2001 From: yangguohao <70266361+yangguohao@users.noreply.github.com> Date: Thu, 13 Oct 2022 11:00:17 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Paddle=20Hackathon=20No.11=E3=80=91=20?= =?UTF-8?q?(#45595)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 2022-08-30_update nn.layer.loss nn.functional.loss, test_file * 2022-08-30_update nn.layer.loss nn.functional.loss, test_file * fix: test_file * fix: test_file, docs, multi_margin_loss * fix: doc weight function * fix: test_multi_margin_loss * fix: weight np.testing.assert_allclose * fix: test_file * fix: en_doc * 2022-10-10 --- .../tests/unittests/test_multimarginloss.py | 454 ++++++++++++++++++ python/paddle/nn/__init__.py | 2 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/loss.py | 111 +++++ python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/loss.py | 97 ++++ 6 files changed, 667 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_multimarginloss.py diff --git a/python/paddle/fluid/tests/unittests/test_multimarginloss.py b/python/paddle/fluid/tests/unittests/test_multimarginloss.py new file mode 100644 index 0000000000..1eff1deb69 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multimarginloss.py @@ -0,0 +1,454 @@ +# -*- coding: utf-8 -* +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +import unittest + + +def call_MultiMarginLoss_layer( + input, + label, + p=1, + margin=1.0, + weight=None, + reduction='mean', +): + triplet_margin_loss = paddle.nn.MultiMarginLoss(p=p, + margin=margin, + weight=weight, + reduction=reduction) + res = triplet_margin_loss( + input=input, + label=label, + ) + return res + + +def call_MultiMarginLoss_functional( + input, + label, + p=1, + margin=1.0, + weight=None, + reduction='mean', +): + res = paddle.nn.functional.multi_margin_loss(input=input, + label=label, + p=p, + margin=margin, + weight=weight, + reduction=reduction) + return res + + +def test_static(place, + input_np, + label_np, + p=1, + margin=1.0, + weight_np=None, + reduction='mean', + functional=False): + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): + input = paddle.static.data(name='input', + shape=input_np.shape, + dtype=input_np.dtype) + label = paddle.static.data(name='label', + shape=label_np.shape, + dtype=label_np.dtype) + feed_dict = { + "input": input_np, + "label": label_np, + } + weight = None + if weight_np is not None: + weight = paddle.static.data(name='weight', + shape=weight_np.shape, + dtype=weight_np.dtype) + feed_dict['weight'] = weight_np + if functional: + res = call_MultiMarginLoss_functional(input=input, + label=label, + p=p, + margin=margin, + weight=weight, + reduction=reduction) + else: + res = call_MultiMarginLoss_layer(input=input, + label=label, + p=p, + margin=margin, + weight=weight, + reduction=reduction) + + exe = paddle.static.Executor(place) + static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) + return static_result[0] + + +def test_static_data_shape(place, + input_np, + label_np, + wrong_label_shape=None, + weight_np=None, + wrong_weight_shape=None, + functional=False): + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): + input = paddle.static.data(name='input', + shape=input_np.shape, + dtype=input_np.dtype) + if wrong_label_shape is None: + label_shape = label_np.shape + else: + label_shape = wrong_label_shape + label = paddle.static.data(name='label', + shape=label_shape, + dtype=label_np.dtype) + feed_dict = { + "input": input_np, + "label": label_np, + } + weight = None + if weight_np is not None: + if wrong_weight_shape is None: + weight_shape = weight_np.shape + else: + weight_shape = wrong_weight_shape + weight = paddle.static.data(name='weight', + shape=weight_shape, + dtype=weight_np.dtype) + feed_dict['weight'] = weight_np + if functional: + res = call_MultiMarginLoss_functional( + input=input, + label=label, + weight=weight, + ) + else: + res = call_MultiMarginLoss_layer( + input=input, + label=label, + weight=weight, + ) + + exe = paddle.static.Executor(place) + static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) + return static_result + + +def test_dygraph(place, + input, + label, + p=1, + margin=1.0, + weight=None, + reduction='mean', + functional=False): + paddle.disable_static() + input = paddle.to_tensor(input) + label = paddle.to_tensor(label) + + if weight is not None: + weight = paddle.to_tensor(weight) + if functional: + dy_res = call_MultiMarginLoss_functional(input=input, + label=label, + p=p, + margin=margin, + weight=weight, + reduction=reduction) + else: + dy_res = call_MultiMarginLoss_layer(input=input, + label=label, + p=p, + margin=margin, + weight=weight, + reduction=reduction) + dy_result = dy_res.numpy() + paddle.enable_static() + return dy_result + + +def calc_multi_margin_loss( + input, + label, + p=1, + margin=1.0, + weight=None, + reduction='mean', +): + index_sample = np.array([input[i, label[i]] + for i in range(label.size)]).reshape(-1, 1) + if weight is None: + expected = np.mean(np.maximum(margin + input - index_sample, 0.0)**p, + axis=1) - margin**p / input.shape[1] + else: + weight = np.array([weight[label[i]] + for i in range(label.size)]).reshape(-1, 1) + expected = np.mean(np.maximum(weight * (margin + input - index_sample), 0.0) ** p, axis=1) - weight*(margin ** p / \ + input.shape[1]) + + if reduction == 'mean': + expected = np.mean(expected) + elif reduction == 'sum': + expected = np.sum(expected) + else: + expected = expected + + return expected + + +class TestMultiMarginLoss(unittest.TestCase): + + def test_MultiMarginLoss(self): + batch_size = 5 + num_classes = 2 + shape = (batch_size, num_classes) + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + label = np.random.uniform(0, input.shape[1], + size=(batch_size, )).astype(np.int64) + + places = [paddle.CPUPlace()] + if paddle.device.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + reductions = ['sum', 'mean', 'none'] + for place in places: + for reduction in reductions: + expected = calc_multi_margin_loss(input=input, + label=label, + reduction=reduction) + + dy_result = test_dygraph( + place=place, + input=input, + label=label, + reduction=reduction, + ) + + static_result = test_static( + place=place, + input_np=input, + label_np=label, + reduction=reduction, + ) + np.testing.assert_allclose(static_result, expected) + np.testing.assert_allclose(static_result, dy_result) + np.testing.assert_allclose(dy_result, expected) + static_functional = test_static(place=place, + input_np=input, + label_np=label, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place=place, + input=input, + label=label, + reduction=reduction, + functional=True) + np.testing.assert_allclose(static_functional, expected) + np.testing.assert_allclose(static_functional, dy_functional) + np.testing.assert_allclose(dy_functional, expected) + + def test_MultiMarginLoss_error(self): + paddle.disable_static() + self.assertRaises(ValueError, + paddle.nn.MultiMarginLoss, + reduction="unsupport reduction") + input = paddle.to_tensor([[0.1, 0.3]], dtype='float32') + label = paddle.to_tensor([0], dtype='int32') + self.assertRaises(ValueError, + paddle.nn.functional.multi_margin_loss, + input=input, + label=label, + reduction="unsupport reduction") + paddle.enable_static() + + def test_MultiMarginLoss_dimension(self): + paddle.disable_static() + + input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32') + label = paddle.to_tensor([0, 1, 1], dtype='int32') + + self.assertRaises( + ValueError, + paddle.nn.functional.multi_margin_loss, + input=input, + label=label, + ) + MMLoss = paddle.nn.MultiMarginLoss() + self.assertRaises( + ValueError, + MMLoss, + input=input, + label=label, + ) + paddle.enable_static() + + def test_MultiMarginLoss_p(self): + p = 2 + batch_size = 5 + num_classes = 2 + shape = (batch_size, num_classes) + reduction = 'mean' + place = paddle.CPUPlace() + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + label = np.random.uniform(0, input.shape[1], + size=(batch_size, )).astype(np.int64) + expected = calc_multi_margin_loss(input=input, + p=p, + label=label, + reduction=reduction) + + dy_result = test_dygraph( + place=place, + p=p, + input=input, + label=label, + reduction=reduction, + ) + + static_result = test_static( + place=place, + p=p, + input_np=input, + label_np=label, + reduction=reduction, + ) + np.testing.assert_allclose(static_result, expected) + np.testing.assert_allclose(static_result, dy_result) + np.testing.assert_allclose(dy_result, expected) + static_functional = test_static(place=place, + p=p, + input_np=input, + label_np=label, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place=place, + p=p, + input=input, + label=label, + reduction=reduction, + functional=True) + np.testing.assert_allclose(static_functional, expected) + np.testing.assert_allclose(static_functional, dy_functional) + np.testing.assert_allclose(dy_functional, expected) + + def test_MultiMarginLoss_weight(self): + batch_size = 5 + num_classes = 2 + shape = (batch_size, num_classes) + reduction = 'mean' + place = paddle.CPUPlace() + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + label = np.random.uniform(0, input.shape[1], + size=(batch_size, )).astype(np.int64) + weight = np.random.uniform(0, 2, + size=(num_classes, )).astype(np.float64) + expected = calc_multi_margin_loss(input=input, + label=label, + weight=weight, + reduction=reduction) + + dy_result = test_dygraph( + place=place, + input=input, + label=label, + weight=weight, + reduction=reduction, + ) + + static_result = test_static( + place=place, + input_np=input, + label_np=label, + weight_np=weight, + reduction=reduction, + ) + np.testing.assert_allclose(static_result, expected) + np.testing.assert_allclose(static_result, dy_result) + np.testing.assert_allclose(dy_result, expected) + static_functional = test_static(place=place, + input_np=input, + label_np=label, + weight_np=weight, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place=place, + input=input, + label=label, + weight=weight, + reduction=reduction, + functional=True) + np.testing.assert_allclose(static_functional, expected) + np.testing.assert_allclose(static_functional, dy_functional) + np.testing.assert_allclose(dy_functional, expected) + + def test_MultiMarginLoss_static_data_shape(self): + batch_size = 5 + num_classes = 2 + shape = (batch_size, num_classes) + place = paddle.CPUPlace() + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + label = np.random.uniform(0, input.shape[1], + size=(batch_size, )).astype(np.int64) + weight = np.random.uniform(0, 2, + size=(num_classes, )).astype(np.float64) + + self.assertRaises( + ValueError, + test_static_data_shape, + place=place, + input_np=input, + label_np=label, + wrong_label_shape=(10, ), + functional=True, + ) + self.assertRaises( + ValueError, + test_static_data_shape, + place=place, + input_np=input, + label_np=label, + wrong_label_shape=(10, ), + functional=False, + ) + self.assertRaises( + ValueError, + test_static_data_shape, + place=place, + input_np=input, + label_np=label, + weight_np=weight, + wrong_weight_shape=(3, ), + functional=True, + ) + self.assertRaises( + ValueError, + test_static_data_shape, + place=place, + input_np=input, + label_np=label, + weight_np=weight, + wrong_weight_shape=(3, ), + functional=False, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index e47fa8c3c5..331131d6e2 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -109,6 +109,7 @@ from .layer.loss import CTCLoss # noqa: F401 from .layer.loss import SmoothL1Loss # noqa: F401 from .layer.loss import HingeEmbeddingLoss # noqa: F401 from .layer.loss import CosineEmbeddingLoss # noqa: F401 +from .layer.loss import MultiMarginLoss from .layer.loss import TripletMarginWithDistanceLoss from .layer.loss import TripletMarginLoss from .layer.loss import SoftMarginLoss @@ -319,6 +320,7 @@ __all__ = [ # noqa 'Identity', 'CosineEmbeddingLoss', 'RReLU', + 'MultiMarginLoss', 'TripletMarginWithDistanceLoss', 'TripletMarginLoss', 'SoftMarginLoss', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 701997e0d0..bf0554d78d 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -92,6 +92,7 @@ from .loss import square_error_cost # noqa: F401 from .loss import ctc_loss # noqa: F401 from .loss import hinge_embedding_loss # noqa: F401 from .loss import cosine_embedding_loss # noqa: F401 +from .loss import multi_margin_loss from .loss import multi_label_soft_margin_loss from .loss import triplet_margin_with_distance_loss from .loss import triplet_margin_loss @@ -241,5 +242,6 @@ __all__ = [ # noqa 'rrelu', 'triplet_margin_with_distance_loss', 'triplet_margin_loss', + 'multi_margin_loss', 'soft_margin_loss', ] diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 20b699c7a2..ed28bc2190 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3457,6 +3457,117 @@ def triplet_margin_loss(input, return loss +def multi_margin_loss(input, + label, + p: int = 1, + margin: float = 1.0, + weight=None, + reduction='mean', + name=None): + r""" + Measures a multi-class classification hinge loss between input :math:`input` and label :math:`label`: + + For i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar + output :math:`label_i` is: + + .. math:: + \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}} + + where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`. + + Optionally, you can give non-equal weighting on the classes by passing + a 1D :attr:`weight` tensor into the constructor. + + The loss function for i-th sample then becomes: + + .. math:: + \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}} + + + Parameters: + input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes. + + label (Tensor): Label tensor, the data type is int32 or int64. The shape of label is (N,) + + p (int, Optional): The power num. Default: :math:`1`. + + margin (float, Optional): Default: :math:`1`. + + weight (Tensor,optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of shape (C,) and the data type is float32, float64. + Default is ``'None'`` . + + + reduction (str, Optional):Indicate how to calculate the loss by batch_size. + the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`reduction` is ``'sum'``, the summed loss is returned. + Default: ``'mean'`` + + name (str, Optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Output: Tensor. The tensor variable storing the multi_margin_loss of input and label. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32) + label = paddle.to_tensor([1, 2, 1], dtype=paddle.int32) + loss = F.multi_margin_loss(input, label, margin=1.0, reduction='none') + print(loss) + + """ + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "'reduction' in 'multi_margin_loss' should be 'sum', 'mean' or 'none', " + "but received {}.".format(reduction)) + + if not _non_static_mode(): + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + 'multi_margin_loss') + check_variable_and_dtype(label, 'label', ['int32', 'int64'], + 'multi_margin_loss') + if not (input.shape[0] == label.shape[0]): + raise ValueError( + "The label's shape[0] should be equal to input's shape[0], " + "but received input's shape[0] {} and label's shape[0]:{}. ".format( + input.shape[0], label.shape[0])) + label = label.reshape((-1, 1)) + index_sample = paddle.index_sample(input, label) + if weight is not None: + if not _non_static_mode(): + check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], + 'multi_margin_loss') + if not (input.shape[1] == weight.shape[0]): + raise ValueError( + "The weight's shape[0] should be equal to input's shape[1]" + "but received weight's shape[0]: {} and input's shape[1]: {}". + format(weight.shape[0], input.shape[1])) + weight = paddle.gather(weight, label, axis=0).reshape((-1, 1)) + loss = paddle.mean( + paddle.pow( + paddle.clip(weight * + (margin - index_sample + input), min=0.0), p), + axis=1) - weight * (margin**p / paddle.shape(input)[1]) + else: + loss = paddle.mean(paddle.pow( + paddle.clip(margin - index_sample + input, min=0.0), p), + axis=1) - margin**p / paddle.shape(input)[1] + + if reduction == 'mean': + return paddle.mean(loss, name=name) + elif reduction == 'sum': + return paddle.sum(loss, name=name) + elif reduction == 'none': + return loss + + def soft_margin_loss(input, label, reduction='mean', name=None): """ The API measures the soft margin loss between input predictions ``input`` diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 45cb652332..1acea10d67 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -83,6 +83,7 @@ from .loss import HingeEmbeddingLoss # noqa: F401 from .loss import TripletMarginWithDistanceLoss from .loss import TripletMarginLoss from .loss import SoftMarginLoss +from .loss import MultiMarginLoss from .norm import BatchNorm1D # noqa: F401 from .norm import BatchNorm2D # noqa: F401 from .norm import BatchNorm3D # noqa: F401 diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index fea2add79b..6de2717a06 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1674,6 +1674,103 @@ class TripletMarginLoss(Layer): name=self.name) +class MultiMarginLoss(Layer): + r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between + input :math:`input` and label :math:`label`: + + For i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar + output :math:`label_i` is: + + .. math:: + \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}} + + where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`. + + Optionally, you can give non-equal weighting on the classes by passing + a 1D :attr:`weight` tensor into the constructor. + + The loss function for i-th sample then becomes: + + .. math:: + \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}} + + + Parameters: + + p (int, Optional):The norm degree for pairwise distance. Default: :math:`1`. + + margin (float, Optional):Default: :math:`1`. + + weight (Tensor,optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of shape (C,) and the data type is float32, float64. + Default is ``'None'`` . + + reduction (str, optional): Indicate how to calculate the loss by batch_size, + the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`reduction` is ``'sum'``, the summed loss is returned. + Default: ``'mean'`` + + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Call parameters: + input (Tensor): Input tensor, the data type is float32 or float64. + + label (Tensor): Label tensor, 0<= label < input.shape[1], the data type is int32 or int64. + + Shape: + input: 2-D Tensor, the shape is [N, C], N is batch size and `C` means number of classes. + + label: 1-D Tensor, the shape is [N,]. + + output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the label. + + Returns: + A callable object of MultiMarginLoss. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + + input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32) + label = paddle.to_tensor([0, 1, 2], dtype=paddle.int32) + + multi_margin_loss = nn.MultiMarginLoss(reduction='mean') + loss = multi_margin_loss(input, label) + print(loss) + """ + + def __init__(self, + p: int = 1, + margin: float = 1.0, + weight=None, + reduction="mean", + name=None): + super(MultiMarginLoss, self).__init__() + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "'reduction' in 'MultiMarginLoss' should be 'sum', 'mean' or 'none', " + "but received {}.".format(reduction)) + self.p = p + self.margin = margin + self.weight = weight + self.reduction = reduction + self.name = name + + def forward(self, input, label): + return F.multi_margin_loss(input, + label, + p=self.p, + margin=self.margin, + weight=self.weight, + reduction=self.reduction, + name=self.name) + + class SoftMarginLoss(Layer): r""" Creates a criterion that measures a two-class soft margin loss between input predictions ``input`` -- GitLab