# -*- coding: utf-8 -*-
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define loss functions of neural network
import numpy as np
import paddle.fluid as fluid
import paddle
from .. import functional as F
from paddle.fluid.framework import _varbase_creator, in_dygraph_mode, _in_legacy_dygraph
from .. import Layer
from paddle import in_dynamic_mode

# Public names are re-exported by the enclosing package, so this module
# deliberately exports nothing through ``from ... import *``.
__all__ = []


class BCEWithLogitsLoss(Layer):
    r"""
    This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
    Also, we can see it as the combination of ``sigmoid_cross_entropy_with_logits``
    layer and some reduce operations.

    This measures the element-wise probability error in classification tasks
    in which each class is independent.
    This can be thought of as predicting labels for a data-point, where labels
    are not mutually exclusive. For example, a news article can be about
    politics, technology or sports at the same time or none of these.

    First this operator calculates the loss function as follows:

    .. math::
           Out = -Labels * \log(\sigma(Logit)) - (1 - Labels) * \log(1 - \sigma(Logit))

    We know that :math:`\sigma(Logit) = \frac{1}{1 + e^{-Logit}}`. By substituting this we get:

    .. math::
           Out = Logit - Logit * Labels + \log(1 + e^{-Logit})

    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
    we reformulate the loss as follows:

    .. math::
           Out = \max(Logit, 0) - Logit * Labels + \log(1 + e^{-\lvert Logit \rvert})

    Then, if ``weight`` or ``pos_weight`` is not None, this operator multiplies the
    weight tensor on the loss `Out`. The ``weight`` tensor will attach different
    weight on every item in the batch. The ``pos_weight`` will attach different
    weight on the positive label of each class.

    Finally, this operator applies reduce operation on the loss.
    If :attr:`reduction` set to ``'none'``, the operator will return the original loss `Out`.
    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target labels ``label`` should be numbers between 0 and 1.

    Args:
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`,
            The data type is float32, float64. Default is ``None``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
            with length equal to the number of classes. The data type is float32, float64.
            Default is ``None``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        logit (Tensor): The input predictions tensor with shape [N, *],
            N is batch_size, `*` means number of additional dimensions. The ``logit``
            is usually the output of Linear layer. Available dtype is float32, float64.
        label (Tensor): The target labels tensor with the same shape as
            ``logit``. The target labels which values should be numbers between 0 and 1.
            Available dtype is float32, float64.
        output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
            same as ``logit`` , else the shape of output is scalar.

    Returns:
        A callable object of BCEWithLogitsLoss.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:

        .. code-block:: python

            import paddle
            logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
            label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
            bce_logit_loss = paddle.nn.BCEWithLogitsLoss()
            output = bce_logit_loss(logit, label)
            print(output.numpy())  # [0.45618808]

    """

    def __init__(self,
                 weight=None,
                 reduction='mean',
                 pos_weight=None,
                 name=None):
        # Validate eagerly so a bad configuration fails at construction time
        # rather than on the first forward pass.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in BCEWithLogitsLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)

        super(BCEWithLogitsLoss, self).__init__()
        self.weight = weight
        self.reduction = reduction
        self.pos_weight = pos_weight
        self.name = name

    def forward(self, logit, label):
        """Compute the loss by delegating to the functional implementation."""
        out = paddle.nn.functional.binary_cross_entropy_with_logits(
            logit, label, self.weight, self.reduction, self.pos_weight,
            self.name)
        return out


