loss.py 90.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import paddle

17
# TODO: define loss functions of neural network
18
from paddle import fluid, in_dynamic_mode
19
from paddle.fluid.framework import in_dygraph_mode
20 21

from .. import functional as F
22
from .layers import Layer
23

24 25
# An empty ``__all__`` means a star-import of this module exports no names;
# the loss layers defined below have to be imported explicitly.
__all__ = []

L
Leo Chen 已提交
26

Z
zhiboniu 已提交
27
class BCEWithLogitsLoss(Layer):
28
    r"""
29

学渣戊's avatar
学渣戊 已提交
30
    Combine the sigmoid layer and the :ref:`api_paddle_nn_BCELoss` layer.
31 32 33 34 35 36 37

    This measures the element-wise probability error in classification tasks
    in which each class is independent.
    This can be thought of as predicting labels for a data-point, where labels
    are not mutually exclusive. For example, a news article can be about
    politics, technology or sports at the same time or none of these.

学渣戊's avatar
学渣戊 已提交
38
    Firstly, calculate loss function as follows:
39 40

    .. math::
41
           Out = -Labels * \log(\sigma(Logit)) - (1 - Labels) * \log(1 - \sigma(Logit))
42

43
    We know that :math:`\sigma(Logit) = \frac{1}{1 + e^{-Logit}}`. By substituting this we get:
44 45

    .. math::
46
           Out = Logit - Logit * Labels + \log(1 + e^{-Logit})
47

48
    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
49 50
    we reformulate the loss as follows:

51
        .. math::
52
           Out = \max(Logit, 0) - Logit * Labels + \log(1 + e^{-\|Logit\|})
53

学渣戊's avatar
学渣戊 已提交
54
    Then, if ``weight`` or ``pos_weight`` is not None, then multiply the
55 56 57 58
    weight tensor on the loss `Out`. The ``weight`` tensor will attach different
    weight on every item in the batch. The ``pos_weight`` will attach different
    weight on the positive label of each class.

学渣戊's avatar
学渣戊 已提交
59 60
    Finally, apply reduce operation on the loss.
    If :attr:`reduction` set to ``'none'``, will return the original loss `Out`.
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target labels ``label`` should be numbers between 0 and 1.

    Args:
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`,
            The data type is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
            with length equal to the number of classes. The data type is float32, float64.
            Default is ``'None'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shapes:
学渣戊's avatar
学渣戊 已提交
83 84 85
        - logit (Tensor): The input predictions tensor. 2-D tensor with shape: [N, `*`], N is batch_size, `*` means number of additional dimensions. The ``logit`` is usually the output of Linear layer. Available dtype is float32, float64.
        - label (Tensor): The target labels tensor. 2-D tensor with the same shape as ``logit``. The target labels which values should be numbers between 0 and 1. Available dtype is float32, float64.
        - output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``logit`` , else the shape of output is scalar.
86 87 88 89 90

    Returns:
        A callable object of BCEWithLogitsLoss.

    Examples:
学渣戊's avatar
学渣戊 已提交
91

92
        .. code-block:: python
93

94
            >>> import paddle
学渣戊's avatar
学渣戊 已提交
95

96 97 98 99 100 101 102
            >>> logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
            >>> label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
            >>> bce_logit_loss = paddle.nn.BCEWithLogitsLoss()
            >>> output = bce_logit_loss(logit, label)
            >>> print(output)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.45618808)
103 104 105

    """

