#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define loss functions of neural network
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle
from .. import functional as F
from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator

# Public loss layers exported by this module.
__all__ = [
    'BCEWithLogitsLoss',
    'CrossEntropyLoss',
    'MSELoss',
    'L1Loss',
    'NLLLoss',
    'BCELoss',
    'KLDivLoss',
    'MarginRankingLoss',
    'CTCLoss',
    'SmoothL1Loss',
]


class BCEWithLogitsLoss(fluid.dygraph.Layer):
    """
    This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
    Also, we can see it as the combination of ``sigmoid_cross_entropy_with_logits``
    layer and some reduce operations.

    This measures the element-wise probability error in classification tasks
    in which each class is independent.
    This can be thought of as predicting labels for a data-point, where labels
    are not mutually exclusive. For example, a news article can be about
    politics, technology or sports at the same time or none of these.

    First this operator calculates the loss function as follows:

    .. math::
           Out = -Labels * \\log(\\sigma(Logit)) - (1 - Labels) * \\log(1 - \\sigma(Logit))

    We know that :math:`\\sigma(Logit) = \\frac{1}{1 + e^{-Logit}}`. By substituting this we get:

    .. math::
           Out = Logit - Logit * Labels + \\log(1 + e^{-Logit})

    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
    we reformulate the loss as follows:

    .. math::
           Out = \\max(Logit, 0) - Logit * Labels + \\log(1 + e^{-|Logit|})

    Then, if ``weight`` or ``pos_weight`` is not None, this operator multiplies the
    weight tensor on the loss `Out`. The ``weight`` tensor will attach different
    weight on every item in the batch. The ``pos_weight`` will attach different
    weight on the positive label of each class.

    Finally, this operator applies reduce operation on the loss.
    If :attr:`reduction` set to ``'none'``, the operator will return the original loss `Out`.
    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target labels ``label`` should be numbers between 0 and 1.

    Args:
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`,
            The data type is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
            with length equal to the number of classes. The data type is float32, float64.
            Default is ``'None'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shapes:
        logit (Tensor): The input predictions tensor. 2-D tensor with shape: [N, *],
            N is batch_size, `*` means number of additional dimensions. The ``logit``
            is usually the output of Linear layer. Available dtype is float32, float64.
        label (Tensor): The target labels tensor. 2-D tensor with the same shape as
            ``logit``. The target labels which values should be numbers between 0 and 1.
            Available dtype is float32, float64.
        output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
            same as ``logit`` , else the shape of output is scalar.

    Returns:
        A callable object of BCEWithLogitsLoss.

    Examples:

        .. code-block:: python

            import paddle
            paddle.disable_static()
            logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
            label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
            bce_logit_loss = paddle.nn.BCEWithLogitsLoss()
            output = bce_logit_loss(logit, label)
            print(output.numpy())  # [0.45618808]

    """

    def __init__(self,
                 weight=None,
                 reduction='mean',
                 pos_weight=None,
                 name=None):
        # Reject unsupported reductions at construction time so the error
        # surfaces where the layer is configured, not on the first forward.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in BCEWithLogitsLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)

        super(BCEWithLogitsLoss, self).__init__()
        self.weight = weight
        self.reduction = reduction
        self.pos_weight = pos_weight
        self.name = name

    def forward(self, logit, label):
        """Return the BCE-with-logits loss of ``logit`` against ``label``."""
        # Arguments are passed positionally as
        # (logit, label, weight, reduction, pos_weight, name).
        out = paddle.nn.functional.binary_cross_entropy_with_logits(
            logit, label, self.weight, self.reduction, self.pos_weight,
            self.name)
        return out