class CrossEntropyLoss(Layer):
    r"""
    By default, this operator implements the cross entropy loss function with softmax. This function
    combines the calculation of the softmax operation and the cross entropy loss function
    to provide a more numerically stable computing.

    This operator will calculate the cross entropy loss function without softmax when use_softmax=False.

    By default, this operator will calculate the mean of the result, and you can also affect
    the default behavior by using the reduction parameter. Please refer to the part of
    parameters for details.

    This operator can be used to calculate the softmax cross entropy loss with soft and hard labels.
    Where, the hard labels mean the actual label value, 0, 1, 2, etc.  And the soft labels
    mean the probability of the actual label, 0.6, 0.8, 0.2, etc.

    The calculation of this operator includes the following two steps.

    -  **I.softmax cross entropy**

        1. Hard label (each sample can only be assigned into one category)

        1.1. when use_softmax=True

            .. math::
              \\loss_j=-\text{logits}_{label_j}+\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories.

        1.2. when use_softmax=False

            .. math::
              \\loss_j=-\log\left({P}_{label_j}\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).


        2. Soft label (each sample is assigned to multiple categories with a certain probability, and the probability sum is 1).

        2.1. when use_softmax=True

            .. math::
              \\loss_j=-\sum_{i=0}^{C}\text{label}_i\left(\text{logits}_i-\log\left(\sum_{k=0}^{C}\exp(\text{logits}_k)\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories.

        2.2. when use_softmax=False

            .. math::
              \\loss_j=-\sum_{i=0}^{C}\left(\text{label}_i*\log\left({P}_i\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).


    -  **II.Weight and reduction processing**

        1. Weight

            If the ``weight`` parameter is ``None`` , go to the next step directly.

            If the ``weight`` parameter is not ``None`` , the cross entropy of each sample is weighted by weight
            according to soft_label = False or True as follows.

            1.1. Hard labels (soft_label = False)

            .. math::
                \\loss_j=loss_j*weight[label_j]

            1.2. Soft labels (soft_label = True)

            .. math::
                \\loss_j=loss_j*\sum_{i}\left(weight[label_i]*logits_i\right)

        2. reduction

            2.1 if the ``reduction`` parameter is ``none``

            Return the previous result directly

            2.2 if the ``reduction`` parameter is ``sum``

            Return the sum of the previous results

            .. math::
               \\loss=\sum_{j}loss_j

            2.3 if the ``reduction`` parameter is ``mean`` , it will be processed according to
            the ``weight`` parameter as follows.

            2.3.1. If the  ``weight``  parameter is ``None``

            Return the average value of the previous results

            .. math::
                \\loss=\sum_{j}loss_j/N

            where, N is the number of samples and C is the number of categories.

            2.3.2. If the 'weight' parameter is not 'None', the weighted average value of the previous result will be returned

            1. Hard labels (soft_label = False)

            .. math::
                \\loss=\sum_{j}loss_j/\sum_{j}weight[label_j]

            2. Soft labels (soft_label = True)

            .. math::
                \\loss=\sum_{j}loss_j/\sum_{j}\left(\sum_{i}weight[label_i]\right)


    Parameters:

        - **weight** (Tensor, optional)

            a manual rescaling weight given to each class.
            If given, has to be a Tensor of size C and the data type is float32, float64.
            Default is ``None`` .

        - **ignore_index** (int64, optional)

            Specifies a target value that is ignored
            and does not contribute to the loss. A negative value means that no label
            value needs to be ignored. Only valid when soft_label = False.
            Default is ``-100`` .

        - **reduction** (str, optional)

            Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.

        - **soft_label** (bool, optional)

            Indicate whether label is soft.
            If soft_label=False, the label is hard.  If soft_label=True, the label is soft.
            Default is ``False``.

        - **axis** (int, optional)

            The index of dimension to perform softmax calculations.
            It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the number
            of dimensions of input :attr:`input`.
            Default is ``-1`` .

        - **use_softmax** (bool, optional)

            Indicate whether compute softmax before cross_entropy.
            Default is ``True``.

        - **name** (str, optional)

            The name of the operator. Default is ``None`` .
            For more information, please refer to :ref:`api_guide_Name` .


    Shape:

        - **input** (Tensor)

            Input tensor, the data type is float32, float64. Shape is
            :math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes ,  ``k >= 1`` .

            Note:

                1. when use_softmax=True, it expects unscaled logits. This operator should not be used with the
                output of softmax operator, which will produce incorrect results.

                2. when use_softmax=False, it expects the output of softmax operator.


        - **label** (Tensor)

            1. If soft_label=False, the shape is
            :math:`[N_1, N_2, ..., N_k]` or :math:`[N_1, N_2, ..., N_k, 1]`, k >= 1.
            the data type is int32, int64, float32, float64, where each value is [0, C-1].

            2. If soft_label=True, the shape and data type should be same with ``input`` ,
            and the sum of the labels for each sample should be 1.

        - **output** (Tensor)

            Return the softmax cross_entropy loss of ``input`` and ``label``.

            The data type is the same as input.

            If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the dimension of return value is ``1``.

            If :attr:`reduction` is ``'none'``:

            1. If soft_label = False, the dimension of return value is the same with ``label`` .

            2. if soft_label = True, the dimension of return value is :math:`[N_1, N_2, ..., N_k, 1]` .

    Example1(hard labels):

        .. code-block:: python

            import paddle
            paddle.seed(99999)
            N=100
            C=200
            reduction='mean'
            input = paddle.rand([N, C], dtype='float64')
            label = paddle.randint(0, C, shape=[N], dtype='int64')
            weight = paddle.rand([C], dtype='float64')

            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
                weight=weight, reduction=reduction)
            dy_ret = cross_entropy_loss(input, label)
            print(dy_ret.numpy()) #[5.41993642]


    Example2(soft labels):

        .. code-block:: python

            import paddle
            paddle.seed(99999)
            axis = -1
            ignore_index = -100
            N = 4
            C = 3
            shape = [N, C]
            reduction='mean'
            weight = None
            logits = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            labels = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            labels /= paddle.sum(labels, axis=axis, keepdim=True)
            paddle_loss_mean = paddle.nn.functional.cross_entropy(
                logits,
                labels,
                soft_label=True,
                axis=axis,
                weight=weight,
                reduction=reduction)
            print(paddle_loss_mean.numpy()) #[1.12908343]

    """

    def __init__(self,
                 weight=None,
                 ignore_index=-100,
                 reduction='mean',
                 soft_label=False,
                 axis=-1,
                 use_softmax=True,
                 name=None):
        super(CrossEntropyLoss, self).__init__()
        # All configuration is stored as-is and validated by the functional
        # ``cross_entropy`` call in ``forward``.
        self.weight = weight
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.soft_label = soft_label
        self.axis = axis
        self.use_softmax = use_softmax
        self.name = name

    def forward(self, input, label):
        """Compute the cross entropy of ``input`` against ``label``."""
        ret = paddle.nn.functional.cross_entropy(
            input,
            label,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            soft_label=self.soft_label,
            axis=self.axis,
            use_softmax=self.use_softmax,
            name=self.name)

        return ret