106 107 108
    def __init__(
        self, weight=None, reduction='mean', pos_weight=None, name=None
    ):
        """Check ``reduction`` and remember the loss configuration.

        Args:
            weight: Optional rescaling weight, handed to the functional op
                unchanged on every ``forward`` call.
            reduction: One of ``'sum'``, ``'mean'`` or ``'none'``.
            pos_weight: Optional positive-example weight, handed to the
                functional op unchanged on every ``forward`` call.
            name: Optional operator name, handed to the functional op.

        Raises:
            ValueError: If ``reduction`` is not one of the accepted values.
        """
        accepted = ('sum', 'mean', 'none')
        # Reject a bad reduction before the base class is initialised.
        if reduction not in accepted:
            message = (
                "The value of 'reduction' in BCEWithLogitsLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )
            raise ValueError(message)

        super().__init__()

        # Assignment order is kept as weight -> reduction -> pos_weight -> name.
        self.weight = weight
        self.reduction = reduction
        self.pos_weight = pos_weight
        self.name = name

    def forward(self, logit, label):
        """Return the BCE-with-logits loss of ``logit`` against ``label``.

        The computation is delegated to
        ``paddle.nn.functional.binary_cross_entropy_with_logits`` together
        with the ``weight``, ``reduction``, ``pos_weight`` and ``name`` stored
        on this layer at construction time.
        """
        return paddle.nn.functional.binary_cross_entropy_with_logits(
            logit,
            label,
            weight=self.weight,
            reduction=self.reduction,
            pos_weight=self.pos_weight,
            name=self.name,
        )


Z
zhiboniu 已提交
133
class CrossEntropyLoss(Layer):
134
    r"""
135

136
    By default, the cross entropy loss function is implemented using softmax. This function
137
    combines the calculation of the softmax operation and the cross entropy loss function
138
    to provide a more numerically stable computing.
S
swtkiwi 已提交
139

140
    Calculate the cross entropy loss function without softmax when use_softmax=False.
141

142
    By default, calculate the mean of the result, and you can also affect
143
    the default behavior by using the reduction parameter. Please refer to the part of
144
    parameters for details.
145

146
    Can be used to calculate the softmax cross entropy loss with soft and hard labels.
147
    Where, the hard labels mean the actual label value, 0, 1, 2, etc.  And the soft labels
148
    mean the probability of the actual label, 0.6, 0.8, 0.2, etc.
149

150
    The calculation includes the following two steps.
151

152
    -  **I.softmax cross entropy**
153

154
        1. Hard label (each sample can only be assigned into one category)
155

156
        1.1. when use_softmax=True
157

158 159
            .. math::
              \\loss_j=-\text{logits}_{label_j}+\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right) , j = 1,...,N
160

161
            where, N is the number of samples and C is the number of categories.
162

163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
        1.2. when use_softmax=False

            .. math::
              \\loss_j=-\log\left({P}_{label_j}\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).


        2. Soft label (each sample is assigned to multiple categories with a certain probability, and the probability sum is 1).

        2.1. when use_softmax=True

            .. math::
              \\loss_j=-\sum_{i=0}^{C}\text{label}_i\left(\text{logits}_i-\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories.

        2.2. when use_softmax=False

            .. math::
              \\loss_j=-\sum_{j=0}^{C}\left({label}_j*\log\left({P}_{label_j}\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).



189
    -  **II.Weight and reduction processing**
190 191 192 193 194 195 196 197 198 199 200

        1. Weight

            If the ``weight`` parameter is ``None`` , go to the next step directly.

            If the ``weight`` parameter is not ``None`` , the cross entropy of each sample is weighted by weight
            according to soft_label = False or True as follows.

            1.1. Hard labels (soft_label = False)

            .. math::
201
                \\loss_j=loss_j*weight[label_j]
202

203

204 205 206 207 208 209 210
            1.2. Soft labels (soft_label = True)

             .. math::
                \\loss_j=loss_j*\sum_{i}\left(weight[label_i]*logits_i\right)

        2. reduction

211
            2.1 if the ``reduction`` parameter is ``none``
212 213 214

            Return the previous result directly

215
            2.2 if the ``reduction`` parameter is ``sum``
216 217 218 219 220 221

            Return the sum of the previous results

            .. math::
               \\loss=\sum_{j}loss_j

222 223
            2.3 if the ``reduction`` parameter is ``mean`` , it will be processed according to
            the ``weight`` parameter as follows.
224

225
            2.3.1. If the  ``weight``  parameter is ``None``
226 227 228 229 230 231 232 233

            Return the average value of the previous results

             .. math::
                \\loss=\sum_{j}loss_j/N

            where, N is the number of samples and C is the number of categories.

234
            2.3.2. If the ``weight`` parameter is not ``None`` , the weighted average value of the previous result will be returned
235 236 237 238

            1. Hard labels (soft_label = False)

             .. math::
239
                \\loss=\sum_{j}loss_j/\sum_{j}weight[label_j]
240 241 242 243 244

            2. Soft labels (soft_label = True)

             .. math::
                \\loss=\sum_{j}loss_j/\sum_{j}\left(\sum_{i}weight[label_i]\right)
245 246


247
    Parameters:
248
        weight (Tensor, optional): a manual rescaling weight given to each class.
249
            If given, has to be a Tensor of size C and the data type is float32, float64.
250
            Default is ``'None'`` .
251
        ignore_index (int64, optional): Specifies a target value that is ignored
252 253
            and does not contribute to the loss. A negative value means that no label
            value needs to be ignored. Only valid when soft_label = False.
254
            Default is ``-100`` .
255
        reduction (str, optional): Indicate how to average the loss by batch_size,
256 257 258 259 260
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
261
        soft_label (bool, optional): Indicate whether label is soft.
262 263
            If soft_label=False, the label is hard.  If soft_label=True, the label is soft.
            Default is ``False``.
264
        axis (int, optional): The index of dimension to perform softmax calculations.
265 266
            It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the number
            of dimensions of input :attr:`input`.
267
            Default is ``-1`` .
268
        use_softmax (bool, optional): Indicate whether compute softmax before cross_entropy.
269
            Default is ``True``.
270
        name (str, optional): The name of the operator. Default is ``None`` .
271 272 273 274
            For more information, please refer to :ref:`api_guide_Name` .


    Shape:
275 276
        - **input** (Tensor), the data type is float32, float64. Shape is :math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes, ``k >= 1`` .

277
            Note:
278

279
                1. when use_softmax=True, it expects unscaled logits. This operator should not be used with the
280 281 282
                output of softmax operator, which will produce incorrect results.

                2. when use_softmax=False, it expects the output of softmax operator.
283

284 285
        - **label** (Tensor)

286
            1. If soft_label=False, the shape is
287 288 289
            :math:`[N_1, N_2, ..., N_k]` or :math:`[N_1, N_2, ..., N_k, 1]`, k >= 1.
            the data type is int32, int64, float32, float64, where each value is [0, C-1].

290
            2. If soft_label=True, the shape and data type should be same with ``input`` ,
291
            and the sum of the labels for each sample should be 1.
292

293 294 295 296
        - **output** (Tensor), Return the softmax cross_entropy loss of ``input`` and ``label``.
          The data type is the same as input.
          If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the dimension of return value is ``1``.
          If :attr:`reduction` is ``'none'``:
297

298
            1. If soft_label = False, the dimension of return value is the same with ``label`` .
299

300
            2. if soft_label = True, the dimension of return value is :math:`[N_1, N_2, ..., N_k, 1]` .
301

302
    Examples:
303 304

        .. code-block:: python
305
            :name: code-example1
306

307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322
            >>> # hard labels
            >>> import paddle
            >>> paddle.seed(2023)
            >>> N=100
            >>> C=200
            >>> reduction='mean'
            >>> input =  paddle.rand([N, C], dtype='float64')
            >>> label =  paddle.randint(0, C, shape=[N], dtype='int64')
            >>> weight = paddle.rand([C], dtype='float64')

            >>> cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
            ...     weight=weight, reduction=reduction)
            >>> dy_ret = cross_entropy_loss(input, label)
            >>> print(dy_ret)
            Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=True,
            5.33697682)
323

324
        .. code-block:: python
325
            :name: code-example2
326

327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348
            >>> # soft labels
            >>> import paddle
            >>> paddle.seed(2023)
            >>> axis = -1
            >>> ignore_index = -100
            >>> N = 4
            >>> C = 3
            >>> shape = [N, C]
            >>> reduction='mean'
            >>> weight = None
            >>> logits = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            >>> labels = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            >>> labels /= paddle.sum(labels, axis=axis, keepdim=True)
            >>> paddle_loss_mean = paddle.nn.functional.cross_entropy(logits,
            ...                                                       labels,
            ...                                                       soft_label=True,
            ...                                                       axis=axis,
            ...                                                       weight=weight,
            ...                                                       reduction=reduction)
            >>> print(paddle_loss_mean)
            Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=True,
            1.14554912)
349

350 351
    """

352 353 354 355 356 357 358 359 360 361
    def __init__(
        self,
        weight=None,
        ignore_index=-100,
        reduction='mean',
        soft_label=False,
        axis=-1,
        use_softmax=True,
        name=None,
    ):
        """Store the cross-entropy configuration on the layer.

        No argument is validated here; each value is kept as given and
        passed to ``paddle.nn.functional.cross_entropy`` by ``forward``.
        """
        super().__init__()

        # Per-class rescaling weight (may be ``None``).
        self.weight = weight

        # Scalar options, stored in signature order.
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.soft_label = soft_label
        self.axis = axis
        self.use_softmax = use_softmax
        self.name = name
370 371

    def forward(self, input, label):
        """Return the cross-entropy loss of ``input`` against ``label``.

        Every option stored by ``__init__`` is forwarded by keyword to
        ``paddle.nn.functional.cross_entropy``.
        """
        options = {
            'weight': self.weight,
            'ignore_index': self.ignore_index,
            'reduction': self.reduction,
            'soft_label': self.soft_label,
            'axis': self.axis,
            'use_softmax': self.use_softmax,
            'name': self.name,
        }
        return paddle.nn.functional.cross_entropy(input, label, **options)
385 386


Z
zhiboniu 已提交
387
class HSigmoidLoss(Layer):
388 389
    """
    Hierarchical Sigmoid Layer.
390

391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424
    The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity
    and speed up the model training, especially the training of language model.
    Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier.
    For each class(word), there's a unique path from root to itself, hsigmoid calculate the cost for each non-leaf node on
    the path, and sum them to get a total cost.
    Comparing to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
    represents the number of classes or the size of word dict.

    The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural
    Network Language Model <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_. For the custom
    tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example):

    1. Using a custom word dict to build a binary tree, each leaf node should be an word in the word dict.
    2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table.
    3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code.
       Code means the label of each binary classifier, 1 indicate true, 0 indicate false.
    4. Now, each word should has its path and code along the path, you can pass a batch of path and code related
       to the same batch of inputs.

    Parameters:
        feature_size (int): The number of features.
        num_classes (int): The number of classes or the size of word dict, must not be less than 2 when the default tree is used.
            If the default tree is used (:attr:`is_custom` is set to False), :attr:`num_classes`
            should not be None. If the custom tree is used (:attr:`is_custom` is set to True),
            :attr:`num_classes` should be the number of non-leaf nodes, which indicates the num of
            classes using by the binary classifier.
        weight_attr (ParamAttr, optional): The parameter attribute for the learnable weights
            of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create a
            ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is
            initialized with Xavier. Default is None.
        bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of hsigmoid. If it
            is set to False, no bias will be added. If it is set to None or one attribute of ParamAttr,
            hsigmoid will create a ParamAttr as bias_attr. If the Initializer of the bias_attr is not
            set, the bias is initialized zero. Default is None.
425
        is_custom (bool, optional): Whether use custom binary tree. If it's True, `path_table` and
426 427 428 429 430 431 432 433 434 435 436 437 438 439 440
            `path_code` should be passed to its forward method, otherwise `path_table` and `path_code`
            should not be passed to its forward method. Default is False.
        is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True,
            the gradient of weight and input will be sparse. Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input (Tensor): The input tensor. The shape is [N, D], where N is batch size and D is feature size. Its data type should be float32, float64.
        label (Tensor): Its shape is [N, 1]. Its data type should be int64.
        output (Tensor): The HSigmoid Loss of ``input`` and ``label``. Shape is [N, 1]

    Examples:
        .. code-block:: python

441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459
            >>> import paddle
            >>> paddle.set_device('cpu')
            >>> paddle.seed(2023)
            >>> input = paddle.uniform([4, 3])
            >>> print(input)
            Tensor(shape=[4, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
            [[ 0.73167229,  0.04029441, -0.48078126],
             [ 0.81050646, -0.15199822, -0.18717426],
             [ 0.94041789,  0.48874724,  0.03570259],
             [ 0.46585739,  0.95573163, -0.91368192]])
            >>> label = paddle.to_tensor([0, 1, 4, 5])
            >>> m = paddle.nn.HSigmoidLoss(3, 6)
            >>> out = m(input, label)
            >>> print(out)
            Tensor(shape=[4, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
            [[1.94512916],
             [2.26129627],
             [2.36135936],
             [2.97453213]])
460 461
    """

462 463 464 465 466 467 468 469 470 471
    def __init__(
        self,
        feature_size,
        num_classes,
        weight_attr=None,
        bias_attr=None,
        is_custom=False,
        is_sparse=False,
        name=None,
    ):
        """Validate the tree configuration and create the layer parameters.

        Args:
            feature_size: Second dimension of the created weight parameter.
            num_classes: Number of classes; must be at least 2 when the
                default tree is used (``is_custom`` is False).
            weight_attr: Attribute passed to ``create_parameter`` for the weight.
            bias_attr: Attribute passed to ``create_parameter`` for the bias.
            is_custom: Whether a custom tree is used.
            is_sparse: Whether sparse updating is requested. It is forced
                back to False (with a printed notice) when ``is_custom`` is
                False.
            name: Optional operator name, forwarded by ``forward``.

        Raises:
            ValueError: If ``num_classes`` is below 2 with the default tree.
        """
        super().__init__()
        if (num_classes < 2) and (not is_custom):
            raise ValueError(
                "num_classes must not be less than 2 with default tree"
            )

        if (not is_custom) and (is_sparse):
            print("Sparse mode should not be used without custom tree")
            is_sparse = False

        self._feature_size = feature_size
        self._num_classes = num_classes
        self._is_custom = is_custom
        self._is_sparse = is_sparse

        self._weight_attr = weight_attr
        self._bias_attr = bias_attr

        self._name = name
        self._dtype = paddle.get_default_dtype()

        # The notice only concerns sparse mode, so it is emitted only when
        # sparse mode is really in effect (previously it was printed for every
        # instance, including dense ones).
        if is_sparse:
            print(
                "With sparse mode, if your models has only"
                " small parameter prefetch may cause speed down"
            )

        # Custom tree: one row per class given by the caller.
        # Default tree: one row fewer than the number of classes.
        num_rows = self._num_classes if is_custom else self._num_classes - 1
        self.weight = self.create_parameter(
            [num_rows, self._feature_size],
            attr=self._weight_attr,
            is_bias=False,
            dtype=self._dtype,
        )
        self.bias = self.create_parameter(
            [num_rows, 1],
            attr=self._bias_attr,
            is_bias=True,
            dtype=self._dtype,
        )
509 510

    def forward(self, input, label, path_table=None, path_code=None):
        """Return the hierarchical sigmoid loss of ``input`` against ``label``.

        Delegates to ``F.hsigmoid_loss`` with this layer's class count,
        weight and bias parameters, sparse flag and name. ``path_table`` and
        ``path_code`` are passed through as given.
        """
        return F.hsigmoid_loss(
            input,
            label,
            self._num_classes,
            self.weight,
            self.bias,
            path_table=path_table,
            path_code=path_code,
            is_sparse=self._is_sparse,
            name=self._name,
        )


Z
zhiboniu 已提交
525
class MSELoss(Layer):
526
    r"""
527 528
    **Mean Square Error Loss**
    Computes the mean square error (squared L2 norm) of given input and label.
529

530
    If :attr:`reduction` is set to ``'none'``, loss is calculated as:
531

532 533
    .. math::
        Out = (input - label)^2
534

535
    If :attr:`reduction` is set to ``'mean'``, loss is calculated as:
536

537 538
    .. math::
        Out = \operatorname{mean}((input - label)^2)
539

540
    If :attr:`reduction` is set to ``'sum'``, loss is calculated as:
541

542 543
    .. math::
        Out = \operatorname{sum}((input - label)^2)
544

545
    where `input` and `label` are `float32` or `float64` tensors of the same shape.
546

547
    Parameters:
548
        reduction (str, optional): The reduction method for the output,
549
            could be 'none' | 'mean' | 'sum'.
550 551 552
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
553
            Default is ``'mean'``.
554

B
Bai Yifan 已提交
555 556 557 558
    Shape:
        input (Tensor): Input tensor, the data type is float32 or float64
        label (Tensor): Label tensor, the data type is float32 or float64
        output (Tensor): output tensor storing the MSE loss of input and label, the data type is same as input.
559

560
    Examples:
561

562
        .. code-block:: python
563

564 565 566 567 568 569 570 571
            >>> import paddle
            >>> mse_loss = paddle.nn.loss.MSELoss()
            >>> input = paddle.to_tensor([1.5])
            >>> label = paddle.to_tensor([1.7])
            >>> output = mse_loss(input, label)
            >>> print(output)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.04000002)
572

573 574 575
    """

    def __init__(self, reduction='mean'):
        """Initialise the layer and store a validated ``reduction``.

        Args:
            reduction: One of ``'sum'``, ``'mean'`` or ``'none'``.

        Raises:
            ValueError: If ``reduction`` is not one of the accepted values.
        """
        super().__init__()
        accepted = ('sum', 'mean', 'none')
        if reduction not in accepted:
            message = (
                "'reduction' in 'MSELoss' should be 'sum', 'mean' or 'none', "
                "but received {}.".format(reduction)
            )
            raise ValueError(message)
        self.reduction = reduction

    def forward(self, input, label):
        """Return the squared error of ``input`` and ``label``, reduced as configured.

        Args:
            input: Input tensor.
            label: Label tensor.

        Returns:
            The element-wise squared difference when ``self.reduction`` is
            ``'none'``, its sum when ``'sum'``, otherwise its mean.
        """
        # Outside dynamic mode, check that both operands are float32/float64.
        if not in_dynamic_mode():
            fluid.data_feeder.check_variable_and_dtype(
                input, 'input', ['float32', 'float64'], 'MSELoss'
            )
            fluid.data_feeder.check_variable_and_dtype(
                label, 'label', ['float32', 'float64'], 'MSELoss'
            )

        difference = paddle.subtract(input, label)
        if in_dygraph_mode():
            square_out = paddle._C_ops.square(difference)
        else:
            square_out = paddle.square(difference)

        if self.reduction == 'none':
            return square_out
        if self.reduction == 'sum':
            return paddle.sum(square_out)
        # __init__ only admits 'sum', 'mean' and 'none', so this is 'mean'.
        return paddle.mean(square_out)
606 607


Z
zhiboniu 已提交
608
class L1Loss(Layer):
609
    r"""
610

611
    Construct a callable object of the ``L1Loss`` class.
612
    The L1Loss layer calculates the L1 Loss of ``input`` and ``label`` as follows.
613

614
    If `reduction` set to ``'none'``, the loss is:
L
Leo Chen 已提交
615 616

    .. math::
617
        Out = \lvert input - label\rvert
618

619
    If `reduction` set to ``'mean'``, the loss is:
620

L
Leo Chen 已提交
621
    .. math::
622
        Out = MEAN(\lvert input - label\rvert)
623

624
    If `reduction` set to ``'sum'``, the loss is:
625

L
Leo Chen 已提交
626
    .. math::
627
        Out = SUM(\lvert input - label\rvert)
L
Leo Chen 已提交
628

629

L
Leo Chen 已提交
630
    Parameters:
631
        reduction (str, optional): Indicate the reduction to apply to the loss,
L
Leo Chen 已提交
632
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
633 634 635
            If `reduction` is ``'none'``, the unreduced loss is returned;
            If `reduction` is ``'mean'``, the reduced mean loss is returned.
            If `reduction` is ``'sum'``, the reduced sum loss is returned.
L
Leo Chen 已提交
636
            Default is ``'mean'``.
637 638 639
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
640 641 642 643
        - input (Tensor): The input tensor. The shape is ``[N, *]``, where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32, int64.
        - label (Tensor): label. The shape is ``[N, *]``, same shape as ``input`` . Its data type should be float32, float64, int32, int64.
        - output (Tensor): The L1 Loss of ``input`` and ``label``.
          If `reduction` is ``'none'``, the shape of output loss is ``[N, *]``, the same as ``input`` .
644
          If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [].
645

L
Leo Chen 已提交
646 647
    Examples:
        .. code-block:: python
648

649
            >>> import paddle
650

651 652
            >>> input = paddle.to_tensor([[1.5, 0.8], [0.2, 1.3]])
            >>> label = paddle.to_tensor([[1.7, 1], [0.4, 0.5]])
653

654 655 656 657 658
            >>> l1_loss = paddle.nn.L1Loss()
            >>> output = l1_loss(input, label)
            >>> print(output)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.34999999)
659

660 661 662 663 664
            >>> l1_loss = paddle.nn.L1Loss(reduction='sum')
            >>> output = l1_loss(input, label)
            >>> print(output)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            1.39999998)
665

666 667 668 669 670 671
            >>> l1_loss = paddle.nn.L1Loss(reduction='none')
            >>> output = l1_loss(input, label)
            >>> print(output)
            Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [[0.20000005, 0.19999999],
             [0.20000000, 0.79999995]])
672

L
Leo Chen 已提交
673 674
    """

675
    def __init__(self, reduction='mean', name=None):
        """Check ``reduction`` and store the layer configuration.

        Args:
            reduction: One of ``'sum'``, ``'mean'`` or ``'none'``.
            name: Optional operator name, forwarded by ``forward``.

        Raises:
            ValueError: If ``reduction`` is not one of the accepted values.
        """
        accepted = ('sum', 'mean', 'none')
        # Reject a bad reduction before the base class is initialised.
        if reduction not in accepted:
            message = (
                "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )
            raise ValueError(message)

        super().__init__()
        self.reduction = reduction
        self.name = name
L
Leo Chen 已提交
684

685
    def forward(self, input, label):
        """Return the L1 loss of ``input`` against ``label``.

        Delegates to ``paddle.nn.functional.l1_loss`` with the ``reduction``
        and ``name`` stored on this layer.
        """
        l1_loss = paddle.nn.functional.l1_loss
        return l1_loss(input, label, reduction=self.reduction, name=self.name)
C
ceci3 已提交
689 690


Z
zhiboniu 已提交
691
class BCELoss(Layer):
    """

    Create a callable object that measures the binary_cross_entropy loss between
    input predictions ``input`` and target labels ``label`` .
    The binary_cross_entropy loss can be described as:

    If :attr:`weight` is set, the loss is:

    .. math::
        Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`weight` is None, the loss is:

    .. math::
        Out = -1 * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`reduction` is set to ``'none'``, the original loss `Out` is returned.

    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(Out)

    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(Out)

    Note that the input predictions ``input`` should always be the output of sigmoid,
    and the target labels ``label`` should be numbers between 0 and 1.

    Parameters:
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, it has to be a Tensor of size nbatch and the data type
            is float32, float64. Default is None.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input (Tensor): 2-D tensor with shape: ``[N, *]``, N is batch_size, `*` means number of additional dimensions. The input ``input`` should always be the output of sigmoid. Available dtype is float16, float32, float64.
        - label (Tensor): 2-D tensor with the same shape as ``input``. The target labels whose values should be numbers between 0 and 1. Available dtype is float16, float32, float64.
        - output (Tensor): If ``reduction`` is ``'none'``, the shape of output is the same as ``input`` , else the output is a scalar.

    Returns:
        A callable object of BCELoss.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> input = paddle.to_tensor([0.5, 0.6, 0.7])
            >>> label = paddle.to_tensor([1.0, 0.0, 1.0])
            >>> bce_loss = paddle.nn.BCELoss()
            >>> output = bce_loss(input, label)
            >>> print(output)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.65537095)

    """

    def __init__(self, weight=None, reduction='mean', name=None):
        # Fail fast on an unsupported reduction, before Layer state exists.
        supported_reductions = ['sum', 'mean', 'none']
        if reduction not in supported_reductions:
            raise ValueError(
                "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )

        super().__init__()
        self.weight = weight
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        # All options are forwarded positionally, in the order
        # (weight, reduction, name), exactly as captured in ``__init__``.
        return paddle.nn.functional.binary_cross_entropy(
            input, label, self.weight, self.reduction, self.name
        )
776 777


Z
zhiboniu 已提交
778
class NLLLoss(Layer):
    r"""

    This class accepts input and target label and returns the negative log
    likelihood loss. It is useful to train a classification problem with C classes.

    The input for the loss is expected to contain log-probabilities of
    each class. It has to be a Tensor of size either (batch_size, C) or
    (batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
    The label for the loss should be a class index in the range [0, C-1]
    where C is the number of classes. If ignore_index is specified, the
    specified target value does not contribute to the input gradient.

    If the optional argument `weight` is provided, it should be a 1D Tensor
    assigning a weight to each of the classes. This is particularly useful
    when you have an unbalanced training set.

    The loss is calculated as follows.
    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:

    .. math::

        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = - w_{y_n} x_{n,y_n}, \quad
        w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore_index}\},

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then

    .. math::

        \ell(x, y) =
        \left\{
            \begin{array}{lcl}
            \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, &
            \text{if  reduction} = \text{'mean';}\\
            \sum_{n=1}^N l_n,  &
            \text{if  reduction} = \text{'sum'.}
            \end{array}
        \right.

    Parameters:
        weight (Tensor, optional): Weight tensor, a manual rescaling weight given
            to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise,
            it is treated as if having all ones. The data type is
            float32, float64. Default is None.
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default is -100.
        reduction (str, optional): Indicate how to average the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'mean'``, the reduced mean loss is returned;
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default is None.

    Shape:
        - input (Tensor): Input tensor, the shape is :math:`[N, C]`, `C` is the number of classes.
            But in K-dimension situation, the shape is :math:`[N, C, d_1, d_2, ..., d_K]`.
            The data type is float32, float64.
        - label (Tensor): Label tensor, the shape is :math:`[N,]` or :math:`[N, d_1, d_2, ..., d_K]`.
            The data type is int64.
        - output (Tensor): the `negative log likelihood loss` between `input` and `label`.
            If `reduction` is `'none'`, the shape is `[N, *]`.
            If `reduction` is `'sum'` or `'mean'`, the shape is `[]`.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> nll_loss = paddle.nn.loss.NLLLoss()
            >>> log_softmax = paddle.nn.LogSoftmax(axis=1)

            >>> input = paddle.to_tensor([[0.88103855, 0.9908683 , 0.6226845 ],
            ...                           [0.53331435, 0.07999352, 0.8549948 ],
            ...                           [0.25879037, 0.39530203, 0.698465  ],
            ...                           [0.73427284, 0.63575995, 0.18827209],
            ...                           [0.05689114, 0.0862954 , 0.6325046 ]], "float32")
            >>> log_out = log_softmax(input)
            >>> label = paddle.to_tensor([0, 2, 1, 1, 0], "int64")
            >>> result = nll_loss(log_out, label)
            >>> print(result)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            1.07202101)

    """

    def __init__(
        self, weight=None, ignore_index=-100, reduction='mean', name=None
    ):
        # Fail fast on an unsupported reduction, before Layer state exists.
        supported_reductions = ['sum', 'mean', 'none']
        if reduction not in supported_reductions:
            raise ValueError(
                "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
                "'none', but received %s, which is not allowed." % reduction
            )

        super().__init__()
        self._weight = weight
        self._ignore_index = ignore_index
        self._reduction = reduction
        self._name = name

    def forward(self, input, label):
        # Delegate to the functional API with the stored configuration.
        loss = F.nll_loss(
            input,
            label,
            weight=self._weight,
            ignore_index=self._ignore_index,
            reduction=self._reduction,
            name=self._name,
        )
        return loss
889 890


891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937
class PoissonNLLLoss(Layer):
    r"""Generate a callable object of 'PoissonNLLLoss' to calculate the
    Poisson negative log likelihood loss between Input(input) and
    Input(label). Note that Input(input) is the expectation of the underlying
    Poisson distribution and Input(label) is the random samples from the
    Poisson distribution.


    Poisson negative log likelihood loss is calculated as follows:

    .. math::
        \text{loss}(\text{input}, \text{label}) = \text{input} - \text{label} * \log(\text{input}) + \log(\text{label!})

    The last term can be approximated with the Stirling formula. This approximation term is used when :attr:`full` is ``True``.
    The approximation is added when label values are more than 1 and omitted when the labels are less than or equal to 1.

    Parameters:
        log_input (bool, optional):
            Whether to treat the input tensor as log input.
            If ``True`` the loss is computed as, :math:`\exp(\text{input}) - \text{label} * \text{input}` .
            If ``False`` then loss is :math:`\text{input} - \text{label} * \log(\text{input}+\text{epsilon})` .
            Default: ``True``.
        full (bool, optional):
            Whether to compute full loss.
            If ``True``, the Stirling approximation term is added.
            If ``False``, the Stirling approximation is dropped.
            Default: ``False``.
        epsilon (float, optional):
            A small value to avoid evaluation of :math:`\log(0)` when ``log_input`` = ``False``. ``epsilon > 0``.
            Default: 1e-8.
        reduction (str, optional):
            Indicate how to reduce the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'mean'``, the reduced mean loss is returned;
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.
        name (str, optional):
            Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Raises:
        ValueError: If ``epsilon`` is not greater than 0, or if ``reduction``
            is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Shape:
        - input (Tensor): The shape of input tensor should be `(N, *)` or `(*)` where `(*)` denotes any number of extra dimensions.
        - label (Tensor): The shape of label tensor should be `(N, *)` or `(*)`, same shape as the input tensor.
        - output (Tensor): scalar if :attr:`reduction` is ``'mean'`` (default) or ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same shape as the input.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.seed(2023)
            >>> poisson_nll_loss = paddle.nn.loss.PoissonNLLLoss()
            >>> input = paddle.randn([5, 2], dtype=paddle.float32)
            >>> label = paddle.randn([5, 2], dtype=paddle.float32)
            >>> loss = poisson_nll_loss(input, label)
            >>> print(loss)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            1.52983975)

    """

    def __init__(
        self,
        log_input=True,
        full=False,
        epsilon=1e-8,
        reduction="mean",
        name=None,
    ):
        # Both checks run before Layer initialisation so that an invalid
        # configuration never produces a half-constructed layer.
        if epsilon <= 0:
            raise ValueError(
                "The value of `epsilon` in PoissonNLLLoss should be positive, but received %f, which is not allowed"
                % epsilon
            )
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in PoissonNLLLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )
        super().__init__()
        self._log_input = log_input
        self._full = full
        self._epsilon = epsilon
        self._reduction = reduction
        self._name = name

    def forward(self, input, label):
        # Delegate to the functional API with the stored configuration.
        return F.poisson_nll_loss(
            input,
            label,
            log_input=self._log_input,
            full=self._full,
            epsilon=self._epsilon,
            reduction=self._reduction,
            name=self._name,
        )


Z
zhiboniu 已提交
987
class KLDivLoss(Layer):
    r"""

    Generate a callable object of 'KLDivLoss' to calculate the
    Kullback-Leibler divergence loss between Input(X) and
    Input(Target). Note that Input(X) is the log-probability
    and Input(Target) is the probability.

    KL divergence loss is calculated as follows:

    .. math::

        l(x, y) = y * (\log(y) - x)

    Here :math:`x` is input and :math:`y` is label.

    If `reduction` is ``'none'``, the output loss is the same shape as the input, and the loss at each point is calculated separately. There is no reduction to the result.

    If `reduction` is ``'mean'``, the output loss is the shape of [], and the output is the average of all losses.

    If `reduction` is ``'sum'``, the output loss is the shape of [], and the output is the sum of all losses.

    If `reduction` is ``'batchmean'``, the output loss is the shape of [], and the output is the sum of all losses divided by the batch size N.

    Parameters:
        reduction (str, optional): Indicate how to average the loss,
            the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'mean'``, the reduced mean loss is returned;
            if `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.

    Shape:

        input (Tensor): ``(N, *)``, where ``*`` means any number of additional dimensions.

        label (Tensor): ``(N, *)``, same shape as input.

        output (Tensor): tensor with shape: [] by default.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn as nn

            >>> shape = (5, 20)
            >>> x = paddle.uniform(shape, min=-10, max=10).astype('float32')
            >>> target = paddle.uniform(shape, min=-10, max=10).astype('float32')

            >>> # 'batchmean' reduction, loss shape will be []
            >>> kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
            >>> pred_loss = kldiv_criterion(x, target)
            >>> print(pred_loss.shape)
            []

            >>> # 'mean' reduction, loss shape will be []
            >>> kldiv_criterion = nn.KLDivLoss(reduction='mean')
            >>> pred_loss = kldiv_criterion(x, target)
            >>> print(pred_loss.shape)
            []

            >>> # 'sum' reduction, loss shape will be []
            >>> kldiv_criterion = nn.KLDivLoss(reduction='sum')
            >>> pred_loss = kldiv_criterion(x, target)
            >>> print(pred_loss.shape)
            []

            >>> # 'none' reduction, loss shape is same with X shape
            >>> kldiv_criterion = nn.KLDivLoss(reduction='none')
            >>> pred_loss = kldiv_criterion(x, target)
            >>> print(pred_loss.shape)
            [5, 20]

    """

    def __init__(self, reduction='mean'):
        super().__init__()
        # NOTE(review): unlike the sibling loss layers, `reduction` is not
        # validated here; it is handed to F.kl_div unchanged in `forward`.
        self.reduction = reduction

    def forward(self, input, label):
        # Delegate to the functional API with the stored reduction mode.
        return F.kl_div(input, label, self.reduction)


Z
zhiboniu 已提交
1071
class MarginRankingLoss(Layer):
    r"""

    Create a callable object of the ``MarginRankingLoss`` class.
    The MarginRankingLoss layer calculates the margin rank loss between the input, other and label,
    using the math function as follows.

    .. math::
        margin\_rank\_loss = max(0, -label * (input - other) + margin)

    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(margin\_rank\_loss)

    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(margin\_rank\_loss)

    If :attr:`reduction` is set to ``'none'``, the original ``margin_rank_loss`` is returned.

    Parameters:
        margin (float, optional): The margin value to add, default value is 0;
        reduction (str, optional): Indicate the reduction to apply to the loss, the candidates are ``'none'``, ``'mean'``, ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:

        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        other: N-D Tensor, `other` has the same shape and dtype as `input`.

        label: N-D Tensor, `label` has the same shape and dtype as `input`.

        output: If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the out shape is :math:`[]`, otherwise the shape is the same as `input` . The same dtype as input tensor.

    Returns:
        A callable object of MarginRankingLoss.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> input = paddle.to_tensor([[1, 2], [3, 4]], dtype="float32")
            >>> other = paddle.to_tensor([[2, 1], [2, 4]], dtype="float32")
            >>> label = paddle.to_tensor([[1, -1], [-1, -1]], dtype="float32")
            >>> margin_rank_loss = paddle.nn.MarginRankingLoss()
            >>> loss = margin_rank_loss(input, other, label)
            >>> print(loss)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.75000000)
    """

    def __init__(self, margin=0.0, reduction='mean', name=None):
        # Fail fast on an unsupported reduction, before Layer state exists.
        supported_reductions = ['sum', 'mean', 'none']
        if reduction not in supported_reductions:
            raise ValueError(
                "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )

        super().__init__()
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input, other, label):
        # Options are forwarded positionally in the order
        # (margin, reduction, name), as captured in ``__init__``.
        return paddle.nn.functional.margin_ranking_loss(
            input, other, label, self.margin, self.reduction, self.name
        )
1143 1144


Z
zhiboniu 已提交
1145
class CTCLoss(Layer):
    r"""

    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation
    is integrated into the Warp-CTC library to normalize values for each row of the input tensor.

    Parameters:
        blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
        reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.

    Shape:
        - log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type should be float32 or float64.
        - labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
        - input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
        - label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
        - norm_by_times (bool, optional): Whether to normalize the gradients by the number of time-step, which is also the sequence's length. There is no need to normalize the gradients if reduction mode is 'mean'. Default: False.

    Returns:
        Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and  ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is []. Data type is the same as ``log_probs``.

    Examples:

        .. code-block:: python

            >>> # declarative mode
            >>> import paddle

            >>> # length of the longest logit sequence
            >>> max_seq_length = 5
            >>> #length of the longest label sequence
            >>> max_label_length = 3
            >>> # number of logit sequences
            >>> batch_size = 2
            >>> # class num
            >>> class_num = 3

            >>> log_probs = paddle.to_tensor([[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04],
            ...                                [3.02332580e-01, 1.46755889e-01, 9.23385918e-02]],
            ...                               [[1.86260208e-01, 3.45560730e-01, 3.96767467e-01],
            ...                                [5.38816750e-01, 4.19194520e-01, 6.85219526e-01]],
            ...                               [[2.04452246e-01, 8.78117442e-01, 2.73875929e-02],
            ...                                [6.70467496e-01, 4.17304814e-01, 5.58689833e-01]],
            ...                               [[1.40386939e-01, 1.98101491e-01, 8.00744593e-01],
            ...                                [9.68261600e-01, 3.13424170e-01, 6.92322612e-01]],
            ...                               [[8.76389146e-01, 8.94606650e-01, 8.50442126e-02],
            ...                                [3.90547849e-02, 1.69830427e-01, 8.78142476e-01]]], dtype="float32")
            >>> labels = paddle.to_tensor([[1, 2, 2], [1, 2, 2]], dtype="int32")
            >>> input_lengths = paddle.to_tensor([5, 5], dtype="int64")
            >>> label_lengths = paddle.to_tensor([3, 3], dtype="int64")

            >>> loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels, input_lengths, label_lengths)
            >>> print(loss)
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [3.91798496, 2.90765214])

            >>> loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels, input_lengths, label_lengths)
            >>> print(loss)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            1.13760614)
    """

    def __init__(self, blank=0, reduction='mean'):
        super().__init__()
        # Stored unchanged and forwarded to the functional API in `forward`.
        self.blank = blank
        self.reduction = reduction

    def forward(
        self,
        log_probs,
        labels,
        input_lengths,
        label_lengths,
        norm_by_times=False,
    ):
        # `blank` and `reduction` come from construction time; only
        # `norm_by_times` is chosen per call.
        return paddle.nn.functional.ctc_loss(
            log_probs,
            labels,
            input_lengths,
            label_lengths,
            self.blank,
            self.reduction,
            norm_by_times=norm_by_times,
        )
1230 1231


H
Hui Zhang 已提交
1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248
class RNNTLoss(Layer):
    """
    Create a callable object that computes the RNN-T loss by delegating to
    ``paddle.nn.functional.rnnt_loss``.

    Parameters:
        blank (int, optional): blank label. Default: 0.
        fastemit_lambda (float, optional): Regularization parameter for FastEmit (https://arxiv.org/pdf/2010.11148.pdf). Default: 0.001.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the output losses will be divided by the target lengths and
            then the mean over the batch is taken. Default: 'mean'
        name (str, optional): Name forwarded to ``paddle.nn.functional.rnnt_loss``. Default: None.

    Shape:
        input: logprob Tensor of (batch x seqLength x labelLength x outputDim) containing output from network
        label: 2 dimensional (batch, labelLength) Tensor containing all the targets of the batch with zero padded
        input_lengths: Tensor of size (batch) containing size of each output sequence from the network
        label_lengths: Tensor of (batch) containing label length of each example

    Returns:
     Tensor, The RNN-T loss between ``logprobs`` and  ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is []. Data type is the same as ``logprobs``.

    Examples:
        .. code-block:: python

            >>> # declarative mode
            >>> import numpy as np
            >>> import paddle
            >>> from paddle.nn import RNNTLoss

            >>> fn = RNNTLoss(reduction='sum', fastemit_lambda=0.0)

            >>> acts = np.array([[[[0.1, 0.6, 0.1, 0.1, 0.1],
            ...                    [0.1, 0.1, 0.6, 0.1, 0.1],
            ...                    [0.1, 0.1, 0.2, 0.8, 0.1]],
            ...                   [[0.1, 0.6, 0.1, 0.1, 0.1],
            ...                    [0.1, 0.1, 0.2, 0.1, 0.1],
            ...                    [0.7, 0.1, 0.2, 0.1, 0.1]]]])
            >>> labels = [[1, 2]]

            >>> acts = paddle.to_tensor(acts, stop_gradient=False)

            >>> lengths = [acts.shape[1]] * acts.shape[0]
            >>> label_lengths = [len(l) for l in labels]
            >>> labels = paddle.to_tensor(labels, paddle.int32)
            >>> lengths = paddle.to_tensor(lengths, paddle.int32)
            >>> label_lengths = paddle.to_tensor(label_lengths, paddle.int32)

            >>> costs = fn(acts, labels, lengths, label_lengths)
            >>> print(costs)
            Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=False,
            -2.85042444)
    """

    def __init__(
        self, blank=0, fastemit_lambda=0.001, reduction='mean', name=None
    ):
        super().__init__()
        # Configuration is stored unchanged and forwarded in `forward`.
        self.blank = blank
        self.reduction = reduction
        self.fastemit_lambda = fastemit_lambda
        self.name = name

    def forward(self, input, label, input_lengths, label_lengths):
        # Delegate to the functional API with the stored configuration.
        loss = paddle.nn.functional.rnnt_loss(
            input,
            label,
            input_lengths,
            label_lengths,
            blank=self.blank,
            fastemit_lambda=self.fastemit_lambda,
            reduction=self.reduction,
            name=self.name,
        )
        return loss


Z
zhiboniu 已提交
1305
class SmoothL1Loss(Layer):
    r"""
    Creates a criterion that computes the smooth L1 loss between ``input`` and
    ``label``. A squared term is used when the absolute element-wise error falls
    below :math:`\delta` and an L1 term is used otherwise. In some cases it can
    prevent exploding gradients, and it is more robust and less sensitive to
    outliers. Also known as the Huber loss:

    .. math::

        loss(x, y) = \frac{1}{n}\sum_{i}z_i

    where :math:`z_i` is given by:

    .. math::

        \mathop{z_i} = \left\{\begin{array}{rcl}
                0.5(x_i - y_i)^2 & & {if |x_i - y_i| < \delta} \\
                \delta * |x_i - y_i| - 0.5 * \delta^2 & & {otherwise}
            \end{array} \right.

    Parameters:
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        delta (float, optional): Specifies the hyperparameter :math:`\delta` to be used.
            The value determines how large the errors need to be to use L1. Errors
            smaller than delta are minimized with L2. Parameter is ignored for
            negative/zero values. Default value is :math:`1.0`.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`.
            Generally, no setting is required. Default: None.

    Call Parameters:

        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C),
        where C is number of classes, and if shape is more than 2D,
        this is (N, C, D1, D2,..., Dk), k >= 1.

        label (Tensor): Label tensor, the data type is float32 or float64.
        The shape of label is the same as the shape of input.

    Returns:
        Tensor, The tensor storing the smooth_l1_loss of input and label.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.seed(2023)
            >>> input = paddle.rand([3, 3]).astype("float32")
            >>> label = paddle.rand([3, 3]).astype("float32")
            >>> loss = paddle.nn.SmoothL1Loss()
            >>> output = loss(input, label)
            >>> print(output)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.08307374)
    """

    def __init__(self, reduction='mean', delta=1.0, name=None):
        super().__init__()
        # Settings are stored as given and handed to F.smooth_l1_loss per call.
        self.name = name
        self.delta = delta
        self.reduction = reduction

    def forward(self, input, label):
        # Thin wrapper: all of the computation lives in the functional API.
        options = {
            'reduction': self.reduction,
            'delta': self.delta,
            'name': self.name,
        }
        return F.smooth_l1_loss(input, label, **options)
1378 1379


Y
yangguohao 已提交
1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395
class MultiLabelSoftMarginLoss(Layer):
    r"""Creates a criterion that optimizes a multi-label one-versus-all loss
    between input :math:`x` and label :math:`y`, which have the same shape.
    For each sample in the mini-batch:

    .. math::
        \text{loss}(x, y) = -\frac{1}{C} \sum_{i} \left[ y[i] \cdot \log\sigma(x[i])
        + (1 - y[i]) \cdot \log\sigma(-x[i]) \right]

    where :math:`\sigma` is the sigmoid function, :math:`C` is the number of
    classes and :math:`i \in \left\{0, \; \cdots , \; C - 1\right\}`.
    :math:`y` and :math:`x` must have the same size.

    Parameters:
        weight (Tensor,optional): a manual rescaling weight given to each class.
                If given, has to be a Tensor of size C and the data type is float32, float64.
                Default is ``'None'`` .
        reduction (str, optional): Indicate how to average the loss by batch_size,
                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
                Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Call parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.
        label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.

    Shape:
        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means number of classes, available dtype is float32, float64. The sum operation operates over all the elements.
        label: N-D Tensor, same shape as the input.
        output: scalar. If :attr:`reduction` is ``'none'``, one loss value per sample is returned (shape [N] for an [N, C] input, as in the example below).

    Returns:
        A callable object of MultiLabelSoftMarginLoss.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn as nn

            >>> input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            >>> label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

            >>> multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='none')
            >>> loss = multi_label_soft_margin_loss(input, label)
            >>> print(loss)
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
            [3.49625897, 0.71111226, 0.43989015])

            >>> multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='mean')
            >>> loss = multi_label_soft_margin_loss(input, label)
            >>> print(loss)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            1.54908717)
    """

    # NOTE(review): the formula in the docstring was derived from the example
    # outputs above (it reproduces them with the label values used as-is);
    # confirm against F.multi_label_soft_margin_loss, including how ``weight``
    # enters the sum.

    def __init__(self, weight=None, reduction="mean", name=None):
        super().__init__()
        # Reject an unknown reduction at construction time rather than on the
        # first forward call.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "'reduction' in 'MultiLabelSoftMarginloss' should be 'sum', 'mean' or 'none', "
                "but received {}.".format(reduction)
            )
        self.weight = weight
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        # Thin wrapper around the functional implementation.
        return F.multi_label_soft_margin_loss(
            input,
            label,
            weight=self.weight,
            reduction=self.reduction,
            name=self.name,
        )
Y
yangguohao 已提交
1461 1462


1463 1464
class HingeEmbeddingLoss(Layer):
    r"""
    Create a callable object of `HingeEmbeddingLoss` to calculate hinge_embedding_loss. Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y` (containing 1 or -1).
    This is usually used for measuring whether two inputs are similar or dissimilar, e.g. using the L1 pairwise distance as :math:`x`,
    and is typically used for learning nonlinear embeddings or semi-supervised learning.

    The loss function for :math:`n`-th sample in the mini-batch is

    .. math::
        l_n = \begin{cases}
            x_n, & \text{if}\; y_n = 1,\\
            \max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
        \end{cases}

    and the total loss function is

    .. math::
        \ell(x, y) = \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
        \end{cases}

    where :math:`L = \{l_1,\dots,l_N\}^\top`.

    Parameters:

        margin (float, optional): Specifies the hyperparameter margin to be used.
            The value determines how large the input needs to be to calculate in
            hinge_embedding_loss. When label is -1, inputs smaller than margin are minimized with hinge_embedding_loss.
            Default = 1.0
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Call Parameters:

        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.

        label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.

    Shape:

        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64. The sum operation operates over all the elements.

        label: N-D Tensor, same shape as the input.

        output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.

    Returns:

        Tensor, The tensor variable storing the hinge_embedding_loss of input and label.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn as nn

            >>> input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            >>> # label elements in {1., -1.}
            >>> label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

            >>> hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=1.0, reduction='none')
            >>> loss = hinge_embedding_loss(input, label)
            >>> print(loss)
            Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
            [[ 0., -2.,  0.],
             [ 0., -1.,  2.],
             [ 1.,  1.,  1.]])

            >>> hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=1.0, reduction='mean')
            >>> loss = hinge_embedding_loss(input, label)
            >>> print(loss)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.22222222)
    """

    def __init__(self, margin=1.0, reduction="mean", name=None):
        super().__init__()
        # Stored verbatim; this constructor performs no validation itself.
        self.name = name
        self.reduction = reduction
        self.margin = margin

    def forward(self, input, label):
        # Delegate to the functional API using the stored configuration.
        loss = F.hinge_embedding_loss(
            input,
            label,
            margin=self.margin,
            reduction=self.reduction,
            name=self.name,
        )
        return loss
1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594


class CosineEmbeddingLoss(Layer):
    r"""
    This interface is used to construct a callable object of the ``CosineEmbeddingLoss`` class.
    The CosineEmbeddingLoss layer measures the cosine_embedding loss between input predictions ``input1``, ``input2``
    and target labels ``label`` with values 1 or -1. This is used for measuring whether two inputs are similar or
    dissimilar and is typically used for learning nonlinear embeddings or semi-supervised learning.
    The cosine embedding loss can be described as:

    If label = 1, then the loss value can be calculated as follow:

    .. math::
        Out = 1 - cos(input1, input2)

    If label = -1, then the loss value can be calculated as follow:

    .. math::
        Out = max(0, cos(input1, input2) - margin)

    The operator cos can be described as follow:

    .. math::
        cos(x1, x2) = \frac{x1 \cdot{} x2}{\Vert x1 \Vert_2 * \Vert x2 \Vert_2}

    Parameters:
        margin (float, optional): Should be a number from :math:`-1` to :math:`1`,
            :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
            default value is :math:`0`.
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
            Default: ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input1 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, which can be 0, 'M' means the length of input array.
                         Available dtypes are float32, float64.
        input2 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, which can be 0, 'M' means the length of input array.
                         Available dtypes are float32, float64.
        label (Tensor): tensor with shape: [N] or [1], 'N' means the length of input array. The target labels values should be -1 or 1.
                         Available dtypes are int32, int64, float32, float64.
        output (Tensor): Tensor, the cosine embedding Loss of Tensor ``input1`` ``input2`` and ``label``.
                         If `reduction` is ``'none'``, the shape of output loss is [N], the same as ``label`` .
                         If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [].

    Raises:
        ValueError: If ``margin`` is outside [-1, 1] or ``reduction`` is not one of
            ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> input1 = paddle.to_tensor([[1.6, 1.2, -0.5], [3.2, 2.6, -5.8]], 'float32')
            >>> input2 = paddle.to_tensor([[0.5, 0.5, -1.8], [2.3, -1.4, 1.1]], 'float32')
            >>> label = paddle.to_tensor([1, -1], 'int64')

            >>> cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='mean')
            >>> output = cosine_embedding_loss(input1, input2, label)
            >>> print(output)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.21155193)

            >>> cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='sum')
            >>> output = cosine_embedding_loss(input1, input2, label)
            >>> print(output)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.42310387)

            >>> cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='none')
            >>> output = cosine_embedding_loss(input1, input2, label)
            >>> print(output)
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [0.42310387, 0.        ])

    """

    def __init__(self, margin=0, reduction='mean', name=None):
        # Both arguments are validated before the base Layer is initialised, so
        # a bad configuration fails without constructing a half-built layer.
        if margin > 1 or margin < -1:
            raise ValueError(
                "The value of 'margin' should be in the interval of [-1, 1], but received %f, which is not allowed."
                % margin
            )
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' should be 'sum', 'mean' or "
                "'none', but received %s, which is not allowed." % reduction
            )
        super().__init__()
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input1, input2, label):
        # Thin wrapper around the functional implementation.
        return F.cosine_embedding_loss(
            input1,
            input2,
            label,
            margin=self.margin,
            reduction=self.reduction,
            name=self.name,
        )
