loss.py 76.0 KB
Newer Older
1
# -*- coding: utf-8 -*-
2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
# TODO: define loss functions of neural network
L
Leo Chen 已提交
17
import paddle.fluid as fluid
18
import paddle
19
from .. import functional as F
20
from paddle.fluid.framework import in_dygraph_mode
Z
zhiboniu 已提交
21
from .. import Layer
Z
zhiboniu 已提交
22
from paddle import in_dynamic_mode
23

24 25
__all__ = []

L
Leo Chen 已提交
26

Z
zhiboniu 已提交
27
class BCEWithLogitsLoss(Layer):
    r"""
    Combine the sigmoid layer and the :ref:`api_paddle_nn_BCELoss` layer.

    It can also be seen as the combination of the
    ``sigmoid_cross_entropy_with_logits`` layer and some reduce operations.

    This measures the element-wise probability error in classification tasks
    in which each class is independent.
    This can be thought of as predicting labels for a data-point, where labels
    are not mutually exclusive. For example, a news article can be about
    politics, technology or sports at the same time or none of these.

    First, this operator calculates the loss function as follows:

    .. math::
           Out = -Labels * \log(\sigma(Logit)) - (1 - Labels) * \log(1 - \sigma(Logit))

    We know that :math:`\sigma(Logit) = \frac{1}{1 + e^{-Logit}}`. By substituting this we get:

    .. math::
           Out = Logit - Logit * Labels + \log(1 + e^{-Logit})

    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
    we reformulate the loss as follows:

    .. math::
           Out = \max(Logit, 0) - Logit * Labels + \log(1 + e^{-\|Logit\|})

    Then, if ``weight`` or ``pos_weight`` is not None, this operator multiplies the
    weight tensor on the loss `Out`. The ``weight`` tensor attaches a different
    weight to every item in the batch. The ``pos_weight`` attaches a different
    weight to the positive label of each class.

    Finally, this operator applies a reduce operation on the loss.
    If :attr:`reduction` is set to ``'none'``, the operator returns the original loss `Out`.
    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target labels ``label`` should be numbers between 0 and 1.

    Args:
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`,
            The data type is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
            with length equal to the number of classes. The data type is float32, float64.
            Default is ``'None'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shapes:
        - logit (Tensor): The input predictions tensor. 2-D tensor with shape: [N, `*`],
          N is batch_size, `*` means number of additional dimensions. The ``logit``
          is usually the output of Linear layer. Available dtype is float32, float64.
        - label (Tensor): The target labels tensor. 2-D tensor with the same shape as
          ``logit``. The target labels which values should be numbers between 0 and 1.
          Available dtype is float32, float64.
        - output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
          same as ``logit`` , else the shape of output is scalar.

    Returns:
        A callable object of BCEWithLogitsLoss.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
            label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
            bce_logit_loss = paddle.nn.BCEWithLogitsLoss()
            output = bce_logit_loss(logit, label)
            print(output.numpy())  # [0.45618808]

    """

    def __init__(
        self, weight=None, reduction='mean', pos_weight=None, name=None
    ):
        # Validate the reduction mode up front, before the base Layer is set up.
        if reduction not in ('sum', 'mean', 'none'):
            message = (
                "The value of 'reduction' in BCEWithLogitsLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )
            raise ValueError(message)

        super().__init__()
        self.weight = weight
        self.reduction = reduction
        self.pos_weight = pos_weight
        self.name = name

    def forward(self, logit, label):
        # Delegate to the functional API using the options stored at construction.
        return paddle.nn.functional.binary_cross_entropy_with_logits(
            logit,
            label,
            weight=self.weight,
            reduction=self.reduction,
            pos_weight=self.pos_weight,
            name=self.name,
        )


Z
zhiboniu 已提交
136
class CrossEntropyLoss(Layer):
    r"""
    Cross entropy loss, by default fused with softmax.

    By default, this operator implements the cross entropy loss function with softmax. This function
    combines the calculation of the softmax operation and the cross entropy loss function
    to provide a more numerically stable computing.

    This operator will calculate the cross entropy loss function without softmax when use_softmax=False.

    By default, this operator will calculate the mean of the result, and you can also affect
    the default behavior by using the reduction parameter. Please refer to the part of
    parameters for details.

    This operator can be used to calculate the softmax cross entropy loss with soft and hard labels.
    Where, the hard labels mean the actual label value, 0, 1, 2, etc.  And the soft labels
    mean the probability of the actual label, 0.6, 0.8, 0.2, etc.

    The calculation of this operator includes the following two steps.

    -  **I.softmax cross entropy**

        1. Hard label (each sample can only be assigned into one category)

        1.1. when use_softmax=True

            .. math::
              \\loss_j=-\text{logits}_{label_j}+\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories.

        1.2. when use_softmax=False

            .. math::
              \\loss_j=-\log\left({P}_{label_j}\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).


        2. Soft label (each sample is assigned to multiple categories with a certain probability, and the probability sum is 1).

        2.1. when use_softmax=True

            .. math::
              \\loss_j=-\sum_{i=0}^{C}\text{label}_i\left(\text{logits}_i-\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories.

        2.2. when use_softmax=False

            .. math::
              \\loss_j=-\sum_{j=0}^{C}\left({label}_j*\log\left({P}_{label_j}\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).



    -  **II.Weight and reduction processing**

        1. Weight

            If the ``weight`` parameter is ``None`` , go to the next step directly.

            If the ``weight`` parameter is not ``None`` , the cross entropy of each sample is weighted by weight
            according to soft_label = False or True as follows.

            1.1. Hard labels (soft_label = False)

            .. math::
                \\loss_j=loss_j*weight[label_j]


            1.2. Soft labels (soft_label = True)

            .. math::
                \\loss_j=loss_j*\sum_{i}\left(weight[label_i]*logits_i\right)

        2. reduction

            2.1 if the ``reduction`` parameter is ``none``

            Return the previous result directly

            2.2 if the ``reduction`` parameter is ``sum``

            Return the sum of the previous results

            .. math::
               \\loss=\sum_{j}loss_j

            2.3 if the ``reduction`` parameter is ``mean`` , it will be processed according to
            the ``weight`` parameter as follows.

            2.3.1. If the  ``weight``  parameter is ``None``

            Return the average value of the previous results

            .. math::
                \\loss=\sum_{j}loss_j/N

            where, N is the number of samples and C is the number of categories.

            2.3.2. If the 'weight' parameter is not 'None', the weighted average value of the previous result will be returned

            1. Hard labels (soft_label = False)

            .. math::
                \\loss=\sum_{j}loss_j/\sum_{j}weight[label_j]

            2. Soft labels (soft_label = True)

            .. math::
                \\loss=\sum_{j}loss_j/\sum_{j}\left(\sum_{i}weight[label_i]\right)


    Parameters:
        weight (Tensor, optional): a manual rescaling weight given to each class.
            If given, has to be a Tensor of size C and the data type is float32, float64.
            Default is ``'None'`` .
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the loss. A negative value means that no label
            value needs to be ignored. Only valid when soft_label = False.
            Default is ``-100`` .
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        soft_label (bool, optional): Indicate whether label is soft.
            If soft_label=False, the label is hard.  If soft_label=True, the label is soft.
            Default is ``False``.
        axis (int, optional): The index of dimension to perform softmax calculations.
            It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the number
            of dimensions of input :attr:`input`.
            Default is ``-1`` .
        use_softmax (bool, optional): Indicate whether compute softmax before cross_entropy.
            Default is ``True``.
        name (str, optional): The name of the operator. Default is ``None`` .
            For more information, please refer to :ref:`api_guide_Name` .


    Shape:
        - **input** (Tensor), the data type is float32, float64. Shape is
          :math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes ,  ``k >= 1`` .
            Note:

                1. when use_softmax=True, it expects unscaled logits. This operator should not be used with the
                output of softmax operator, which will produce incorrect results.

                2. when use_softmax=False, it expects the output of softmax operator.

        - **label** (Tensor)

            1. If soft_label=False, the shape is
            :math:`[N_1, N_2, ..., N_k]` or :math:`[N_1, N_2, ..., N_k, 1]`, k >= 1.
            the data type is int32, int64, float32, float64, where each value is [0, C-1].

            2. If soft_label=True, the shape and data type should be same with ``input`` ,
            and the sum of the labels for each sample should be 1.

        - **output** (Tensor), Return the softmax cross_entropy loss of ``input`` and ``label``.
          The data type is the same as input.
          If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the dimension of return value is ``1``.
          If :attr:`reduction` is ``'none'``:

            1. If soft_label = False, the dimension of return value is the same with ``label`` .

            2. if soft_label = True, the dimension of return value is :math:`[N_1, N_2, ..., N_k, 1]` .

    Examples:

        .. code-block:: python

            # hard labels
            import paddle
            paddle.seed(99999)
            N=100
            C=200
            reduction='mean'
            input =  paddle.rand([N, C], dtype='float64')
            label =  paddle.randint(0, C, shape=[N], dtype='int64')
            weight = paddle.rand([C], dtype='float64')

            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
                weight=weight, reduction=reduction)
            dy_ret = cross_entropy_loss(
                                       input,
                                       label)
            print(dy_ret.numpy()) #[5.41993642]

        .. code-block:: python

            # soft labels
            import paddle
            paddle.seed(99999)
            axis = -1
            ignore_index = -100
            N = 4
            C = 3
            shape = [N, C]
            reduction='mean'
            weight = None
            logits = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            labels = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            labels /= paddle.sum(labels, axis=axis, keepdim=True)
            paddle_loss_mean = paddle.nn.functional.cross_entropy(
                                                                  logits,
                                                                  labels,
                                                                  soft_label=True,
                                                                  axis=axis,
                                                                  weight=weight,
                                                                  reduction=reduction)
            print(paddle_loss_mean.numpy()) #[1.12908343]

    """

    def __init__(
        self,
        weight=None,
        ignore_index=-100,
        reduction='mean',
        soft_label=False,
        axis=-1,
        use_softmax=True,
        name=None,
    ):
        super().__init__()
        # Store every option untouched; this constructor performs no
        # validation of its own and simply hands the values to the
        # functional call in ``forward``.
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.soft_label = soft_label
        self.axis = axis
        self.use_softmax = use_softmax
        self.name = name

    def forward(self, input, label):
        # Forward the stored configuration to the functional implementation.
        return F.cross_entropy(
            input,
            label,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            soft_label=self.soft_label,
            axis=self.axis,
            use_softmax=self.use_softmax,
            name=self.name,
        )
385 386


Z
zhiboniu 已提交
387
class HSigmoidLoss(Layer):
    """
    Hierarchical Sigmoid Layer.

    The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity
    and speed up the model training, especially the training of language model.
    Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier.
    For each class(word), there's a unique path from root to itself, hsigmoid calculates the cost for each non-leaf node on
    the path, and sums them to get a total cost.
    Comparing to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
    represents the number of classes or the size of word dict.

    The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural
    Network Language Model <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_. For the custom
    tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example):

    1. Using a custom word dict to build a binary tree, each leaf node should be a word in the word dict.
    2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table.
    3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code.
       Code means the label of each binary classifier, 1 indicates true, 0 indicates false.
    4. Now, each word should have its path and code along the path, you can pass a batch of path and code related
       to the same batch of inputs.

    Parameters:
        feature_size (int): The number of features.
        num_classes (int): The number of classes or the size of word dict. It must not be
            less than 2 when the default tree is used.
            If the default tree is used (:attr:`is_custom` is set to False), :attr:`num_classes`
            should not be None. If the custom tree is used (:attr:`is_custom` is set to True),
            :attr:`num_classes` should be the number of non-leaf nodes, which indicates the num of
            classes using by the binary classifier.
        weight_attr (ParamAttr, optional): The parameter attribute for the learnable weights
            of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create a
            ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is
            initialized with Xavier. Default is None.
        bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of hsigmoid. If it
            is set to False, no bias will be added. If it is set to None or one attribute of ParamAttr,
            hsigmoid will create a ParamAttr as bias_attr. If the Initializer of the bias_attr is not
            set, the bias is initialized zero. Default is None.
        is_custom (bool, optional): Whether use custom binary tree. If it's True, `path_table` and
            `path_code` should be passed to its forward method, otherwise `path_table` and `path_code`
            should not be passed to its forward method. Default is False.
        is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True,
            the gradient of weight and input will be sparse. It is reset to False (with a printed
            notice) when :attr:`is_custom` is False. Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Raises:
        ValueError: If :attr:`num_classes` is less than 2 while the default tree is used.

    Shape:
        input (Tensor): The input tensor. The shape is [N, D], where N is batch size and D is feature size. Its data type should be float32, float64.
        label (Tensor): Its shape is [N, 1]. Its data type should be int64.
        output (Tensor): The HSigmoid Loss of ``input`` and ``label``. Shape is [N, 1]

    Examples:
        .. code-block:: python

            import paddle
            paddle.set_device('cpu')

            input = paddle.uniform([4, 3])
            # [[0.56194401  -0.22450298  -0.10741806] # random
            #  [0.36136317  0.23556745  0.88748658] # random
            #  [0.18151939  0.80947340  -0.31078976] # random
            #  [0.68886101  -0.14239830  -0.41297770]] # random
            label = paddle.to_tensor([0, 1, 4, 5])
            m = paddle.nn.HSigmoidLoss(3, 5)
            out = m(input, label)
            # [[2.42524505]
            #  [1.74917245]
            #  [3.14571381]
            #  [2.34564662]]
    """

    def __init__(
        self,
        feature_size,
        num_classes,
        weight_attr=None,
        bias_attr=None,
        is_custom=False,
        is_sparse=False,
        name=None,
    ):
        super().__init__()
        if (num_classes < 2) and (not is_custom):
            raise ValueError(
                "num_classes must not be less than 2 with default tree"
            )

        # Sparse updating is only honoured together with a custom tree.
        if (not is_custom) and (is_sparse):
            print("Sparse mode should not be used without custom tree")
            is_sparse = False

        self._feature_size = feature_size
        self._num_classes = num_classes
        self._is_custom = is_custom
        self._is_sparse = is_sparse

        self._weight_attr = weight_attr
        self._bias_attr = bias_attr

        self._name = name
        self._dtype = paddle.get_default_dtype()

        # The notice is about sparse mode, so emit it only when sparse mode is
        # actually in effect (it used to be printed for every instance).
        if is_sparse:
            print(
                "With sparse mode, if your models has only"
                " small parameter prefetch may cause speed down"
            )

        # Number of weight/bias rows: ``num_classes`` for a custom tree,
        # ``num_classes - 1`` for the default tree.
        C = self._num_classes if is_custom else self._num_classes - 1
        self.weight = self.create_parameter(
            [C, self._feature_size],
            attr=self._weight_attr,
            is_bias=False,
            dtype=self._dtype,
        )
        self.bias = self.create_parameter(
            [C, 1], attr=self._bias_attr, is_bias=True, dtype=self._dtype
        )

    def forward(self, input, label, path_table=None, path_code=None):
        # Delegate to the functional API with this layer's parameters.
        out = F.hsigmoid_loss(
            input,
            label,
            self._num_classes,
            self.weight,
            self.bias,
            path_table=path_table,
            path_code=path_code,
            is_sparse=self._is_sparse,
            name=self._name,
        )
        return out


Z
zhiboniu 已提交
521
class MSELoss(Layer):
    r"""
    **Mean Square Error Loss**
    Computes the mean square error (squared L2 norm) of given input and label.

    If :attr:`reduction` is set to ``'none'``, loss is calculated as:

    .. math::
        Out = (input - label)^2

    If :attr:`reduction` is set to ``'mean'``, loss is calculated as:

    .. math::
        Out = \operatorname{mean}((input - label)^2)

    If :attr:`reduction` is set to ``'sum'``, loss is calculated as:

    .. math::
        Out = \operatorname{sum}((input - label)^2)

    where `input` and `label` are `float32` or `float64` tensors of same shape.

    Parameters:
        reduction (string, optional): The reduction method for the output,
            could be 'none' | 'mean' | 'sum'.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Shape:
        input (Tensor): Input tensor, the data type is float32 or float64
        label (Tensor): Label tensor, the data type is float32 or float64
        output (Tensor): output tensor storing the MSE loss of input and label, the data type is same as input.

    Examples:
        .. code-block:: python

            import paddle

            mse_loss = paddle.nn.loss.MSELoss()
            input = paddle.to_tensor([1.5])
            label = paddle.to_tensor([1.7])
            output = mse_loss(input, label)
            print(output)
            # [0.04000002]
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "'reduction' in 'MSELoss' should be 'sum', 'mean' or 'none', "
                "but received {}.".format(reduction)
            )
        self.reduction = reduction

    def forward(self, input, label):
        # Outside dynamic mode, check that both operands are float32/float64
        # variables before building the computation.
        if not in_dynamic_mode():
            fluid.data_feeder.check_variable_and_dtype(
                input, 'input', ['float32', 'float64'], 'MSELoss'
            )
            fluid.data_feeder.check_variable_and_dtype(
                label, 'label', ['float32', 'float64'], 'MSELoss'
            )

        if in_dygraph_mode():
            square_out = paddle._C_ops.square(paddle.subtract(input, label))
        else:
            square_out = paddle.square(paddle.subtract(input, label))

        if self.reduction == 'none':
            return square_out
        if self.reduction == 'sum':
            return paddle.sum(square_out)

        # 'mean' is the only remaining mode (enforced in __init__). Call
        # paddle.mean directly rather than resolving 'reduce_mean' on the
        # legacy fluid.layers namespace through getattr.
        return paddle.mean(square_out)


Z
zhiboniu 已提交
602
class L1Loss(Layer):
    r"""

    Construct a callable object of the ``L1Loss`` class.
    The L1Loss layer calculates the L1 Loss of ``input`` and ``label`` as follows.

    If `reduction` is set to ``'none'``, the loss is:

    .. math::
        Out = \lvert input - label\rvert

    If `reduction` is set to ``'mean'``, the loss is:

    .. math::
        Out = MEAN(\lvert input - label\rvert)

    If `reduction` is set to ``'sum'``, the loss is:

    .. math::
        Out = SUM(\lvert input - label\rvert)


    Parameters:
        reduction (str, optional): Indicate the reduction to apply to the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'none'``, the unreduced loss is returned;
            If `reduction` is ``'mean'``, the reduced mean loss is returned.
            If `reduction` is ``'sum'``, the reduced sum loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Shape:
        - input (Tensor): The input tensor. The shape is ``[N, *]``, where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32, int64.
        - label (Tensor): label. The shape is ``[N, *]``, same shape as ``input`` . Its data type should be float32, float64, int32, int64.
        - output (Tensor): The L1 Loss of ``input`` and ``label``.
          If `reduction` is ``'none'``, the shape of output loss is ``[N, *]``, the same as ``input`` .
          If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[1.5, 0.8], [0.2, 1.3]])
            label = paddle.to_tensor([[1.7, 1], [0.4, 0.5]])

            l1_loss = paddle.nn.L1Loss()
            output = l1_loss(input, label)
            print(output.numpy())
            # [0.35]

            l1_loss = paddle.nn.L1Loss(reduction='sum')
            output = l1_loss(input, label)
            print(output.numpy())
            # [1.4]

            l1_loss = paddle.nn.L1Loss(reduction='none')
            output = l1_loss(input, label)
            print(output)
            # [[0.20000005 0.19999999]
            # [0.2        0.79999995]]

    """

    def __init__(self, reduction='mean', name=None):
        # Validate the reduction mode up front, before the base Layer is set up.
        if reduction not in ('sum', 'mean', 'none'):
            message = (
                "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )
            raise ValueError(message)
        super().__init__()
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        # Delegate to the functional API using the stored reduction and name.
        loss = F.l1_loss(input, label, reduction=self.reduction, name=self.name)
        return loss
C
ceci3 已提交
680 681


Z
zhiboniu 已提交
682
class BCELoss(Layer):
    """

    This interface is used to construct a callable object of the ``BCELoss`` class.
    The BCELoss layer measures the binary_cross_entropy loss between input predictions
    ``input`` and target labels ``label``. The binary_cross_entropy loss can be described as:

    If :attr:`weight` is set, the loss is:

    .. math::
        Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`weight` is None, the loss is:

    .. math::
        Out = -1 * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`reduction` is set to ``'none'``, the interface returns the original loss `Out`.

    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(Out)

    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(Out)

    Note that the input predictions ``input`` should always be the output of sigmoid,
    and the target labels ``label`` should be numbers between 0 and 1.

    Parameters:
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, has to be a Tensor of size nbatch and the data type
            is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input (Tensor): 2-D tensor with shape: ``[N, *]``, N is batch_size, `*` means
          number of additional dimensions. The input ``input`` should always
          be the output of sigmoid. Available dtype is float32, float64.
        - label (Tensor): 2-D tensor with the same shape as ``input``. The target
          labels which values should be numbers between 0 and 1. Available
          dtype is float32, float64.
        - output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
          same as ``input`` , else the shape of output is scalar.

    Returns:
        A callable object of BCELoss.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.to_tensor([0.5, 0.6, 0.7])
            label = paddle.to_tensor([1.0, 0.0, 1.0])
            bce_loss = paddle.nn.BCELoss()
            output = bce_loss(input, label)
            print(output)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [0.65537101])

    """

    def __init__(self, weight=None, reduction='mean', name=None):
        # Fail early on an unsupported reduction mode.
        supported = ('sum', 'mean', 'none')
        if reduction not in supported:
            message = (
                "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )
            raise ValueError(message)

        super().__init__()
        self.weight = weight
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        # Thin wrapper: all the work happens in the functional API.
        return paddle.nn.functional.binary_cross_entropy(
            input,
            label,
            weight=self.weight,
            reduction=self.reduction,
            name=self.name,
        )
772 773


Z
zhiboniu 已提交
774
class NLLLoss(Layer):
    r"""

    This class accepts input and target label and returns negative log likelihood
    cross error. It is useful to train a classification problem with C classes.

    The input for the loss is expected to contain log-probabilities of
    each class. It has to be a Tensor of size either (batch_size, C) or
    (batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
    The label for the loss should be a class index in the range [0, C-1]
    where C is the number of classes. If ignore_index is specified, the
    specified target value does not contribute to the input gradient.

    If the optional argument `weight` is provided, it should be a 1D Tensor
    assigning weight to each of the classes. This is particularly useful
    when you have an unbalanced training set.

    The loss is calculated as follows.
    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:

    .. math::

        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = - w_{y_n} x_{n,y_n}, \quad
        w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore_index}\},

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then

    .. math::

        \ell(x, y) =
        \left\{
            \begin{array}{lcl}
            \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, &
            \text{if  reduction} = \text{'mean';}\\
            \sum_{n=1}^N l_n,  &
            \text{if  reduction} = \text{'sum'.}
            \end{array}
        \right.

    Parameters:
        weight (Tensor, optional): Weight tensor, a manual rescaling weight given
            to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise,
            it is treated as if having all ones. The data type is
            float32, float64. Default is ``'None'``.
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default is ``-100``.
        reduction (str, optional): Indicate how to average the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'mean'``, the reduced mean loss is returned;
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Shape:
        - input (Tensor): Input tensor, the shape is :math:`[N, C]`, `C` is the number of classes.
            But in K-dimension situation, the shape is :math:`[N, C, d_1, d_2, ..., d_K]`.
            The data type is float32, float64.
        - label (Tensor): Label tensor, the shape is :math:`[N,]` or :math:`[N, d_1, d_2, ..., d_K]`.
            The data type is int64.
        - output (Tensor): the `negative log likelihood loss` between input `x` and `label`.
            If `reduction` is `'none'`, the shape is `[N, *]`.
            If `reduction` is `'sum'` or `'mean'`, the shape is `[1]`.

    Examples:
        .. code-block:: python

            import paddle

            nll_loss = paddle.nn.loss.NLLLoss()
            log_softmax = paddle.nn.LogSoftmax(axis=1)

            input = paddle.to_tensor([[0.88103855, 0.9908683 , 0.6226845 ],
                                      [0.53331435, 0.07999352, 0.8549948 ],
                                      [0.25879037, 0.39530203, 0.698465  ],
                                      [0.73427284, 0.63575995, 0.18827209],
                                      [0.05689114, 0.0862954 , 0.6325046 ]], "float32")
            log_out = log_softmax(input)
            label = paddle.to_tensor([0, 2, 1, 1, 0], "int64")
            result = nll_loss(log_out, label)
            print(result) # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, [1.07202101])

    """

    def __init__(
        self, weight=None, ignore_index=-100, reduction='mean', name=None
    ):
        # Fail early on an unsupported reduction mode.
        supported = ('sum', 'mean', 'none')
        if reduction not in supported:
            message = (
                "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
                "'none', but received %s, which is not allowed." % reduction
            )
            raise ValueError(message)
        super().__init__()
        self._weight = weight
        self._ignore_index = ignore_index
        self._reduction = reduction
        self._name = name

    def forward(self, input, label):
        # Thin wrapper: forward the stored configuration to the functional API.
        loss = F.nll_loss(
            input,
            label,
            weight=self._weight,
            ignore_index=self._ignore_index,
            reduction=self._reduction,
            name=self._name,
        )
        return loss
883 884


Z
zhiboniu 已提交
885
class KLDivLoss(Layer):
    r"""

    Generate a callable object of 'KLDivLoss' to calculate the
    Kullback-Leibler divergence loss between Input(X) and
    Input(Target). Note that Input(X) is the log-probability
    and Input(Target) is the probability.

    KL divergence loss is calculated as follows:

    .. math::

        l(x, y) = y * (\log(y) - x)

    Parameters:
        reduction (str, optional): Indicate how to average the loss,
             the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
             If `reduction` is ``'mean'``, the reduced mean loss is returned;
             If `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
             if `reduction` is ``'sum'``, the reduced sum loss is returned;
             if `reduction` is ``'none'``, no reduction will be applied.
             Default is ``'mean'``.

    Shape:
        - input (Tensor): ``(N, *)``, where ``*`` means, any number of additional dimensions.
        - label (Tensor): ``(N, *)``, same shape as input.
        - output (Tensor): tensor with shape: [1] by default.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'batchmean'``,
            ``'mean'`` or ``'sum'``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            shape = (5, 20)
            x = paddle.uniform(shape, min=-10, max=10).astype('float32')
            target = paddle.uniform(shape, min=-10, max=10).astype('float32')

            # 'batchmean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
            pred_loss = kldiv_criterion(x, target)
            # shape=[1]

            # 'mean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='mean')
            pred_loss = kldiv_criterion(x, target)
            # shape=[1]

            # 'sum' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='sum')
            pred_loss = kldiv_criterion(x, target)
            # shape=[1]

            # 'none' reduction, loss shape is same with X shape
            kldiv_criterion = nn.KLDivLoss(reduction='none')
            pred_loss = kldiv_criterion(x, target)
            # shape=[5, 20]

    """

    def __init__(self, reduction='mean'):
        # Validate at construction time, like the other loss layers in this
        # file, so a misspelled mode is reported where it was written.
        if reduction not in ['none', 'batchmean', 'mean', 'sum']:
            raise ValueError(
                "The value of 'reduction' in KLDivLoss should be 'none', 'batchmean', "
                "'mean' or 'sum', but received {}, which is not allowed.".format(
                    reduction
                )
            )
        super().__init__()
        self.reduction = reduction

    def forward(self, input, label):
        # Thin wrapper over the functional implementation.
        return F.kl_div(input, label, self.reduction)


Z
zhiboniu 已提交
952
class MarginRankingLoss(Layer):
    r"""

    This interface is used to construct a callable object of the ``MarginRankingLoss`` class.
    The MarginRankingLoss layer calculates the margin rank loss between the input, other and label,
    using the math function as follows.

    .. math::
        margin\_rank\_loss = max(0, -label * (input - other) + margin)

    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(margin\_rank\_loss)

    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(margin\_rank\_loss)

    If :attr:`reduction` set to ``'none'``, just return the origin ``margin_rank_loss``.

    Parameters:
        margin (float, optional): The margin value to add, default value is 0;
        reduction (str, optional): Indicate the reduction to apply to the loss, the candidates are ``'none'``, ``'mean'``, ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:

        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        other: N-D Tensor, `other` have the same shape and dtype as `input`.

        label: N-D Tensor, label have the same shape and dtype as `input`.

        output: If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the out shape is :math:`[1]`, otherwise the shape is the same as `input`. The same dtype as input tensor.

    Returns:
        A callable object of MarginRankingLoss.

    Examples:

        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[1, 2], [3, 4]], dtype="float32")
            other = paddle.to_tensor([[2, 1], [2, 4]], dtype="float32")
            label = paddle.to_tensor([[1, -1], [-1, -1]], dtype="float32")
            margin_rank_loss = paddle.nn.MarginRankingLoss()
            loss = margin_rank_loss(input, other, label)

            print(loss)
            # [0.75]
    """

    def __init__(self, margin=0.0, reduction='mean', name=None):
        # Fail early on an unsupported reduction mode.
        supported = ('sum', 'mean', 'none')
        if reduction not in supported:
            message = (
                "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )
            raise ValueError(message)
        super().__init__()
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input, other, label):
        # Thin wrapper: forward the stored configuration to the functional API.
        return paddle.nn.functional.margin_ranking_loss(
            input,
            other,
            label,
            margin=self.margin,
            reduction=self.reduction,
            name=self.name,
        )
1024 1025


Z
zhiboniu 已提交
1026
class CTCLoss(Layer):
    """

    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation
    is integrated into the Warp-CTC library to normalize values for each row of the input tensor.

    Parameters:
        blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
        reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.

    Shape:
        log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type should be float32 or float64.
        labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
        input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
        label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
        norm_by_times (bool, optional): Whether to normalize the gradients by the number of time-step, which is also the sequence's length. There is no need to normalize the gradients if reduction mode is 'mean'. Default is False.

    Returns:
        Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and  ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:

        .. code-block:: python

            # declarative mode
            import paddle

            # length of the longest logit sequence
            max_seq_length = 5
            #length of the longest label sequence
            max_label_length = 3
            # number of logit sequences
            batch_size = 2
            # class num
            class_num = 3

            log_probs = paddle.to_tensor([[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04],
                                    [3.02332580e-01, 1.46755889e-01, 9.23385918e-02]],

                                    [[1.86260208e-01, 3.45560730e-01, 3.96767467e-01],
                                    [5.38816750e-01, 4.19194520e-01, 6.85219526e-01]],

                                    [[2.04452246e-01, 8.78117442e-01, 2.73875929e-02],
                                    [6.70467496e-01, 4.17304814e-01, 5.58689833e-01]],

                                    [[1.40386939e-01, 1.98101491e-01, 8.00744593e-01],
                                    [9.68261600e-01, 3.13424170e-01, 6.92322612e-01]],

                                    [[8.76389146e-01, 8.94606650e-01, 8.50442126e-02],
                                    [3.90547849e-02, 1.69830427e-01, 8.78142476e-01]]], dtype="float32")
            labels = paddle.to_tensor([[1, 2, 2],
                            [1, 2, 2]], dtype="int32")
            input_lengths = paddle.to_tensor([5, 5], dtype="int64")
            label_lengths = paddle.to_tensor([3, 3], dtype="int64")

            loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels,
                input_lengths,
                label_lengths)
            print(loss)
            # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [3.91798496, 2.90765190])

            loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels,
                input_lengths,
                label_lengths)
            print(loss)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [1.13760614])
    """

    def __init__(self, blank=0, reduction='mean'):
        # Validate at construction time, like the other loss layers in this
        # file, so a misspelled mode is reported where it was written.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in CTCLoss should be 'sum', 'mean' or 'none', "
                "but received {}, which is not allowed.".format(reduction)
            )
        super().__init__()
        self.blank = blank
        self.reduction = reduction

    def forward(
        self,
        log_probs,
        labels,
        input_lengths,
        label_lengths,
        norm_by_times=False,
    ):
        # Thin wrapper: forward the stored configuration to the functional API.
        return paddle.nn.functional.ctc_loss(
            log_probs,
            labels,
            input_lengths,
            label_lengths,
            self.blank,
            self.reduction,
            norm_by_times=norm_by_times,
        )
1120 1121


Z
zhiboniu 已提交
1122
class SmoothL1Loss(Layer):
    r"""
    This operator calculates smooth_l1_loss. Creates a criterion that uses a squared
    term if the absolute element-wise error falls below 1 and an L1 term otherwise.
    In some cases it can prevent exploding gradients and it is more robust and less
    sensitive to outliers. Also known as the Huber loss:

    .. math::

        loss(x, y) = \frac{1}{n}\sum_{i}z_i

    where :math:`z_i` is given by:

    .. math::

        \mathop{z_i} = \left\{\begin{array}{rcl}
                0.5(x_i - y_i)^2 & & {if |x_i - y_i| < \delta} \\
                \delta * |x_i - y_i| - 0.5 * \delta^2 & & {otherwise}
            \end{array} \right.

    Parameters:
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        delta (float, optional): Specifies the hyperparameter :math:`\delta` to be used.
            The value determines how large the errors need to be to use L1. Errors
            smaller than delta are minimized with L2. Parameter is ignored for
            negative/zero values. Default value is :math:`1.0`.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Call Parameters:

        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C),
        where C is number of classes, and if shape is more than 2D,
        this is (N, C, D1, D2,..., Dk), k >= 1.

        label (Tensor): Label tensor, the data type is float32 or float64.
        The shape of label is the same as the shape of input.

    Returns:
        Tensor, The tensor storing the smooth_l1_loss of input and label.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            input = paddle.rand([3, 3]).astype("float32")
            label = paddle.rand([3, 3]).astype("float32")
            loss = paddle.nn.SmoothL1Loss()
            output = loss(input, label)
            print(output)
            # [0.049606]
    """

    def __init__(self, reduction='mean', delta=1.0, name=None):
        # Validate at construction time, like the other loss layers in this
        # file, so a misspelled mode is reported where it was written.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in SmoothL1Loss should be 'sum', 'mean' or 'none', "
                "but received {}, which is not allowed.".format(reduction)
            )
        super().__init__()
        self.reduction = reduction
        self.delta = delta
        self.name = name

    def forward(self, input, label):
        # Thin wrapper: forward the stored configuration to the functional API.
        return F.smooth_l1_loss(
            input,
            label,
            reduction=self.reduction,
            delta=self.delta,
            name=self.name,
        )
1193 1194


Y
yangguohao 已提交
1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210
class MultiLabelSoftMarginLoss(Layer):
    r"""Creates a criterion that optimizes a multi-class multi-classification
    hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
    and output :math:`y` (which is a 2D `Tensor` of target class indices).
    For each sample in the mini-batch:

    .. math::
        \text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)}

    where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \
    :math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \
    :math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \
    and :math:`i \neq y[j]` for all :math:`i` and :math:`j`.
    :math:`y` and :math:`x` must have the same size.

    Parameters:
        weight (Tensor,optional): a manual rescaling weight given to each class.
                If given, has to be a Tensor of size C and the data type is float32, float64.
                Default is ``'None'`` .
        reduction (str, optional): Indicate how to average the loss by batch_size,
                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
                Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Call parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.
        label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.

    Shape:
        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means number of classes, available dtype is float32, float64. The sum operation operates over all the elements.
        label: N-D Tensor, same shape as the input.
        output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.

    Returns:
        A callable object of MultiLabelSoftMarginLoss.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

            multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='none')
            loss = multi_label_soft_margin_loss(input, label)
            print(loss)
            # Tensor([3.49625897, 0.71111226, 0.43989015])

            multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='mean')
            loss = multi_label_soft_margin_loss(input, label)
            print(loss)
            # Tensor([1.54908717])
    """

    # NOTE(review): the docstring formula reads like the multi-label *margin*
    # (hinge) loss rather than a soft-margin loss -- confirm it against
    # F.multi_label_soft_margin_loss before relying on it.

    def __init__(self, weight=None, reduction="mean", name=None):
        super().__init__()
        # The base Layer is initialised first, then the mode is validated.
        supported = ('sum', 'mean', 'none')
        if reduction not in supported:
            message = (
                "'reduction' in 'MultiLabelSoftMarginloss' should be 'sum', 'mean' or 'none', "
                "but received {}.".format(reduction)
            )
            raise ValueError(message)
        self.weight = weight
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        # Thin wrapper: forward the stored configuration to the functional API.
        loss = F.multi_label_soft_margin_loss(
            input,
            label,
            weight=self.weight,
            reduction=self.reduction,
            name=self.name,
        )
        return loss
Y
yangguohao 已提交
1274 1275


1276 1277
class HingeEmbeddingLoss(Layer):
    r"""
    Create a callable object of `HingeEmbeddingLoss` to calculate hinge_embedding_loss. Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y` (containing 1 or -1).
    This is usually used for measuring whether two inputs are similar or dissimilar, e.g. using the L1 pairwise distance as :math:`x`,
    and is typically used for learning nonlinear embeddings or semi-supervised learning.

    The loss function for :math:`n`-th sample in the mini-batch is

    .. math::
        l_n = \begin{cases}
            x_n, & \text{if}\; y_n = 1,\\
            \max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
        \end{cases}

    and the total loss functions is

    .. math::
        \ell(x, y) = \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
        \end{cases}

    where :math:`L = \{l_1,\dots,l_N\}^\top`.

    Parameters:

        margin (float, optional): Specifies the hyperparameter margin to be used.
            The value determines how large the input need to be to calculate in
            hinge_embedding_loss. When label is -1, Input smaller than margin are minimized with hinge_embedding_loss.
            Default = 1.0
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Call Parameters:

        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.

        label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.

    Shape:

        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64. The sum operation operates over all the elements.

        label: N-D Tensor, same shape as the input.

        output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.

    Returns:

        Tensor, The tensor variable storing the hinge_embedding_loss of input and label.

    Raises:

        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            # label elements in {1., -1.}
            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

            hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=1.0, reduction='none')
            loss = hinge_embedding_loss(input, label)
            print(loss)
            # Tensor([[0., -2., 0.],
            #         [0., -1., 2.],
            #         [1., 1., 1.]])

            hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=1.0, reduction='mean')
            loss = hinge_embedding_loss(input, label)
            print(loss)
            # Tensor([0.22222222])
    """

    def __init__(self, margin=1.0, reduction="mean", name=None):
        # Validate at construction time, like the other loss layers in this
        # file, so a misspelled mode is reported where it was written.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in HingeEmbeddingLoss should be 'sum', 'mean' or 'none', "
                "but received {}, which is not allowed.".format(reduction)
            )
        super().__init__()
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        # Thin wrapper: forward the stored configuration to the functional API.
        return F.hinge_embedding_loss(
            input,
            label,
            reduction=self.reduction,
            margin=self.margin,
            name=self.name,
        )
1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442


class CosineEmbeddingLoss(Layer):
    r"""
    This interface is used to construct a callable object of the ``CosineEmbeddingLoss`` class.
    The CosineEmbeddingLoss layer measures the cosine_embedding loss between input predictions ``input1``, ``input2``
    and target labels ``label`` with values 1 or -1. This is used for measuring whether two inputs are similar or
    dissimilar and is typically used for learning nonlinear embeddings or semi-supervised learning.
    The cosine embedding loss can be described as:

    If label = 1, then the loss value can be calculated as follow:

    .. math::
        Out = 1 - cos(input1, input2)

    If label = -1, then the loss value can be calculated as follow:

    .. math::
        Out = max(0, cos(input1, input2) - margin)

    The operator cos can be described as follow:

    .. math::
        cos(x1, x2) = \frac{x1 \cdot{} x2}{\Vert x1 \Vert_2 * \Vert x2 \Vert_2}

    Parameters:
        margin (float, optional): Should be a number from :math:`-1` to :math:`1`,
            :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
            default value is :math:`0`.
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input1 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
                         Available dtypes are float32, float64.
        input2 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
                         Available dtypes are float32, float64.
        label (Tensor): tensor with shape: [N] or [1]. The target labels values should be -1 or 1.
                         Available dtypes are int32, int64, float32, float64.
        output (Tensor): Tensor, the cosine embedding Loss of Tensor ``input1`` ``input2`` and ``label``.
                         If `reduction` is ``'none'``, the shape of output loss is [N].
                         If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].

    Raises:
        ValueError: If :attr:`margin` is greater than 1 or less than -1.
        ValueError: If :attr:`reduction` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle

            input1 = paddle.to_tensor([[1.6, 1.2, -0.5], [3.2, 2.6, -5.8]], 'float32')
            input2 = paddle.to_tensor([[0.5, 0.5, -1.8], [2.3, -1.4, 1.1]], 'float32')
            label = paddle.to_tensor([1, -1], 'int64')

            cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='mean')
            output = cosine_embedding_loss(input1, input2, label)
            print(output) # [0.21155193]

            cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='sum')
            output = cosine_embedding_loss(input1, input2, label)
            print(output) # [0.42310387]

            cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='none')
            output = cosine_embedding_loss(input1, input2, label)
            print(output) # [0.42310387, 0.        ]

    """

    def __init__(self, margin=0, reduction='mean', name=None):
        # Both arguments are checked before the base Layer is initialised, so
        # an invalid configuration raises before super().__init__() runs.
        if margin > 1 or margin < -1:
            raise ValueError(
                "The value of 'margin' should be in the interval of [-1, 1], but received %f, which is not allowed."
                % margin
            )
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' should be 'sum', 'mean' or "
                "'none', but received %s, which is not allowed." % reduction
            )
        super().__init__()
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input1, input2, label):
        """Return ``F.cosine_embedding_loss`` of ``input1``, ``input2`` and ``label``.

        The stored ``margin``, ``reduction`` and ``name`` are forwarded as
        keyword arguments.
        """
        return F.cosine_embedding_loss(
            input1,
            input2,
            label,
            margin=self.margin,
            reduction=self.reduction,
            name=self.name,
        )
Y
yangguohao 已提交
1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480


class TripletMarginWithDistanceLoss(Layer):
    r"""
    Creates a criterion that measures the triplet loss given an input
    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
    This is used for measuring a relative similarity between samples. A triplet
    is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
    examples` respectively). The shapes of all input tensors should be
    :math:`(N, D)`.

    The loss function for each sample in the mini-batch is:

    .. math::
        L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}

    where the default `distance_function`

    .. math::
        d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_2

    or user can define their own distance function. `margin` is a nonnegative margin representing the minimum difference
    between the positive and negative distances that is required for the loss to be 0. If `swap` is true, it will compare distance of (input, negative) with
    distance of (negative, positive) and change it to the smaller one. For more details see http://www.bmva.org/bmvc/2016/papers/paper119/paper119.pdf.

    Parameters:
        distance_function (Callable, Optional): Quantifies the distance between two tensors. If not specified, 2 norm functions will be used.
                The callable is forwarded unchanged to ``F.triplet_margin_with_distance_loss`` on every call.

        margin (float, Optional): Default: :math:`1`. A nonnegative margin representing the minimum difference
                between the positive and negative distances required for the loss to be 0. Larger
                margins penalize cases where the negative examples are not distant enough from the
                anchors, relative to the positives.

        swap (bool, Optional): The distance swap changes the negative distance to the swap distance (distance between positive samples
                and negative samples) if swap distance smaller than negative distance. Default: ``False``.

        reduction (str, Optional): Indicate how to average the loss by batch_size.
                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
                Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shapes:
        input (Tensor): Input tensor, the data type is float32 or float64.
                The shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        positive (Tensor): Positive tensor, the data type is float32 or float64.
                The shape of positive is the same as the shape of input.

        negative (Tensor): Negative tensor, the data type is float32 or float64.
                The shape of negative is the same as the shape of input.

        output (Tensor): The tensor variable storing the triplet_margin_with_distance_loss of input and positive and negative.

    Returns:
        A callable object of TripletMarginWithDistanceLoss

    Raises:
        ValueError: If :attr:`reduction` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.nn import TripletMarginWithDistanceLoss

            input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
            positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
            negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
            triplet_margin_with_distance_loss = TripletMarginWithDistanceLoss(reduction='none')
            loss = triplet_margin_with_distance_loss(input, positive, negative,)
            print(loss)
            # Tensor([0.        , 0.57496738, 0.        ])

            triplet_margin_with_distance_loss = TripletMarginWithDistanceLoss(reduction='mean')
            loss = triplet_margin_with_distance_loss(input, positive, negative,)
            print(loss)
            # Tensor([0.19165580])

    """

    def __init__(
        self,
        distance_function=None,
        margin=1.0,
        swap=False,
        reduction: str = 'mean',
        name=None,
    ):
        super().__init__()
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in TripletMarginWithDistanceLoss "
                "should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )
        self.margin = margin
        self.swap = swap
        self.reduction = reduction
        self.distance_function = distance_function
        self.name = name

    def forward(self, input, positive, negative):
        """Return ``F.triplet_margin_with_distance_loss`` for the triplet.

        The stored ``distance_function``, ``margin``, ``swap``, ``reduction``
        and ``name`` are forwarded as keyword arguments.
        """
        return F.triplet_margin_with_distance_loss(
            input,
            positive,
            negative,
            # Forward the user-supplied metric; without this the functional
            # always falls back to its default distance and the constructor
            # argument has no effect.
            distance_function=self.distance_function,
            margin=self.margin,
            swap=self.swap,
            reduction=self.reduction,
            name=self.name,
        )
Y
yangguohao 已提交
1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644


class TripletMarginLoss(Layer):
    r"""
    Creates a criterion that measures the triplet loss given an input
    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
    This is used for measuring a relative similarity between samples. A triplet
    is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
    examples` respectively). The shapes of all input tensors should be
    :math:`(N, *)`.

    The loss function for each sample in the mini-batch is:

    .. math::
        L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}


    where

    .. math::
        d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p

    Parameters:
        margin (float, Optional): Default: :math:`1`.

        p (float, Optional): The norm degree for pairwise distance. Default: :math:`2`.

        epsilon (float, Optional): Add small value to avoid division by zero,
            default value is 1e-6.

        swap (bool, Optional): The distance swap change the negative distance to the distance between
            positive sample and negative sample. For more details, see `Learning shallow convolutional feature descriptors with triplet losses`.
            Default: ``False``.

        reduction (str, Optional): Indicate how to average the loss by batch_size.
                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
                Default: ``'mean'``

        name (str, Optional): Name for the operation (optional, default is None).
                For more information, please refer to :ref:`api_guide_Name`.

    Call Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64.
        The shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        positive (Tensor): Positive tensor, the data type is float32 or float64.
        The shape of positive is the same as the shape of input.

        negative (Tensor): Negative tensor, the data type is float32 or float64.
        The shape of negative is the same as the shape of input.

    Returns:
        Tensor. The tensor variable storing the triplet_margin_loss of input and positive and negative.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
            positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
            negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
            triplet_margin_loss = paddle.nn.TripletMarginLoss(reduction='none')
            loss = triplet_margin_loss(input, positive, negative)
            print(loss)
            # Tensor([0.        , 0.57496738, 0.        ])

            triplet_margin_loss = paddle.nn.TripletMarginLoss(margin=1.0, swap=True, reduction='mean', )
            loss = triplet_margin_loss(input, positive, negative,)
            print(loss)
            # Tensor([0.19165580])

    """

    def __init__(
        self,
        margin=1.0,
        p=2.0,
        epsilon=1e-6,
        swap=False,
        reduction='mean',
        name=None,
    ):
        super().__init__()
        accepted = ('sum', 'mean', 'none')
        if reduction not in accepted:
            template = (
                "The value of 'reduction' in TripletMarginLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed."
            )
            raise ValueError(template % reduction)
        # Plain configuration values, read back in forward().
        self.margin, self.p, self.epsilon = margin, p, epsilon
        self.swap, self.reduction, self.name = swap, reduction, name

    def forward(self, input, positive, negative):
        """Return ``F.triplet_margin_loss`` for the triplet.

        The stored ``margin``, ``p``, ``epsilon``, ``swap``, ``reduction`` and
        ``name`` are forwarded as keyword arguments.
        """
        settings = {
            "margin": self.margin,
            "p": self.p,
            "epsilon": self.epsilon,
            "swap": self.swap,
            "reduction": self.reduction,
            "name": self.name,
        }
        return F.triplet_margin_loss(input, positive, negative, **settings)
1687 1688


Y
yangguohao 已提交
1689 1690
class MultiMarginLoss(Layer):
    r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between
    input :math:`input` and label :math:`label`:

    For i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar
    output :math:`label_i` is:

    .. math::
        \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}}

    where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`.

    Optionally, you can give non-equal weighting on the classes by passing
    a 1D :attr:`weight` tensor into the constructor.

    The loss function for i-th sample then becomes:

    .. math::
        \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}}


    Parameters:

        p (int, Optional): The norm degree for pairwise distance. Default: :math:`1`.

        margin (float, Optional): Default: :math:`1`.

        weight (Tensor, optional): a manual rescaling weight given to each class.
                If given, has to be a Tensor of shape (C,) and the data type is float32, float64.
                Default is ``'None'`` .

        reduction (str, optional): Indicate how to calculate the loss by batch_size,
                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
                Default: ``'mean'``

        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Call parameters:
        input (Tensor): Input tensor, the data type is float32 or float64.

        label (Tensor): Label tensor, 0<= label < input.shape[1], the data type is int32 or int64.

    Shape:
        input: 2-D Tensor, the shape is [N, C], N is batch size and `C` means number of classes.

        label: 1-D Tensor, the shape is [N,].

        output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the label.

    Returns:
        A callable object of MultiMarginLoss.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            label = paddle.to_tensor([0, 1, 2], dtype=paddle.int32)

            multi_margin_loss = nn.MultiMarginLoss(reduction='mean')
            loss = multi_margin_loss(input, label)
            print(loss)
    """

    def __init__(
        self,
        p: int = 1,
        margin: float = 1.0,
        weight=None,
        reduction="mean",
        name=None,
    ):
        super().__init__()
        accepted = ('sum', 'mean', 'none')
        if reduction not in accepted:
            template = (
                "'reduction' in 'MultiMarginLoss' should be 'sum', 'mean' or 'none', "
                "but received {}."
            )
            raise ValueError(template.format(reduction))
        # Plain configuration values, read back in forward().
        self.p, self.margin, self.weight = p, margin, weight
        self.reduction, self.name = reduction, name

    def forward(self, input, label):
        """Return ``F.multi_margin_loss`` of ``input`` and ``label``.

        The stored ``p``, ``margin``, ``weight``, ``reduction`` and ``name``
        are forwarded as keyword arguments.
        """
        settings = {
            "p": self.p,
            "margin": self.margin,
            "weight": self.weight,
            "reduction": self.reduction,
            "name": self.name,
        }
        return F.multi_margin_loss(input, label, **settings)