class HSigmoidLoss(Layer):
    """
    Hierarchical Sigmoid Layer.

    The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity
    and speed up the model training, especially the training of language model.
    Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier.
    For each class(word), there's a unique path from root to itself, hsigmoid calculate the cost for each non-leaf node on
    the path, and sum them to get a total cost.
    Comparing to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
    represents the number of classes or the size of word dict.

    The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural
    Network Language Model <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>_`. For the custom
    tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example):

    1. Using a custom word dict to build a binary tree, each leaf node should be a word in the word dict.
    2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table.
    3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code.
       Code means the label of each binary classifier, 1 indicate true, 0 indicate false.
    4. Now, each word should have its path and code along the path, you can pass a batch of path and code related
       to the same batch of inputs.

    Parameters:
        feature_size (int): The number of features.
        num_classes (int): The number of classes or the size of word dict. When the default
            tree is used it must not be less than 2.
            If the default tree is used (:attr:`is_custom` is set to False), :attr:`num_classes`
            should not be None. If the custom tree is used (:attr:`is_custom` is set to True),
            :attr:`num_classes` should be the number of non-leaf nodes, which indicates the num of
            classes using by the binary classifier.
        weight_attr (ParamAttr, optional): The parameter attribute for the learnable weights
            of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create a
            ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is
            initialized with Xavier. Default is None.
        bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of hsigmoid. If it
            is set to False, no bias will be added. If it is set to None or one attribute of ParamAttr,
            hsigmoid will create a ParamAttr as bias_attr. If the Initializer of the bias_attr is not
            set, the bias is initialized zero. Default is None.
        is_custom (bool, optional): Whether use custom binary tree. If it's True, `path_table` and
            `path_code` should be passed to its forward method, otherwise `path_table` and `path_code`
            should not be passed to its forward method. Default is False.
        is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True,
            the gradient of weight and input will be sparse. Sparse mode requires a custom tree; it is
            switched off (with a warning) when :attr:`is_custom` is False. Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Raises:
        ValueError: If ``num_classes`` is less than 2 while the default tree is used.

    Shape:
        input (Tensor): The input tensor. The shape is [N, D], where N is batch size and D is feature size. Its data type should be float32, float64.
        label (Tensor): Its shape is [N, 1]. Its data type should be int64.
        output (Tensor): The HSigmoid Loss of ``input`` and ``label``. Shape is [N, 1]

    Examples:
        .. code-block:: python

            import paddle
            paddle.set_device('cpu')

            input = paddle.uniform([4, 3])
            # [[0.56194401  -0.22450298  -0.10741806] # random
            #  [0.36136317  0.23556745  0.88748658] # random
            #  [0.18151939  0.80947340  -0.31078976] # random
            #  [0.68886101  -0.14239830  -0.41297770]] # random
            label = paddle.to_tensor([0, 1, 4, 5])
            m = paddle.nn.HSigmoidLoss(3, 5)
            out = m(input, label)
            # [[2.42524505]
            #  [1.74917245]
            #  [3.14571381]
            #  [2.34564662]]
    """

    def __init__(self,
                 feature_size,
                 num_classes,
                 weight_attr=None,
                 bias_attr=None,
                 is_custom=False,
                 is_sparse=False,
                 name=None):
        import warnings

        super(HSigmoidLoss, self).__init__()
        if (num_classes < 2) and (not is_custom):
            raise ValueError(
                "num_classes must not be less than 2 with default tree")

        if (not is_custom) and (is_sparse):
            # Sparse updates are only supported together with a custom tree,
            # so fall back to dense mode instead of failing.
            warnings.warn("Sparse mode should not be used without custom tree")
            is_sparse = False

        self._feature_size = feature_size
        self._num_classes = num_classes
        self._is_custom = is_custom
        self._is_sparse = is_sparse

        self._weight_attr = weight_attr
        self._bias_attr = bias_attr

        self._name = name
        self._dtype = paddle.get_default_dtype()

        # Only relevant when sparse mode is actually in effect; previously this
        # message was emitted on every construction regardless of the mode.
        if is_sparse:
            warnings.warn("With sparse mode, if your models has only"
                          " small parameter prefetch may cause speed down")

        # The default tree over ``num_classes`` leaves has ``num_classes - 1``
        # internal nodes; a custom tree passes its non-leaf count directly.
        C = self._num_classes if is_custom else self._num_classes - 1
        self.weight = self.create_parameter(
            [C, self._feature_size],
            attr=self._weight_attr,
            is_bias=False,
            dtype=self._dtype)
        self.bias = self.create_parameter(
            [C, 1], attr=self._bias_attr, is_bias=True, dtype=self._dtype)

    def forward(self, input, label, path_table=None, path_code=None):
        """Compute the hierarchical sigmoid loss via ``F.hsigmoid_loss``."""
        out = F.hsigmoid_loss(
            input,
            label,
            self._num_classes,
            self.weight,
            self.bias,
            path_table=path_table,
            path_code=path_code,
            is_sparse=self._is_sparse,
            name=self._name)
        return out