Y
yangguohao 已提交
1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675


class TripletMarginWithDistanceLoss(Layer):
    r"""
    Creates a criterion that measures the triplet loss given input
    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
    This is used for measuring a relative similarity between samples. A triplet
    is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
    examples` respectively). The shapes of all input tensors should be
    :math:`(N, D)`.

    The loss function for each sample in the mini-batch is:

    .. math::
        L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}

    where the default `distance_function`

    .. math::
        d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_2

    or user can define their own distance function. `margin` is a nonnegative margin representing the minimum difference
    between the positive and negative distances that is required for the loss to be 0. If `swap` is true, it will compare distance of (input, negative) with
    distance of (negative, positive) and change it to the smaller one. For more details see http://www.bmva.org/bmvc/2016/papers/paper119/paper119.pdf.

    Parameters:
        distance_function (Callable, Optional): Quantifies the distance between two tensors. If not specified, 2 norm functions will be used.

        margin (float, Optional): Default: :math:`1`. A nonnegative margin representing the minimum difference
                between the positive and negative distances required for the loss to be 0. Larger
                margins penalize cases where the negative examples are not distant enough from the
                anchors, relative to the positives.

        swap (bool, Optional): The distance swap changes the negative distance to the swap distance (distance between positive samples
                and negative samples) if swap distance smaller than negative distance. Default: ``False``.

        reduction (str, Optional): Indicate how to average the loss by batch_size.
                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
                Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shapes:
        input (Tensor): Input tensor, the data type is float32 or float64.
            The shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        positive (Tensor): Positive tensor, the data type is float32 or float64.
            The shape of positive is the same as the shape of input.

        negative (Tensor): Negative tensor, the data type is float32 or float64.
            The shape of negative is the same as the shape of input.

        output (Tensor): The tensor variable storing the triplet_margin_with_distance_loss of input and positive and negative.

    Returns:
        A callable object of TripletMarginWithDistanceLoss

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> from paddle.nn import TripletMarginWithDistanceLoss

            >>> input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
            >>> positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
            >>> negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
            >>> triplet_margin_with_distance_loss = TripletMarginWithDistanceLoss(reduction='none')
            >>> loss = triplet_margin_with_distance_loss(input, positive, negative,)
            >>> print(loss)
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
            [0.        , 0.57496595, 0.        ])

            >>> triplet_margin_with_distance_loss = TripletMarginWithDistanceLoss(reduction='mean')
            >>> loss = triplet_margin_with_distance_loss(input, positive, negative,)
            >>> print(loss)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.19165532)

    """

    def __init__(
        self,
        distance_function=None,
        margin=1.0,
        swap=False,
        reduction: str = 'mean',
        name=None,
    ):
        super().__init__()
        # Reject an unknown reduction at construction time rather than on the
        # first forward call.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in TripletMarginWithDistanceLoss "
                "should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )
        self.margin = margin
        self.swap = swap
        self.reduction = reduction
        self.distance_function = distance_function
        self.name = name

    def forward(self, input, positive, negative):
        # The stored distance function must be forwarded explicitly; without it
        # the functional API falls back to its default distance and a
        # user-supplied callable would be silently ignored. ``None`` keeps the
        # default behaviour.
        return F.triplet_margin_with_distance_loss(
            input,
            positive,
            negative,
            distance_function=self.distance_function,
            margin=self.margin,
            swap=self.swap,
            reduction=self.reduction,
            name=self.name,
        )