Y
yangguohao 已提交
1789 1790


1791 1792
class SoftMarginLoss(Layer):
    r"""

    Creates a criterion that measures a two-class soft margin loss between input predictions ``input``
    and target labels ``label`` . It can be described as:

    .. math::
        Out = log(1 + exp((-label * input)))

    Parameters:

        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.

        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shapes:
        - Input (Tensor): The input tensor with shape: ``[N, *]``,
          N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf
          Available dtype is float32, float64.
        - Label (Tensor): The target labels tensor with the same shape as
          ``input``. The target labels which values should be numbers -1 or 1.
          Available dtype is int32, int64, float32, float64.
        - Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
          same as ``input`` , else the shape of output is [1].

    Returns:
        A callable object of SoftMarginLoss.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32')
            label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32')
            soft_margin_loss = paddle.nn.SoftMarginLoss()
            output = soft_margin_loss(input, label)
            print(output)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [0.64022040])

            input_np = paddle.uniform(shape=(5, 5), min=0.1, max=0.8, dtype="float64")
            label_np = paddle.randint(high=2, shape=(5, 5), dtype="int64")
            label_np[label_np==0]=-1
            input = paddle.to_tensor(input_np)
            label = paddle.to_tensor(label_np)
            soft_margin_loss = paddle.nn.SoftMarginLoss(reduction='none')
            output = soft_margin_loss(input, label)
            print(output)
            # Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
            #        [[0.61739663, 0.51405668, 1.09346100, 0.42385561, 0.91602303],
            #         [0.76997038, 1.01977148, 0.98971722, 1.13976032, 0.88152088],
            #         [0.55476735, 1.10505384, 0.89923519, 0.45018155, 1.06587511],
            #         [0.37998142, 0.48067240, 0.47791212, 0.55664053, 0.98581399],
            #         [0.78571653, 0.59319711, 0.39701841, 0.76172109, 0.83781742]])

    """

    def __init__(self, reduction='mean', name=None):
        # The reduction mode is checked before the base Layer is initialised.
        accepted = ('sum', 'mean', 'none')
        if reduction not in accepted:
            template = (
                "The value of 'reduction' in SoftMarginLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed."
            )
            raise ValueError(template % reduction)

        super().__init__()
        self.reduction, self.name = reduction, name

    def forward(self, input, label):
        """Return ``soft_margin_loss`` of ``input`` and ``label``.

        The stored ``reduction`` and ``name`` are passed positionally after
        ``input`` and ``label``.
        """
        # F is this module's alias for the functional package.
        return F.soft_margin_loss(input, label, self.reduction, self.name)