class MSELoss(Layer):
    r"""
    **Mean Square Error Loss**
    Computes the mean square error (squared L2 norm) of given input and label.

    If :attr:`reduction` is set to ``'none'``, loss is calculated as:

    .. math::
        Out = (input - label)^2

    If :attr:`reduction` is set to ``'mean'``, loss is calculated as:

    .. math::
        Out = \operatorname{mean}((input - label)^2)

    If :attr:`reduction` is set to ``'sum'``, loss is calculated as:

    .. math::
        Out = \operatorname{sum}((input - label)^2)

    where `input` and `label` are tensors of the same shape with data type
    float32 or float64.

    Parameters:
        reduction (string, optional): The reduction method for the output,
            could be 'none' | 'mean' | 'sum'.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Shape:
        input (Tensor): Input tensor, the data type is float32 or float64
        label (Tensor): Label tensor, the data type is float32 or float64
        output (Tensor): output tensor storing the MSE loss of input and label, the data type is same as input.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle

            input_data = np.array([1.5]).astype("float32")
            label_data = np.array([1.7]).astype("float32")

            mse_loss = paddle.nn.loss.MSELoss()
            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)
            output = mse_loss(input, label)
            print(output)
            # [0.04000002]
    """

    def __init__(self, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "'reduction' in 'MSELoss' should be 'sum', 'mean' or 'none', "
                "but received {}.".format(reduction))
        self.reduction = reduction

    def forward(self, input, label):
        """Return the squared error of ``input`` vs ``label``, reduced per ``self.reduction``."""
        # Dtype checks are only performed when building a static graph.
        if not in_dynamic_mode():
            fluid.data_feeder.check_variable_and_dtype(
                input, 'input', ['float32', 'float64'], 'MSELoss')
            fluid.data_feeder.check_variable_and_dtype(
                label, 'label', ['float32', 'float64'], 'MSELoss')

        diff = paddle.subtract(input, label)
        if in_dygraph_mode():
            # Eager mode: call the final-state kernel directly.
            square_out = paddle._C_ops.final_state_square(diff)
        else:
            square_out = paddle.square(diff)

        if self.reduction == 'none':
            return square_out
        if self.reduction == 'sum':
            return fluid.layers.reduce_sum(square_out)
        return fluid.layers.reduce_mean(square_out)