class CrossEntropyLoss(fluid.dygraph.Layer):
    """
    :alias_main: paddle.nn.CrossEntropyLoss
    :alias: paddle.nn.CrossEntropyLoss,paddle.nn.layer.CrossEntropyLoss,paddle.nn.layer.loss.CrossEntropyLoss

    This operator implements the cross entropy loss function. This OP combines ``LogSoftmax``
    and ``NLLLoss`` together.

    It is useful when training a classification problem with ``C`` classes.
    If provided, the optional argument ``weight`` should be a 1D Variable assigning
    weight to each of the classes.

    For predictions label, and target label, the loss is calculated as follows.

    .. math::

        loss_j =  -\\text{input[class]} +
        \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right), j = 1,..., K

    If weight is not ``None``:

    .. math::

        loss_j =  \\text{weight[class]}(-\\text{input[class]} +
        \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right)), j = 1,..., K

    Parameters:
        weight (Variable, optional): Weight tensor, a manual rescaling weight given
            to each class and the shape is (C). It has the same dimensions as class
            number and the data type is float32, float64. Default is ``'None'``.
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default is ``-100``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.

    Shape:
        input (Variable): Input tensor, the data type is float32, float64. Shape is
            (N, C), where C is number of classes, and if shape is more than 2D, this
            is (N, C, D1, D2,..., Dk), k >= 1.
        label (Variable): Label tensor, the data type is int64. Shape is (N), where each
            value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
            (N, D1, D2,..., Dk), k >= 1.

    Returns:
        The tensor variable storing the cross_entropy_loss of input and label.

    Return type: Variable.

    Examples:
        .. code-block:: python

            # declarative mode
            import paddle
            import paddle.fluid as fluid
            import numpy as np

            input = fluid.data(name='input', shape=[5, 100], dtype='float64')
            label = fluid.data(name='label', shape=[5], dtype='int64')
            weight = fluid.data(name='weight', shape=[100], dtype='float64')
            ce_loss = paddle.nn.loss.CrossEntropyLoss(weight=weight, reduction='mean')
            output = ce_loss(input, label)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            input_data = np.random.random([5, 100]).astype("float64")
            label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
            weight_data = np.random.random([100]).astype("float64")
            output = exe.run(fluid.default_main_program(),
                             feed={"input": input_data, "label": label_data, "weight": weight_data},
                             fetch_list=[output],
                             return_numpy=True)
            print(output)

            # imperative mode
            import paddle.fluid.dygraph as dg
            with dg.guard(place) as g:
                input = dg.to_variable(input_data)
                label = dg.to_variable(label_data)
                weight = dg.to_variable(weight_data)
                ce_loss = paddle.nn.loss.CrossEntropyLoss(weight=weight, reduction='mean')
                output = ce_loss(input, label)
                print(output.numpy())
    """

    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        # Validate eagerly, like the other loss layers in this module, so an
        # unsupported ``reduction`` is reported at construction time.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or"
                " 'none', but received %s, which is not allowed." % reduction)
        super(CrossEntropyLoss, self).__init__()
        self.weight = weight
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, input, label):
        """Return the cross entropy loss of ``input`` against integer ``label``."""
        fluid.data_feeder.check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'cross_entropy_loss')
        fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
                                                   'cross_entropy_loss')

        # Re-checked here because ``reduction`` is a public attribute that may
        # have been reassigned after construction.
        if self.reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or"
                " 'none', but received %s, which is not allowed." %
                self.reduction)

        return paddle.nn.functional.cross_entropy(
            input,
            label,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction)


class MSELoss(fluid.dygraph.layers.Layer):
    """
    **Mean Square Error Loss**
    Computes the mean square error (squared L2 norm) of given input and label.

    If :attr:`reduction` is set to ``'none'``, loss is calculated as:

    .. math::
        Out = (input - label)^2

    If :attr:`reduction` is set to ``'mean'``, loss is calculated as:

    .. math::
        Out = \\operatorname{mean}((input - label)^2)

    If :attr:`reduction` is set to ``'sum'``, loss is calculated as:

    .. math::
        Out = \\operatorname{sum}((input - label)^2)

    where `input` and `label` are `float32` or `float64` tensors of the same shape.

    Parameters:
        reduction (string, optional): The reduction method for the output,
            could be 'none' | 'mean' | 'sum'.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.

    Shape:
        input (Tensor): Input tensor, the data type is float32 or float64
        label (Tensor): Label tensor, the data type is float32 or float64
        output (Tensor): output tensor storing the MSE loss of input and label, the data type is same as input.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle

            input_data = np.array([1.5]).astype("float32")
            label_data = np.array([1.7]).astype("float32")

            paddle.disable_static()
            mse_loss = paddle.nn.loss.MSELoss()
            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)
            output = mse_loss(input, label)
            print(output.numpy())
            # [0.04000002]
    """

    def __init__(self, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "'reduction' in 'MSELoss' should be 'sum', 'mean' or 'none', "
                "but received {}.".format(reduction))
        self.reduction = reduction

    def forward(self, input, label):
        """Return the (optionally reduced) squared error between ``input`` and ``label``."""
        # The dtype checks run only outside dygraph mode.
        if not fluid.framework.in_dygraph_mode():
            fluid.data_feeder.check_variable_and_dtype(
                input, 'input', ['float32', 'float64'], 'MSELoss')
            fluid.data_feeder.check_variable_and_dtype(
                label, 'label', ['float32', 'float64'], 'MSELoss')

        square_out = fluid.layers.square(
            fluid.layers.elementwise_sub(input, label))
        if self.reduction == 'none':
            return square_out
        # Explicit calls instead of a getattr() lookup by op-name string.
        if self.reduction == 'sum':
            return fluid.layers.reduce_sum(square_out)
        return fluid.layers.reduce_mean(square_out)


