loss.py 75.3 KB
Newer Older
1
# -*- coding: utf-8 -*
2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
# TODO: define loss functions of neural network
17
import numpy as np
L
Leo Chen 已提交
18
import paddle.fluid as fluid
19
import paddle
20
from .. import functional as F
21
from paddle.fluid.framework import _varbase_creator, in_dygraph_mode, _in_legacy_dygraph
Z
zhiboniu 已提交
22
from .. import Layer
Z
zhiboniu 已提交
23
from paddle import in_dynamic_mode
24

25 26
__all__ = []

L
Leo Chen 已提交
27

Z
zhiboniu 已提交
28
class BCEWithLogitsLoss(Layer):
    r"""
    This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
    Also, we can see it as the combination of the ``sigmoid_cross_entropy_with_logits``
    layer and some reduce operations.

    This measures the element-wise probability error in classification tasks
    in which each class is independent.
    This can be thought of as predicting labels for a data-point, where labels
    are not mutually exclusive. For example, a news article can be about
    politics, technology or sports at the same time or none of these.

    First this operator calculates the loss function as follows:

    .. math::
           Out = -Labels * \log(\sigma(Logit)) - (1 - Labels) * \log(1 - \sigma(Logit))

    We know that :math:`\sigma(Logit) = \frac{1}{1 + e^{-Logit}}`. By substituting this we get:

    .. math::
           Out = Logit - Logit * Labels + \log(1 + e^{-Logit})

    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
    we reformulate the loss as follows:

    .. math::
           Out = \max(Logit, 0) - Logit * Labels + \log(1 + e^{-\lvert Logit\rvert})

    Then, if ``weight`` or ``pos_weight`` is not None, this operator multiplies the
    weight tensor on the loss `Out`. The ``weight`` tensor will attach different
    weight on every item in the batch. The ``pos_weight`` will attach different
    weight on the positive label of each class.

    Finally, this operator applies reduce operation on the loss.
    If :attr:`reduction` set to ``'none'``, the operator will return the original loss `Out`.
    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target labels ``label`` should be numbers between 0 and 1.

    Args:
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`,
            The data type is float32, float64. Default is ``None``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
            with length equal to the number of classes. The data type is float32, float64.
            Default is ``None``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shapes:
        logit (Tensor): The input predictions tensor. 2-D tensor with shape: [N, *],
            N is batch_size, `*` means number of additional dimensions. The ``logit``
            is usually the output of Linear layer. Available dtype is float32, float64.
        label (Tensor): The target labels tensor. 2-D tensor with the same shape as
            ``logit``. The target labels which values should be numbers between 0 and 1.
            Available dtype is float32, float64.
        output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
            same as ``logit`` , else the shape of output is scalar.

    Returns:
        A callable object of BCEWithLogitsLoss.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:

        .. code-block:: python

            import paddle
            logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
            label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
            bce_logit_loss = paddle.nn.BCEWithLogitsLoss()
            output = bce_logit_loss(logit, label)
            print(output.numpy())  # [0.45618808]

    """

    def __init__(self,
                 weight=None,
                 reduction='mean',
                 pos_weight=None,
                 name=None):
        # Reject an unknown reduction at construction time, before the layer
        # base class is initialised.
        if reduction not in ('sum', 'mean', 'none'):
            raise ValueError(
                "The value of 'reduction' in BCEWithLogitsLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)

        super(BCEWithLogitsLoss, self).__init__()
        self.weight = weight
        self.reduction = reduction
        self.pos_weight = pos_weight
        self.name = name

    def forward(self, logit, label):
        """Compute the loss of ``logit`` against ``label`` by delegating to
        ``paddle.nn.functional.binary_cross_entropy_with_logits`` with the
        options stored on this layer."""
        return paddle.nn.functional.binary_cross_entropy_with_logits(
            logit,
            label,
            weight=self.weight,
            reduction=self.reduction,
            pos_weight=self.pos_weight,
            name=self.name)


Z
zhiboniu 已提交
132
class CrossEntropyLoss(Layer):
    r"""
    By default, this operator implements the cross entropy loss function with softmax. This function
    combines the calculation of the softmax operation and the cross entropy loss function
    to provide a more numerically stable computing.

    This operator will calculate the cross entropy loss function without softmax when use_softmax=False.

    By default, this operator will calculate the mean of the result, and you can also affect
    the default behavior by using the reduction parameter. Please refer to the part of
    parameters for details.

    This operator can be used to calculate the softmax cross entropy loss with soft and hard labels.
    Where, the hard labels mean the actual label value, 0, 1, 2, etc.  And the soft labels
    mean the probability of the actual label, 0.6, 0.8, 0.2, etc.

    The calculation of this operator includes the following two steps.

    -  **I.softmax cross entropy**

        1. Hard label (each sample can only be assigned into one category)

        1.1. when use_softmax=True

            .. math::
              loss_j=-\text{logits}_{label_j}+\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories.

        1.2. when use_softmax=False

            .. math::
              loss_j=-\log\left({P}_{label_j}\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).


        2. Soft label (each sample is assigned to multiple categories with a certain probability, and the probability sum is 1).

        2.1. when use_softmax=True

            .. math::
              loss_j=-\sum_{i=0}^{C}\text{label}_i\left(\text{logits}_i-\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories.

        2.2. when use_softmax=False

            .. math::
              loss_j=-\sum_{i=0}^{C}\left({label}_i*\log\left({P}_i\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).



    -  **II.Weight and reduction processing**

        1. Weight

            If the ``weight`` parameter is ``None`` , go to the next step directly.

            If the ``weight`` parameter is not ``None`` , the cross entropy of each sample is weighted by weight
            according to soft_label = False or True as follows.

            1.1. Hard labels (soft_label = False)

            .. math::
                loss_j=loss_j*weight[label_j]


            1.2. Soft labels (soft_label = True)

            .. math::
                loss_j=loss_j*\sum_{i}\left(weight[label_i]*logits_i\right)

        2. reduction

            2.1 if the ``reduction`` parameter is ``none``

            Return the previous result directly

            2.2 if the ``reduction`` parameter is ``sum``

            Return the sum of the previous results

            .. math::
               loss=\sum_{j}loss_j

            2.3 if the ``reduction`` parameter is ``mean`` , it will be processed according to
            the ``weight`` parameter as follows.

            2.3.1. If the  ``weight``  parameter is ``None``

            Return the average value of the previous results

            .. math::
                loss=\sum_{j}loss_j/N

            where, N is the number of samples and C is the number of categories.

            2.3.2. If the ``weight`` parameter is not ``None``, the weighted average value of the previous result will be returned

            1. Hard labels (soft_label = False)

            .. math::
                loss=\sum_{j}loss_j/\sum_{j}weight[label_j]

            2. Soft labels (soft_label = True)

            .. math::
                loss=\sum_{j}loss_j/\sum_{j}\left(\sum_{i}weight[label_i]\right)


    Parameters:

        - **weight** (Tensor, optional)

            a manual rescaling weight given to each class.
            If given, has to be a Tensor of size C and the data type is float32, float64.
            Default is ``None`` .

        - **ignore_index** (int64, optional)

            Specifies a target value that is ignored
            and does not contribute to the loss. A negative value means that no label
            value needs to be ignored. Only valid when soft_label = False.
            Default is ``-100`` .

        - **reduction** (str, optional)

            Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.

        - **soft_label** (bool, optional)

            Indicate whether label is soft.
            If soft_label=False, the label is hard.  If soft_label=True, the label is soft.
            Default is ``False``.

        - **axis** (int, optional)

            The index of dimension to perform softmax calculations.
            It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the number
            of dimensions of input :attr:`input`.
            Default is ``-1`` .

        - **use_softmax** (bool, optional)

            Indicate whether compute softmax before cross_entropy.
            Default is ``True``.

        - **name** (str, optional)

            The name of the operator. Default is ``None`` .
            For more information, please refer to :ref:`api_guide_Name` .


    Shape:

        - **input** (Tensor)

            Input tensor, the data type is float32, float64. Shape is
            :math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes ,  ``k >= 1`` .

            Note:

                1. when use_softmax=True, it expects unscaled logits. This operator should not be used with the
                output of softmax operator, which will produce incorrect results.

                2. when use_softmax=False, it expects the output of softmax operator.


        - **label** (Tensor)

            1. If soft_label=False, the shape is
            :math:`[N_1, N_2, ..., N_k]` or :math:`[N_1, N_2, ..., N_k, 1]`, k >= 1.
            the data type is int32, int64, float32, float64, where each value is [0, C-1].

            2. If soft_label=True, the shape and data type should be same with ``input`` ,
            and the sum of the labels for each sample should be 1.

        - **output** (Tensor)

            Return the softmax cross_entropy loss of ``input`` and ``label``.

            The data type is the same as input.

            If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the dimension of return value is ``1``.

            If :attr:`reduction` is ``'none'``:

            1. If soft_label = False, the dimension of return value is the same with ``label`` .

            2. if soft_label = True, the dimension of return value is :math:`[N_1, N_2, ..., N_k, 1]` .

    Examples:

        .. code-block:: python

            # hard labels
            import paddle
            paddle.seed(99999)
            N=100
            C=200
            reduction='mean'
            input =  paddle.rand([N, C], dtype='float64')
            label =  paddle.randint(0, C, shape=[N], dtype='int64')
            weight = paddle.rand([C], dtype='float64')

            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
                weight=weight, reduction=reduction)
            dy_ret = cross_entropy_loss(
                                       input,
                                       label)
            print(dy_ret.numpy()) #[5.41993642]

        .. code-block:: python

            # soft labels
            import paddle
            paddle.seed(99999)
            axis = -1
            ignore_index = -100
            N = 4
            C = 3
            shape = [N, C]
            reduction='mean'
            weight = None
            logits = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            labels = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            labels /= paddle.sum(labels, axis=axis, keepdim=True)
            paddle_loss_mean = paddle.nn.functional.cross_entropy(
                                                                  logits,
                                                                  labels,
                                                                  soft_label=True,
                                                                  axis=axis,
                                                                  weight=weight,
                                                                  reduction=reduction)
            print(paddle_loss_mean.numpy()) #[1.12908343]

    """

    def __init__(self,
                 weight=None,
                 ignore_index=-100,
                 reduction='mean',
                 soft_label=False,
                 axis=-1,
                 use_softmax=True,
                 name=None):
        super(CrossEntropyLoss, self).__init__()
        # The options are only stored here; they are handed to the functional
        # API unchanged on every forward call.
        self.weight = weight
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.soft_label = soft_label
        self.axis = axis
        self.use_softmax = use_softmax
        self.name = name

    def forward(self, input, label):
        """Compute the loss of ``input`` against ``label`` by delegating to
        ``paddle.nn.functional.cross_entropy`` with the options stored on
        this layer."""
        return paddle.nn.functional.cross_entropy(
            input,
            label,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            soft_label=self.soft_label,
            axis=self.axis,
            use_softmax=self.use_softmax,
            name=self.name)
407 408


Z
zhiboniu 已提交
409
class HSigmoidLoss(Layer):
    """
    Hierarchical Sigmoid Layer.

    The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity
    and speed up the model training, especially the training of language model.
    Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier.
    For each class(word), there's a unique path from root to itself, hsigmoid calculate the cost for each non-leaf node on
    the path, and sum them to get a total cost.
    Comparing to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
    represents the number of classes or the size of word dict.

    The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural
    Network Language Model <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_. For the custom
    tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example):

    1. Using a custom word dict to build a binary tree, each leaf node should be a word in the word dict.
    2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table.
    3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code.
       Code means the label of each binary classifier, 1 indicate true, 0 indicate false.
    4. Now, each word should have its path and code along the path, you can pass a batch of path and code related
       to the same batch of inputs.

    Parameters:
        feature_size (int): The number of features.
        num_classes (int): The number of classes or the size of word dict. With the default tree
            (:attr:`is_custom` is set to False) it must not be less than 2 and
            should not be None. If the custom tree is used (:attr:`is_custom` is set to True),
            :attr:`num_classes` should be the number of non-leaf nodes, which indicates the num of
            classes using by the binary classifier.
        weight_attr (ParamAttr, optional): The parameter attribute for the learnable weights
            of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create a
            ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is
            initialized with Xavier. Default is None.
        bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of hsigmoid. If it
            is set to False, no bias will be added. If it is set to None or one attribute of ParamAttr,
            hsigmoid will create a ParamAttr as bias_attr. If the Initializer of the bias_attr is not
            set, the bias is initialized zero. Default is None.
        is_custom (bool, optional): Whether use custom binary tree. If it's True, `path_table` and
            `path_code` should be passed to its forward method, otherwise `path_table` and `path_code`
            should not be passed to its forward method. Default is False.
        is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True,
            the gradient of weight and input will be sparse. It is only honored together with a custom
            tree; with the default tree a notice is printed and sparse mode is turned off.
            Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input (Tensor): The input tensor. The shape is [N, D], where N is batch size and D is feature size. Its data type should be float32, float64.
        label (Tensor): Its shape is [N, 1]. Its data type should be int64.
        output (Tensor): The HSigmoid Loss of ``input`` and ``label``. Shape is [N, 1]

    Raises:
        ValueError: If the default tree is used and ``num_classes`` is less than 2.

    Examples:
        .. code-block:: python

            import paddle
            paddle.set_device('cpu')

            input = paddle.uniform([4, 3])
            # [[0.56194401  -0.22450298  -0.10741806] # random
            #  [0.36136317  0.23556745  0.88748658] # random
            #  [0.18151939  0.80947340  -0.31078976] # random
            #  [0.68886101  -0.14239830  -0.41297770]] # random
            label = paddle.to_tensor([0, 1, 4, 5])
            m = paddle.nn.HSigmoidLoss(3, 5)
            out = m(input, label)
            # [[2.42524505]
            #  [1.74917245]
            #  [3.14571381]
            #  [2.34564662]]
    """

    def __init__(self,
                 feature_size,
                 num_classes,
                 weight_attr=None,
                 bias_attr=None,
                 is_custom=False,
                 is_sparse=False,
                 name=None):
        super(HSigmoidLoss, self).__init__()
        if (num_classes < 2) and (not is_custom):
            raise ValueError(
                "num_classes must not be less than 2 with default tree")

        # Sparse updates are only kept when a custom tree is used; otherwise
        # fall back to dense mode and tell the user.
        if (not is_custom) and (is_sparse):
            print("Sparse mode should not be used without custom tree")
            is_sparse = False

        self._feature_size = feature_size
        self._num_classes = num_classes
        self._is_custom = is_custom
        self._is_sparse = is_sparse

        self._weight_attr = weight_attr
        self._bias_attr = bias_attr

        self._name = name
        self._dtype = paddle.get_default_dtype()

        # The performance hint only concerns sparse mode, so emit it only
        # when sparse mode is actually in effect.
        if is_sparse:
            print("With sparse mode, if your models has only"
                  " small parameter prefetch may cause speed down")

        # Number of parameter rows: `num_classes` itself for a custom tree,
        # `num_classes - 1` for the default tree.
        C = self._num_classes if is_custom else self._num_classes - 1
        self.weight = self.create_parameter([C, self._feature_size],
                                            attr=self._weight_attr,
                                            is_bias=False,
                                            dtype=self._dtype)
        self.bias = self.create_parameter([C, 1],
                                          attr=self._bias_attr,
                                          is_bias=True,
                                          dtype=self._dtype)

    def forward(self, input, label, path_table=None, path_code=None):
        """Compute the hierarchical sigmoid loss by delegating to
        ``F.hsigmoid_loss`` with this layer's weight, bias and options.
        ``path_table`` and ``path_code`` are forwarded unchanged."""
        return F.hsigmoid_loss(input,
                               label,
                               self._num_classes,
                               self.weight,
                               self.bias,
                               path_table=path_table,
                               path_code=path_code,
                               is_sparse=self._is_sparse,
                               name=self._name)


Z
zhiboniu 已提交
535
class MSELoss(Layer):
    r"""
    **Mean Square Error Loss**

    Computes the mean square error (squared L2 norm) of given input and label.

    If :attr:`reduction` is set to ``'none'``, loss is calculated as:

    .. math::
        Out = (input - label)^2

    If :attr:`reduction` is set to ``'mean'``, loss is calculated as:

    .. math::
        Out = \operatorname{mean}((input - label)^2)

    If :attr:`reduction` is set to ``'sum'``, loss is calculated as:

    .. math::
        Out = \operatorname{sum}((input - label)^2)

    where `input` and `label` are `float32` or `float64` tensors of same shape.

    Parameters:
        reduction (string, optional): The reduction method for the output,
            could be 'none' | 'mean' | 'sum'.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.

    Shape:
        input (Tensor): Input tensor, the data type is float32 or float64
        label (Tensor): Label tensor, the data type is float32 or float64
        output (Tensor): output tensor storing the MSE loss of input and label, the data type is same as input.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle

            input_data = np.array([1.5]).astype("float32")
            label_data = np.array([1.7]).astype("float32")

            mse_loss = paddle.nn.loss.MSELoss()
            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)
            output = mse_loss(input, label)
            print(output)
            # [0.04000002]
    """

    def __init__(self, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in ('sum', 'mean', 'none'):
            raise ValueError(
                "'reduction' in 'MSELoss' should be 'sum', 'mean' or 'none', "
                "but received {}.".format(reduction))
        self.reduction = reduction

    def forward(self, input, label):
        """Return the squared error of ``input`` and ``label``, reduced
        according to ``self.reduction``."""
        if not in_dynamic_mode():
            # Outside dynamic mode, check that both operands are float32 or
            # float64 variables before building the ops.
            for tensor, tensor_name in ((input, 'input'), (label, 'label')):
                fluid.data_feeder.check_variable_and_dtype(
                    tensor, tensor_name, ['float32', 'float64'], 'MSELoss')

        difference = paddle.subtract(input, label)
        if in_dygraph_mode():
            square_out = paddle._C_ops.final_state_square(difference)
        else:
            square_out = paddle.square(difference)

        if self.reduction == 'none':
            return square_out
        if self.reduction == 'sum':
            return fluid.layers.reduce_sum(square_out)
        # Every remaining value of `reduction` is reduced with the mean.
        return fluid.layers.reduce_mean(square_out)


Z
zhiboniu 已提交
619
class L1Loss(Layer):
    r"""
    This interface is used to construct a callable object of the ``L1Loss`` class.
    The L1Loss layer calculates the L1 Loss of ``input`` and ``label`` as follows.

    If `reduction` set to ``'none'``, the loss is:

    .. math::
        Out = \lvert input - label\rvert

    If `reduction` set to ``'mean'``, the loss is:

    .. math::
        Out = MEAN(\lvert input - label\rvert)

    If `reduction` set to ``'sum'``, the loss is:

    .. math::
        Out = SUM(\lvert input - label\rvert)


    Parameters:
        reduction (str, optional): Indicate the reduction to apply to the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'none'``, the unreduced loss is returned;
            If `reduction` is ``'mean'``, the reduced mean loss is returned.
            If `reduction` is ``'sum'``, the reduced sum loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input (Tensor): The input tensor. The shape is [N, *], where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32, int64.
        label (Tensor): label. The shape is [N, *], same shape as ``input`` . Its data type should be float32, float64, int32, int64.
        output (Tensor): The L1 Loss of ``input`` and ``label``.
            If `reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``input`` .
            If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
            label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)

            l1_loss = paddle.nn.L1Loss()
            output = l1_loss(input, label)
            print(output.numpy())
            # [0.35]

            l1_loss = paddle.nn.L1Loss(reduction='sum')
            output = l1_loss(input, label)
            print(output.numpy())
            # [1.4]

            l1_loss = paddle.nn.L1Loss(reduction='none')
            output = l1_loss(input, label)
            print(output)
            # [[0.20000005 0.19999999]
            # [0.2        0.79999995]]
    """

    def __init__(self, reduction='mean', name=None):
        # Reject an unknown reduction at construction time, before the layer
        # base class is initialised.
        if reduction not in ('sum', 'mean', 'none'):
            raise ValueError(
                "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(L1Loss, self).__init__()
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        """Compute the L1 loss of ``input`` against ``label`` by delegating to
        ``paddle.nn.functional.l1_loss`` with the options stored on this layer."""
        return paddle.nn.functional.l1_loss(input,
                                            label,
                                            reduction=self.reduction,
                                            name=self.name)
C
ceci3 已提交
698 699


Z
zhiboniu 已提交
700
class BCELoss(Layer):
    """
    Creates a callable layer that measures the binary cross entropy between
    predictions ``input`` and target labels ``label``.

    When :attr:`weight` is set, the element-wise loss is:

    .. math::
        Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))

    When :attr:`weight` is None, the element-wise loss is:

    .. math::
        Out = -1 * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`reduction` is ``'none'``, the element-wise loss `Out` is
    returned unchanged.

    If :attr:`reduction` is ``'mean'``, the returned loss is:

    .. math::
        Out = MEAN(Out)

    If :attr:`reduction` is ``'sum'``, the returned loss is:

    .. math::
        Out = SUM(Out)

    Note that the predictions ``input`` should always be the output of a
    sigmoid, and the target labels ``label`` should be numbers between 0 and 1.

    Parameters:
        weight (Tensor, optional): A manual rescaling weight given to the loss
            of each batch element. If given, it has to be a Tensor of size
            nbatch whose data type is float32 or float64.
            Default is ``'None'``.
        reduction (str, optional): Indicates how to reduce the loss over the
            batch; the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input (Tensor): Tensor with shape [N, *], where N is batch_size and
            `*` means any number of additional dimensions. ``input`` should
            always be the output of a sigmoid. Available dtype is float32,
            float64.
        label (Tensor): Tensor with the same shape as ``input``. The target
            labels, whose values should be numbers between 0 and 1. Available
            dtype is float32, float64.
        output (Tensor): If ``reduction`` is ``'none'``, the shape of output
            is the same as ``input``; otherwise the output is a scalar.

    Returns:
        A callable object of BCELoss.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            input_data = np.array([0.5, 0.6, 0.7]).astype("float32")
            label_data = np.array([1.0, 0.0, 1.0]).astype("float32")

            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)
            bce_loss = paddle.nn.BCELoss()
            output = bce_loss(input, label)
            print(output)  # [0.65537095]

    """

    def __init__(self, weight=None, reduction='mean', name=None):
        # Fail fast on an unsupported mode, before the base layer is set up.
        if reduction not in ('sum', 'mean', 'none'):
            message = (
                "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
            raise ValueError(message)

        super().__init__()
        self.name = name
        self.reduction = reduction
        self.weight = weight

    def forward(self, input, label):
        # Thin wrapper: all computation happens in the functional API.
        return paddle.nn.functional.binary_cross_entropy(
            input, label, self.weight, self.reduction, self.name)
790 791


Z
zhiboniu 已提交
792
class NLLLoss(Layer):
    r"""

    Accepts an input and a target label and returns the negative log
    likelihood loss. It is useful for training a classification problem with
    C classes.

    The input for the loss is expected to contain log-probabilities of
    each class. It has to be a Tensor of size either (batch_size, C) or
    (batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
    The label for the loss should be a class index in the range [0, C-1]
    where C is the number of classes. If ignore_index is specified, the
    specified target value does not contribute to the input gradient.

    If the optional argument `weight` is provided, it should be a 1D Tensor
    assigning a weight to each of the classes. This is particularly useful
    when you have an unbalanced training set.

    The loss is calculated as follows.
    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be
    described as:

    .. math::

        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = - w_{y_n} x_{n,y_n}, \quad
        w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\},

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then

    .. math::

        \ell(x, y) =
        \left\{
            \begin{array}{lcl}
            \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, &
            \text{if  reduction} = \text{'mean';}\\
            \sum_{n=1}^N l_n,  &
            \text{if  reduction} = \text{'sum'.}
            \end{array}
        \right.

    Parameters:
        weight (Tensor, optional): Weight tensor, a manual rescaling weight
            given to each class. If given, it has to be a 1D Tensor whose size
            is `[C, ]`. Otherwise, it is treated as if having all ones. The
            data type is float32 or float64. Default is ``'None'``.
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient.
        reduction (str, optional): Indicates how to reduce the loss;
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'mean'``, the reduced mean loss is returned;
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input (Tensor): Input tensor, the shape is :math:`[N, C]`, `C` is the number of classes.
            But in K-dimension situation, the shape is :math:`[N, C, d_1, d_2, ..., d_K]`.
            The data type is float32, float64.
        label (Tensor): Label tensor, the shape is :math:`[N,]` or :math:`[N, d_1, d_2, ..., d_K]`.
            The data type is int64.
        output (Tensor): the `negative log likelihood loss` between input `x` and `label`.
            If `reduction` is `'none'`, the shape is `[N, *]`.
            If `reduction` is `'sum'` or `'mean'`, the shape is `[1]`.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle

            nll_loss = paddle.nn.loss.NLLLoss()
            log_softmax = paddle.nn.LogSoftmax(axis=1)

            input = paddle.to_tensor([[0.88103855, 0.9908683 , 0.6226845 ],
                                      [0.53331435, 0.07999352, 0.8549948 ],
                                      [0.25879037, 0.39530203, 0.698465  ],
                                      [0.73427284, 0.63575995, 0.18827209],
                                      [0.05689114, 0.0862954 , 0.6325046 ]], "float32")
            log_out = log_softmax(input)
            label = paddle.to_tensor([0, 2, 1, 1, 0], "int64")
            result = nll_loss(log_out, label)
            print(result) # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, [1.07202101])

    """

    def __init__(self,
                 weight=None,
                 ignore_index=-100,
                 reduction='mean',
                 name=None):
        # Fail fast on an unsupported mode, before the base layer is set up.
        if reduction not in ('sum', 'mean', 'none'):
            message = (
                "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
                "'none', but received %s, which is not allowed." % reduction)
            raise ValueError(message)
        super().__init__()
        self._name = name
        self._reduction = reduction
        self._ignore_index = ignore_index
        self._weight = weight

    def forward(self, input, label):
        # Collect the stored configuration and hand everything to F.nll_loss.
        options = dict(weight=self._weight,
                       ignore_index=self._ignore_index,
                       reduction=self._reduction,
                       name=self._name)
        return F.nll_loss(input, label, **options)
901 902


Z
zhiboniu 已提交
903
class KLDivLoss(Layer):
    r"""
    Calculates the Kullback-Leibler divergence loss between Input(X) and
    Input(Target). Note that Input(X) is the log-probability and
    Input(Target) is the probability.

    KL divergence loss is calculated as follows:

    .. math::
        l(x, y) = y * (\log(y) - x)

    Parameters:
        reduction (str, optional): Indicates how to reduce the loss;
             the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
             If `reduction` is ``'mean'``, the reduced mean loss is returned;
             if `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
             if `reduction` is ``'sum'``, the reduced sum loss is returned;
             if `reduction` is ``'none'``, no reduction will be applied.
             Default is ``'mean'``.

    Shape:

        - input (Tensor): (N, *), where * means any number of additional dimensions.

        - label (Tensor): (N, *), same shape as input.

        - output (Tensor): tensor with shape: [1] by default.


    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            import paddle.nn as nn

            shape = (5, 20)
            x = np.random.uniform(-10, 10, shape).astype('float32')
            target = np.random.uniform(-10, 10, shape).astype('float32')

            # 'batchmean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[1]

            # 'mean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='mean')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[1]

            # 'sum' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='sum')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[1]

            # 'none' reduction, loss shape is same with X shape
            kldiv_criterion = nn.KLDivLoss(reduction='none')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[5, 20]
    """

    def __init__(self, reduction='mean'):
        # The mode is stored without checking here; it is only consumed when
        # forward passes it on to F.kl_div.
        super().__init__()
        self.reduction = reduction

    def forward(self, input, label):
        # Thin wrapper around the functional implementation.
        return F.kl_div(input, label, self.reduction)


Z
zhiboniu 已提交
976
class MarginRankingLoss(Layer):
    r"""

    Creates a callable layer that calculates the margin rank loss between
    ``input``, ``other`` and ``label``, using the following formula.

    .. math::
        margin\_rank\_loss = max(0, -label * (input - other) + margin)

    If :attr:`reduction` is ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(margin\_rank\_loss)

    If :attr:`reduction` is ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(margin\_rank\_loss)

    If :attr:`reduction` is ``'none'``, the original ``margin_rank_loss`` is
    returned.

    Parameters:
        margin (float, optional): The margin value to add. Default is 0.
        reduction (str, optional): Indicates the reduction to apply to the
            loss; the candidates are ``'none'``, ``'mean'``, ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            if :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            if :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:

        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        other: N-D Tensor, `other` has the same shape and dtype as `input`.

        label: N-D Tensor, `label` has the same shape and dtype as `input`.

        output: If :attr:`reduction` is ``'mean'`` or ``'sum'``, the output shape is :math:`[1]`, otherwise the shape is the same as `input`. The dtype is the same as the input tensor.

    Returns:
        A callable object of MarginRankingLoss.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:

        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[1, 2], [3, 4]], dtype="float32")
            other = paddle.to_tensor([[2, 1], [2, 4]], dtype="float32")
            label = paddle.to_tensor([[1, -1], [-1, -1]], dtype="float32")
            margin_rank_loss = paddle.nn.MarginRankingLoss()
            loss = margin_rank_loss(input, other, label)

            print(loss)
            # [0.75]
    """

    def __init__(self, margin=0.0, reduction='mean', name=None):
        # Fail fast on an unsupported mode, before the base layer is set up.
        if reduction not in ('sum', 'mean', 'none'):
            message = (
                "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
            raise ValueError(message)
        super().__init__()
        self.name = name
        self.reduction = reduction
        self.margin = margin

    def forward(self, input, other, label):
        # Thin wrapper: all computation happens in the functional API.
        return paddle.nn.functional.margin_ranking_loss(
            input, other, label, self.margin, self.reduction, self.name)
1048 1049


Z
zhiboniu 已提交
1050
class CTCLoss(Layer):
    """

    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation
    is integrated into the Warp-CTC library to normalize values for each row of the input tensor.

    Parameters:
        blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
        reduction (str, optional): Indicates how to reduce the loss; the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then the mean of the quotient is returned; If :attr:`reduction` is ``'sum'``, the sum of the loss is returned; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.

    Shape:
        log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type should be float32 or float64.
        labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
        input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
        label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
        norm_by_times (bool, optional): Whether to normalize the gradients by the number of time-steps, which is also the sequence's length. There is no need to normalize the gradients if reduction mode is 'mean'. Default is False.

    Returns:
        Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and  ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle

            # length of the longest logit sequence
            max_seq_length = 5
            # length of the longest label sequence
            max_label_length = 3
            # number of logit sequences
            batch_size = 2
            # class num
            class_num = 3

            np.random.seed(1)
            log_probs = np.array([[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04],
                                    [3.02332580e-01, 1.46755889e-01, 9.23385918e-02]],

                                    [[1.86260208e-01, 3.45560730e-01, 3.96767467e-01],
                                    [5.38816750e-01, 4.19194520e-01, 6.85219526e-01]],

                                    [[2.04452246e-01, 8.78117442e-01, 2.73875929e-02],
                                    [6.70467496e-01, 4.17304814e-01, 5.58689833e-01]],

                                    [[1.40386939e-01, 1.98101491e-01, 8.00744593e-01],
                                    [9.68261600e-01, 3.13424170e-01, 6.92322612e-01]],

                                    [[8.76389146e-01, 8.94606650e-01, 8.50442126e-02],
                                    [3.90547849e-02, 1.69830427e-01, 8.78142476e-01]]]).astype("float32")
            labels = np.array([[1, 2, 2],
                            [1, 2, 2]]).astype("int32")
            input_lengths = np.array([5, 5]).astype("int64")
            label_lengths = np.array([3, 3]).astype("int64")

            log_probs = paddle.to_tensor(log_probs)
            labels = paddle.to_tensor(labels)
            input_lengths = paddle.to_tensor(input_lengths)
            label_lengths = paddle.to_tensor(label_lengths)

            loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels,
                input_lengths,
                label_lengths)
            print(loss)  #[3.9179852 2.9076521]

            loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels,
                input_lengths,
                label_lengths)
            print(loss)  #[1.1376063]
    """

    def __init__(self, blank=0, reduction='mean'):
        # Both settings are stored as given and only consumed in forward.
        super().__init__()
        self.reduction = reduction
        self.blank = blank

    def forward(self,
                log_probs,
                labels,
                input_lengths,
                label_lengths,
                norm_by_times=False):
        # Thin wrapper: all computation happens in the functional API.
        ctc_loss_fn = paddle.nn.functional.ctc_loss
        return ctc_loss_fn(log_probs,
                           labels,
                           input_lengths,
                           label_lengths,
                           self.blank,
                           self.reduction,
                           norm_by_times=norm_by_times)
1143 1144


Z
zhiboniu 已提交
1145
class SmoothL1Loss(Layer):
    r"""
    Calculates smooth_l1_loss. Creates a criterion that uses a squared
    term if the absolute element-wise error falls below delta and an L1 term
    otherwise. In some cases it can prevent exploding gradients, and it is
    more robust and less sensitive to outliers. Also known as the Huber loss:

    .. math::

         loss(x,y) = \frac{1}{n}\sum_{i}z_i

    where z_i is given by:

    .. math::

        \mathop{z_i} = \left\{\begin{array}{rcl}
        0.5(x_i - y_i)^2 & & {if |x_i - y_i| < delta} \\
        delta * |x_i - y_i| - 0.5 * delta^2 & & {otherwise}
        \end{array} \right.

    Parameters:
        reduction (str, optional): Indicates how to reduce the loss over the
            batch; the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        delta (float, optional): Specifies the hyperparameter delta to be used.
            The value determines how large the errors need to be to use L1. Errors
            smaller than delta are minimized with L2. Parameter is ignored for
            negative/zero values. Default = 1.0
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to :ref:`api_guide_Name`.

    Call Parameters:

        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C),
        where C is number of classes, and if shape is more than 2D,
        this is (N, C, D1, D2,..., Dk), k >= 1.

        label (Tensor): Label tensor, the data type is float32 or float64.
        The shape of label is the same as the shape of input.

    Returns:
        Tensor, The tensor storing the smooth_l1_loss of input and label.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            input_data = np.random.rand(3,3).astype("float32")
            label_data = np.random.rand(3,3).astype("float32")
            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)
            loss = paddle.nn.SmoothL1Loss()
            output = loss(input, label)
            print(output)
    """

    def __init__(self, reduction='mean', delta=1.0, name=None):
        # Settings are stored as given and only consumed in forward.
        super().__init__()
        self.name = name
        self.delta = delta
        self.reduction = reduction

    def forward(self, input, label):
        # Collect the stored configuration and hand everything to the
        # functional implementation.
        options = dict(reduction=self.reduction,
                       delta=self.delta,
                       name=self.name)
        return F.smooth_l1_loss(input, label, **options)
1217 1218


Y
yangguohao 已提交
1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296
class MultiLabelSoftMarginLoss(Layer):
    r"""Creates a criterion that optimizes a multi-class multi-classification
    hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
    and output :math:`y` (which is a 2D `Tensor` of target class indices).
    For each sample in the mini-batch:

    .. math::
        \text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)}

    where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`,
    :math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`,
    :math:`0 \leq y[j] \leq \text{x.size}(0)-1`,
    and :math:`i \neq y[j]` for all :math:`i` and :math:`j`.
    :math:`y` and :math:`x` must have the same size.

    Parameters:
        weight (Tensor, optional): A manual rescaling weight given to each class.
            If given, it has to be a Tensor of size C and the data type is
            float32 or float64. Default is ``'None'``.
        reduction (str, optional): Indicates how to reduce the loss over the
            batch; the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Call parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.
        label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.

    Shape:
        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means number of classes, available dtype is float32, float64. The sum operation operates over all the elements.
        label: N-D Tensor, same shape as the input.
        output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.

    Returns:
        A callable object of MultiLabelSoftMarginLoss.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

            multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='none')
            loss = multi_label_soft_margin_loss(input, label)
            print(loss)
            # Tensor([3.49625897, 0.71111226, 0.43989015])

            multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='mean')
            loss = multi_label_soft_margin_loss(input, label)
            print(loss)
            # Tensor([1.54908717])
    """

    # NOTE(review): the docstring formula describes a hinge-style margin loss
    # and states that reduction='none' keeps the input shape, yet the example
    # prints 3 values for a 3x3 input. Confirm both against
    # F.multi_label_soft_margin_loss before relying on them.

    def __init__(self, weight=None, reduction="mean", name=None):
        super().__init__()
        # Unlike most sibling layers, the mode is checked after base-class
        # initialisation; that order is kept here.
        if reduction not in ('sum', 'mean', 'none'):
            message = (
                "'reduction' in 'MultiLabelSoftMarginloss' should be 'sum', 'mean' or 'none', "
                "but received {}.".format(reduction))
            raise ValueError(message)
        self.name = name
        self.reduction = reduction
        self.weight = weight

    def forward(self, input, label):
        # Collect the stored configuration and hand everything to the
        # functional implementation.
        options = dict(weight=self.weight,
                       reduction=self.reduction,
                       name=self.name)
        return F.multi_label_soft_margin_loss(input, label, **options)


1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383
class HingeEmbeddingLoss(Layer):
    r"""
    Calculates the hinge embedding loss given an input tensor :math:`x` and a labels tensor :math:`y` (containing 1 or -1).
    This is usually used for measuring whether two inputs are similar or dissimilar, e.g. using the L1 pairwise distance as :math:`x`,
    and is typically used for learning nonlinear embeddings or semi-supervised learning.

    The loss function for the :math:`n`-th sample in the mini-batch is

    .. math::
        l_n = \begin{cases}
            x_n, & \text{if}\; y_n = 1,\\
            \max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
        \end{cases}

    and the total loss function is

    .. math::
        \ell(x, y) = \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
        \end{cases}

    where :math:`L = \{l_1,\dots,l_N\}^\top` and :math:`\Delta` is the margin.

    Parameters:

        margin (float, optional): The margin :math:`\Delta` in the formula above.
            When the label is -1, an input smaller than the margin contributes
            ``margin - input`` to the loss; larger inputs contribute 0.
            Default = 1.0
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Call Parameters:

        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.

        label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.

    Shape:

        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64. The sum operation operates over all the elements.

        label: N-D Tensor, same shape as the input.

        output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.

    Returns:

        Tensor, The tensor variable storing the hinge_embedding_loss of input and label.

    Raises:

        ValueError: If :attr:`reduction` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            # label elements in {1., -1.}
            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

            hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=1.0, reduction='none')
            loss = hinge_embedding_loss(input, label)
            print(loss)
            # Tensor([[0., -2., 0.],
            #         [0., -1., 2.],
            #         [1., 1., 1.]])

            hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=1.0, reduction='mean')
            loss = hinge_embedding_loss(input, label)
            print(loss)
            # Tensor([0.22222222])
    """

    def __init__(self, margin=1.0, reduction="mean", name=None):
        super(HingeEmbeddingLoss, self).__init__()
        # Fail at construction time, like the sibling loss layers in this
        # module, instead of deferring the error to the first forward call.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in HingeEmbeddingLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        """Return ``F.hinge_embedding_loss`` of ``input`` and ``label`` using the stored settings."""
        return F.hinge_embedding_loss(input,
                                      label,
                                      reduction=self.reduction,
                                      margin=self.margin,
                                      name=self.name)
1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478


class CosineEmbeddingLoss(Layer):
    r"""
    This interface is used to construct a callable object of the ``CosineEmbeddingLoss`` class.
    The CosineEmbeddingLoss layer measures the cosine_embedding loss between input predictions ``input1``, ``input2``
    and target labels ``label`` with values 1 or -1. This is used for measuring whether two inputs are similar or
    dissimilar and is typically used for learning nonlinear embeddings or semi-supervised learning.
    The cosine embedding loss can be described as:

    If label = 1, then the loss value can be calculated as follow:

    .. math::
        Out = 1 - cos(input1, input2)

    If label = -1, then the loss value can be calculated as follow:

    .. math::
        Out = max(0, cos(input1, input2) - margin)

    The operator cos can be described as follow:

    .. math::
        cos(x1, x2) = \frac{x1 \cdot{} x2}{\Vert x1 \Vert_2 * \Vert x2 \Vert_2}

    Parameters:
        margin (float, optional): Should be a number from :math:`-1` to :math:`1`,
            :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
            default value is :math:`0`.
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed.
            Default: ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input1 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
                         Available dtypes are float32, float64.
        input2 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, 'M' means the length of input array.
                         Available dtypes are float32, float64.
        label (Tensor): tensor with shape: [N] or [1]. The target labels values should be -1 or 1.
                         Available dtypes are int32, int64, float32, float64.
        output (Tensor): Tensor, the cosine embedding Loss of Tensor ``input1`` ``input2`` and ``label``.
                         If `reduction` is ``'none'``, the shape of output loss is [N], the same as ``label`` .
                         If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].

    Raises:
        ValueError: If :attr:`margin` is outside :math:`[-1, 1]` or :attr:`reduction`
            is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle

            input1 = paddle.to_tensor([[1.6, 1.2, -0.5], [3.2, 2.6, -5.8]], 'float32')
            input2 = paddle.to_tensor([[0.5, 0.5, -1.8], [2.3, -1.4, 1.1]], 'float32')
            label = paddle.to_tensor([1, -1], 'int64')

            cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='mean')
            output = cosine_embedding_loss(input1, input2, label)
            print(output) # [0.21155193]

            cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='sum')
            output = cosine_embedding_loss(input1, input2, label)
            print(output) # [0.42310387]

            cosine_embedding_loss = paddle.nn.CosineEmbeddingLoss(margin=0.5, reduction='none')
            output = cosine_embedding_loss(input1, input2, label)
            print(output) # [0.42310387, 0.        ]

    """

    def __init__(self, margin=0, reduction='mean', name=None):
        # Arguments are checked before the base Layer is initialised, so an
        # invalid configuration never produces a partially built layer.
        if margin > 1 or margin < -1:
            raise ValueError(
                "The value of 'margin' should be in the interval of [-1, 1], but received %f, which is not allowed."
                % margin)
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' should be 'sum', 'mean' or "
                "'none', but received %s, which is not allowed." % reduction)
        super(CosineEmbeddingLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input1, input2, label):
        """Return ``F.cosine_embedding_loss`` of the inputs and ``label`` using the stored settings."""
        return F.cosine_embedding_loss(input1,
                                       input2,
                                       label,
                                       margin=self.margin,
                                       reduction=self.reduction,
                                       name=self.name)
Y
yangguohao 已提交
1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585


class TripletMarginWithDistanceLoss(Layer):
    r"""
    Creates a criterion that measures the triplet loss given an input
    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
    This is used for measuring a relative similarity between samples. A triplet
    is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
    examples` respectively). The shapes of all input tensors should be
    :math:`(N, D)`.

    The loss function for each sample in the mini-batch is:

    .. math::
        L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}

    where the default `distance_function`

    .. math::
        d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_2

    or user can define their own distance function. `margin` is a nonnegative margin representing the minimum difference
    between the positive and negative distances that is required for the loss to be 0. If `swap` is true, it will compare distance of (input, negative) with
    distance of (negative, positive) and change it to the smaller one. For more details see http://www.bmva.org/bmvc/2016/papers/paper119/paper119.pdf.

    Parameters:
        distance_function (Callable, Optional): Quantifies the distance between two tensors. If not specified, 2 norm functions will be used.

        margin (float, Optional): Default: :math:`1`. A nonnegative margin representing the minimum difference
                between the positive and negative distances required for the loss to be 0. Larger
                margins penalize cases where the negative examples are not distant enough from the
                anchors, relative to the positives.

        swap (bool, Optional): The distance swap changes the negative distance to the swap distance (distance between positive samples
                and negative samples) if swap distance smaller than negative distance. Default: ``False``.

        reduction (str, Optional): Indicate how to average the loss by batch_size.
                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
                Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shapes:
        input (Tensor): Input tensor, the data type is float32 or float64.
            The shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        positive (Tensor): Positive tensor, the data type is float32 or float64.
            The shape of positive is the same as the shape of input.

        negative (Tensor): Negative tensor, the data type is float32 or float64.
            The shape of negative is the same as the shape of input.

        output (Tensor): The tensor variable storing the triplet_margin_with_distance_loss of input and positive and negative.

    Return:
        A callable object of TripletMarginWithDistanceLoss

    Raises:
        ValueError: If :attr:`reduction` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.nn import TripletMarginWithDistanceLoss

            input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
            positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
            negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
            triplet_margin_with_distance_loss = TripletMarginWithDistanceLoss(reduction='none')
            loss = triplet_margin_with_distance_loss(input, positive, negative,)
            print(loss)
            # Tensor([0.        , 0.57496738, 0.        ])

            triplet_margin_with_distance_loss = TripletMarginWithDistanceLoss(reduction='mean')
            loss = triplet_margin_with_distance_loss(input, positive, negative,)
            print(loss)
            # Tensor([0.19165580])

    """

    def __init__(self,
                 distance_function=None,
                 margin=1.0,
                 swap=False,
                 reduction: str = 'mean',
                 name=None):
        super(TripletMarginWithDistanceLoss, self).__init__()
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in TripletMarginWithDistanceLoss "
                "should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        self.margin = margin
        self.swap = swap
        self.reduction = reduction
        self.distance_function = distance_function
        self.name = name

    def forward(self, input, positive, negative):
        """Return ``F.triplet_margin_with_distance_loss`` for the triplet using the stored settings."""
        # The stored distance function must be forwarded explicitly; without
        # it the functional API falls back to its default distance and a
        # user-supplied callable is silently ignored.
        return F.triplet_margin_with_distance_loss(
            input,
            positive,
            negative,
            distance_function=self.distance_function,
            margin=self.margin,
            swap=self.swap,
            reduction=self.reduction,
            name=self.name)
Y
yangguohao 已提交
1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691


class TripletMarginLoss(Layer):
    r"""
    Layer form of the triplet margin loss.

    Given an anchor `input`, a `positive` example and a `negative` example
    (all of shape :math:`(N, *)`), the criterion measures the relative
    similarity between samples using a margin with a value greater than
    :math:`0`.

    The loss function for each sample in the mini-batch is:

    .. math::
        L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}


    where

    .. math::
        d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p

    Parameters:
        margin (float, Optional): Default: :math:`1`.

        p (float, Optional): The norm degree for pairwise distance. Default: :math:`2`.

        epsilon (float, Optional): Add small value to avoid division by zero,
            default value is 1e-6.

        swap (bool, Optional): The distance swap change the negative distance to the distance between
            positive sample and negative sample. For more details, see `Learning shallow convolutional feature descriptors with triplet losses`.
            Default: ``False``.

        reduction (str, Optional): Indicate how to average the loss by batch_size.
            The candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``

        name (str, Optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Call Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64.
            The shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        positive (Tensor): Positive tensor, the data type is float32 or float64.
            The shape of positive is the same as the shape of input.

        negative (Tensor): Negative tensor, the data type is float32 or float64.
            The shape of negative is the same as the shape of input.

    Returns:
        Tensor. The tensor variable storing the triplet_margin_loss of input and positive and negative.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
            positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
            negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
            triplet_margin_loss = paddle.nn.TripletMarginLoss(reduction='none')
            loss = triplet_margin_loss(input, positive, negative)
            print(loss)
            # Tensor([0.        , 0.57496738, 0.        ])

            triplet_margin_loss = paddle.nn.TripletMarginLoss(margin=1.0, swap=True, reduction='mean', )
            loss = triplet_margin_loss(input, positive, negative,)
            print(loss)
            # Tensor([0.19165580])

    """

    def __init__(self,
                 margin=1.0,
                 p=2.,
                 epsilon=1e-6,
                 swap=False,
                 reduction='mean',
                 name=None):
        super(TripletMarginLoss, self).__init__()
        supported_reductions = ('sum', 'mean', 'none')
        if reduction not in supported_reductions:
            message = (
                "The value of 'reduction' in TripletMarginLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
            raise ValueError(message)
        # Plain configuration values; they are only read back in forward().
        self.margin = margin
        self.p = p
        self.epsilon = epsilon
        self.swap = swap
        self.reduction = reduction
        self.name = name

    def forward(self, input, positive, negative):
        """Return ``F.triplet_margin_loss`` for the triplet using the stored settings."""
        settings = {
            'margin': self.margin,
            'p': self.p,
            'epsilon': self.epsilon,
            'swap': self.swap,
            'reduction': self.reduction,
            'name': self.name,
        }
        return F.triplet_margin_loss(input, positive, negative, **settings)
1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762


class SoftMarginLoss(Layer):
    r"""
    Layer form of the two-class soft margin loss between input predictions ``input``
    and target labels ``label`` . The loss is defined as:

    .. math::
        Out = log(1 + exp((-label * input)))

    Parameters:

        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.

        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shapes:

        Input (Tensor): The input tensor with shape: [N, \*],
            N is batch_size, `\*` means any number of additional dimensions. The ``input`` ranges from -inf to inf.
            Available dtype is float32, float64.

        Label (Tensor): The target labels tensor with the same shape as
            ``input``. The target labels which values should be numbers -1 or 1.
            Available dtype is int32, int64, float32, float64.

        Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
            same as ``input`` , else the shape of output is [1].

    Returns:
        A callable object of SoftMarginLoss.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32')
            label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32')
            soft_margin_loss = paddle.nn.SoftMarginLoss()
            output = soft_margin_loss(input, label)

            input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64)
            label_np = np.random.randint(0, 2, size=(5, 5)).astype(np.int64)
            label_np[label_np==0]=-1
            input = paddle.to_tensor(input_np)
            label = paddle.to_tensor(label_np)
            soft_margin_loss = paddle.nn.SoftMarginLoss(reduction='none')
            output = soft_margin_loss(input, label)
    """

    def __init__(self, reduction='mean', name=None):
        # The reduction mode is checked before the base Layer is initialised.
        permitted = ('sum', 'mean', 'none')
        if reduction not in permitted:
            message = (
                "The value of 'reduction' in SoftMarginLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
            raise ValueError(message)

        super(SoftMarginLoss, self).__init__()
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        """Return ``paddle.nn.functional.soft_margin_loss`` of ``input`` and ``label``."""
        # reduction and name are passed positionally, in that order.
        return paddle.nn.functional.soft_margin_loss(input, label,
                                                     self.reduction, self.name)