class L1Loss(Layer):
    r"""
    This interface is used to construct a callable object of the ``L1Loss`` class.
    The L1Loss layer calculates the L1 Loss of ``input`` and ``label`` as follows.

    If `reduction` set to ``'none'``, the loss is:

    .. math::
        Out = \lvert input - label\rvert

    If `reduction` set to ``'mean'``, the loss is:

    .. math::
        Out = MEAN(\lvert input - label\rvert)

    If `reduction` set to ``'sum'``, the loss is:

    .. math::
        Out = SUM(\lvert input - label\rvert)


    Parameters:
        reduction (str, optional): Indicate the reduction to apply to the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'none'``, the unreduced loss is returned;
            If `reduction` is ``'mean'``, the reduced mean loss is returned.
            If `reduction` is ``'sum'``, the reduced sum loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Shape:
        input (Tensor): The input tensor. The shape is [N, *], where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32, int64.
        label (Tensor): label. The shape is [N, *], same shape as ``input`` . Its data type should be float32, float64, int32, int64.
        output (Tensor): The L1 Loss of ``input`` and ``label``.
            If `reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``input`` .
            If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
            label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)

            l1_loss = paddle.nn.L1Loss()
            output = l1_loss(input, label)
            print(output.numpy())
            # [0.35]

            l1_loss = paddle.nn.L1Loss(reduction='sum')
            output = l1_loss(input, label)
            print(output.numpy())
            # [1.4]

            l1_loss = paddle.nn.L1Loss(reduction='none')
            output = l1_loss(input, label)
            print(output)
            # [[0.20000005 0.19999999]
            # [0.2        0.79999995]]
    """

    def __init__(self, reduction='mean', name=None):
        # Validate eagerly so a bad configuration fails at construction time.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(L1Loss, self).__init__()
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        """Compute the L1 loss by delegating to ``paddle.nn.functional.l1_loss``."""
        return paddle.nn.functional.l1_loss(
            input, label, self.reduction, name=self.name)


Z
zhiboniu 已提交
698
class BCELoss(Layer):
    """
    This interface is used to construct a callable object of the ``BCELoss`` class.
    The BCELoss layer measures the binary_cross_entropy loss between input predictions ``input``
    and target labels ``label`` . The binary_cross_entropy loss can be described as:

    If :attr:`weight` is set, the loss is:

    .. math::
        Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`weight` is None, the loss is:

    .. math::
        Out = -1 * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`reduction` set to ``'none'``, the interface will return the original loss `Out`.

    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(Out)

    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(Out)

    Note that the input predictions ``input`` should always be the output of sigmoid, and the target labels ``label``
    should be numbers between 0 and 1.

    Parameters:
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, has to be a Tensor of size nbatch and the data type
            is float32, float64. Default is ``None``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input (Tensor): 2-D tensor with shape: [N, *], N is batch_size, `*` means
            number of additional dimensions. The input ``input`` should always
            be the output of sigmoid.  Available dtype is float32, float64.
        label (Tensor): 2-D tensor with the same shape as ``input``. The target
            labels which values should be numbers between 0 and 1. Available
            dtype is float32, float64.
        output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
            same as ``input`` , else the shape of output is scalar.

    Returns:
        A callable object of BCELoss.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            input_data = np.array([0.5, 0.6, 0.7]).astype("float32")
            label_data = np.array([1.0, 0.0, 1.0]).astype("float32")

            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)
            bce_loss = paddle.nn.BCELoss()
            output = bce_loss(input, label)
            print(output)  # [0.65537095]

    """

    def __init__(self, weight=None, reduction='mean', name=None):
        # Reject unsupported reductions at construction time.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)

        super(BCELoss, self).__init__()
        self.weight = weight
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        # The computation itself lives in the functional API; this layer only
        # supplies the configuration captured in __init__.
        return paddle.nn.functional.binary_cross_entropy(
            input, label, self.weight, self.reduction, self.name)