class L1Loss(fluid.dygraph.Layer):
    """
    This interface is used to construct a callable object of the ``L1Loss`` class.
    The L1Loss layer calculates the L1 Loss of ``input`` and ``label`` as follows.

    If `reduction` set to ``'none'``, the loss is:

    .. math::
        Out = \\lvert input - label\\rvert

    If `reduction` set to ``'mean'``, the loss is:

    .. math::
        Out = MEAN(\\lvert input - label\\rvert)

    If `reduction` set to ``'sum'``, the loss is:

    .. math::
        Out = SUM(\\lvert input - label\\rvert)


    Parameters:
        reduction (str, optional): Indicate the reduction to apply to the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'none'``, the unreduced loss is returned;
            If `reduction` is ``'mean'``, the reduced mean loss is returned.
            If `reduction` is ``'sum'``, the reduced sum loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input (Tensor): The input tensor. The shape is [N, *], where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32, int64.
        label (Tensor): label. The shape is [N, *], same shape as ``input`` . Its data type should be float32, float64, int32, int64.
        output (Tensor): The L1 Loss of ``input`` and ``label``.
            If `reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``input`` .
            If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            paddle.disable_static()
            input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
            label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
            input = paddle.to_variable(input_data)
            label = paddle.to_variable(label_data)

            l1_loss = paddle.nn.loss.L1Loss()
            output = l1_loss(input, label)
            print(output.numpy())
            # [0.35]

            l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
            output = l1_loss(input, label)
            print(output.numpy())
            # [1.4]

            l1_loss = paddle.nn.loss.L1Loss(reduction='none')
            output = l1_loss(input, label)
            print(output.numpy())
            # [[0.20000005 0.19999999]
            # [0.2        0.79999995]]
    """

    def __init__(self, reduction='mean', name=None):
        # Reject unsupported reductions before the layer is initialized.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(L1Loss, self).__init__()
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        """Return the L1 loss of ``input`` against ``label``."""
        return paddle.nn.functional.l1_loss(
            input, label, self.reduction, name=self.name)


class BCELoss(fluid.dygraph.Layer):
    """
    This interface is used to construct a callable object of the ``BCELoss`` class.
    The BCELoss layer measures the binary_cross_entropy loss between input predictions ``input``
    and target labels ``label`` . The binary_cross_entropy loss can be described as:

    If :attr:`weight` is set, the loss is:

    .. math::
        Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`weight` is None, the loss is:

    .. math::
        Out = -1 * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`reduction` set to ``'none'``, the interface will return the original loss `Out`.

    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(Out)

    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(Out)

    Note that the input predictions ``input`` should always be the output of sigmoid, and the
    target labels ``label`` should be numbers between 0 and 1.

    Parameters:
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, has to be a Tensor of size nbatch and the data type
            is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input (Tensor): 2-D tensor with shape: (N, *), N is batch_size, `*` means
            number of additional dimensions. The input ``input`` should always
            be the output of sigmoid.  Available dtype is float32, float64.
        label (Tensor): 2-D tensor with the same shape as ``input``. The target
            labels which values should be numbers between 0 and 1. Available
            dtype is float32, float64.
        output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
            same as ``input`` , else the shape of output is scalar.

    Returns:
        A callable object of BCELoss.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            input_data = np.array([0.5, 0.6, 0.7]).astype("float32")
            label_data = np.array([1.0, 0.0, 1.0]).astype("float32")

            paddle.disable_static()
            input = paddle.to_variable(input_data)
            label = paddle.to_variable(label_data)
            bce_loss = paddle.nn.loss.BCELoss()
            output = bce_loss(input, label)
            print(output.numpy())  # [0.65537095]
            paddle.enable_static()

    """

    def __init__(self, weight=None, reduction='mean', name=None):
        # Reject unsupported reductions before the layer is initialized.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)

        super(BCELoss, self).__init__()
        self.weight = weight
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        """Return the binary cross entropy of probabilities ``input`` against ``label``."""
        # Arguments are passed positionally as (input, label, weight, reduction, name).
        out = paddle.nn.functional.binary_cross_entropy(
            input, label, self.weight, self.reduction, self.name)
        return out