Y
yangguohao 已提交
1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832


class TripletMarginLoss(Layer):
    r"""
    Creates a criterion that measures the triplet loss given input
    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
    This is used for measuring a relative similarity between samples. A triplet
    is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
    examples` respectively). The shapes of all input tensors should be
    :math:`(N, *)`.

    The loss function for each sample in the mini-batch is:

    .. math::
        L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}


    where

    .. math::
        d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p

    Parameters:
        margin (float, Optional): Default: :math:`1`.

        p (int, Optional): The norm degree for pairwise distance. Default: :math:`2`.

        epsilon (float, Optional): Add small value to avoid division by zero,
            default value is 1e-6.

        swap (bool, Optional): The distance swap change the negative distance to the distance between
            positive sample and negative sample. For more details, see `Learning shallow convolutional feature descriptors with triplet losses`.
            Default: ``False``.

        reduction (str, Optional): Indicate how to average the loss by batch_size.
                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
                Default: ``'mean'``

        name (str,Optional): Name for the operation (optional, default is None).
                For more information, please refer to :ref:`api_guide_Name`.

    Call Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64.
        The shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        positive (Tensor): Positive tensor, the data type is float32 or float64.
        The shape of positive is the same as the shape of input.

        negative (Tensor): Negative tensor, the data type is float32 or float64.
        The shape of negative is the same as the shape of input.

    Returns:
        Tensor. The tensor variable storing the triplet_margin_loss of input and positive and negative.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
            >>> positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
            >>> negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
            >>> triplet_margin_loss = paddle.nn.TripletMarginLoss(reduction='none')
            >>> loss = triplet_margin_loss(input, positive, negative)
            >>> print(loss)
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
            [0.        , 0.57496595, 0.        ])

            >>> triplet_margin_loss = paddle.nn.TripletMarginLoss(margin=1.0, swap=True, reduction='mean')
            >>> loss = triplet_margin_loss(input, positive, negative)
            >>> print(loss)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            2.40039468)

    """

    def __init__(
        self,
        margin=1.0,
        p=2.0,
        epsilon=1e-6,
        swap=False,
        reduction='mean',
        name=None,
    ):
        super().__init__()
        # Fail early on an unsupported reduction mode.
        supported = ['sum', 'mean', 'none']
        if reduction not in supported:
            raise ValueError(
                "The value of 'reduction' in TripletMarginLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction
            )
        self.name = name
        self.reduction = reduction
        self.swap = swap
        self.epsilon = epsilon
        self.p = p
        self.margin = margin

    def forward(self, input, positive, negative):
        # Collect the stored configuration, then delegate to the functional API.
        config = {
            'margin': self.margin,
            'p': self.p,
            'epsilon': self.epsilon,
            'swap': self.swap,
            'reduction': self.reduction,
            'name': self.name,
        }
        return F.triplet_margin_loss(input, positive, negative, **config)