class NLLLoss(Layer):
    r"""

    This class accepts input and target label and returns negative log likelihood
    cross error. It is useful to train a classification problem with C classes.

    The input for the loss is expected to contain log-probabilities of
    each class. It has to be a Tensor of size either (batch_size, C) or
    (batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
    The label for the loss should be a class index in the range [0, C-1]
    where C is the number of classes. If ignore_index is specified, the
    specified target value does not contribute to the input gradient.

    If the optional argument `weight` is provided, it should be a 1D Tensor
    assigning weight to each of the classes. This is particularly useful
    when you have an unbalanced training set.

    The loss is calculated as follows.
    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:

    .. math::

        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = - w_{y_n} x_{n,y_n}, \quad
        w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\},

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then

    .. math::

        \ell(x, y) =
        \left\{
            \begin{array}{lcl}
            \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, &
            \text{if  reduction} = \text{'mean';}\\
            \sum_{n=1}^N l_n,  &
            \text{if  reduction} = \text{'sum'.}
            \end{array}
        \right.

    Parameters:
        weight (Tensor, optional): Weight tensor, a manual rescaling weight given
            to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise,
            it is treated as if having all ones. The data type is
            float32 or float64. Default is ``None``.
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default is ``-100``.
        reduction (str, optional): Indicate how to average the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'mean'``, the reduced mean loss is returned;
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input (Tensor): Input tensor, the shape is :math:`[N, C]`, `C` is the number of classes.
            But in K-dimension situation, the shape is :math:`[N, C, d_1, d_2, ..., d_K]`.
            The data type is float32, float64.
        label (Tensor): Label tensor, the shape is :math:`[N,]` or :math:`[N, d_1, d_2, ..., d_K]`.
            The data type is int64.
        output (Tensor): the `negative log likelihood loss` between `input` and `label`.
            If `reduction` is `'none'`, the shape is `[N, *]`.
            If `reduction` is `'sum'` or `'mean'`, the shape is `[1]`.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Examples:
        .. code-block:: python

            import paddle

            nll_loss = paddle.nn.loss.NLLLoss()
            log_softmax = paddle.nn.LogSoftmax(axis=1)

            input = paddle.to_tensor([[0.88103855, 0.9908683 , 0.6226845 ],
                                      [0.53331435, 0.07999352, 0.8549948 ],
                                      [0.25879037, 0.39530203, 0.698465  ],
                                      [0.73427284, 0.63575995, 0.18827209],
                                      [0.05689114, 0.0862954 , 0.6325046 ]], "float32")
            log_out = log_softmax(input)
            label = paddle.to_tensor([0, 2, 1, 1, 0], "int64")
            result = nll_loss(log_out, label)
            print(result) # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, [1.07202101])

    """

    def __init__(self,
                 weight=None,
                 ignore_index=-100,
                 reduction='mean',
                 name=None):
        # Reject unsupported reductions before touching the base Layer.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
                "'none', but received %s, which is not allowed." % reduction)
        super(NLLLoss, self).__init__()
        self._weight = weight
        self._ignore_index = ignore_index
        self._reduction = reduction
        self._name = name

    def forward(self, input, label):
        # Delegate to the functional implementation with the stored settings.
        return F.nll_loss(
            input,
            label,
            weight=self._weight,
            ignore_index=self._ignore_index,
            reduction=self._reduction,
            name=self._name)


class KLDivLoss(Layer):
    r"""
    This interface calculates the Kullback-Leibler divergence loss
    between Input(X) and Input(Target). Note that Input(X) is the
    log-probability and Input(Target) is the probability.

    KL divergence loss is calculated as follows:

    .. math::

        l(x, y) = y * (\log(y) - x)

    Parameters:
        reduction (str, optional): Indicate how to average the loss,
             the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
             If `reduction` is ``'mean'``, the reduced mean loss is returned;
             If `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
             if `reduction` is ``'sum'``, the reduced sum loss is returned;
             if `reduction` is ``'none'``, no reduction will be applied.
             Default is ``'mean'``.

    Shape:

        - input (Tensor): (N, *), where * means, any number of additional dimensions.

        - label (Tensor): (N, *), same shape as input.

        - output (Tensor): tensor with shape: [1] by default.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'batchmean'``,
            ``'mean'`` or ``'sum'``.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            import paddle.nn as nn

            shape = (5, 20)
            x = np.random.uniform(-10, 10, shape).astype('float32')
            target = np.random.uniform(-10, 10, shape).astype('float32')

            # 'batchmean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[1]

            # 'mean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='mean')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[1]

            # 'sum' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='sum')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[1]

            # 'none' reduction, loss shape is same with X shape
            kldiv_criterion = nn.KLDivLoss(reduction='none')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[5, 20]
    """

    def __init__(self, reduction='mean'):
        # Validate up front like the other loss layers in this module, so a
        # misspelled reduction is reported when the layer is built.
        if reduction not in ['none', 'batchmean', 'mean', 'sum']:
            raise ValueError(
                "The value of 'reduction' in KLDivLoss should be 'sum', 'mean', "
                "'batchmean' or 'none', but received %s, which is not allowed."
                % reduction)
        super(KLDivLoss, self).__init__()
        self.reduction = reduction

    def forward(self, input, label):
        # ``input`` is expected to hold log-probabilities and ``label``
        # probabilities (see class docstring); the math is in F.kl_div.
        out = F.kl_div(input, label, self.reduction)
        return out