class NLLLoss(fluid.dygraph.Layer):
    """
    :alias_main: paddle.nn.NLLLoss
    :alias: paddle.nn.NLLLoss,paddle.nn.layer.NLLLoss,paddle.nn.layer.loss.NLLLoss

    This class accepts input and target label and returns negative log likelihood
    cross error. It is useful to train a classification problem with C classes.

    The input for the loss is expected to contain log-probabilities of
    each class. It has to be a Tensor of size either (batch_size, C) or
    (batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
    The label for the loss should be a class index in the range [0, C-1]
    where C is the number of classes. If ignore_index is specified, the
    specified target value does not contribute to the input gradient.

    If the optional argument `weight` is provided, it should be a 1D Tensor
    assigning weight to each of the classes. This is particularly useful
    when you have an unbalanced training set.

    The loss is calculated as follows.
    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:

    .. math::
        \\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad
        l_n = - w_{y_n} x_{n,y_n}, \\quad
        w_{c} = \\text{weight}[c] \\cdot \\mathbb{1}\\{c \\not= \\text{ignore\\_index}\\},

    where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
    (default ``'mean'``), then

    .. math::
        \\ell(x, y) = \\begin{cases}
            \\sum_{n=1}^N \\frac{1}{\\sum_{n=1}^N w_{y_n}} l_n, &
            \\text{if reduction} = \\text{'mean';}\\\\
            \\sum_{n=1}^N l_n,  &
            \\text{if reduction} = \\text{'sum'.}
        \\end{cases}

    Parameters:
        weight (Tensor, optional): Weight tensor, a manual rescaling weight given
            to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise,
            it is treated as if having all ones. The data type is
            float32, float64. Default is ``'None'``.
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default is ``-100``.
        reduction (str, optional): Indicate how to average the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'mean'``, the reduced mean loss is returned;
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input (Tensor): Input tensor, the shape is :math:`[N, C]`, `C` is the number of classes.
            But in K-dimension situation, the shape is :math:`[N, C, d_1, d_2, ..., d_K]`.
            The data type is float32, float64.
        label (Tensor): Label tensor, the shape is :math:`[N,]` or :math:`[N, d_1, d_2, ..., d_K]`.
            The data type is int64.
        output (Tensor): the `negative log likelihood loss` between input `x` and `label`.
            If `reduction` is `'none'`, the shape is `[N, *]`.
            If `reduction` is `'sum'` or `'mean'`, the shape is `[1]`.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            nll_loss = paddle.nn.layer.NLLLoss()
            log_softmax = paddle.nn.LogSoftmax(axis=1)

            input_np = np.array([[0.88103855, 0.9908683 , 0.6226845 ],
                                 [0.53331435, 0.07999352, 0.8549948 ],
                                 [0.25879037, 0.39530203, 0.698465  ],
                                 [0.73427284, 0.63575995, 0.18827209],
                                 [0.05689114, 0.0862954 , 0.6325046 ]]).astype(np.float32)
            label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)

            place = paddle.CPUPlace()
            paddle.disable_static(place)
            input = paddle.to_variable(input_np)
            log_out = log_softmax(input)
            label = paddle.to_variable(label_np)
            result = nll_loss(log_out, label)
            print(result.numpy()) # [1.0720209]

    """

    def __init__(self,
                 weight=None,
                 ignore_index=-100,
                 reduction='mean',
                 name=None):
        # Reject unsupported reductions before the layer is initialized.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
                "'none', but received %s, which is not allowed." % reduction)
        super(NLLLoss, self).__init__()
        self._weight = weight
        self._ignore_index = ignore_index
        self._reduction = reduction
        self._name = name

    def forward(self, input, label):
        """Return the negative log likelihood loss of log-probabilities ``input``."""
        return F.nll_loss(
            input,
            label,
            weight=self._weight,
            ignore_index=self._ignore_index,
            reduction=self._reduction,
            name=self._name)