1886 1887


Y
yangguohao 已提交
1888 1889
class MultiMarginLoss(Layer):
    r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between
    input :math:`input` and label :math:`label`:

    For i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar
    output :math:`label_i` is:

    .. math::
        \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}}

    where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`.

    Optionally, you can give non-equal weighting on the classes by passing
    a 1D :attr:`weight` tensor into the constructor.

    The loss function for i-th sample then becomes:

    .. math::
        \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}}


    Parameters:

        p (int, Optional): The norm degree for pairwise distance. Default: :math:`1`.

        margin (float, Optional): Default: :math:`1`.

        weight (Tensor,optional): a manual rescaling weight given to each class.
                If given, has to be a Tensor of shape (C,) and the data type is float32, float64.
                Default is ``'None'`` .

        reduction (str, optional): Indicate how to calculate the loss by batch_size,
                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
                Default: ``'mean'``

        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Call parameters:
        input (Tensor): Input tensor, the data type is float32 or float64.

        label (Tensor): Label tensor, 0<= label < input.shape[1], the data type is int32 or int64.

    Shape:
        input: 2-D Tensor, the shape is [N, C], N is batch size and `C` means number of classes.

        label: 1-D Tensor, the shape is [N,].

        output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the label.

    Returns:
        A callable object of MultiMarginLoss.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn as nn

            >>> input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            >>> label = paddle.to_tensor([0, 1, 2], dtype=paddle.int32)

            >>> multi_margin_loss = nn.MultiMarginLoss(reduction='mean')
            >>> loss = multi_margin_loss(input, label)
            >>> print(loss)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            1.11111104)
    """

    def __init__(
        self,
        p: int = 1,
        margin: float = 1.0,
        weight=None,
        reduction="mean",
        name=None,
    ):
        super().__init__()
        # Fail early on an unsupported reduction mode.
        supported = ['sum', 'mean', 'none']
        if reduction not in supported:
            raise ValueError(
                "'reduction' in 'MultiMarginLoss' should be 'sum', 'mean' or 'none', "
                f"but received {reduction}."
            )
        self.name = name
        self.reduction = reduction
        self.weight = weight
        self.margin = margin
        self.p = p

    def forward(self, input, label):
        # Collect the stored configuration, then delegate to the functional API.
        config = {
            'p': self.p,
            'margin': self.margin,
            'weight': self.weight,
            'reduction': self.reduction,
            'name': self.name,
        }
        return F.multi_margin_loss(input, label, **config)
Y
yangguohao 已提交
1990 1991


1992 1993
class SoftMarginLoss(Layer):
    r"""

    Creates a criterion that measures a two-class soft margin loss between input predictions ``input``
    and target labels ``label`` . It can be described as:

    .. math::
        Out = \log(1 + \exp(-label * input))

    Parameters:

        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.

        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shapes:
        - Input (Tensor): The input tensor with shape: ``[N, *]``,
          N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf.
          Available dtype is float32, float64.
        - Label (Tensor): The target labels tensor with the same shape as
          ``input``. The target label values should be -1 or 1.
          Available dtype is int32, int64, float32, float64.
        - Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
          same as ``input`` , else the shape of output is [].

    Returns:
        A callable object of SoftMarginLoss.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.seed(2023)
            >>> input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32')
            >>> label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32')
            >>> soft_margin_loss = paddle.nn.SoftMarginLoss()
            >>> output = soft_margin_loss(input, label)
            >>> print(output)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            0.64022040)

            >>> input_np = paddle.uniform(shape=(5, 5), min=0.1, max=0.8, dtype="float64")
            >>> label_np = paddle.randint(high=2, shape=(5, 5), dtype="int64")
            >>> label_np[label_np==0]=-1
            >>> input = paddle.to_tensor(input_np)
            >>> label = paddle.to_tensor(label_np)
            >>> soft_margin_loss = paddle.nn.SoftMarginLoss(reduction='none')
            >>> output = soft_margin_loss(input, label)
            >>> print(output)
            Tensor(shape=[5, 5], dtype=float64, place=Place(cpu), stop_gradient=True,
            [[1.10725628, 0.48778139, 0.56217249, 1.12581404, 0.51430043],
             [0.90375795, 0.37761249, 0.43007557, 0.95089798, 0.43288319],
             [1.16043599, 0.63015939, 0.51362715, 0.43617541, 0.57783301],
             [0.81927846, 0.52558369, 0.59713908, 0.83100696, 0.50811616],
             [0.82684205, 1.02064907, 0.50296995, 1.13461733, 0.93222519]])

    """

    def __init__(self, reduction='mean', name=None):
        # Validate first so an invalid configuration fails before the base
        # Layer is initialised.
        if reduction not in ['sum', 'mean', 'none']:
            # An f-string is used instead of ``"%s" % reduction``: the latter
            # raises TypeError (or silently unwraps the value) when a tuple is
            # passed, hiding the intended ValueError.
            raise ValueError(
                "The value of 'reduction' in SoftMarginLoss should be 'sum', 'mean' or 'none', but "
                f"received {reduction}, which is not allowed."
            )

        super().__init__()
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        # Thin wrapper: all computation lives in the functional API.
        return F.soft_margin_loss(input, label, self.reduction, self.name)
Z
Zman 已提交
2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123


class GaussianNLLLoss(Layer):
    r"""Create a callable object of 'GaussianNLLLoss' to calculate Gaussian negative log likelihood loss.

    This class creates a callable object of Gaussian negative log likelihood loss among ``input``, ``variance`` and
    ``label``. Note that the ``label`` is treated as samples from Gaussian distributions.
    This class is used to train a neural network that predicts
    the ``input`` and ``variance`` of a Gaussian distribution that ``label`` is supposed to
    be coming from. This means ``input`` and ``variance`` should be functions (the neural network) of some inputs.

    For a ``label`` having Gaussian distribution with ``input`` and ``variance`` predicted by neural network
    the loss is calculated as follows:

    .. math::
        \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
        \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{label}\right)^2}
        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}

    where :attr:`epsilon` is used for stability. By default, the constant term of
    the loss function is omitted unless :attr:`full` is ``True``. If ``variance`` is not the same
    size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
    of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.

    Args:
        full (bool, optional): include the constant term in the loss
            calculation. Default: ``False``, means omit the constant term.
        epsilon (float, optional): value used to clamp ``variance`` (see note below), for
            stability. Default: 1e-6.
        reduction (str, optional): specifies the reduction to apply to the
            output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
            will be applied, ``'mean'``: the output is the average of all batch
            member losses, ``'sum'``: the output is the sum of all batch member
            losses. Default: ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - Input(Tensor): :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
          dimensions. Available dtype is float32, float64.
        - Label(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
          but with one dimension equal to 1 (to allow for broadcasting). Available dtype is float32, float64.
        - Variance(Tensor): :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
          with one dimension equal to 1, or same shape as the input but with one fewer
          dimension (to allow for broadcasting). Available dtype is float32, float64.
        - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
          ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
          shape as the input

    Returns:
        A callable object of GaussianNLLLoss.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn as nn
            >>> paddle.seed(2023)

            >>> input = paddle.randn([5, 2], dtype=paddle.float32)
            >>> label = paddle.randn([5, 2], dtype=paddle.float32)
            >>> variance = paddle.ones([5, 2], dtype=paddle.float32)

            >>> gs_nll_loss = nn.GaussianNLLLoss(full=False, epsilon=1e-6, reduction='none')
            >>> loss = gs_nll_loss(input, label, variance)
            >>> print(loss)
            Tensor(shape=[5, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
            [[0.21808575, 1.43013096],
             [1.05245590, 0.00394560],
             [1.20861185, 0.00000062],
             [0.56946373, 0.73300570],
             [0.37142906, 0.12038800]])

    Note:
        The clamping of ``variance`` is ignored with respect to autograd, and so the
        gradients are unaffected by it.
    """

    def __init__(self, full=False, epsilon=1e-6, reduction='mean', name=None):
        # Validate first so an invalid configuration fails before the base
        # Layer is initialised.
        if reduction not in ['sum', 'mean', 'none']:
            # An f-string is used instead of ``"%s" % reduction``: the latter
            # raises TypeError (or silently unwraps the value) when a tuple is
            # passed, hiding the intended ValueError.
            raise ValueError(
                "The value of 'reduction' in GaussianNLLLoss should be 'sum', 'mean' or 'none', but "
                f"received {reduction}, which is not allowed."
            )

        super().__init__()
        self.full = full
        self.epsilon = epsilon
        self.reduction = reduction
        self.name = name

    def forward(self, input, label, variance):
        # Thin wrapper: the stored configuration is handed positionally to the
        # functional implementation, in the same order as the original call.
        return F.gaussian_nll_loss(
            input,
            label,
            variance,
            self.full,
            self.epsilon,
            self.reduction,
            self.name,
        )