class MarginRankingLoss(Layer):
    r"""

    This interface is used to construct a callable object of the ``MarginRankingLoss`` class.
    The MarginRankingLoss layer calculates the margin rank loss between the input, other and label,
    using the following formula:

    .. math::
        margin\_rank\_loss = max(0, -label * (input - other) + margin)

    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(margin\_rank\_loss)

    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(margin\_rank\_loss)

    If :attr:`reduction` set to ``'none'``, just return the origin ``margin_rank_loss``.

    Parameters:
        margin (float, optional): The margin value to add, default value is 0;
        reduction (str, optional): Indicate the reduction to apply to the loss, the candidates are ``'none'``, ``'mean'``, ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:

        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        other: N-D Tensor, `other` have the same shape and dtype as `input`.

        label: N-D Tensor, label have the same shape and dtype as `input`.

        output: If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the out shape is :math:`[1]`, otherwise the shape is the same as `input` .The same dtype as input tensor.

    Returns:
        A callable object of MarginRankingLoss.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Examples:

        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[1, 2], [3, 4]], dtype="float32")
            other = paddle.to_tensor([[2, 1], [2, 4]], dtype="float32")
            label = paddle.to_tensor([[1, -1], [-1, -1]], dtype="float32")
            margin_rank_loss = paddle.nn.MarginRankingLoss()
            loss = margin_rank_loss(input, other, label)

            print(loss)
            # [0.75]
    """

    def __init__(self, margin=0.0, reduction='mean', name=None):
        # Reject unsupported reductions before initializing the base Layer.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(MarginRankingLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input, other, label):
        # All computation is delegated to the functional implementation.
        return paddle.nn.functional.margin_ranking_loss(
            input, other, label, self.margin, self.reduction, self.name)


class CTCLoss(Layer):
    """

    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation
    is integrated into the Warp-CTC library to normalize values for each row of the input tensor.

    Parameters:
        blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
        reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.

    Shape:
        log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type should be float32 or float64.
        labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
        input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
        label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
        norm_by_times (bool, optional): Argument of ``forward`` (a plain bool, not a Tensor). Whether to normalize the gradients by the number of time-steps, which is also the sequence's length. There is no need to normalize the gradients if reduction mode is 'mean'. Default is False.

    Returns:
        Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and  ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Examples:

        .. code-block:: python

            # declarative mode
            import numpy as np
            import paddle

            # length of the longest logit sequence
            max_seq_length = 5
            # length of the longest label sequence
            max_label_length = 3
            # number of logit sequences
            batch_size = 2
            # class num
            class_num = 3

            np.random.seed(1)
            log_probs = np.array([[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04],
                                    [3.02332580e-01, 1.46755889e-01, 9.23385918e-02]],

                                    [[1.86260208e-01, 3.45560730e-01, 3.96767467e-01],
                                    [5.38816750e-01, 4.19194520e-01, 6.85219526e-01]],

                                    [[2.04452246e-01, 8.78117442e-01, 2.73875929e-02],
                                    [6.70467496e-01, 4.17304814e-01, 5.58689833e-01]],

                                    [[1.40386939e-01, 1.98101491e-01, 8.00744593e-01],
                                    [9.68261600e-01, 3.13424170e-01, 6.92322612e-01]],

                                    [[8.76389146e-01, 8.94606650e-01, 8.50442126e-02],
                                    [3.90547849e-02, 1.69830427e-01, 8.78142476e-01]]]).astype("float32")
            labels = np.array([[1, 2, 2],
                            [1, 2, 2]]).astype("int32")
            input_lengths = np.array([5, 5]).astype("int64")
            label_lengths = np.array([3, 3]).astype("int64")

            log_probs = paddle.to_tensor(log_probs)
            labels = paddle.to_tensor(labels)
            input_lengths = paddle.to_tensor(input_lengths)
            label_lengths = paddle.to_tensor(label_lengths)

            loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels,
                input_lengths,
                label_lengths)
            print(loss)  #[3.9179852 2.9076521]

            loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels,
                input_lengths,
                label_lengths)
            print(loss)  #[1.1376063]
    """

    def __init__(self, blank=0, reduction='mean'):
        # Validate like the sibling loss layers so a bad reduction is
        # reported when the layer is constructed.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in CTCLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(CTCLoss, self).__init__()
        self.blank = blank
        self.reduction = reduction

    def forward(self,
                log_probs,
                labels,
                input_lengths,
                label_lengths,
                norm_by_times=False):
        # norm_by_times is a per-call option rather than layer state, so it is
        # passed straight through to the functional implementation.
        return paddle.nn.functional.ctc_loss(
            log_probs,
            labels,
            input_lengths,
            label_lengths,
            self.blank,
            self.reduction,
            norm_by_times=norm_by_times)