class KLDivLoss(fluid.dygraph.Layer):
    """
    This interface calculates the Kullback-Leibler divergence loss
    between ``input`` and ``label``. Note that ``input`` is expected to be
    the log-probability and ``label`` the probability.

    KL divergence loss is calculated as follows:

    .. math::
        l(x, y) = y * (\\log(y) - x)

    Parameters:
        reduction (str, optional): Indicate how to average the loss,
            the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned;
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            ``'batchmean'`` is forwarded to :func:`paddle.nn.functional.kl_div`.
            Default is ``'mean'``.

    Shape:

        - input (Tensor): (N, *), where * means, any number of additional dimensions.

        - label (Tensor): (N, *), same shape as input.

        - output (Tensor): tensor with shape: [1] by default.


    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            import paddle.nn as nn

            paddle.disable_static()

            shape = (5, 20)
            x = np.random.uniform(-10, 10, shape).astype('float32')
            target = np.random.uniform(-10, 10, shape).astype('float32')

            # 'batchmean' reduction, loss shape will be [N]
            kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[5]

            # 'mean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='mean')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[1]

            # 'sum' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='sum')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[1]

            # 'none' reduction, loss shape is same with X shape
            kldiv_criterion = nn.KLDivLoss(reduction='none')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
            # shape=[5, 20]
    """

    def __init__(self, reduction='mean'):
        # Validate eagerly, like the other loss layers in this module. The
        # accepted set is the four reductions documented and exercised above.
        if reduction not in ['batchmean', 'mean', 'sum', 'none']:
            raise ValueError(
                "The value of 'reduction' in KLDivLoss should be 'sum', 'mean', "
                "'batchmean' or 'none', but received %s, which is not allowed." %
                reduction)
        super(KLDivLoss, self).__init__()
        self.reduction = reduction

    def forward(self, input, label):
        """Return the KL divergence of log-probabilities ``input`` from ``label``."""
        out = paddle.nn.functional.kl_div(input, label, self.reduction)
        return out


692 693 694 695
class MarginRankingLoss(fluid.dygraph.Layer):
    r"""

    This interface is used to construct a callable object of the ``MarginRankingLoss`` class.
    The MarginRankingLoss layer calculates the margin rank loss between the input, other and label,
    using the math function as follows.

    .. math::
        margin\_rank\_loss = max(0, -label * (input - other) + margin)

    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(margin\_rank\_loss)

    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(margin\_rank\_loss)

    If :attr:`reduction` set to ``'none'``, just return the origin ``margin_rank_loss``.

    Parameters:
        margin (float, optional): The margin value to add, default value is 0.
        reduction (str, optional): Indicate the reduction to apply to the loss, the candidates are ``'none'``, ``'mean'``, ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input: N-D Tensor, the shape is [N, *], N is batch size and `*` means any number of additional dimensions, available dtype is float32, float64.
        other: N-D Tensor, `other` have the same shape and dtype as `input`.
        label: N-D Tensor, label have the same shape and dtype as `input`.
        output: If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the out shape is :math:`[1]`, otherwise the shape is the same as `input` .The same dtype as input tensor.

    Returns:
        A callable object of MarginRankingLoss.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle

            paddle.disable_static()

            input = paddle.to_tensor(np.array([[1, 2], [3, 4]]).astype("float32"))
            other = paddle.to_tensor(np.array([[2, 1], [2, 4]]).astype("float32"))
            label = paddle.to_tensor(np.array([[1, -1], [-1, -1]]).astype("float32"))
            margin_rank_loss = paddle.nn.MarginRankingLoss()
            loss = margin_rank_loss(input, other, label)
            print(loss.numpy()) # [0.75]
    """

    def __init__(self, margin=0.0, reduction='mean', name=None):
        # Reject unsupported reductions before initialising the Layer so the
        # error surfaces at construction time.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(MarginRankingLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction
        self.name = name

    def forward(self, input, other, label):
        # Delegate to the functional implementation with the stored settings.
        out = paddle.nn.functional.margin_ranking_loss(
            input, other, label, self.margin, self.reduction, self.name)
        return out
759 760