class SmoothL1Loss(Layer):
    r"""
    This operator calculates smooth_l1_loss. Creates a criterion that uses a squared
    term if the absolute element-wise error falls below delta and an L1 term otherwise.
    In some cases it can prevent exploding gradients and it is more robust and less
    sensitive to outliers. Also known as the Huber loss:

    .. math::

         loss(x,y) = \frac{1}{n}\sum_{i}z_i

    where z_i is given by:

    .. math::

        \mathop{z_i} = \left\{\begin{array}{rcl}
        0.5(x_i - y_i)^2 & & {if |x_i - y_i| < delta} \\
        delta * |x_i - y_i| - 0.5 * delta^2 & & {otherwise}
        \end{array} \right.

    Parameters:
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        delta (float, optional): Specifies the hyperparameter delta to be used.
            The value determines how large the errors need to be to use L1. Errors
            smaller than delta are minimized with L2. Parameter is ignored for
            negative/zero values. Default = 1.0
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to :ref:`api_guide_Name`.

    Call Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is
            (N, C), where C is number of classes, and if shape is more than 2D, this
            is (N, C, D1, D2,..., Dk), k >= 1.
        label (Tensor): Label tensor, the data type is float32 or float64. The shape of label
            is the same as the shape of input.

    Returns:
        The tensor storing the smooth_l1_loss of input and label.

    Return type: Tensor.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            input_data = np.random.rand(3,3).astype("float32")
            label_data = np.random.rand(3,3).astype("float32")
            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)
            loss = paddle.nn.SmoothL1Loss()
            output = loss(input, label)
            print(output)
    """

    def __init__(self, reduction='mean', delta=1.0, name=None):
        # Validate like the sibling loss layers so a bad reduction is
        # reported when the layer is constructed.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in SmoothL1Loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(SmoothL1Loss, self).__init__()
        self.reduction = reduction
        self.delta = delta
        self.name = name

    def forward(self, input, label):
        # Delegate to the functional implementation with the stored settings.
        return F.smooth_l1_loss(
            input,
            label,
            reduction=self.reduction,
            delta=self.delta,
            name=self.name)


class HingeEmbeddingLoss(Layer):
    r"""
    This operator calculates hinge_embedding_loss. Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y` (containing 1 or -1).
    This is usually used for measuring whether two inputs are similar or dissimilar, e.g. using the L1 pairwise distance as :math:`x`,
    and is typically used for learning nonlinear embeddings or semi-supervised learning.

    The loss function for :math:`n`-th sample in the mini-batch is

    .. math::
        l_n = \begin{cases}
            x_n, & \text{if}\; y_n = 1,\\
            \max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
        \end{cases}

    and the total loss function is

    .. math::
        \ell(x, y) = \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
        \end{cases}

    where :math:`L = \{l_1,\dots,l_N\}^\top`. If :attr:`reduction` is ``'none'``,
    :math:`L` itself is returned.

    Parameters:

        margin (float, optional): Specifies the hyperparameter margin to be used.
            The value determines how large the input need to be to calculate in
            hinge_embedding_loss. When label is -1, Input smaller than margin are minimized with hinge_embedding_loss.
            Default = 1.0
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Call Parameters:

        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.

        label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.

    Shape:

        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64. The sum operation operates over all the elements.

        label: N-D Tensor, same shape as the input.

        output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.

    Returns:

        Tensor, The tensor variable storing the hinge_embedding_loss of input and label.

    Raises:

        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            # label elements in {1., -1.}
            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

            hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=1.0, reduction='none')
            loss = hinge_embedding_loss(input, label)
            print(loss)
            # Tensor([[0., -2., 0.],
            #         [0., -1., 2.],
            #         [1., 1., 1.]])

            hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=1.0, reduction='mean')
            loss = hinge_embedding_loss(input, label)
            print(loss)
            # Tensor([0.22222222])
    """

    def __init__(self, margin=1.0, reduction="mean", name=None):
        # Validate like the sibling loss layers so a bad reduction is
        # reported when the layer is constructed.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in HingeEmbeddingLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(HingeEmbeddingLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        # Delegate to the functional implementation with the stored settings.
        return F.hinge_embedding_loss(
            input,
            label,
            reduction=self.reduction,
            margin=self.margin,
            name=self.name)