761 762 763 764 765
class CTCLoss(fluid.dygraph.Layer):
    """
    :alias_main: paddle.nn.CTCLoss
    :alias: paddle.nn.CTCLoss, paddle.nn.layer.CTCLoss, paddle.nn.layer.loss.CTCLoss

    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation
    is integrated into the Warp-CTC library to normalize values for each row of the input tensor.

    Parameters:
        blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
        reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.

    Shape:
        log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type must be float32.
        labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
        input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
        label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.

    Returns:
        Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and  ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:

        .. code-block:: python

            # dynamic graph (imperative) mode
            import numpy as np
            import paddle

            # length of the longest logit sequence
            max_seq_length = 5
            # length of the longest label sequence
            max_label_length = 3
            # number of logit sequences
            batch_size = 2
            # class num
            class_num = 3

            np.random.seed(1)
            log_probs = np.array([[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04],
                                    [3.02332580e-01, 1.46755889e-01, 9.23385918e-02]],

                                    [[1.86260208e-01, 3.45560730e-01, 3.96767467e-01],
                                    [5.38816750e-01, 4.19194520e-01, 6.85219526e-01]],

                                    [[2.04452246e-01, 8.78117442e-01, 2.73875929e-02],
                                    [6.70467496e-01, 4.17304814e-01, 5.58689833e-01]],

                                    [[1.40386939e-01, 1.98101491e-01, 8.00744593e-01],
                                    [9.68261600e-01, 3.13424170e-01, 6.92322612e-01]],

                                    [[8.76389146e-01, 8.94606650e-01, 8.50442126e-02],
                                    [3.90547849e-02, 1.69830427e-01, 8.78142476e-01]]]).astype("float32")
            labels = np.array([[1, 2, 2],
                            [1, 2, 2]]).astype("int32")
            input_lengths = np.array([5, 5]).astype("int64")
            label_lengths = np.array([3, 3]).astype("int64")

            paddle.disable_static()
            log_probs = paddle.to_tensor(log_probs)
            labels = paddle.to_tensor(labels)
            input_lengths = paddle.to_tensor(input_lengths)
            label_lengths = paddle.to_tensor(label_lengths)

            loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels,
                input_lengths,
                label_lengths)
            print(loss.numpy())  #[3.9179852 2.9076521]

            loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels,
                input_lengths,
                label_lengths)
            print(loss.numpy())  #[1.1376063]
    """

    def __init__(self, blank=0, reduction='mean'):
        # Validate eagerly, matching MarginRankingLoss, so an unsupported
        # reduction is reported at construction rather than during forward.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in CTCLoss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(CTCLoss, self).__init__()
        self.blank = blank
        self.reduction = reduction

    def forward(self, log_probs, labels, input_lengths, label_lengths):
        # Delegate to the functional CTC loss with the stored blank index and
        # reduction mode.
        return paddle.nn.functional.ctc_loss(log_probs, labels, input_lengths,
                                             label_lengths, self.blank,
                                             self.reduction)


849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875
class SmoothL1Loss(fluid.dygraph.Layer):
    """
    This operator calculates smooth_l1_loss. Creates a criterion that uses a squared
    term if the absolute element-wise error falls below delta and an L1 term otherwise.
    In some cases it can prevent exploding gradients and it is more robust and less
    sensitive to outliers. Also known as the Huber loss:

    .. math::

         loss(x,y)=\\frac{1}{n}\\sum_{i}z_i

    where n is the number of elements (the :math:`\\frac{1}{n}` factor applies to
    ``reduction='mean'``) and z_i is given by:

    .. math::

         \\mathop{z_i}=\\left\\{\\begin{array}{rcl}
        0.5(x_i - y_i)^2 & & {if |x_i - y_i| < delta} \\\\
        delta * |x_i - y_i| - 0.5 * delta^2 & & {otherwise}
        \\end{array} \\right.

    Parameters:
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        delta (float, optional): Specifies the hyperparameter delta to be used.
            The value determines how large the errors need to be to use L1. Errors
            smaller than delta are minimized with L2. Parameter is ignored for
            negative/zero values. Default = 1.0
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to :ref:`api_guide_Name`.

    Call Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is
            (N, C), where C is number of classes, and if shape is more than 2D, this
            is (N, C, D1, D2,..., Dk), k >= 1.
        label (Tensor): Label tensor, the data type is float32 or float64. The shape of label
            is the same as the shape of input.

    Returns:
        The tensor variable storing the smooth_l1_loss of input and label.

    Return type: Tensor.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            paddle.disable_static()
            input_data = np.random.rand(3,3).astype("float32")
            label_data = np.random.rand(3,3).astype("float32")
            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)
            loss = paddle.nn.SmoothL1Loss()
            output = loss(input, label)
            print(output.numpy())
    """

    def __init__(self, reduction='mean', delta=1.0, name=None):
        # Validate eagerly, matching MarginRankingLoss, so an unsupported
        # reduction is reported at construction rather than during forward.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in SmoothL1Loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)
        super(SmoothL1Loss, self).__init__()
        self.reduction = reduction
        self.delta = delta
        self.name = name

    def forward(self, input, label):
        # Delegate to the functional implementation with the stored settings.
        return F.smooth_l1_loss(
            input,
            label,
            reduction=self.reduction,
            delta=self.delta,
            name=self.name)