loss.py 158.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import math

17
# TODO: define loss functions of neural network
18
import paddle
19
from paddle import _C_ops, _legacy_C_ops, fluid, in_dynamic_mode
20
from paddle.framework import core
21
from paddle.utils import deprecated
22

23
from ...common_ops_import import Variable
24
from ...fluid.data_feeder import check_variable_and_dtype
姜永久 已提交
25
from ...fluid.framework import _current_expected_place, in_dygraph_mode
26 27
from ...fluid.layer_helper import LayerHelper
from ...tensor.manipulation import reshape
28

29 30
__all__ = []

# Default ``ignore_index`` label value for hard-label cross-entropy: samples
# whose label equals it do not contribute to the input gradient.
kIgnoreIndex = -100


def dice_loss(input, label, epsilon=0.00001, name=None):
    r"""

    Dice loss for comparing the similarity between the input predictions and the label.
    This implementation is for binary classification, where the input is sigmoid
    predictions of each pixel, usually used for segmentation task. The dice loss can
    be defined as the following equation:

    .. math::

        dice\_loss &= 1 - \frac{2 * intersection\_area}{total\_area} \\
                  &= \frac{(total\_area - intersection\_area) - intersection\_area}{total\_area} \\
                  &= \frac{(union\_area - intersection\_area)}{total\_area}

    The loss is computed per sample (reducing over every axis except the first,
    batch, axis) with :math:`\epsilon` added to ``total_area``, and the per-sample
    losses are then averaged.

    Parameters:
        input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_k, D]`, where :math:`N_1` is
                          the batch_size, :math:`D` is the number of categories. It is usually the output
                          predictions of sigmoid activation. The data type can be float32 or float64.
        label (Tensor): Tensor, the ground truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_k, 1]`.
                          where :math:`N_1` is the batch_size. The data type can be int32 or int64.
        epsilon (float): A small constant added to the denominator (``total_area``) to
                         avoid division by zero. Default: 0.00001
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`

    Returns:
        Tensor, which shape is [1], data type is the same as `input` .

    Raises:
        AssertionError: If the dtypes, ranks or shapes of ``input`` and ``label``
            do not satisfy the requirements above, or either tensor is empty.

    Example:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.randn((3,224,224,2))
            label = paddle.randint(high=2, shape=(3,224,224,1))
            predictions = F.softmax(x)
            loss = F.dice_loss(input=predictions, label=label)
    """
    assert input.dtype in (paddle.float32, paddle.float64)
    assert label.dtype in (paddle.int32, paddle.int64)
    assert (
        len(input.shape) >= 2
    ), "The rank of input should be greater than or equal to 2."
    assert len(input.shape) == len(label.shape), (
        "The rank of input and label should be equal, "
        "but received input: %d, label: %d."
        % (len(input.shape), len(label.shape))
    )
    assert label.shape[-1] == 1, (
        "The last dimension of label should be 1, "
        "but received %d." % label.shape[-1]
    )
    assert (
        input.shape[:-1] == label.shape[:-1]
    ), "All dimensions should be equal except the last one."
    assert (
        input.numel() > 0 and label.numel() > 0
    ), "Any dimension of input and label cannot be equal to 0."

    # [N_1, ..., N_k, 1] -> [N_1, ..., N_k] -> one-hot [N_1, ..., N_k, D].
    label = paddle.squeeze(label, [-1])
    label = paddle.nn.functional.one_hot(label, input.shape[-1])
    # one_hot returns float32 regardless of `input`; match input's dtype so
    # the documented float64 inputs do not mix dtypes in `input * label`.
    label = paddle.cast(label, input.dtype)
    # Reduce over every axis except the batch axis (axis 0).
    reduce_dim = list(range(1, len(input.shape)))
    inse = paddle.sum(input * label, axis=reduce_dim)
    dice_denominator = paddle.sum(input, axis=reduce_dim) + paddle.sum(
        label, axis=reduce_dim
    )
    # Per-sample loss; epsilon only guards the denominator.
    dice_score = 1 - inse * 2 / (dice_denominator + epsilon)
    return paddle.mean(dice_score)


def log_loss(input, label, epsilon=1e-4, name=None):
    r"""

    **Negative Log Loss Layer**

    This layer accepts input predictions and target label and returns the
    negative log loss.

    .. math::

        Out = -label * \log{(input + \epsilon)}
              - (1 - label) * \log{(1 - input + \epsilon)}

    Args:
        input (Tensor|list):  A 2-D tensor with shape [N x 1], where N is the
                                batch size. This input is a probability computed
                                by the previous operator. Data type float32.
        label (Tensor|list):  The ground truth which is a 2-D tensor with
                                shape [N x 1], where N is the batch size.
                                Data type float32.
        epsilon (float, optional): A small number for numerical stability. Default 1e-4.
        name(str|None): For detailed information, please refer to
            :ref:`api_guide_Name` . Usually name is no need to set and None by default.

    Returns:
        Tensor, which shape is [N x 1], data type is float32.

    Examples:
        .. code-block:: python

          import paddle
          import paddle.nn.functional as F

          # Probabilities and labels must lie in [0, 1]; normally distributed
          # values can be negative and would make the logarithm NaN.
          label = paddle.rand((10, 1))
          prob = paddle.rand((10, 1))
          cost = F.log_loss(input=prob, label=label)
    """
    if in_dygraph_mode():
        return _C_ops.log_loss(input, label, epsilon)

    # Static-graph path: validate dtypes, then append a `log_loss` op.
    helper = LayerHelper('log_loss', **locals())
    check_variable_and_dtype(input, 'input', ['float32'], 'log_loss')
    check_variable_and_dtype(label, 'label', ['float32'], 'log_loss')

    loss = helper.create_variable_for_type_inference(dtype=input.dtype)

    helper.append_op(
        type='log_loss',
        inputs={'Predicted': [input], 'Labels': [label]},
        outputs={'Loss': [loss]},
        attrs={'epsilon': epsilon},
    )
    return loss


def fluid_softmax_with_cross_entropy(
    logits,
    label,
    soft_label=False,
    ignore_index=-100,
    numeric_stable_mode=True,
    return_softmax=False,
    axis=-1,
):
    r"""

    This operator implements the cross entropy loss function with softmax. This function
    combines the calculation of the softmax operation and the cross entropy loss function
    to provide a more numerically stable gradient.

    Because this operator performs a softmax on logits internally, it expects
    unscaled logits. This operator should not be used with the output of
    softmax operator since that would produce incorrect results.

    When the attribute :attr:`soft_label` is set to :attr:`False`, this operator
    expects mutually exclusive hard labels, each sample in a batch is in exactly
    one class with a probability of 1.0. Each sample in the batch will have a
    single label.

    The equation is as follows:

    1) Hard label (one-hot label, so every sample has exactly one class)

    .. math::
        \\loss_j=-\text{logits}_{label_j} +\log\left(\sum_{i=0}^{K}\exp(\text{logits}_i)\right), j = 1,..., K

    2) Soft label (each sample can have a distribution over all classes)

    .. math::
        \\loss_j= -\sum_{i=0}^{K}\text{label}_i\left(\text{logits}_i - \log\left(\sum_{i=0}^{K}\exp(\text{logits}_i)\right)\right), j = 1,...,K

    3) If :attr:`numeric_stable_mode` is :attr:`True`, softmax is calculated first by:

    .. math::
        \\max_j&=\max_{i=0}^{K}{\text{logits}_i} \\
                log\_max\_sum_j &= \log\sum_{i=0}^{K}\exp(logits_i - max_j)\\
                softmax_j &= \exp(logits_j - max_j - {log\_max\_sum}_j)

    and then cross entropy loss is calculated by softmax and label.

    Args:
        logits (Tensor): A multi-dimension ``Tensor`` , and the data type is float32 or float64. The input tensor of unscaled log probabilities.
        label (Tensor): The ground truth ``Tensor``. If :attr:`soft_label` is set to :attr:`True`,
            Label is a ``Tensor`` in the same shape and data type as :attr:`logits`.
            If :attr:`soft_label` is set to :attr:`False`, Label is a ``Tensor`` of class
            indices in the same shape as :attr:`logits` except shape in dimension :attr:`axis` as 1.
        soft_label (bool, optional): A flag to indicate whether to interpret the given
            labels as soft labels. Default False.
        ignore_index (int, optional): Specifies a target value that is ignored and does
                                      not contribute to the input gradient. Only valid
                                      if :attr:`soft_label` is set to :attr:`False`.
                                      Default: kIgnoreIndex(-100).
        numeric_stable_mode (bool, optional): A flag to indicate whether to use a more
                                              numerically stable algorithm. Only valid
                                              when :attr:`soft_label` is :attr:`False`
                                              and GPU is used. When :attr:`soft_label`
                                              is :attr:`True` or CPU is used, the
                                              algorithm is always numerically stable.
                                              Note that the speed may be slower when using
                                              the stable algorithm. Default: True.
        return_softmax (bool, optional): A flag indicating whether to return the softmax
                                         along with the cross entropy loss. Default: False.
        axis (int, optional): The index of dimension to perform softmax calculations. It
                              should be in range :math:`[-1, rank - 1]`, while :math:`rank`
                              is the rank of input :attr:`logits`. Default: -1.

    Returns:
        ``Tensor`` or Tuple of two ``Tensor`` : Return the cross entropy loss if \
                                                    `return_softmax` is False, otherwise the tuple \
                                                    (loss, softmax), softmax is in the same shape \
                                                    with input logits and cross entropy loss is in \
                                                    the same shape with input logits except shape \
                                                    in dimension :attr:`axis` as 1.

    Examples:
        .. code-block:: python

            import paddle

            logits = paddle.to_tensor([0.4, 0.6, 0.9])
            label = paddle.to_tensor([1], dtype="int64")

            out = paddle.nn.functional.softmax_with_cross_entropy(logits=logits, label=label)
            print(out)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [1.15328646])
    """
    if in_dygraph_mode():
        if core.is_compiled_with_custom_device("npu"):
            npu_label = label
            if not soft_label:
                # Hard labels equal to ignore_index are zeroed (mask * label)
                # before the kernel call; ignore_index is still passed below.
                # NOTE(review): presumably the NPU kernel cannot index with
                # the ignore value itself -- TODO confirm.
                npu_label = (
                    paddle.cast(label != ignore_index, dtype=label.dtype)
                    * label
                )
            softmax, loss = _legacy_C_ops.softmax_with_cross_entropy(
                logits,
                npu_label,
                'soft_label',
                soft_label,
                'ignore_index',
                ignore_index,
                'numeric_stable_mode',
                numeric_stable_mode,
                'axis',
                axis,
                'use_softmax',
                True,
            )
        else:
            softmax, loss = _C_ops.cross_entropy_with_softmax(
                logits,
                label,
                soft_label,
                True,  # presumably use_softmax, matching the legacy NPU call
                numeric_stable_mode,
                ignore_index,
                axis,
            )
        if not return_softmax:
            return loss
        else:
            return loss, softmax
    else:
        # Static-graph path: append a `softmax_with_cross_entropy` op that
        # produces both the softmax and the loss.
        attrs = {
            'soft_label': soft_label,
            'ignore_index': ignore_index,
            'numeric_stable_mode': numeric_stable_mode,
            'axis': axis,
        }
        helper = LayerHelper('softmax_with_cross_entropy', **locals())
        softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
        loss = helper.create_variable_for_type_inference(dtype=logits.dtype)

        outputs = {'Softmax': softmax, 'Loss': loss}
        helper.append_op(
            type='softmax_with_cross_entropy',
            inputs={'Logits': logits, 'Label': label},
            outputs=outputs,
            attrs=attrs,
        )

        if return_softmax:
            return loss, softmax

        return loss


def npair_loss(anchor, positive, labels, l2_reg=0.002):
    """

    Npair loss requires paired data. Npair loss has two parts: the first part is L2
    regularizer on the embedding vector; the second part is cross entropy loss which
    takes the similarity matrix of anchor and positive as logits.

    For more information, please refer to:
    `Improved Deep Metric Learning with Multi class N pair Loss Objective <http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf>`_

    Args:
      anchor(Tensor): embedding vector for the anchor image. shape=[batch_size, embedding_dims],
                        the data type is float32 or float64.
      positive(Tensor): embedding vector for the positive image. shape=[batch_size, embedding_dims],
                        the data type is float32 or float64.
      labels(Tensor): 1-D tensor. shape=[batch_size], the data type is float32 or float64 or int64.
      l2_reg(float32): L2 regularization term on embedding vector, default: 0.002.

    Returns:
      A Tensor representing the npair loss, the data type is the same as anchor, the shape is [1].

    Raises:
      ValueError: If ``anchor`` or ``positive`` is empty.

    Examples:

      .. code-block:: python

          import paddle

          DATATYPE = "float32"

          anchor = paddle.rand(shape=(18, 6), dtype=DATATYPE)
          positive = paddle.rand(shape=(18, 6), dtype=DATATYPE)
          labels = paddle.rand(shape=(18,), dtype=DATATYPE)

          npair_loss = paddle.nn.functional.npair_loss(anchor, positive, labels, l2_reg = 0.002)
          print(npair_loss)

    """
    if anchor.size == 0:
        raise ValueError("The dims of anchor should be greater than 0.")
    if positive.size == 0:
        raise ValueError("The dims of positive should be greater than 0.")
    check_variable_and_dtype(
        anchor, 'anchor', ['float32', 'float64'], 'npair_loss'
    )
    check_variable_and_dtype(
        positive, 'positive', ['float32', 'float64'], 'npair_loss'
    )
    check_variable_and_dtype(
        labels, 'labels', ['float32', 'float64', 'int64'], 'npair_loss'
    )
    # Fixed scale applied to the L2 term in addition to `l2_reg`.
    beta = 0.25
    batch_size = labels.shape[0]

    # Build a [batch_size, batch_size] matrix whose (i, j) entry is 1 when
    # labels[i] == labels[j], then row-normalize it into soft labels.
    labels = paddle.reshape(labels, shape=[batch_size, 1])
    labels = paddle.tile(labels, repeat_times=[1, batch_size])
    # Use the embeddings' dtype (not a hard-coded float32) so the soft labels
    # match the dtype of the similarity logits for float64 inputs too.
    labels = paddle.equal(labels, paddle.transpose(labels, perm=[1, 0])).astype(
        anchor.dtype
    )
    labels = labels / paddle.sum(labels, axis=1, keepdim=True)

    l2loss = paddle.mean(paddle.sum(paddle.square(anchor), 1)) + paddle.mean(
        paddle.sum(paddle.square(positive), 1)
    )
    l2loss = l2loss * beta * l2_reg

    similarity_matrix = paddle.matmul(
        anchor, positive, transpose_x=False, transpose_y=True
    )
    softmax_ce = fluid_softmax_with_cross_entropy(
        logits=similarity_matrix, label=labels, soft_label=True
    )
    cross_entropy = paddle.sum(labels * softmax_ce, 0)
    celoss = paddle.mean(cross_entropy)

    return l2loss + celoss


def square_error_cost(input, label):
    r"""

    This op accepts input predictions and target label and returns the
    squared error cost.

    For predictions ``input`` and target ``label``, the equation is:

    .. math::

        Out = (input - label)^2

    Parameters:
        input (Tensor): Input tensor, the data type should be float32 or float64.
        label (Tensor): Label tensor, the data type should be float32 or float64.

    Returns:
        Tensor, The tensor storing the element-wise squared error
        difference between input and label.

    Examples:

        .. code-block:: python

            import paddle
            input = paddle.to_tensor([1.1, 1.9])
            label = paddle.to_tensor([1.0, 2.0])
            output = paddle.nn.functional.square_error_cost(input, label)
            print(output)
            # [0.01, 0.01]

    """
    if in_dygraph_mode():
        minus_out = _C_ops.subtract(input, label)
        square_out = _C_ops.square(minus_out)
        return square_out
    else:
        # Static-graph path: append `elementwise_sub` followed by `square`.
        check_variable_and_dtype(
            input, "input", ['float32', 'float64'], 'square_error_cost'
        )
        check_variable_and_dtype(
            label, "label", ['float32', 'float64'], 'square_error_cost'
        )
        helper = LayerHelper('square_error_cost', **locals())
        minus_out = helper.create_variable_for_type_inference(dtype=input.dtype)
        helper.append_op(
            type='elementwise_sub',
            inputs={'X': [input], 'Y': [label]},
            outputs={'Out': [minus_out]},
        )

        square_out = helper.create_variable_for_type_inference(
            dtype=input.dtype
        )
        helper.append_op(
            type='square',
            inputs={'X': [minus_out]},
            outputs={'Out': [square_out]},
        )
        return square_out


def edit_distance(
    input,
    label,
    normalized=True,
    ignored_tokens=None,
    input_length=None,
    label_length=None,
):
    """
    This op computes the edit distances, also called Levenshtein distance, between a batch of
    hypothesis strings and their references. It measures how dissimilar two strings are by counting
    the minimum number of operations to transform one string into another.
    The operations include insertion, deletion, and substitution.

    For example, given hypothesis string A = "kitten" and reference
    B = "sitting", A will be transformed into B
    at least after two substitutions and one insertion:

    "kitten" -> "sitten" -> "sittin" -> "sitting"

    So the edit distance between A and B is 3.

    The input is a Tensor; input_length and label_length should also be provided.

    The `batch_size` of labels should be same as `input`.

    The output includes the edit distance value between every pair of input and related label, and the number of sequences.
    If Attr(normalized) is true,
    the edit distance value will be divided by the length of label.

    Parameters:
        input(Tensor): The input tensor, its rank should be equal to 2 and its data type should be int64.
        label(Tensor): The label tensor, its rank should be equal to 2 and its data type should be int64.
        normalized(bool, default True): Indicated whether to normalize the edit distance.
        ignored_tokens(list<int>, default None): Tokens that will be removed before
                                     calculating edit distance.
        input_length(Tensor): The length for each sequence in `input` if it's of Tensor type, it should have shape `(batch_size, )` and its data type should be int64.
        label_length(Tensor): The length for each sequence in `label` if it's of Tensor type, it should have shape `(batch_size, )` and its data type should be int64.
        NOTE: To avoid unexpected results, every element of input_length and label_length should be equal to the second dimension of input and label respectively. For example, for the input [[1,2,3,4],[5,6,7,8],[9,10,11,12]] of shape [3,4], input_length should be [4,4,4].

    Returns:
        Tuple:
            distance(Tensor): edit distance result, its data type is float32, and its shape is (batch_size, 1).
            sequence_num(Tensor): sequence number, its data type is float32, and its shape is (1,).

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            input = paddle.to_tensor([[1,2,3],[4,5,6],[4,4,4],[1,1,1]], dtype='int64')
            label = paddle.to_tensor([[1,3,4,1],[4,5,8,1],[7,7,7,1],[1,1,1,1]], dtype='int64')
            input_len = paddle.to_tensor([3,3,3,3], dtype='int64')
            label_len = paddle.to_tensor([4,4,4,4], dtype='int64')

            distance, sequence_num = F.loss.edit_distance(input=input, label=label, input_length=input_len, label_length=label_len, normalized=False)

            # print(distance)
            # [[3.]
            #  [2.]
            #  [4.]
            #  [1.]]
            # if set normalized to True
            # [[0.75]
            #  [0.5 ]
            #  [1.  ]
            #  [0.25]]
            #
            # print(sequence_num)
            # [4]

    """

    helper = LayerHelper("edit_distance", **locals())

    # Remove the ignored tokens from input and label (via `sequence_erase`)
    # before computing distances, in both dygraph and static mode.
    if ignored_tokens is not None and len(ignored_tokens) > 0:
        erased_input = helper.create_variable_for_type_inference(dtype="int64")
        erased_label = helper.create_variable_for_type_inference(dtype="int64")

        helper.append_op(
            type="sequence_erase",
            inputs={"X": [input]},
            outputs={"Out": [erased_input]},
            attrs={"tokens": ignored_tokens},
        )
        input = erased_input

        helper.append_op(
            type="sequence_erase",
            inputs={"X": [label]},
            outputs={"Out": [erased_label]},
            attrs={"tokens": ignored_tokens},
        )
        label = erased_label

    if in_dygraph_mode():
        return _C_ops.edit_distance(
            input, label, input_length, label_length, normalized
        )

    # Static-graph path. Length tensors are only wired in when both are given.
    check_variable_and_dtype(input, 'input', ['int64'], 'edit_distance')
    check_variable_and_dtype(label, 'label', ['int64'], 'edit_distance')
    this_inputs = {"Hyps": [input], "Refs": [label]}
    if input_length is not None and label_length is not None:
        this_inputs['HypsLength'] = [input_length]
        this_inputs['RefsLength'] = [label_length]

    # edit distance op
    edit_distance_out = helper.create_variable_for_type_inference(dtype="int64")
    sequence_num = helper.create_variable_for_type_inference(dtype="int64")
    helper.append_op(
        type="edit_distance",
        inputs=this_inputs,
        outputs={"Out": [edit_distance_out], "SequenceNum": [sequence_num]},
        attrs={"normalized": normalized},
    )

    return edit_distance_out, sequence_num


def binary_cross_entropy(
    input, label, weight=None, reduction='mean', name=None
):
    """
    Measure the binary_cross_entropy loss between input predictions ``input``
    and target labels ``label`` . The binary_cross_entropy loss can be described as:

    If :attr:`weight` is set, the loss is:

    .. math::
        Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`weight` is None, the loss is:

    .. math::
        Out = -1 * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`reduction` set to ``'none'``, the interface will return the original loss `Out`.

    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(Out)

    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(Out)

    Note that the input predictions ``input`` should always be the output of sigmoid, and the target labels ``label``
    should be numbers between 0 and 1.

    Parameters:
        input (Tensor): The input predictions tensor. 2-D tensor with shape: [N, *],
            N is batch_size, `*` means number of additional dimensions. The ``input``
            should always be the output of sigmoid.  Available dtype is float32, float64.
        label (Tensor): The target labels tensor. 2-D tensor with the same shape as
            ``input``. The target labels which values should be numbers between 0 and 1.
            Available dtype is float32, float64.
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, has to be a Tensor of size nbatch and the data type
            is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.


    Returns:
        Tensor. If ``reduction`` is ``'none'``, the shape of output is
            same as ``input`` , else the shape of output is scalar.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'``,
            or (static graph only) if ``weight`` is given but is not a Tensor.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.to_tensor([0.5, 0.6, 0.7], 'float32')
            label = paddle.to_tensor([1.0, 0.0, 1.0], 'float32')
            output = paddle.nn.functional.binary_cross_entropy(input, label)
            print(output)  # [0.65537095]

    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in binary_cross_entropy should be 'sum', "
            "'mean' or 'none', but received %s, which is not allowed."
            % reduction
        )

    if in_dygraph_mode():
        out = _C_ops.bce_loss(input, label)
        if weight is not None:
            # The eager `_C_ops.multiply` takes just (x, y); the legacy
            # 'axis', -1 attribute pair is not part of its signature.
            out = _C_ops.multiply(out, weight)

        if reduction == 'sum':
            return _C_ops.sum(out, [], None, False)

        elif reduction == 'mean':
            return _C_ops.mean_all(out)
        else:
            return out
    else:
        check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'binary_cross_entropy'
        )
        check_variable_and_dtype(
            label, 'label', ['float32', 'float64'], 'binary_cross_entropy'
        )

        # `name` is attached to whichever op produces the final output, so the
        # bce_loss op only gets it when no weighting or reduction follows.
        sub_name = name if weight is None and reduction == 'none' else None
        helper = LayerHelper("binary_cross_entropy", name=sub_name)
        out = helper.create_variable_for_type_inference(dtype=input.dtype)
        helper.append_op(
            type='bce_loss',
            inputs={
                'X': [input],
                'Label': [label],
            },
            outputs={'Out': [out]},
        )

        if weight is not None:
            if isinstance(weight, paddle.static.Variable):
                weight_name = name if reduction == 'none' else None
                out = paddle.multiply(out, weight, name=weight_name)
            else:
                raise ValueError(
                    "The weight is not a Tensor, please convert to Tensor."
                )

        if reduction == 'sum':
            return paddle.sum(out, name=name)
        elif reduction == 'mean':
            return paddle.mean(out, name=name)
        else:
            return out


def binary_cross_entropy_with_logits(
    logit, label, weight=None, reduction='mean', pos_weight=None, name=None
):
    r"""
    Combine the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.

    This measures the element-wise probability error in classification tasks
    in which each class is independent.
    This can be thought of as predicting labels for a data-point, where labels
    are not mutually exclusive. For example, a news article can be about
    politics, technology or sports at the same time or none of these.

    Firstly, calculate loss function as follows:

    .. math::
           Out = -Labels * \log(\sigma(Logit)) - (1 - Labels) * \log(1 - \sigma(Logit))

    We know that :math:`\sigma(Logit) = \frac{1}{1 + e^{-Logit}}`. By substituting this we get:

    .. math::
           Out = Logit - Logit * Labels + \log(1 + e^{-Logit})

    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
    we reformulate the loss as follows:

    .. math::
           Out = \max(Logit, 0) - Logit * Labels + \log(1 + e^{-|Logit|})

    Then, if ``weight`` or ``pos_weight`` is not None, multiply the
    weight tensor on the loss `Out`. The ``weight`` tensor will attach different
    weight on every items in the batch. The ``pos_weight`` will attach different
    weight on the positive label of each class; concretely the loss is scaled by
    :math:`1 + (pos\_weight - 1) * Labels`.

    Finally, apply reduce operation on the loss.
    If :attr:`reduction` set to ``'none'``, will return the original loss `Out`.
    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target labels ``label`` should be numbers between 0 and 1.

    Args:
        logit (Tensor): The input predictions tensor. 2-D tensor with shape: [N, *],
            N is batch_size, `*` means number of additional dimensions. The ``logit``
            is usually the output of Linear layer. Available dtype is float32, float64.
        label (Tensor): The target labels tensor. 2-D tensor with the same shape as
            ``logit``. The target labels which values should be numbers between 0 and 1.
            Available dtype is float32, float64.
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`,
            The data type is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
            with length equal to the number of classes. The data type is float32, float64.
            Default is ``'None'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor. If ``reduction`` is ``'none'``, the shape of output is
            same as ``logit`` , else the shape of output is scalar.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'``.

    Examples:

        .. code-block:: python

            import paddle

            logit = paddle.to_tensor([5.0, 1.0, 3.0])
            label = paddle.to_tensor([1.0, 0.0, 1.0])
            output = paddle.nn.functional.binary_cross_entropy_with_logits(logit, label)
            print(output)  # [0.45618808]

    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in binary_cross_entropy_with_logits "
            "should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
            % reduction
        )

    if in_dygraph_mode():
        out = _C_ops.sigmoid_cross_entropy_with_logits(
            logit, label, False, kIgnoreIndex
        )
        if pos_weight is not None:
            # The constant 1 is only needed to build 1 + (pos_weight - 1) * label.
            one = _C_ops.full(
                [1],
                float(1.0),
                logit.dtype,
                _current_expected_place(),
            )
            log_weight = _C_ops.add(
                _C_ops.multiply(label, _C_ops.subtract(pos_weight, one)), one
            )
            out = _C_ops.multiply(out, log_weight)
        if weight is not None:
            out = _C_ops.multiply(out, weight)

        if reduction == "sum":
            return _C_ops.sum(out, [], None, False)
        elif reduction == "mean":
            return _C_ops.mean_all(out)
        else:
            return out
    else:
        check_variable_and_dtype(
            logit,
            'logit',
            ['float32', 'float64'],
            'binary_cross_entropy_with_logits',
        )
        check_variable_and_dtype(
            label,
            'label',
            ['float32', 'float64'],
            'binary_cross_entropy_with_logits',
        )
        # The user-supplied name belongs to the tensor that is finally
        # returned; the raw sigmoid-cross-entropy output is that tensor only
        # when no weighting and no reduction follow it.
        sigmoid_name = None
        if reduction == 'none' and pos_weight is None and weight is None:
            sigmoid_name = name

        helper = LayerHelper(
            "sigmoid_cross_entropy_with_logits", name=sigmoid_name
        )

        out = helper.create_variable_for_type_inference(dtype=logit.dtype)

        helper.append_op(
            type="sigmoid_cross_entropy_with_logits",
            inputs={"X": logit, "Label": label},
            attrs={"ignore_index": kIgnoreIndex, 'normalize': False},
            outputs={"Out": out},
        )

        if pos_weight is not None:
            check_variable_and_dtype(
                pos_weight,
                'pos_weight',
                ['float32', 'float64'],
                'binary_cross_entropy_with_logits',
            )
            # Created only here so no unused fill_constant op is appended
            # to the program when pos_weight is absent.
            one = paddle.full(shape=[1], fill_value=1.0, dtype=logit.dtype)
            log_weight = paddle.add(
                paddle.multiply(label, paddle.subtract(pos_weight, one)), one
            )
            pos_weight_name = (
                name if reduction == 'none' and weight is None else None
            )
            out = paddle.multiply(out, log_weight, name=pos_weight_name)

        if weight is not None:
            check_variable_and_dtype(
                weight,
                'weight',
                ['float32', 'float64'],
                'binary_cross_entropy_with_logits',
            )
            weight_name = name if reduction == 'none' else None
            out = paddle.multiply(out, weight, name=weight_name)

        if reduction == "sum":
            return paddle.sum(out, name=name)
        elif reduction == "mean":
            return paddle.mean(out, name=name)
        return out
885 886


887 888 889 890 891 892 893 894 895 896 897
def hsigmoid_loss(
    input,
    label,
    num_classes,
    weight,
    bias=None,
    path_table=None,
    path_code=None,
    is_sparse=False,
    name=None,
):
    """
    The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity
    and speed up the model training, especially the training of language model.

    Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier.
    For each class(word), there's a unique path from root to itself, hsigmoid calculates the cost for each non-leaf node on
    the path, and sums them to get a total cost.

    Comparing to softmax, hsigmoid can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
    represents the number of classes or the size of word dict.

    The API supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural
    Network Language Model <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_.

    The custom tree is used when both :attr:`path_table` and :attr:`path_code` are provided. To build them,
    do the following steps (take the language model as an example):

    1. Using a custom word dict to build a binary tree, each leaf node should be a word in the word dict.
    2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table.
    3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code.
       Code means the label of each binary classifier, 1 indicate true, 0 indicate false.
    4. Now, each word should have its path and code along the path, you can pass a batch of path and code related
       to the same batch of inputs.

    Parameters:
        input (Tensor): A tensor with the shape [N, D], where N is the size of mini-batch,
            and D is the feature size. Its data type supports float32 or float64.
        label (Tensor): A tensor contains the labels of training data. Its shape is [N, 1]
            and data type is int64.
        num_classes (int): The number of classes or the size of word dict, must be at least 2.
            If the default tree is used (`path_table` and `path_code` are None), `num_classes`
            should not be None. If the custom tree is used (`path_table` and `path_code` are not None),
            `num_classes` should be the number of non-leaf nodes, which indicates the num of
            classes using by the binary classifier.
        weight (Tensor): A tensor with shape (num_classes - 1, D), with the same data type as `input`.
        bias (Tensor, optional): A tensor with shape (num_classes - 1, 1), with the same data type as `input`.
            If `bias` is None, no bias will be added. Default is None.
        path_table (Tensor, optional): A tensor that stores each batch of samples' path from leaf to root
            node, its shape is [N, L] and data type is int64, where L is the length of path. For each sample i,
            path_table[i] is a np.array like structure and each element in this array is the indexes in parent
            nodes' weight matrix. If `path_table` and `path_code` are None, the default tree will be used.
            Default is None.
        path_code (Tensor, optional): A tensor that stores each batch of samples' code of path from leaf
            to root node, its shape is [N, L] and data type is int64, which is the same as :attr:`path_table`.
            Each code of path is consisted with the code of nodes from leaf to root node. If `path_table` and
            `path_code` are None, the default tree will be used. Default is None.
        is_sparse (bool, optional): Whether use sparse updating instead of dense updating. If `is_sparse` is True,
            the gradient of `weight` and `input` will be sparse. Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A tensor with the cost of hierarchical sigmoid, its shape is [N, 1] and data type is the same as `input`.

    Raises:
        ValueError: If ``num_classes`` is less than 2.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            paddle.set_device('cpu')

            input = paddle.uniform([4, 3])
            # [[0.45424712  -0.77296764  0.82943869] # random
            #  [0.85062802  0.63303483  0.35312140] # random
            #  [0.57170701  0.16627562  0.21588242] # random
            #  [0.27610803  -0.99303514  -0.17114788]] # random
            label = paddle.to_tensor([0, 1, 4, 5])
            num_classes = 5
            weight=paddle.uniform([num_classes-1, 3])
            # [[-0.64477652  0.24821866  -0.17456549] # random
            #  [-0.04635394  0.07473493  -0.25081766] # random
            #  [ 0.05986035  -0.12185556  0.45153677] # random
            #  [-0.66236806  0.91271877  -0.88088769]] # random

            out=F.hsigmoid_loss(input, label, num_classes, weight)
            # [[1.96709502]
            #  [2.40019274]
            #  [2.11009121]
            #  [1.92374969]]
    """
    if num_classes < 2:
        raise ValueError(f'Expected num_classes >= 2 (got {num_classes})')

    if in_dygraph_mode():
        # is_sparse is forwarded twice: the kernel's signature takes two
        # boolean flags at these positions (original call kept unchanged).
        out, _, _ = _C_ops.hsigmoid_loss(
            input,
            label,
            weight,
            bias,
            path_table,
            path_code,
            num_classes,
            is_sparse,
            is_sparse,
        )
        return out
    else:

        check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'hsigmoid_loss'
        )
        check_variable_and_dtype(label, 'label', ['int64'], 'hsigmoid_loss')
        check_variable_and_dtype(
            weight, 'weight', ['float32', 'float64'], 'hsigmoid_loss'
        )
        if bias is not None:
            check_variable_and_dtype(
                bias, 'bias', ['float32', 'float64'], 'hsigmoid_loss'
            )
        if path_table is not None:
            check_variable_and_dtype(
                path_table, 'path_table', ['int64'], 'hsigmoid_loss'
            )
        if path_code is not None:
            check_variable_and_dtype(
                path_code, 'path_code', ['int64'], 'hsigmoid_loss'
            )

        attrs = {
            "num_classes": num_classes,
            "is_sparse": is_sparse,
        }

        inputs = {
            "X": input,
            "W": weight,
            "Bias": bias,
            "PathTable": path_table,
            "PathCode": path_code,
            "Label": label,
        }

        helper = LayerHelper('hsigmoid_loss', **locals())
        out = helper.create_variable_for_type_inference(input.dtype)
        pre_out = helper.create_variable_for_type_inference(input.dtype)
        # W_Out aliases the weight tensor itself (in-place output of the op).
        outputs = {"Out": out, "PreOut": pre_out, "W_Out": weight}

        helper.append_op(
            type="hierarchical_sigmoid",
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
        )
        return out
1042 1043


1044
def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None):
    r"""
    Calculate smooth_l1_loss. Creates a criterion that uses a squared
    term if the absolute element-wise error falls below ``delta`` and an L1 term otherwise.
    In some cases it can prevent exploding gradients and it is more robust and less
    sensitive to outliers. Also known as the Huber loss:

    .. math::

        loss(x,y) = \frac{1}{n}\sum_{i}z_i


    where :math:`z_i` is given by (the :math:`\frac{1}{n}` averaging applies
    when ``reduction`` is ``'mean'``):

    .. math::

        \mathop{z_i} = \left\{\begin{array}{rcl}
                0.5(x_i - y_i)^2 & & {if |x_i - y_i| < \delta} \\
                \delta * |x_i - y_i| - 0.5 * \delta^2 & & {otherwise}
            \end{array} \right.

    Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is
            (N, C), where C is number of classes, and if shape is more than 2D, this
            is (N, C, D1, D2,..., Dk), k >= 1.
        label (Tensor): Label tensor, the data type is float32 or float64. The shape of label
            is the same as the shape of input.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        delta (float, optional): Specifies the hyperparameter :math:`\delta` to be used.
            The value determines how large the errors need to be to use L1. Errors
            smaller than delta are minimized with L2. Parameter is ignored for
            negative/zero values. Default = 1.0
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor, The tensor variable storing the smooth_l1_loss of input and label.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'``.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.rand([3, 3]).astype('float32')
            label = paddle.rand([3, 3]).astype('float32')
            output = paddle.nn.functional.smooth_l1_loss(input, label)
            print(output)
            # [0.068004]
    """
    # Validate before building/running any op so an invalid argument does not
    # leave a dangling huber_loss op in the static program.
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in smooth_l1_loss should be 'sum', 'mean' or"
            " 'none', but received %s, which is not allowed." % reduction
        )

    if in_dygraph_mode():
        out = _C_ops.huber_loss(input, label, delta)
    else:
        check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'smooth_l1_loss'
        )
        check_variable_and_dtype(
            label, 'label', ['float32', 'float64'], 'smooth_l1_loss'
        )
        helper = LayerHelper('huber_loss', **locals())
        # The op also emits the raw residual; it is required as an output but
        # not returned to the caller.
        residual = helper.create_variable_for_type_inference(
            dtype=helper.input_dtype()
        )
        out = helper.create_variable_for_type_inference(
            dtype=helper.input_dtype()
        )
        helper.append_op(
            type='huber_loss',
            inputs={'X': input, 'Y': label},
            outputs={'Out': out, 'Residual': residual},
            attrs={'delta': delta},
        )

    if reduction == 'none':
        return out
    elif reduction == 'mean':
        return paddle.mean(out, name=name)
    else:
        return paddle.sum(out, name=name)
1132 1133


1134 1135 1136
def margin_ranking_loss(
    input, other, label, margin=0.0, reduction='mean', name=None
):
    r"""

    Calculate the margin rank loss between the input, other and label, use the math function as follows.

    .. math::
        margin\_rank\_loss = max(0, -label * (input - other) + margin)

    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(margin\_rank\_loss)

    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(margin\_rank\_loss)

    If :attr:`reduction` set to ``'none'``, just return the origin ``margin_rank_loss``.

    Parameters:
        input(Tensor): the first input tensor, its data type should be float32, float64.
        other(Tensor): the second input tensor, its data type should be float32, float64.
        label(Tensor): the label value corresponding to input, its data type should be float32, float64.
        margin (float, optional): The margin value to add, default value is 0;
        reduction (str, optional): Indicate the reduction to apply to the loss, the candidates are ``'none'``, ``'mean'``, ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, if :attr:`reduction` is ``'mean'`` or ``'sum'``, the out shape is :math:`[1]`, otherwise the shape is the same as `input` .The same dtype as input tensor.

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'``.

    Examples:

        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[1, 2], [3, 4]], dtype='float32')
            other = paddle.to_tensor([[2, 1], [2, 4]], dtype='float32')
            label = paddle.to_tensor([[1, -1], [-1, -1]], dtype='float32')
            loss = paddle.nn.functional.margin_ranking_loss(input, other, label)
            print(loss) # [0.75]
    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
            "received %s, which is not allowed." % reduction
        )
    if in_dygraph_mode():
        # (other - input) * label == -label * (input - other)
        out = _C_ops.subtract(other, input)
        out = _C_ops.multiply(out, label)
        if margin != 0.0:
            margin = fluid.dygraph.base.to_variable([margin], dtype=out.dtype)
            out = _C_ops.add(out, margin)
        out = _C_ops.relu(out)
        if reduction == 'sum':
            return _C_ops.sum(out, [], None, False)
        elif reduction == 'mean':
            return _C_ops.mean_all(out)
        return out
    else:
        helper = LayerHelper("margin_ranking_loss", **locals())
        check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'margin_rank_loss'
        )
        check_variable_and_dtype(
            other, 'other', ['float32', 'float64'], 'margin_rank_loss'
        )
        check_variable_and_dtype(
            label, 'label', ['float32', 'float64'], 'margin_rank_loss'
        )

        out = paddle.subtract(input, other)
        neg_label = paddle.neg(label)
        out = paddle.multiply(neg_label, out)

        if margin != 0.0:
            # paddle.full creates the constant directly; no placeholder
            # variable needs to be pre-created in the block.
            margin_var = paddle.full(
                shape=[1], fill_value=margin, dtype=out.dtype
            )
            out = paddle.add(out, margin_var)

        result_out = helper.create_variable_for_type_inference(input.dtype)

        if reduction == 'none':
            helper.append_op(
                type="relu", inputs={"X": out}, outputs={"Out": result_out}
            )
            return result_out
        elif reduction == 'sum':
            out = paddle.nn.functional.relu(out)
            attrs = {"dim": [0], "keep_dim": False, "reduce_all": True}
            helper.append_op(
                type="reduce_sum",
                inputs={"X": out},
                outputs={"Out": result_out},
                attrs=attrs,
            )
            return result_out
        elif reduction == 'mean':
            out = paddle.nn.functional.relu(out)
            helper.append_op(
                type="mean",
                inputs={"X": out},
                outputs={"Out": result_out},
                attrs={},
            )
            return result_out
1245 1246


1247
def l1_loss(input, label, reduction='mean', name=None):
    r"""

    Computes the L1 Loss of Tensor ``input`` and ``label`` as follows.

    If `reduction` set to ``'none'``, the loss is:

    .. math::
        Out = \lvert input - label \rvert

    If `reduction` set to ``'mean'``, the loss is:

    .. math::
        Out = MEAN(\lvert input - label \rvert)

    If `reduction` set to ``'sum'``, the loss is:

    .. math::
        Out = SUM(\lvert input - label \rvert)


    Parameters:
        input (Tensor): The input tensor. The shape is [N, `*`], where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32, int64.
        label (Tensor): label. The shape is [N, `*`], same shape as ``input`` . Its data type should be float32, float64, int32, int64.
        reduction (str, optional): Indicate the reduction to apply to the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'none'``, the unreduced loss is returned;
            If `reduction` is ``'mean'``, the reduced mean loss is returned.
            If `reduction` is ``'sum'``, the reduced sum loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, the L1 Loss of Tensor ``input`` and ``label``.
        If `reduction` is ``'none'``, the shape of output loss is :math:`[N, *]`, the same as ``input`` .
        If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].

    Raises:
        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'``.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[1.5, 0.8], [0.2, 1.3]])
            label = paddle.to_tensor([[1.7, 1], [0.4, 0.5]])

            l1_loss = paddle.nn.functional.l1_loss(input, label)
            print(l1_loss)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [0.34999999])

            l1_loss = paddle.nn.functional.l1_loss(input, label, reduction='none')
            print(l1_loss)
            # Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[0.20000005, 0.19999999],
            #         [0.20000000, 0.79999995]])

            l1_loss = paddle.nn.functional.l1_loss(input, label, reduction='sum')
            print(l1_loss)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [1.39999998])

    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
            "received %s, which is not allowed." % reduction
        )

    if in_dygraph_mode():
        unreduced = _C_ops.abs(_C_ops.subtract(input, label))

        if reduction == 'mean':
            return _C_ops.mean_all(unreduced)
        elif reduction == 'sum':
            return _C_ops.sum(unreduced, [], None, False)
        else:
            return unreduced
    else:
        check_variable_and_dtype(
            input,
            'input',
            ['float32', 'float64', 'int32', 'int64'],
            'l1_loss',
        )
        check_variable_and_dtype(
            label,
            'label',
            ['float32', 'float64', 'int32', 'int64'],
            'l1_loss',
        )

        diff = paddle.subtract(x=input, y=label)
        if reduction == 'sum':
            return paddle.sum(paddle.abs(diff), name=name)
        elif reduction == 'mean':
            return paddle.mean(paddle.abs(diff), name=name)
        else:
            # The user-supplied name belongs on the tensor that is returned
            # (the abs output), not on the intermediate difference.
            return paddle.abs(diff, name=name)
1346 1347 1348 1349 1350


def nll_loss(
    input, label, weight=None, ignore_index=-100, reduction='mean', name=None
):
    """
    This api returns negative log likelihood.

    See more detail in :ref:`NLLLoss <api_paddle_nn_NLLLoss>` .

    Parameters:
         input (Tensor): Input tensor, the shape is :math:`[N, C]`, `C` is the number of classes.
             But in K-dimension situation, the shape is :math:`[N, C, d_1, d_2, ..., d_K]`.
             The data type is float32, float64.
         label (Tensor): Label tensor, the shape is :math:`[N,]` or :math:`[N, d_1, d_2, ..., d_K]`.
             The data type is int64.
         weight (Tensor, optional): Weight tensor, a manual rescaling weight given
             to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise,
             it is treated as if having all ones. The data type is
             float32, float64. Default is ``None``.
         ignore_index (int, optional): Specifies a target value that is ignored
             and does not contribute to the input gradient. Default is -100.
         reduction (str, optional): Indicate how to average the loss,
             the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
             If `reduction` is ``'mean'``, the reduced mean loss is returned;
             if `reduction` is ``'sum'``, the reduced sum loss is returned;
             if `reduction` is ``'none'``, no reduction will be applied.
             Default is ``'mean'``.
         name (str, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`.

    Returns:
         `Tensor`, the value of negative log likelihood loss.

    Raises:
         ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``;
             if the rank of ``label`` is neither equal to nor one less than the rank of
             ``input``; if ``input`` has fewer than 2 dimensions; or if ``input.shape[1]``
             is smaller than 1.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.nn.functional import nll_loss
            log_softmax = paddle.nn.LogSoftmax(axis=1)

            input = paddle.to_tensor([[0.88103855, 0.9908683 , 0.6226845 ],
                      [0.53331435, 0.07999352, 0.8549948 ],
                      [0.25879037, 0.39530203, 0.698465  ],
                      [0.73427284, 0.63575995, 0.18827209],
                      [0.05689114, 0.0862954 , 0.6325046 ]], "float32")
            log_out = log_softmax(input)
            label = paddle.to_tensor([0, 2, 1, 1, 0], "int64")
            result = nll_loss(log_out, label)
            print(result) # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, [1.07202101])
    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
            "'none', but received %s, which is not allowed." % reduction
        )

    input_shape = list(input.shape)
    input_dims = len(input_shape)
    label_shape = list(label.shape)
    label_dims = len(label_shape)

    if input_dims - 1 != label_dims and input_dims != label_dims:
        # Build the message from adjacent literals: a backslash continuation
        # inside a single literal would embed the next line's indentation.
        raise ValueError(
            "Expected input_dims - 1 = label_dims or input_dims == label_dims "
            f"(got input_dims {input_dims}, label_dims {label_dims})"
        )

    if input_dims < 2:
        raise ValueError(f'Expected 2 or more dimensions (got {input_dims})')

    if input_shape[1] < 1:
        raise ValueError(
            f"Expected 1 or more classes (got num classes {input_shape[1]})"
        )

    n = input_shape[0]
    c = input_shape[1]

    # Both branches below only hand 2-D or 4-D tensors to the nll_loss op:
    # any other rank is flattened to [n, c, 1, -1] first, and for
    # reduction='none' the output is reshaped back to [n, d_1, ..., d_K].
    if in_dygraph_mode():
        if input_dims != 2 and input_dims != 4:
            input = _C_ops.reshape(input, [n, c, 1, -1])
            label = _C_ops.reshape(label, [n, 1, -1])
            out_shape = [n] + input_shape[2:]
        out, total_weight = _C_ops.nll_loss(
            input, label, weight, ignore_index, reduction
        )
        if input_dims != 2 and input_dims != 4 and reduction == 'none':
            out = _C_ops.reshape(out, out_shape)
        return out
    else:
        helper = LayerHelper('nll_loss', **locals())

        if input_dims != 2 and input_dims != 4:
            input = reshape(input, shape=[n, c, 1, -1])
            label = reshape(label, shape=[n, 1, -1])
            out_shape = [n] + input_shape[2:]

        check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'nll_loss'
        )
        check_variable_and_dtype(label, 'label', ['int64'], 'nll_loss')
        inputs = {'X': input, 'Label': label}
        attrs = {'reduction': reduction, 'ignore_index': ignore_index}
        # Only a Variable weight is forwarded to the static op.
        if weight is not None:
            if isinstance(weight, Variable):
                inputs['Weight'] = weight

        out = helper.create_variable_for_type_inference(dtype=input.dtype)
        total_weight = helper.create_variable_for_type_inference(
            dtype=input.dtype
        )
        outputs = {'Out': out, 'Total_weight': total_weight}

        helper.append_op(
            type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs
        )
        if input_dims != 2 and input_dims != 4 and reduction == 'none':
            out = reshape(out, shape=out_shape)

        return out
1470 1471


1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581
def poisson_nll_loss(
    input,
    label,
    log_input=True,
    full=False,
    epsilon=1e-8,
    reduction="mean",
    name=None,
):
    r"""Poisson negative log likelihood loss.
    See more detail in :ref:`PoissonNLLLoss <api_paddle_nn_PoissonNLLLoss>` .

    Parameters:
         input (Tensor):
            Input tensor, expectation of underlying Poisson distribution.
            The shape of input tensor should be `(N, *)` or `(*)` where `(*)` denotes any number of extra dimensions.
            It's data type should be float16, bfloat16, float32, float64.
         label (Tensor):
            Label tensor, random sampled from Poisson distribution :math:`label \sim \text{Poisson}(input)`.
            The shape of label tensor should be `(N, *)` or `(*)`, same shape as the input tensor.
            It's data type should be float16, bfloat16, float32, float64.
         log_input (bool, optional):
            Whether to treat the input tensor as log input.
            If ``True`` the loss is computed as, :math:`\exp(\text{input}) - \text{label} * \text{input}` .
            If ``False`` then loss is :math:`\text{input} - \text{label} * \log(\text{input}+\text{epsilon})` .
            Default: ``True``.
         full (bool, optional):
            Whether to compute full loss.
            If ``True``, the Stirling approximation term
            :math:`\text{label} * \log(\text{label}) - \text{label} + 0.5 * \log(2 \pi \text{label})`
            is added for elements where ``label > 1`` (it is zero elsewhere).
            If ``False``, the Stirling approximation is dropped.
            Default: ``False``.
         epsilon (float, optional):
            A small value to avoid evaluation of :math:`\log(0)` when `log_input`\ =\ ``False``. ``epsilon > 0``.
            Default: 1e-8.
         reduction (str, optional):
            Indicate how to reduce the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'mean'``, the reduced mean loss is returned;
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.
         name (str, optional):
            Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, the Poisson negative log likelihood loss, reduced according to ``reduction``.

    Raises:
        ValueError: If ``epsilon`` is not positive, if ``reduction`` is not one of
            ``'none'``, ``'mean'`` or ``'sum'``, or if ``input`` and ``label`` shapes differ.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            input = paddle.randn([5, 2], dtype=paddle.float32)
            label = paddle.randn([5, 2], dtype=paddle.float32)
            loss = F.poisson_nll_loss(input, label, log_input=True, reduction='none')
            print(loss)
            loss = F.poisson_nll_loss(input, label, reduction='mean')
            print(loss)

    """
    # check parameter values
    if epsilon <= 0:
        raise ValueError(
            "The value of `epsilon` in poisson_nll_loss should be positive, but received %f, which is not allowed"
            % epsilon
        )

    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in poisson_nll_loss should be 'sum', 'mean' or 'none', but "
            "received %s, which is not allowed." % reduction
        )
    # check input dtype and dimension
    check_variable_and_dtype(
        input,
        'input',
        ['float16', 'uint16', 'float32', 'float64'],
        'poisson_nll_loss',
    )
    check_variable_and_dtype(
        label,
        'label',
        ['float16', 'uint16', 'float32', 'float64'],
        'poisson_nll_loss',
    )

    if input.shape != label.shape:
        raise ValueError("input's shape must equal to label's shape")

    label = paddle.cast(label, input.dtype)
    if log_input:
        loss_out = paddle.exp(input) - label * input
    else:
        loss_out = input - label * paddle.log(input + epsilon)
    if full:
        # Stirling approximation of log(label!). It is only meaningful for
        # label > 1 (log(0!) = log(1!) = 0), and at label == 0 it evaluates to
        # NaN (0 * log(0)), so the mask must be on the label, not on the
        # approximation's value.
        stirling_approx = (
            label * paddle.log(label)
            - label
            + 0.5 * paddle.log(2 * math.pi * label)
        )
        loss_out = loss_out + paddle.where(
            label > 1,
            stirling_approx,
            paddle.zeros_like(stirling_approx),
        )
    if reduction == 'mean':
        loss_out = paddle.mean(loss_out)
    elif reduction == 'sum':
        loss_out = paddle.sum(loss_out)
    return loss_out


1582
def kl_div(input, label, reduction='mean', name=None):
    r"""
    Calculate the Kullback-Leibler divergence loss
    between Input(X) and Input(Target). Note that Input(X) is the
    log-probability and Input(Target) is the probability.

    KL divergence loss is calculated as follows:

    $$l(x, y) = y * (\log(y) - x)$$

    Here :math:`x` is input and :math:`y` is label.

    If `reduction` is ``'none'``, the output loss is the same shape as the input, and the loss at each point is calculated separately. There is no reduction to the result.

    If `reduction` is ``'mean'``, the output loss is the shape of [1], and the output is the average of all losses.

    If `reduction` is ``'sum'``, the output loss is the shape of [1], and the output is the sum of all losses.

    If `reduction` is ``'batchmean'``, the output loss is the shape of [1], and the output is the sum of all losses divided by the batch size N (the first dimension of ``input``).

    Args:
        input (Tensor): The input tensor. The shape is [N, *], where N is batch size and `*` means
            any number of additional dimensions. It's data type should be float32, float64.
        label (Tensor): label. The shape is [N, *], same shape as ``input`` . It's data type should be float32, float64.
        reduction (str, optional): Indicate how to average the loss,
            the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'mean'``, the reduced mean loss is returned;
            If `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
            if `reduction` is ``'sum'``, the reduced sum loss is returned;
            if `reduction` is ``'none'``, no reduction will be applied.
            Default is ``'mean'``.
        name(str, optional): Name for the operation (optional, default is None). For more information,
            please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The KL divergence loss. The data type is same as input tensor

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'batchmean'``, ``'mean'`` or ``'sum'``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            shape = (5, 20)
            x = paddle.uniform(shape, min=-10, max=10).astype('float32')
            target = paddle.uniform(shape, min=-10, max=10).astype('float32')

            # 'batchmean' reduction, loss shape will be [1]
            pred_loss = F.kl_div(x, target, reduction='batchmean')
            # shape=[1]

            # 'mean' reduction, loss shape will be [1]
            pred_loss = F.kl_div(x, target, reduction='mean')
            # shape=[1]

            # 'sum' reduction, loss shape will be [1]
            pred_loss = F.kl_div(x, target, reduction='sum')
            # shape=[1]

            # 'none' reduction, loss shape is same with input shape
            pred_loss = F.kl_div(x, target, reduction='none')
            # shape=[5, 20]

    """
    # Reject unknown reductions up front; otherwise both branches below would
    # silently fall through and return the unreduced loss.
    if reduction not in ['sum', 'mean', 'none', 'batchmean']:
        raise ValueError(
            "The value of 'reduction' in kl_div should be 'sum', 'mean', "
            "'batchmean' or 'none', but received %s, which is not allowed."
            % reduction
        )

    # ugly type promotion: when one side is float32 and the other float64,
    # cast the float32 side up to float64.
    if (
        fluid.data_feeder.convert_dtype(input.dtype) == 'float32'
        and fluid.data_feeder.convert_dtype(label.dtype) == 'float64'
    ):
        input = paddle.cast(input, 'float64')
    elif (
        fluid.data_feeder.convert_dtype(input.dtype) == 'float64'
        and fluid.data_feeder.convert_dtype(label.dtype) == 'float32'
    ):
        label = paddle.cast(label, 'float64')

    if in_dygraph_mode():
        # The op always runs unreduced; the reduction is applied here.
        out = _C_ops.kldiv_loss(input, label, 'none')
        if reduction == 'mean':
            out = paddle.mean(out)
        elif reduction == 'sum':
            out = paddle.sum(out)
        elif reduction == 'batchmean':
            if len(input.shape) > 0:
                batch_size = input.shape[0]
                out = paddle.sum(out) / batch_size
        return out
    else:
        helper = LayerHelper('kl_div', **locals())

        check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'kl_div'
        )
        check_variable_and_dtype(
            label, 'label', ['float32', 'float64'], 'kl_div'
        )
        fluid.data_feeder.check_type(reduction, 'reduction', str, 'kl_div')

        loss = helper.create_variable_for_type_inference(dtype=input.dtype)
        helper.append_op(
            type='kldiv_loss',
            inputs={'X': input, 'Target': label},
            outputs={'Loss': loss},
            attrs={'reduction': 'none'},
        )

        if reduction == 'mean':
            loss = paddle.mean(loss)
        elif reduction == 'sum':
            loss = paddle.sum(loss)
        elif reduction == 'batchmean':
            # Use the runtime batch size: the static shape may be -1.
            batch_size = paddle.shape(input)[0]
            loss = paddle.sum(loss) / batch_size
        return loss
1696 1697


1698
def mse_loss(input, label, reduction='mean', name=None):
    r"""
    Compute the squared error between the ``input`` predictions and ``label``.

    The element-wise loss is

    .. math::
        Out = (input - label)^2

    and :attr:`reduction` selects how it is returned. With ``'none'`` the
    element-wise tensor is returned unchanged. With ``'mean'`` it is averaged:

    .. math::
        Out = \operatorname{mean}((input - label)^2)

    With ``'sum'`` it is summed:

    .. math::
        Out = \operatorname{sum}((input - label)^2)

    Parameters:
        input (Tensor): Input tensor, the data type should be float32 or float64.
        label (Tensor): Label tensor, the data type should be float32 or float64.
        reduction (string, optional): The reduction method for the output,
            could be 'none' | 'mean' | 'sum'.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, the squared error between ``input`` and ``label`` after the requested reduction.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:

        .. code-block:: python

            import paddle
            mse_loss = paddle.nn.loss.MSELoss()
            input = paddle.to_tensor(1.5)
            label = paddle.to_tensor(1.7)
            output = mse_loss(input, label)
            print(output)
            # [0.04000002]

    """
    accepted = ('sum', 'mean', 'none')
    if reduction not in accepted:
        raise ValueError(
            "'reduction' in 'mse_loss' should be 'sum', 'mean' or 'none', "
            "but received {}.".format(reduction)
        )

    # Dtype checks are only performed when building a static graph.
    if not in_dynamic_mode():
        for tensor, arg_name in ((input, 'input'), (label, 'label')):
            check_variable_and_dtype(
                tensor, arg_name, ['float32', 'float64'], 'mse_loss'
            )

    difference = paddle.subtract(input, label)
    if reduction == 'none':
        # The user-supplied name labels the element-wise result.
        return paddle.square(difference, name=name)

    # Otherwise the name labels the reduction op.
    reducer = paddle.mean if reduction == 'mean' else paddle.sum
    return reducer(paddle.square(difference), name=name)
1770 1771


1772 1773 1774 1775 1776 1777 1778 1779 1780
def ctc_loss(
    log_probs,
    labels,
    input_lengths,
    label_lengths,
    blank=0,
    reduction='mean',
    norm_by_times=False,
):
    """

    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation
    is integrated into the Warp-CTC library to normalize values for each row of the input tensor.

    Parameters:
        log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type should be float32 or float64.
        labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
        input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
        label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
        blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default: 0.
        reduction (str, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default: ``'mean'``.
        norm_by_times (bool, optional): Whether to normalize the gradients by the number of time-step, which is also the sequence's length. There is no need to normalize the gradients if reduction mode is 'mean'. Default: False.

    Returns:
        Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and  ``labels``. If attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Examples:

        .. code-block:: python

            import paddle.nn.functional as F
            import paddle

            # length of the longest logit sequence
            max_seq_length = 5
            #length of the longest label sequence
            max_label_length = 3
            # number of logit sequences
            batch_size = 2
            # class num
            class_num = 3

            log_probs = paddle.to_tensor([[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04],
                                    [3.02332580e-01, 1.46755889e-01, 9.23385918e-02]],

                                    [[1.86260208e-01, 3.45560730e-01, 3.96767467e-01],
                                    [5.38816750e-01, 4.19194520e-01, 6.85219526e-01]],

                                    [[2.04452246e-01, 8.78117442e-01, 2.73875929e-02],
                                    [6.70467496e-01, 4.17304814e-01, 5.58689833e-01]],

                                    [[1.40386939e-01, 1.98101491e-01, 8.00744593e-01],
                                    [9.68261600e-01, 3.13424170e-01, 6.92322612e-01]],

                                    [[8.76389146e-01, 8.94606650e-01, 8.50442126e-02],
                                    [3.90547849e-02, 1.69830427e-01, 8.78142476e-01]]],
                                    dtype="float32")
            labels = paddle.to_tensor([[1, 2, 2],
                                    [1, 2, 2]], dtype="int32")
            input_lengths = paddle.to_tensor([5, 5], dtype="int64")
            label_lengths = paddle.to_tensor([3, 3], dtype="int64")

            loss = F.ctc_loss(log_probs, labels,
                input_lengths,
                label_lengths,
                blank=0,
                reduction='none')
            print(loss)
            # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [3.91798496, 2.90765190])

            loss = F.ctc_loss(log_probs, labels,
                input_lengths,
                label_lengths,
                blank=0,
                reduction='mean')
            print(loss)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [1.13760614])

    """
    # Validate before doing any work. An `assert` would be stripped under
    # `python -O`, silently returning the unreduced loss for bad values.
    if reduction not in ['mean', 'sum', 'none']:
        raise ValueError(
            "The value of 'reduction' in ctc_loss should be 'sum', 'mean' or "
            "'none', but received %s, which is not allowed." % reduction
        )

    def warpctc(
        input,
        label,
        blank=0,
        norm_by_times=False,
        input_length=None,
        label_length=None,
    ):
        # Thin wrapper over the `warpctc` op for both dygraph and static modes.
        if in_dygraph_mode():
            if input_length is None or label_length is None:
                raise ValueError(
                    "input_length and label_length must not be None in dygraph mode!"
                )
            loss_out = _C_ops.warpctc(
                input, label, input_length, label_length, blank, norm_by_times
            )
            return loss_out
        else:
            helper = LayerHelper('warpctc', **locals())
            check_variable_and_dtype(
                input, 'input', ['float32', 'float64'], "warpctc"
            )
            check_variable_and_dtype(label, 'label', ['int32'], "warpctc")
            this_inputs = {'Logits': [input], 'Label': [label]}
            if input_length is not None and label_length is not None:
                check_variable_and_dtype(
                    input_length, 'LogitsLength', ['int64'], "warpctc"
                )
                check_variable_and_dtype(
                    label_length, 'LabelLength', ['int64'], "warpctc"
                )
                this_inputs['LogitsLength'] = [input_length]
                this_inputs['LabelLength'] = [label_length]

            loss_out = helper.create_variable_for_type_inference(
                dtype=input.dtype
            )
            grad_out = helper.create_variable_for_type_inference(
                dtype=input.dtype
            )

            helper.append_op(
                type='warpctc',
                inputs=this_inputs,
                outputs={'WarpCTCGrad': [grad_out], 'Loss': [loss_out]},
                attrs={
                    'blank': blank,
                    'norm_by_times': norm_by_times,
                },
            )
            return loss_out

    loss_out = warpctc(
        log_probs, labels, blank, norm_by_times, input_lengths, label_lengths
    )

    # Drop the trailing size-1 axis so the per-sample loss lines up with
    # `label_lengths` for the 'mean' division below.
    loss_out = paddle.squeeze(loss_out, [-1])
    if reduction == 'mean':
        loss_out = paddle.mean(loss_out / label_lengths)
    elif reduction == 'sum':
        loss_out = paddle.sum(loss_out)
    return loss_out
H
Hui Zhang 已提交
1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043


def rnnt_loss(
    input,
    label,
    input_lengths,
    label_lengths,
    blank=0,
    fastemit_lambda=0.001,
    reduction='mean',
    name=None,
):
    """
    An operator integrating the open source Warp-Transducer library (https://github.com/b-flo/warp-transducer.git)
    to compute Sequence Transduction with Recurrent Neural Networks (RNN-T) loss.

    Parameters:
        input (Tensor): The logprobs sequence with padding, which is a 4-D Tensor. The tensor shape is [B, Tmax, Umax, D], where Tmax, is the longest length of input logit sequence. The data type should be float32 or float64.
        label (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [B, Umax], where Umax is the longest length of label sequence. The data type must be int32.
        input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int32.
        label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int32.
        blank (int, optional): The blank label index of RNN-T loss, which is in the half-opened interval [0, D). The data type must be int32. Default is 0.
        fastemit_lambda (float, default 0.001): Regularization parameter for FastEmit (https://arxiv.org/pdf/2010.11148.pdf)
        reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output will be sum of loss and be divided by the batch_size; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, The RNN-T loss between ``logprobs`` and  ``labels``. If attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``logprobs``.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Examples:

        .. code-block:: python

            import paddle.nn.functional as F
            import numpy as np
            import paddle
            import functools

            fn = functools.partial(F.rnnt_loss, reduction='sum', fastemit_lambda=0.0, blank=0)

            acts = np.array([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                            [0.1, 0.1, 0.6, 0.1, 0.1],
                            [0.1, 0.1, 0.2, 0.8, 0.1]],
                            [[0.1, 0.6, 0.1, 0.1, 0.1],
                            [0.1, 0.1, 0.2, 0.1, 0.1],
                            [0.7, 0.1, 0.2, 0.1, 0.1]]]])
            labels = [[1, 2]]

            acts = paddle.to_tensor(acts, stop_gradient=False)

            lengths = [acts.shape[1]] * acts.shape[0]
            label_lengths = [len(l) for l in labels]
            labels = paddle.to_tensor(labels, paddle.int32)
            lengths = paddle.to_tensor(lengths, paddle.int32)
            label_lengths = paddle.to_tensor(label_lengths, paddle.int32)

            costs = fn(acts, labels, lengths, label_lengths)
            print(costs)
            # Tensor(shape=[1], dtype=float64, place=Place(gpu:0), stop_gradient=False,
            #        [4.49566677])
    """
    # Validate before doing any work. An `assert` would be stripped under
    # `python -O`, silently returning the unreduced loss for bad values.
    if reduction not in ['mean', 'sum', 'none']:
        raise ValueError(
            "The value of 'reduction' in rnnt_loss should be 'sum', 'mean' or "
            "'none', but received %s, which is not allowed." % reduction
        )

    def warprnnt(
        input, label, input_length, label_length, blank=0, fastemit_lambda=0.001
    ):
        # Thin wrapper over the `warprnnt` op for both dygraph and static modes.
        if in_dygraph_mode():
            loss_out = _C_ops.warprnnt(
                input,
                label,
                input_length,
                label_length,
                blank,
                fastemit_lambda,
            )
            return loss_out
        helper = LayerHelper('warprnnt', **locals())
        check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], "warprnnt"
        )
        check_variable_and_dtype(label, 'label', ['int32'], "warprnnt")
        check_variable_and_dtype(
            input_length, 'input_lengths', ['int32'], "warprnnt"
        )
        check_variable_and_dtype(
            label_length, 'label_lengths', ['int32'], "warprnnt"
        )
        this_inputs = {
            'input': [input],
            'label': [label],
            'input_lengths': [input_length],
            'label_lengths': [label_length],
        }

        loss_out = helper.create_variable_for_type_inference(dtype=input.dtype)
        grad_out = helper.create_variable_for_type_inference(dtype=input.dtype)

        helper.append_op(
            type='warprnnt',
            inputs=this_inputs,
            outputs={'warprnntgrad': [grad_out], 'loss': [loss_out]},
            attrs={
                'blank': blank,
                'fastemit_lambda': fastemit_lambda,
            },
        )
        return loss_out

    # NOTE manually done log_softmax for CPU version,
    # log_softmax is computed within GPU version.

    # (B,)
    loss_out = warprnnt(
        input, label, input_lengths, label_lengths, blank, fastemit_lambda
    )

    if reduction == 'mean':
        # Mean over the per-sample losses equals sum / batch_size, but does
        # not depend on `input.shape[0]`. That static shape is -1 when the
        # batch dimension is unknown in a static graph, and dividing by it
        # would flip the sign of the loss.
        loss_out = paddle.mean(loss_out, name=name)
    elif reduction == 'sum':
        loss_out = paddle.sum(loss_out, name=name)
    return loss_out
2044 2045


def margin_cross_entropy(
    logits,
    label,
    margin1=1.0,
    margin2=0.5,
    margin3=0.0,
    scale=64.0,
    group=None,
    return_softmax=False,
    reduction='mean',
):
    r"""

    .. math::

        L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\neq y_i} e^{scos\theta_{j}}}

    where :math:`\theta_{y_i}` is the angle between the feature :math:`x_i` and
    the representation of its ground-truth class :math:`y_i`, and :math:`\theta_{j}`
    is the angle between :math:`x_i` and the representation of class :math:`j`.
    The details of ArcFace loss could be referred to https://arxiv.org/abs/1801.07698.

    .. hint::
        The API supports single GPU and multi GPU, and does not support CPU.
        For data parallel mode, set ``group=False``.
        For model parallel mode, set ``group=None`` or the group instance returned by paddle.distributed.new_group.
        And logits.shape[-1] can be different at each rank.

    Args:
        logits (Tensor): shape[N, local_num_classes], the output of the normalized X multiply the normalized W.
                The logits is shard_logits when using model parallel.
        label (Tensor): shape[N] or shape[N, 1], the ground truth label.
        margin1 (float, optional): m1 of margin loss, default value is `1.0`.
        margin2 (float, optional): m2 of margin loss, default value is `0.5`.
        margin3 (float, optional): m3 of margin loss, default value is `0.0`.
        scale (float, optional): s of margin loss, default value is `64.0`.
        group (Group, optional): The group instance returned by paddle.distributed.new_group
            or ``None`` for global default group or ``False`` for data parallel (do not communicate across ranks).
            Default is ``None``.
        return_softmax (bool, optional): Whether to return the softmax probability. Default value is `False`.
        reduction (str, optional): The candidates are ``'none'`` | ``'mean'`` | ``'sum'`` | ``None``.
                    If :attr:`reduction` is ``'mean'``, return the average of loss;
                    If :attr:`reduction` is ``'sum'``, return the sum of loss;
                    If :attr:`reduction` is ``'none'`` or ``None``, no reduction will be applied.
                    Default value is `'mean'`.

    Returns:
        Tensor|tuple[Tensor, Tensor], return the cross entropy loss if
            `return_softmax` is False, otherwise the tuple (loss, softmax),
            softmax is shard_softmax when using model parallel, otherwise
            softmax is in the same shape with input logits. If
            ``reduction`` is ``'none'`` or ``None``, the shape of loss is ``[N, 1]``, otherwise
            the shape is ``[1]``. Returns ``None`` on ranks that are not
            members of ``group``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values, if
            ``group`` is neither ``False``, ``None`` nor a group instance, or
            if the ranks of ``logits`` and ``label`` are incompatible.

    Examples:

    .. code-block:: python
        :name: code-example1

        # required: gpu
        # Single GPU
        import paddle
        m1 = 1.0
        m2 = 0.5
        m3 = 0.0
        s = 64.0
        batch_size = 2
        feature_length = 4
        num_classes = 4

        label = paddle.randint(low=0, high=num_classes, shape=[batch_size], dtype='int64')

        X = paddle.randn(
            shape=[batch_size, feature_length],
            dtype='float64')
        X_l2 = paddle.sqrt(paddle.sum(paddle.square(X), axis=1, keepdim=True))
        X = paddle.divide(X, X_l2)

        W = paddle.randn(
            shape=[feature_length, num_classes],
            dtype='float64')
        W_l2 = paddle.sqrt(paddle.sum(paddle.square(W), axis=0, keepdim=True))
        W = paddle.divide(W, W_l2)

        logits = paddle.matmul(X, W)
        loss, softmax = paddle.nn.functional.margin_cross_entropy(
            logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, return_softmax=True, reduction=None)

        print(logits)
        print(label)
        print(loss)
        print(softmax)

        #Tensor(shape=[2, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[ 0.85204151, -0.55557678,  0.04994566,  0.71986042],
        #        [-0.20198586, -0.35270476, -0.55182702,  0.09749021]])
        #Tensor(shape=[2], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
        #       [2, 3])
        #Tensor(shape=[2, 1], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[82.37059586],
        #        [12.13448420]])
        #Tensor(shape=[2, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[0.99978819, 0.00000000, 0.00000000, 0.00021181],
        #        [0.99992995, 0.00006468, 0.00000000, 0.00000537]])

    .. code-block:: python
        :name: code-example2

        # required: distributed
        # Multi GPU, test_margin_cross_entropy.py
        import paddle
        import paddle.distributed as dist
        strategy = dist.fleet.DistributedStrategy()
        dist.fleet.init(is_collective=True, strategy=strategy)
        rank_id = dist.get_rank()
        m1 = 1.0
        m2 = 0.5
        m3 = 0.0
        s = 64.0
        batch_size = 2
        feature_length = 4
        num_class_per_card = [4, 8]
        num_classes = paddle.sum(paddle.to_tensor(num_class_per_card))

        label = paddle.randint(low=0, high=num_classes.item(), shape=[batch_size], dtype='int64')
        label_list = []
        dist.all_gather(label_list, label)
        label = paddle.concat(label_list, axis=0)

        X = paddle.randn(
            shape=[batch_size, feature_length],
            dtype='float64')
        X_list = []
        dist.all_gather(X_list, X)
        X = paddle.concat(X_list, axis=0)
        X_l2 = paddle.sqrt(paddle.sum(paddle.square(X), axis=1, keepdim=True))
        X = paddle.divide(X, X_l2)

        W = paddle.randn(
            shape=[feature_length, num_class_per_card[rank_id]],
            dtype='float64')
        W_l2 = paddle.sqrt(paddle.sum(paddle.square(W), axis=0, keepdim=True))
        W = paddle.divide(W, W_l2)

        logits = paddle.matmul(X, W)
        loss, softmax = paddle.nn.functional.margin_cross_entropy(
            logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, return_softmax=True, reduction=None)

        print(logits)
        print(label)
        print(loss)
        print(softmax)

        # python -m paddle.distributed.launch --gpus=0,1 test_margin_cross_entropy.py
        ## for rank0 input
        #Tensor(shape=[4, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[ 0.32888934,  0.02408748, -0.02763289,  0.18173063],
        #        [-0.52893978, -0.10623845, -0.21596515, -0.06432517],
        #        [-0.00536345, -0.03924667,  0.66735314, -0.28640926],
        #        [-0.09907366, -0.48534973, -0.10365338, -0.39472322]])
        #Tensor(shape=[4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
        #       [11, 1 , 10, 11])

        ## for rank1 input
        #Tensor(shape=[4, 8], dtype=float64, place=CUDAPlace(1), stop_gradient=True,
        #       [[ 0.68654754,  0.28137170,  0.69694954, -0.60923933, -0.57077653,  0.54576703, -0.38709028,  0.56028204],
        #        [-0.80360371, -0.03042448, -0.45107338,  0.49559349,  0.69998950, -0.45411693,  0.61927630, -0.82808600],
        #        [ 0.11457570, -0.34785879, -0.68819499, -0.26189226, -0.48241491, -0.67685711,  0.06510185,  0.49660849],
        #        [ 0.31604851,  0.52087884,  0.53124749, -0.86176582, -0.43426329,  0.34786144, -0.10850784,  0.51566383]])
        #Tensor(shape=[4], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
        #       [11, 1 , 10, 11])

        ## for rank0 output
        #Tensor(shape=[4, 1], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[38.96608230],
        #        [81.28152394],
        #        [69.67229865],
        #        [31.74197251]])
        #Tensor(shape=[4, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[0.00000000, 0.00000000, 0.00000000, 0.00000000],
        #        [0.00000000, 0.00000000, 0.00000000, 0.00000000],
        #        [0.00000000, 0.00000000, 0.99998205, 0.00000000],
        #        [0.00000000, 0.00000000, 0.00000000, 0.00000000]])
        ## for rank1 output
        #Tensor(shape=[4, 1], dtype=float64, place=CUDAPlace(1), stop_gradient=True,
        #       [[38.96608230],
        #        [81.28152394],
        #        [69.67229865],
        #        [31.74197251]])
        #Tensor(shape=[4, 8], dtype=float64, place=CUDAPlace(1), stop_gradient=True,
        #       [[0.33943993, 0.00000000, 0.66051859, 0.00000000, 0.00000000, 0.00004148, 0.00000000, 0.00000000],
        #        [0.00000000, 0.00000000, 0.00000000, 0.00000207, 0.99432097, 0.00000000, 0.00567696, 0.00000000],
        #        [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00001795],
        #        [0.00000069, 0.33993085, 0.66006319, 0.00000000, 0.00000000, 0.00000528, 0.00000000, 0.00000000]])
    """

    # Validate with an explicit exception rather than `assert`, which is
    # stripped under `python -O` and would let an unknown value silently
    # skip the reduction below.
    if reduction not in ('mean', 'sum', 'none', None):
        raise ValueError(
            "The value of 'reduction' in margin_cross_entropy should be "
            "'sum', 'mean', 'none' or None, but received {}.".format(
                reduction
            )
        )

    # `group` must be False (data parallel), None (global default group) or
    # a group object exposing `is_member`.
    if not (group is False or group is None or hasattr(group, 'is_member')):
        raise ValueError(
            'Expected group is False, None or instance of '
            'paddle.distributed.collective.Group (got group: {})'.format(group)
        )

    # Ranks outside the requested group do not take part; they get None.
    if hasattr(group, 'is_member') and not group.is_member():
        return

    # Defaults describe a single, non-communicating rank (data parallel).
    ring_id = 0
    rank = 0
    nranks = 1
    if group is not False:
        ring_id = 0 if group is None else group.id
        if core.is_compiled_with_dist():
            parallel_env = paddle.distributed.ParallelEnv()
            global_rank = parallel_env.rank
            rank = (
                global_rank
                if group is None
                else group.get_group_rank(global_rank)
            )
            nranks = parallel_env.world_size if group is None else group.nranks

    input_dims = len(list(logits.shape))
    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected input_dims - 1 = label_dims or input_dims == label_dims '
            '(got input_dims {}, label_dims {})'.format(input_dims, label_dims)
        )
    # The kernel expects a trailing unit dimension on the label.
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=-1)

    if in_dygraph_mode():
        softmax, loss = _C_ops.margin_cross_entropy(
            logits,
            label,
            return_softmax,
            ring_id,
            rank,
            nranks,
            margin1,
            margin2,
            margin3,
            scale,
        )
    else:
        op_type = 'margin_cross_entropy'
        helper = LayerHelper(op_type, **locals())
        softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
        loss = helper.create_variable_for_type_inference(dtype=logits.dtype)

        check_variable_and_dtype(
            logits,
            'logits',
            ['float16', 'float32', 'float64'],
            'margin_cross_entropy',
        )
        check_variable_and_dtype(
            label, 'label', ['int32', 'int64'], 'margin_cross_entropy'
        )

        helper.append_op(
            type=op_type,
            inputs={'Logits': logits, 'Label': label},
            outputs={'Softmax': softmax, 'Loss': loss},
            attrs={
                'return_softmax': return_softmax,
                'ring_id': ring_id,
                'rank': rank,
                'nranks': nranks,
                'margin1': margin1,
                'margin2': margin2,
                'margin3': margin3,
                'scale': scale,
            },
        )

    # Reduction is identical for the dygraph and static-graph paths;
    # 'none' / None leave the per-sample [N, 1] loss untouched.
    if reduction == 'mean':
        loss = paddle.mean(loss)
    elif reduction == 'sum':
        loss = paddle.sum(loss)

    if not return_softmax:
        return loss
    return loss, softmax


@deprecated(
    since="2.0.0",
    update_to="paddle.nn.functional.cross_entropy",
    level=1,
    reason=(
        'Please notice that behavior of "paddle.nn.functional.softmax_with_cross_entropy" '
        'and "paddle.nn.functional.cross_entropy" is different.'
    ),
)
def softmax_with_cross_entropy(
    logits,
    label,
    soft_label=False,
    ignore_index=-100,
    numeric_stable_mode=True,
    return_softmax=False,
    axis=-1,
):
    r"""
    This operator implements the cross entropy loss function with softmax. This function
    combines the calculation of the softmax operation and the cross entropy loss function
    to provide a more numerically stable gradient.

    Because this operator performs a softmax on logits internally, it expects
    unscaled logits. This operator should not be used with the output of
    softmax operator since that would produce incorrect results.

    When the attribute :attr:`soft_label` is set :attr:`False`, this operator
    expects mutually exclusive hard labels, each sample in a batch is in exactly
    one class with a probability of 1.0. Each sample in the batch will have a
    single label.

    The equation is as follows, where :math:`N` is the number of samples and
    :math:`K` is the number of classes:

    1) Hard label (one-hot label, so every sample has exactly one class)

    .. math::
        \\loss_j=-\text{logits}_{label_j} +\log\left(\sum_{i=0}^{K}\exp(\text{logits}_i)\right), j = 1,..., N

    2) Soft label (each sample can have a distribution over all classes)

    .. math::
        \\loss_j= -\sum_{i=0}^{K}\text{label}_i\left(\text{logits}_i - \log\left(\sum_{i=0}^{K}\exp(\text{logits}_i)\right)\right), j = 1,...,N

    3) If :attr:`numeric_stable_mode` is :attr:`True`, softmax is calculated first by:

    .. math::
        \\max_j&=\max_{i=0}^{K}{\text{logits}_i} \\
                log\_max\_sum_j &= \log\sum_{i=0}^{K}\exp(logits_i - max_j)\\
                softmax_j &= \exp(logits_j - max_j - {log\_max\_sum}_j)

    and then cross entropy loss is calculated by softmax and label.

    Args:
        logits (Tensor): A multi-dimension ``Tensor`` , and the data type is float32 or float64. The input tensor of unscaled log probabilities.
        label (Tensor): The ground truth ``Tensor``.
            If :attr:`soft_label` is set to :attr:`True`, Label is a ``Tensor``
            with the same shape and data type as :attr:`logits`.
            If :attr:`soft_label` is set to :attr:`False`, Label is a ``Tensor`` of
            class indices (e.g. int64, as in the example below) in the same shape
            as :attr:`logits` except in dimension :attr:`axis`, whose size is 1.
        soft_label (bool, optional): A flag to indicate whether to interpret the given
            labels as soft labels. Default False.
        ignore_index (int, optional): Specifies a target value that is ignored and does
                                      not contribute to the input gradient. Only valid
                                      if :attr:`soft_label` is set to :attr:`False`.
                                      Default: kIgnoreIndex(-100).
        numeric_stable_mode (bool, optional): A flag to indicate whether to use a more
                                              numerically stable algorithm. Only valid
                                              when :attr:`soft_label` is :attr:`False`
                                              and GPU is used. When :attr:`soft_label`
                                              is :attr:`True` or CPU is used, the
                                              algorithm is always numerically stable.
                                              Note that the speed may be slower when using
                                              the stable algorithm. Default: True.
        return_softmax (bool, optional): A flag indicating whether to return the softmax
                                         along with the cross entropy loss. Default: False.
        axis (int, optional): The index of dimension to perform softmax calculations. It
                              should be in range :math:`[-1, rank - 1]`, while :math:`rank`
                              is the rank of input :attr:`logits`. Default: -1.

    Returns:
        ``Tensor`` or Tuple of two ``Tensor`` : Return the cross entropy loss if
        `return_softmax` is False, otherwise the tuple (loss, softmax). softmax
        is in the same shape as input logits, and the cross entropy loss is in
        the same shape as input logits except in dimension :attr:`axis`, whose
        size is 1.

    Examples:
        .. code-block:: python

            import paddle

            logits = paddle.to_tensor([0.4, 0.6, 0.9], dtype="float32")
            label = paddle.to_tensor([1], dtype="int64")

            out = paddle.nn.functional.softmax_with_cross_entropy(logits=logits, label=label)
            print(out)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [1.15328646])
    """
    # Thin deprecated wrapper: every argument is forwarded positionally, in
    # the same order, to the fluid implementation.
    return fluid_softmax_with_cross_entropy(
        logits,
        label,
        soft_label,
        ignore_index,
        numeric_stable_mode,
        return_softmax,
        axis,
    )


def cross_entropy(
    input,
    label,
    weight=None,
    ignore_index=-100,
    reduction='mean',
    soft_label=False,
    axis=-1,
    use_softmax=True,
    name=None,
):
2467
    r"""
2468

2469
    By default, the cross entropy loss function is implemented using softmax. This function
2470 2471
    combines the calculation of the softmax operation and the cross entropy loss function
    to provide a more numerically stable computing.
2472

2473
    Calculate the cross entropy loss function without softmax when use_softmax=False.
2474

2475
    By default, calculate the mean of the result, and you can also affect
2476
    the default behavior by using the reduction parameter. Please refer to the part of
2477
    parameters for details.
2478

2479
    Can be used to calculate the softmax cross entropy loss with soft and hard labels.
2480
    Where, the hard labels mean the actual label value, 0, 1, 2, etc.  And the soft labels
2481
    mean the probability of the actual label, 0.6, 0.8, 0.2, etc.
2482

2483
    The calculation includes the following two steps.
2484

2485
    - **1.softmax cross entropy**
2486

2487
        1. Hard label (each sample can only be assigned into one category)
2488

2489
        1.1. when use_softmax=True
2490

2491 2492
            .. math::
              \\loss_j=-\text{logits}_{label_j}+\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right) , j = 1,...,N
2493

2494 2495 2496 2497 2498 2499 2500 2501 2502 2503 2504 2505 2506 2507 2508 2509 2510 2511 2512 2513 2514 2515 2516 2517 2518 2519 2520 2521 2522 2523 2524 2525 2526 2527 2528 2529 2530 2531 2532 2533 2534
            where, N is the number of samples and C is the number of categories.

        1.2. when use_softmax=False

            .. math::
              \\loss_j=-\log\left({P}_{label_j}\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).


        2. Soft label (each sample is assigned to multiple categories with a certain probability, and the probability sum is 1).

        2.1. when use_softmax=True

            .. math::
              \\loss_j=-\sum_{i=0}^{C}\text{label}_i\left(\text{logits}_i-\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories.

        2.2. when use_softmax=False

            .. math::
              \\loss_j=-\sum_{j=0}^{C}\left({label}_j*\log\left({P}_{label_j}\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).




    - **2. Weight and reduction processing**

        1. Weight

            If the ``weight`` parameter is ``None`` , go to the next step directly.

            If the ``weight`` parameter is not ``None`` , the cross entropy of each sample is weighted by weight
            according to soft_label = False or True as follows.

            1.1. Hard labels (soft_label = False)

            .. math::
2535
                \\loss_j=loss_j*weight[label_j]
2536

2537

2538 2539 2540 2541 2542 2543 2544
            1.2. Soft labels (soft_label = True)

             .. math::
                \\loss_j=loss_j*\sum_{i}\left(weight[label_i]*logits_i\right)

        2. reduction

2545
            2.1 if the ``reduction`` parameter is ``none``
2546 2547 2548

                Return the previous result directly

2549
            2.2 if the ``reduction`` parameter is ``sum``
2550 2551 2552 2553 2554 2555

                Return the sum of the previous results

            .. math::
               \\loss=\sum_{j}loss_j

2556 2557
            2.3 if the ``reduction`` parameter is ``mean`` , it will be processed according to
            the ``weight`` parameter as follows.
2558

2559
            2.3.1. If the  ``weight``  parameter is ``None``
2560 2561 2562

                   Return the average value of the previous results

2563
            .. math::
2564 2565 2566 2567 2568 2569 2570 2571
                \\loss=\sum_{j}loss_j/N

                  where, N is the number of samples and C is the number of categories.

            2.3.2. If the 'weight' parameter is not 'None', the weighted average value of the previous result will be returned

            1. Hard labels (soft_label = False)

2572
            .. math::
2573
                \\loss=\sum_{j}loss_j/\sum_{j}weight[label_j]
2574 2575 2576

            2. Soft labels (soft_label = True)

2577
            .. math::
2578
                \\loss=\sum_{j}loss_j/\sum_{j}\left(\sum_{i}weight[label_i]\right)
2579 2580


2581
    Parameters:
2582
        input (Tensor): the data type is float32, float64. Shape is :math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes, ``k >= 1`` .
2583

2584
            Note:
2585
                1. when use_softmax=True, it expects unscaled logits. This operator should not be used with the output of softmax operator, which will produce incorrect results.
2586
                2. when use_softmax=False, it expects the output of softmax operator.
2587

2588
        label (Tensor):
2589 2590 2591 2592
            1. If soft_label=False, the shape is
            :math:`[N_1, N_2, ..., N_k]` or :math:`[N_1, N_2, ..., N_k, 1]`, k >= 1.
            the data type is int32, int64, float32, float64, where each value is [0, C-1].

2593
            2. If soft_label=True, the shape and data type should be same with ``input`` ,
2594 2595
            and the sum of the labels for each sample should be 1.

2596
        weight (Tensor, optional): a manual rescaling weight given to each class.
2597
            If given, has to be a Tensor of size C and the data type is float32, float64.
2598
            Default is ``'None'`` .
2599
        ignore_index (int64, optional): Specifies a target value that is ignored
2600 2601
            and does not contribute to the loss. A negative value means that no label
            value needs to be ignored. Only valid when soft_label = False.
2602
            Default is ``-100`` .
2603
        reduction (str, optional): Indicate how to average the loss by batch_size,
2604 2605
            the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
H
Hui Zhang 已提交
2606
            If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned.
2607 2608
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
2609 2610
        soft_label (bool, optional): Indicate whether label is soft. Default is ``False``.
        axis (int, optional):The index of dimension to perform softmax calculations.
2611 2612
            It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the
            number of dimensions of input :attr:`input`.
2613
            Default is ``-1`` .
2614
        use_softmax (bool, optional): Indicate whether compute softmax before cross_entropy.
2615
            Default is ``True``.
2616
        name (str, optional): The name of the operator. Default is ``None`` .
2617
            For more information, please refer to :ref:`api_guide_Name` .
2618 2619 2620

    Returns:

2621 2622
        Tensor. Return the softmax cross_entropy loss of ``input`` and ``label``.
        The data type is the same as input.
2623

2624
        If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the dimension of return value is ``1``.
2625

2626
        If :attr:`reduction` is ``'none'``:
C
Chen Long 已提交
2627

2628
        1. If soft_label = False, the dimension of return value is the same with ``label`` .
C
Chen Long 已提交
2629

2630
        2. if soft_label = True, the dimension of return value is :math:`[N_1, N_2, ..., N_k, 1]` .
2631

2632
    Examples:
2633
        .. code-block:: python
2634 2635

            # hard labels
2636 2637 2638 2639 2640
            import paddle
            paddle.seed(99999)
            N=100
            C=200
            reduction='mean'
2641
            input =  paddle.rand([N, C], dtype='float64')
2642
            label =  paddle.randint(0, C, shape=[N], dtype='int64')
2643 2644
            weight = paddle.rand([C], dtype='float64')

2645 2646 2647
            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
                weight=weight, reduction=reduction)
            dy_ret = cross_entropy_loss(
2648 2649 2650 2651 2652
                                        input,
                                        label)
            print(dy_ret)
            # Tensor(shape=[1], dtype=float64, place=Place(gpu:0), stop_gradient=True,
            #        [5.34043430])
2653 2654

        .. code-block:: python
2655 2656

            # soft labels
2657 2658 2659 2660 2661 2662 2663 2664 2665 2666 2667 2668 2669
            import paddle
            paddle.seed(99999)
            axis = -1
            ignore_index = -100
            N = 4
            C = 3
            shape = [N, C]
            reduction='mean'
            weight = None
            logits = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            labels = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            labels /= paddle.sum(labels, axis=axis, keepdim=True)
            paddle_loss_mean = paddle.nn.functional.cross_entropy(
2670 2671 2672 2673 2674 2675 2676 2677 2678
                                                                    logits,
                                                                    labels,
                                                                    soft_label=True,
                                                                    axis=axis,
                                                                    weight=weight,
                                                                    reduction=reduction)
            print(paddle_loss_mean)
            # Tensor(shape=[1], dtype=float64, place=Place(gpu:0), stop_gradient=True,
            #        [1.11043464])
C
Chen Long 已提交
2679

2680 2681 2682 2683
    """

    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
2684 2685
            "The value of 'reduction' in softmax_cross_entropy"
            "should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
2686 2687
            % reduction
        )
2688
    if ignore_index > 0 and soft_label:
2689 2690
        raise ValueError(
            "When soft_label == True, the value of 'ignore_index' in softmax_cross_entropy"
2691 2692 2693
            "should be '-100', but received %s, which is not allowed."
            % ignore_index
        )
2694

2695
    input_dims = len(list(input.shape))
2696 2697 2698
    if input_dims == 0:
        raise ValueError('The dimention of input should be larger than zero!')

2699 2700 2701
    label_dims = len(list(label.shape))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=axis)
2702

2703
    if in_dygraph_mode():
2704
        if not soft_label:
2705 2706 2707
            valid_label = (
                paddle.cast(label != ignore_index, dtype=label.dtype) * label
            )
2708 2709 2710
        if core.is_compiled_with_custom_device(
            "npu"
        ) or core.is_compiled_with_custom_device("mlu"):
2711
            if not soft_label:
2712
                _, out = _legacy_C_ops.softmax_with_cross_entropy(
2713 2714 2715 2716 2717 2718 2719 2720 2721 2722 2723 2724 2725
                    input,
                    valid_label,
                    'soft_label',
                    soft_label,
                    'ignore_index',
                    ignore_index,
                    'numeric_stable_mode',
                    True,
                    'axis',
                    axis,
                    'use_softmax',
                    use_softmax,
                )
2726
            else:
2727
                _, out = _legacy_C_ops.softmax_with_cross_entropy(
2728 2729 2730 2731 2732 2733 2734 2735 2736 2737 2738 2739 2740
                    input,
                    label,
                    'soft_label',
                    soft_label,
                    'ignore_index',
                    ignore_index,
                    'numeric_stable_mode',
                    True,
                    'axis',
                    axis,
                    'use_softmax',
                    use_softmax,
                )
2741
        else:
2742 2743 2744
            _, out = _C_ops.cross_entropy_with_softmax(
                input, label, soft_label, use_softmax, True, ignore_index, axis
            )
2745 2746 2747 2748

        if weight is not None:

            # trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases.
2749
            if soft_label:
2750 2751 2752 2753
                # chajchaj:
                # weight's shape is C, where C is class num.
                # for 1d case: label's shape is [N,C], weight_gather's shape is N.
                # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W].
2754 2755 2756 2757 2758 2759
                weight_gather = paddle.matmul(
                    x=paddle.cast(label, weight.dtype),
                    y=weight,
                    transpose_x=False,
                    transpose_y=True,
                )
2760 2761 2762 2763
                out_shape = list(out.shape)
                weight_gather_reshape = reshape(weight_gather, shape=out_shape)
                out = paddle.cast(out, weight_gather_reshape.dtype)

2764
                out = _C_ops.multiply(out, weight_gather_reshape)
2765 2766 2767 2768 2769
            else:
                if input.shape[axis] != weight.shape[-1]:
                    raise ValueError(
                        "input's class_dimension({}) must equal to "
                        "weight's class_dimension({}) "
2770 2771 2772 2773 2774 2775 2776 2777 2778 2779 2780 2781
                        "when weight is provided".format(
                            input.shape[axis], weight.shape[-1]
                        )
                    )

                ignore_weight_mask = paddle.cast(
                    (label != ignore_index), out.dtype
                )
                if (
                    ignore_weight_mask.ndim > 1
                    and ignore_weight_mask.shape[axis] == 1
                ):
2782
                    # TODO: Temporarily use squeeze instead of squeeze_
2783 2784 2785
                    ignore_weight_mask = paddle.squeeze(
                        ignore_weight_mask, axis
                    )
2786
                if axis != -1 and axis != valid_label.ndim - 1:
2787 2788 2789 2790 2791 2792 2793 2794 2795
                    temp_perm = (
                        list(range(axis % valid_label.ndim))
                        + list(
                            range(
                                (axis % valid_label.ndim + 1), valid_label.ndim
                            )
                        )
                        + [axis % valid_label.ndim]
                    )
2796
                    weight_gather = _C_ops.gather_nd(
2797 2798
                        weight, valid_label.transpose(temp_perm)
                    )
2799
                else:
2800
                    weight_gather = _C_ops.gather_nd(weight, valid_label)
2801 2802 2803
                weight_gather = _C_ops.multiply(
                    weight_gather, ignore_weight_mask
                )
2804
                input_shape = list(label.shape)
2805 2806 2807
                weight_gather_reshape = reshape(
                    weight_gather, shape=input_shape
                )
2808
                out = paddle.cast(out, weight_gather_reshape.dtype)
2809
                out = _C_ops.multiply(out, weight_gather_reshape)
2810 2811 2812 2813 2814

        if reduction == "sum":
            #   because of fluid_softmax_with_cross_entropy op's inner logic,
            #   in the out tensor of this op, the loss of sample with class_index==ignore_index is 0
            #   so, reduce_sum all directly is ok
2815
            return _C_ops.sum(out, [], None, False)
2816 2817 2818 2819 2820 2821 2822
        elif reduction == "mean":
            # 1. if weight==none,
            #     numerator: reduce_sum all loss directly is ok causeof fluid_softmax_with_cross_entropy's inner logic
            #     denominator: count sample num with class_index!=ignore_index
            # 2. else
            #     numerator: loss's weighted sum
            #     denominator: cal the sum of weight where the sample's class_index!=ignore_index
H
huangjun12 已提交
2823 2824 2825
            is_ignore = label == ignore_index
            mask = ~is_ignore
            if paddle.count_nonzero(is_ignore) > 0:  # ignore label
2826
                out_sum = _C_ops.sum(out, [], None, False)
2827 2828 2829 2830 2831
                # for each label[i],set 1 or 0, according to ignore_index
                # mask[i]=0, if label[i]==ignore_index
                # mask[i]=1, otherwise
                if weight is None:
                    mask = paddle.cast(mask, dtype=out_sum.dtype)
2832
                    count = _C_ops.sum(mask, [], None, False)
2833 2834 2835
                    ret = out_sum / (count + (count == 0.0))
                else:
                    mask = paddle.cast(mask, weight_gather_reshape.dtype)
2836 2837 2838
                    weight_ignored = _C_ops.multiply(
                        mask, weight_gather_reshape
                    )
2839
                    weight_sum = _C_ops.sum(weight_ignored, [], None, False)
2840 2841 2842
                    ret = out_sum / (weight_sum + (weight_sum == 0.0))
                return ret
            elif weight is not None:
2843
                out_sum = _C_ops.sum(out, [], None, False)
2844 2845 2846
                total_weight = _C_ops.sum(
                    weight_gather_reshape, [], None, False
                )
2847 2848
                return out_sum / (total_weight + (total_weight == 0.0))
            else:
2849
                return _C_ops.mean_all(out)
2850 2851 2852 2853 2854 2855

        else:
            if input_dims - 1 == label_dims:
                out = paddle.squeeze(out, axis=axis)
            return out

姜永久 已提交
2856 2857 2858 2859 2860 2861 2862 2863 2864 2865 2866 2867 2868 2869 2870 2871 2872 2873 2874 2875 2876 2877 2878 2879 2880 2881 2882 2883 2884 2885 2886
    else:
        check_variable_and_dtype(
            input,
            'input',
            ['float16', 'float32', 'float64'],
            'softmax_cross_entropy',
        )
        check_variable_and_dtype(
            label,
            'label',
            ['uint8', 'int8', 'int16', 'int32', 'int64', 'float32', 'float64'],
            'softmax_cross_entropy',
        )
        attrs = {
            'soft_label': soft_label,
            'ignore_index': ignore_index,
            'numeric_stable_mode': True,
            'axis': axis,
            'use_softmax': use_softmax,
        }
        helper = LayerHelper('softmax_with_cross_entropy', **locals())
        softmax = helper.create_variable_for_type_inference(dtype=input.dtype)
        out = helper.create_variable_for_type_inference(dtype=input.dtype)

        outputs = {'Softmax': softmax, 'Loss': out}
        helper.append_op(
            type='softmax_with_cross_entropy',
            inputs={'Logits': input, 'Label': label},
            outputs=outputs,
            attrs=attrs,
        )
2887

2888
        if weight is not None:
姜永久 已提交
2889 2890 2891 2892 2893 2894 2895
            check_variable_and_dtype(
                weight,
                'weight',
                ['float32', 'float64'],
                'softmax_cross_entropy',
            )
            weight_name = name if reduction == 'none' else None
2896
            if soft_label:
2897
                # chajchaj:
姜永久 已提交
2898
                # trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases.
H
HydrogenSulfate 已提交
2899
                # weight's shape is C, where C is class num.
2900 2901
                # for 1d case: label's shape is [N,C], weight_gather's shape is N.
                # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W].
2902 2903 2904 2905 2906 2907
                weight_gather = paddle.matmul(
                    x=paddle.cast(label, weight.dtype),
                    y=weight,
                    transpose_x=False,
                    transpose_y=True,
                )
姜永久 已提交
2908

2909 2910 2911 2912
                out_shape = list(out.shape)
                weight_gather_reshape = reshape(weight_gather, shape=out_shape)
                out = paddle.cast(out, weight_gather_reshape.dtype)
            else:
2913 2914 2915 2916
                if input.shape[axis] != weight.shape[-1]:
                    raise ValueError(
                        "input's class_dimension({}) must equal to "
                        "weight's class_dimension({}) "
2917 2918 2919 2920 2921
                        "when weight is provided".format(
                            input.shape[axis], weight.shape[-1]
                        )
                    )

姜永久 已提交
2922 2923 2924
                valid_label = paddle.multiply(
                    paddle.cast(label != ignore_index, dtype=label.dtype), label
                )
2925
                ignore_weight_mask = paddle.cast(
姜永久 已提交
2926
                    (label != ignore_index), input.dtype
2927 2928 2929 2930 2931 2932 2933 2934
                )
                if (
                    ignore_weight_mask.ndim > 1
                    and ignore_weight_mask.shape[axis] == 1
                ):
                    ignore_weight_mask = paddle.squeeze(
                        ignore_weight_mask, axis
                    )
H
HydrogenSulfate 已提交
2935
                if axis != -1 and axis != valid_label.ndim - 1:
2936 2937 2938 2939 2940 2941 2942 2943 2944
                    temp_perm = (
                        list(range(axis % valid_label.ndim))
                        + list(
                            range(
                                (axis % valid_label.ndim + 1), valid_label.ndim
                            )
                        )
                        + [axis % valid_label.ndim]
                    )
姜永久 已提交
2945 2946
                    weight_gather = paddle.gather_nd(
                        weight, paddle.transpose(valid_label, temp_perm)
2947
                    )
2948
                else:
姜永久 已提交
2949 2950
                    weight_gather = paddle.gather_nd(weight, valid_label)
                weight_gather = paddle.multiply(
2951 2952
                    weight_gather, ignore_weight_mask
                )
姜永久 已提交
2953

2954
                input_shape = list(label.shape)
2955 2956 2957
                weight_gather_reshape = reshape(
                    weight_gather, shape=input_shape
                )
姜永久 已提交
2958
            out = paddle.multiply(out, weight_gather_reshape, name=weight_name)
2959

2960
        if reduction == "sum":
姜永久 已提交
2961
            return paddle.sum(out, name=name)
2962
        elif reduction == "mean":
姜永久 已提交
2963 2964
            if ignore_index >= 0:
                out_sum = paddle.sum(out, name=name)
H
HydrogenSulfate 已提交
2965 2966 2967
                # for each label[i],set 1 or 0, according to ignore_index
                # mask[i]=0, if label[i]==ignore_index
                # mask[i]=1, otherwise
姜永久 已提交
2968
                mask = label != ignore_index
2969
                if weight is None:
2970
                    mask = paddle.cast(mask, dtype=out_sum.dtype)
姜永久 已提交
2971
                    count = paddle.sum(mask, name=name)
2972
                    ret = out_sum / (count + (count == 0.0))
2973 2974
                else:
                    mask = paddle.cast(mask, weight_gather_reshape.dtype)
姜永久 已提交
2975
                    weight_ignored = paddle.multiply(
2976 2977
                        mask, weight_gather_reshape
                    )
姜永久 已提交
2978
                    weight_sum = paddle.sum(weight_ignored, name=name)
2979
                    ret = out_sum / (weight_sum + (weight_sum == 0.0))
2980 2981
                return ret
            elif weight is not None:
姜永久 已提交
2982 2983
                out_sum = paddle.sum(out, name=name)
                total_weight = paddle.sum(weight_gather_reshape)
2984
                return out_sum / (total_weight + (total_weight == 0.0))
2985
            else:
姜永久 已提交
2986 2987
                return paddle.mean(out, name=name)

2988
        else:
2989 2990 2991
            if input_dims - 1 == label_dims:
                out = paddle.squeeze(out, axis=axis)

姜永久 已提交
2992
            return out
2993 2994


def sigmoid_focal_loss(
    logit,
    label,
    normalizer=None,
    alpha=0.25,
    gamma=2.0,
    reduction='sum',
    name=None,
):
    r"""
    `Focal Loss <https://arxiv.org/abs/1708.02002>`_ is proposed to address the
    foreground-background class imbalance for classification tasks. It down-weights
    easily-classified examples and thus focuses training on hard examples. For example,
    it is used in one-stage object detection where the foreground-background class
    imbalance is extremely high.

    This operator measures the focal loss as follows:

    .. math::
           Out = -Labels * alpha * {(1 - \sigma(Logit))}^{gamma}\log(\sigma(Logit)) - (1 - Labels) * (1 - alpha) * {\sigma(Logit)}^{gamma}\log(1 - \sigma(Logit))

    where :math:`\sigma(Logit) = \frac{1}{1 + \exp(-Logit)}`.

    Then, if :attr:`normalizer` is not None, this operator divides the loss `Out`
    by the normalizer tensor:

    .. math::
           Out = \frac{Out}{normalizer}

    Finally, this operator applies a reduce operation on the loss.
    If :attr:`reduction` is set to ``'none'``, the operator returns the original loss `Out`.
    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target ``label`` is 0 for the negative class and is 1 for the positive class.

    Args:
        logit (Tensor): The input logit tensor. The shape is [N, *], where N is batch_size,
            `*` means any number of additional dimensions. The ``logit`` is usually the
            output of a convolution layer. Available dtype is float32, float64.
        label (Tensor): The target label tensor with the same shape as
            ``logit``. Its values should be numbers between 0 and 1.
            Available dtype is float32, float64.
        normalizer (Tensor, optional): The tensor the focal loss is divided by. It has to be
            a 1-D Tensor with shape `[1, ]` or a 0-D Tensor with shape `[]`. The data type
            is float32, float64. For object detection tasks, it is the number of positive samples.
            If set to None, the focal loss will not be normalized. Default is None.
        alpha(int|float, optional): Hyper-parameter to balance the positive and negative examples,
            it should be between 0 and 1. Default value is set to 0.25.
        gamma(int|float, optional): Hyper-parameter to modulate the easy and hard examples.
            Default value is set to 2.0.
        reduction (str, optional): Indicate how to reduce the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'sum'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, if :attr:`reduction` is ``'mean'`` or ``'sum'``, the output is the reduced
        loss; otherwise the shape is the same as ``logit``. The dtype is the same as ``logit``.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``,
            or if ``normalizer`` has more than one dimension.

    Examples:

        .. code-block:: python

            import paddle

            logit = paddle.to_tensor([[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]], dtype='float32')
            label = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32')
            one = paddle.to_tensor([1.], dtype='float32')
            fg_label = paddle.greater_equal(label, one)
            fg_num = paddle.sum(paddle.cast(fg_label, dtype='float32'))
            output = paddle.nn.functional.sigmoid_focal_loss(logit, label, normalizer=fg_num)
            print(output)  # [0.65782464]

    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in sigmoid_focal_loss "
            "should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
            % reduction
        )

    if normalizer is not None:
        check_variable_and_dtype(
            normalizer,
            'normalizer',
            ['float32', 'float64'],
            'sigmoid_focal_loss',
        )
        normalizer_shape = list(normalizer.shape)
        normalizer_dims = len(normalizer_shape)
        # Only 0-D or 1-D normalizers are accepted; the element count is not checked here.
        if normalizer_dims > 1:
            raise ValueError(
                "Expected zero or one dimension of normalizer in sigmoid_focal_loss but got {}.".format(
                    normalizer_dims
                )
            )

    if in_dygraph_mode():
        place = _current_expected_place()
        # Tensor of ones shaped like ``logit``; used for every ``1 - x`` term below.
        one = _C_ops.full(logit.shape, float(1.0), logit.dtype, place)

        # Element-wise sigmoid cross entropy on the raw logits.
        # NOTE(review): the trailing (False, -100) args presumably mean
        # (normalize=False, ignore_index=-100) -- confirm against the op signature.
        loss = _C_ops.sigmoid_cross_entropy_with_logits(
            logit, label, False, -100
        )

        pred = _C_ops.sigmoid(logit)

        # p_t = pred * label + (1 - pred) * (1 - label)
        p_t = _C_ops.add(
            _C_ops.multiply(pred, label),
            _C_ops.multiply(
                _C_ops.subtract(one, pred), _C_ops.subtract(one, label)
            ),
        )

        # alpha_t = alpha * label + (1 - alpha) * (1 - label)
        alpha = fluid.dygraph.base.to_variable([alpha], dtype=loss.dtype)
        alpha_t = _C_ops.add(
            _C_ops.multiply(alpha, label),
            _C_ops.multiply(
                _C_ops.subtract(one, alpha), _C_ops.subtract(one, label)
            ),
        )
        loss = _C_ops.multiply(alpha_t, loss)

        # Modulating factor (1 - p_t) ** gamma down-weights well-classified examples.
        gamma = fluid.dygraph.base.to_variable([gamma], dtype=loss.dtype)
        gamma_t = _C_ops.pow(_C_ops.subtract(one, p_t), gamma)
        loss = _C_ops.multiply(gamma_t, loss)

        if normalizer is not None:
            loss = _C_ops.divide(loss, normalizer)

        if reduction == "sum":
            return _C_ops.sum(loss, [], None, False)
        elif reduction == "mean":
            return _C_ops.mean_all(loss)

        return loss

    else:
        check_variable_and_dtype(
            logit, 'logit', ['float32', 'float64'], 'sigmoid_focal_loss'
        )
        check_variable_and_dtype(
            label, 'label', ['float32', 'float64'], 'sigmoid_focal_loss'
        )

        # ``name`` is attached to whichever op produces the final output, so the
        # BCE op only receives it when nothing runs after it.
        bce_name = None
        if reduction == 'none' and normalizer is None:
            bce_name = name
        loss = paddle.nn.functional.binary_cross_entropy_with_logits(
            logit, label, reduction='none', name=bce_name
        )

        pred = paddle.nn.functional.sigmoid(logit)
        p_t = pred * label + (1 - pred) * (1 - label)

        alpha_t = alpha * label + (1 - alpha) * (1 - label)
        loss = paddle.multiply(alpha_t, loss)

        gamma_t = paddle.pow((1 - p_t), gamma)
        loss = paddle.multiply(gamma_t, loss)

        if normalizer is not None:
            normalizer_name = name if reduction == 'none' else None
            loss = paddle.divide(loss, normalizer, name=normalizer_name)

        if reduction == 'mean':
            loss = paddle.mean(loss, name=name)
        elif reduction == 'sum':
            loss = paddle.sum(loss, name=name)

        return loss


def multi_label_soft_margin_loss(
    input, label, weight=None, reduction="mean", name=None
):
    r"""
    Calculate the multi-label one-versus-all soft margin loss between ``input``
    (logits) and ``label`` (binary targets). For each sample, the per-class binary
    cross entropy of the sigmoid of the logits is averaged over the classes:

    .. math::
        \text{loss}(x, y) = - \frac{1}{C} \sum_{i} \left( y[i] \cdot \log\sigma(x[i])
        + (1 - y[i]) \cdot \log\sigma(-x[i]) \right)

    where :math:`\sigma(x) = \frac{1}{1 + \exp(-x)}` (so :math:`\log\sigma(-x) = \log(1 - \sigma(x))`)
    and :math:`C` is the size of the last dimension of ``input`` (the class dimension).
    If ``weight`` is given, each element-wise term is multiplied by it before averaging.

    Parameters:
        input (Tensor): Input logits tensor, the data type is float32 or float64. Shape is (N, C),
            where C is the number of classes. The average is always taken over the last axis.
        label (Tensor): Label tensor, the data type is float32 or float64. The shape of label is
            the same as the shape of input. Values are expected to be 0 or 1, since the loss
            uses both ``label`` and ``1 - label``.
        weight (Tensor,optional): a manual rescaling weight given to each class. It is multiplied
                element-wise with the unreduced loss, so it must be broadcastable to the shape of
                ``input`` (e.g. a Tensor of shape [C]). The data type is float32, float64.
                Default is ``'None'`` .
        reduction (str, optional): Indicate how to reduce the per-sample losses,
                the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
                If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
                If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
                If :attr:`reduction` is ``'sum'``, the summed loss is returned.
                Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
                For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input: N-D Tensor, the shape is [N, \*], where the last dimension holds the classes, available dtype is float32, float64.
        label: N-D Tensor, same shape as the input.
        weight: Tensor broadcastable to the shape of input, e.g. [C].
        output: scalar if :attr:`reduction` is ``'mean'`` or ``'sum'``. If :attr:`reduction` is ``'none'``,
            the shape of input with the last dimension removed (e.g. [N] for an [N, C] input).

    Returns:
        Tensor, The tensor variable storing the multi_label_soft_margin_loss of input and label.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``,
            or if ``input`` and ``label`` have different shapes.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            # label elements in {0., 1.}
            label = paddle.to_tensor([[0, 1, 0], [1, 1, 1], [1, 0, 1]], dtype=paddle.float32)
            loss = F.multi_label_soft_margin_loss(input, label, reduction='none')
            print(loss)
            # approximately [2.1629, 0.7111, 0.4399]
            loss = F.multi_label_soft_margin_loss(input, label, reduction='mean')
            print(loss)
            # approximately 1.1046
    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "'reduction' in 'multi_label_soft_margin_loss' should be 'sum', 'mean' or 'none', "
            "but received {}.".format(reduction)
        )

    if input.shape != label.shape:
        raise ValueError(
            "The input and label should have same dimension, "
            "but received {}!={}".format(input.shape, label.shape)
        )

    # Dtype checks are only performed when building a static graph.
    if not in_dygraph_mode():
        check_variable_and_dtype(
            input,
            'input',
            ['float32', 'float64'],
            'multilabel_soft_margin_loss',
        )
        check_variable_and_dtype(
            label,
            'label',
            ['float32', 'float64'],
            'multilabel_soft_margin_loss',
        )

    # Element-wise binary cross entropy on the logits.
    loss = -(
        label * paddle.nn.functional.log_sigmoid(input)
        + (1 - label) * paddle.nn.functional.log_sigmoid(-input)
    )

    if weight is not None:
        if not in_dygraph_mode():
            check_variable_and_dtype(
                weight,
                'weight',
                ['float32', 'float64'],
                'multilabel_soft_margin_loss',
            )
        loss = loss * weight

    # Average over the last (class) axis: one loss value per sample.
    loss = loss.mean(axis=-1)

    if reduction == "none":
        return loss
    elif reduction == "mean":
        return paddle.mean(loss, name=name)
    elif reduction == "sum":
        return paddle.sum(loss, name=name)


def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):
    r"""
    Calculates hinge_embedding_loss. Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y` (containing 1 or -1).
    This is usually used for measuring whether two inputs are similar or dissimilar, e.g. using the L1 pairwise distance as :math:`x`,
    and is typically used for learning nonlinear embeddings or semi-supervised learning.

    The loss function for the :math:`n`-th sample in the mini-batch is

    .. math::
        l_n = \begin{cases}
            x_n, & \text{if}\; y_n = 1,\\
            \max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
        \end{cases}

    where :math:`\Delta` is ``margin``. Elements whose label is neither 1 nor -1 contribute 0.
    The total loss function is

    .. math::
        \ell(x, y) = \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
        \end{cases}

    where :math:`L = \{l_1,\dots,l_N\}^\top`.

    Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64.
            The shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.
        label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64.
            The shape of label is the same as the shape of input.
        margin (float, optional): Specifies the hyperparameter margin to be used.
            The value determines how large the input needs to be to calculate in
            hinge_embedding_loss. When label is -1, inputs smaller than margin are minimized with hinge_embedding_loss.
            Default = 1.0
        reduction (str, optional): Indicate how to reduce the loss.
            The candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:

        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64. The sum operation operates over all the elements.

        label: N-D Tensor, same shape as the input. Tensor elements should contain 1 or -1, the data type is float32 or float64.

        output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.

    Returns:
        Tensor. The tensor variable storing the hinge_embedding_loss of input and label.

    Raises:
        ValueError: If ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            # label elements in {1., -1.}
            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

            loss = F.hinge_embedding_loss(input, label, margin=1.0, reduction='none')
            print(loss)
            # Tensor([[0., -2., 0.],
            #         [0., -1., 2.],
            #         [1., 1., 1.]])

            loss = F.hinge_embedding_loss(input, label, margin=1.0, reduction='mean')
            print(loss)
            # Tensor([0.22222222])
    """

    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "'reduction' in 'hinge_embedding_loss' should be 'sum', 'mean' or 'none', "
            "but received {}.".format(reduction)
        )

    # Dtype checks are only performed when building a static graph.
    if not in_dygraph_mode():
        check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'hinge_embedding_loss'
        )
        check_variable_and_dtype(
            label, 'label', ['float32', 'float64'], 'hinge_embedding_loss'
        )

    zero_ = paddle.zeros([1], dtype=input.dtype)
    # label == 1 keeps x; label == -1 uses relu(margin - x); any other label value yields 0.
    loss = paddle.where(label == 1.0, input, zero_) + paddle.where(
        label == -1.0, paddle.nn.functional.relu(margin - input), zero_
    )

    if reduction == 'mean':
        return paddle.mean(loss, name=name)
    elif reduction == 'sum':
        return paddle.sum(loss, name=name)
    elif reduction == 'none':
        return loss
3379 3380


3381 3382 3383
def cosine_embedding_loss(
    input1, input2, label, margin=0, reduction='mean', name=None
):
    r"""
    Compute the cosine embedding loss of Tensor ``input1``, ``input2`` and ``label`` as follows.

    If label = 1, then the loss value can be calculated as follow:

    .. math::
        Out = 1 - cos(input1, input2)

    If label = -1, then the loss value can be calculated as follow:

    .. math::
        Out = max(0, cos(input1, input2) - margin)

    The operator cos can be described as follow:

    .. math::
        cos(x1, x2) = \frac{x1 \cdot{} x2}{\Vert x1 \Vert_2 * \Vert x2 \Vert_2}

    Samples whose label is neither 1 nor -1 contribute 0 to the loss.

    Parameters:
        input1 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, which can be 0, 'M' means the length of input array.
                         Available dtypes are float32, float64.
        input2 (Tensor): tensor with shape: [N, M] or [M], 'N' means batch size, which can be 0, 'M' means the length of input array.
                         Available dtypes are float32, float64.
        label (Tensor): tensor with shape: [N] or [1], 'N' means the length of input array. The target labels values should be -1 or 1.
                         Available dtypes are int32, int64, float32, float64.
        margin (float, optional): Should be a number from :math:`-1` to :math:`1`,
                         :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
                         default value is :math:`0`.
        reduction (string, optional): Specifies the reduction to apply to the output:
                         ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
                         ``'mean'``: the sum of the output will be divided by the number of elements in the output
                         ``'sum'``: the output will be summed. Any other value raises ``ValueError``.
        name (str, optional): Name for the operation (optional, default is None).
                         For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, the cosine embedding Loss of Tensor ``input1`` ``input2`` and ``label``.
            If `reduction` is ``'none'``, the shape of output loss is [N] (one loss value per sample).
            If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].

    Raises:
        ValueError: if ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``;
            if ``label`` is not 1-D; if ``input1`` and ``input2`` differ in shape or
            have more than 2 dimensions; or if a dtype is not supported.

    Examples:
        .. code-block:: python

            import paddle

            input1 = paddle.to_tensor([[1.6, 1.2, -0.5], [3.2, 2.6, -5.8]], 'float32')
            input2 = paddle.to_tensor([[0.5, 0.5, -1.8], [2.3, -1.4, 1.1]], 'float32')
            label = paddle.to_tensor([1, -1], 'int64')

            output = paddle.nn.functional.cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='mean')
            print(output)  # [0.21155193]

            output = paddle.nn.functional.cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='sum')
            print(output)  # [0.42310387]

            output = paddle.nn.functional.cosine_embedding_loss(input1, input2, label, margin=0.5, reduction='none')
            print(output)  # [0.42310387, 0.        ]

    """
    # Validate `reduction` up front, like the sibling losses; previously an
    # unknown value fell through every branch and the function returned None.
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "'reduction' in 'cosine_embedding_loss' should be 'sum', 'mean' or 'none', "
            "but received {}.".format(reduction)
        )

    if len(label.shape) != 1:
        raise ValueError(
            "1D target tensor expected, multi-target not supported"
        )

    if input1.shape != input2.shape:
        raise ValueError(
            "the shape of input tensor 1 should be equal to input tensor 2, but found inputs with "
            "different sizes"
        )

    if len(input1.shape) > 2:
        raise ValueError(
            "1D target tensor expects 1D or 2D input tensors, but found inputs with {} dimensions".format(
                len(input1.shape)
            )
        )

    if input1.dtype not in [paddle.float32, paddle.float64]:
        raise ValueError(
            "The data type of input Variable must be 'float32' or 'float64'"
        )
    if label.dtype not in [
        paddle.int32,
        paddle.int64,
        paddle.float32,
        paddle.float64,
    ]:
        raise ValueError(
            "The data type of label Variable must be 'int32', 'int64', 'float32', 'float64'"
        )

    prod_sum = (input1 * input2).sum(axis=-1)
    # The small additive constant keeps the denominator non-zero when a
    # vector is all zeros.
    mag_square1 = paddle.square(input1).sum(axis=-1) + 10e-12
    mag_square2 = paddle.square(input2).sum(axis=-1) + 10e-12
    denom = paddle.sqrt(mag_square1 * mag_square2)
    cos = prod_sum / denom
    zeros = paddle.zeros_like(cos)
    pos = 1 - cos
    neg = paddle.clip(cos - margin, min=0)
    # Select the per-sample term by label; any other label value yields 0.
    out_pos = paddle.where(label == 1, pos, zeros)
    out_neg = paddle.where(label == -1, neg, zeros)
    out = out_pos + out_neg

    if reduction == 'none':
        return out
    if reduction == 'mean':
        return paddle.mean(out, name=name)
    return paddle.sum(out, name=name)
Y
yangguohao 已提交
3490 3491


3492 3493 3494 3495 3496 3497 3498 3499 3500 3501
def triplet_margin_with_distance_loss(
    input,
    positive,
    negative,
    distance_function=None,
    margin=1.0,
    swap=False,
    reduction='mean',
    name=None,
):
    r"""
    Measures the triplet loss given input
    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
    This is used for measuring a relative similarity between samples. A triplet
    is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
    examples` respectively). The shapes of all input tensors should be
    :math:`(N, D)`.

    The loss function for each sample in the mini-batch is:

    .. math::
        L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}


    where the default distance function

    .. math::
        d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p

    or users can define their own distance function. `margin` is a nonnegative margin representing the minimum difference
    between the positive and negative distances that is required for the loss to be 0. If `swap` is true, the distance of
    (input, negative) is compared with the distance of (positive, negative) and the smaller one is used.
    For more details see http://www.bmva.org/bmvc/2016/papers/paper119/paper119.pdf.

    Parameters:

        input (Tensor):Input tensor, the data type is float32 or float64.
            the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        positive (Tensor):Positive tensor, the data type is float32 or float64.
            The shape of label is the same as the shape of input.

        negative (Tensor):Negative tensor, the data type is float32 or float64.
            The shape of label is the same as the shape of input.

        distance_function (callable, optional): Quantifies the distance between two tensors and must return
            non-negative values. If not specified, ``paddle.nn.PairwiseDistance(2)`` is used.

        margin (float, optional): A nonnegative margin representing the minimum difference
            between the positive and negative distances required for the loss to be 0. Default value is :math:`1`.

        swap (bool, optional):The distance swap changes the negative distance to the swap distance (distance between positive samples
                and negative samples) if swap distance smaller than negative distance. Default: ``False``.

        reduction (str, optional):Indicate how to average the loss by batch_size.
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Output: Tensor. The tensor variable storing the triplet_margin_with_distance_loss of input and positive and negative.

    Raises:
        ValueError: if ``reduction`` is invalid, ``margin`` is negative, the three inputs
            differ in shape, or the distance function produces a negative distance.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
            positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
            negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
            loss = F.triplet_margin_with_distance_loss(input, positive, negative, margin=1.0, reduction='none')
            print(loss)
            # Tensor([0.        , 0.57496738, 0.        ])


            loss = F.triplet_margin_with_distance_loss(input, positive, negative, margin=1.0, reduction='mean')
            print(loss)
            # Tensor([0.19165580])

    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "'reduction' in 'triplet_margin_with_distance_loss' "
            "should be 'sum', 'mean' or 'none', "
            "but received {}.".format(reduction)
        )
    if margin < 0:
        raise ValueError(
            "The margin between positive samples and negative samples should be greater than 0."
        )
    if not in_dygraph_mode():
        check_variable_and_dtype(
            input,
            'input',
            ['float32', 'float64'],
            'triplet_margin_with_distance_loss',
        )
        check_variable_and_dtype(
            positive,
            'positive',
            ['float32', 'float64'],
            'triplet_margin_with_distance_loss',
        )
        check_variable_and_dtype(
            negative,
            'negative',
            ['float32', 'float64'],
            'triplet_margin_with_distance_loss',
        )

    if not (input.shape == positive.shape == negative.shape):
        raise ValueError(
            "input's shape must equal to "
            "positive's shape and  "
            "negative's shape"
        )

    distance_function = (
        distance_function
        if distance_function is not None
        else paddle.nn.PairwiseDistance(2)
    )

    positive_dist = distance_function(input, positive)
    negative_dist = distance_function(input, negative)

    if swap:
        swap_dist = distance_function(positive, negative)
        negative_dist = paddle.minimum(negative_dist, swap_dist)

    # A distance of exactly 0 is legitimate (e.g. identical samples under a
    # user-supplied metric); only negative values indicate a broken function.
    if not paddle.all(positive_dist >= 0) or not paddle.all(
        negative_dist >= 0
    ):
        raise ValueError(
            "The positive distance or negative distance should be non-negative, "
            "The distance functions should be checked."
        )

    loss = paddle.clip(positive_dist - negative_dist + margin, min=0.0)

    if reduction == 'mean':
        return paddle.mean(loss, name=name)
    elif reduction == 'sum':
        return paddle.sum(loss, name=name)
    elif reduction == 'none':
        return loss
Y
yangguohao 已提交
3639 3640


3641 3642 3643 3644 3645 3646 3647 3648 3649 3650 3651
def triplet_margin_loss(
    input,
    positive,
    negative,
    margin=1.0,
    p=2,
    epsilon=1e-6,
    swap=False,
    reduction='mean',
    name=None,
):
    r"""
    Measures the triplet loss given input
    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
    This is used for measuring a relative similarity between samples. A triplet
    is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative
    examples` respectively). The shapes of all input tensors should be
    :math:`(N, *)`.

    The loss function for each sample in the mini-batch is:

    .. math::
        L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}


    where

    .. math::
        d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p

    The distance is computed with ``paddle.nn.PairwiseDistance(p, epsilon=epsilon)``.

    Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64.
            the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64.

        positive (Tensor): Positive tensor, the data type is float32 or float64.
            The shape of label is the same as the shape of input.

        negative (Tensor): Negative tensor, the data type is float32 or float64.
            The shape of label is the same as the shape of input.

        margin (float, Optional): Must not be negative. Default: :math:`1`.

        p (int, Optional): The norm degree for pairwise distance. Default: :math:`2`.

        epsilon (float, Optional): Add small value to avoid division by zero,
            default value is 1e-6.

        swap (bool,Optional): If ``True``, the negative distance is replaced by the distance between the
            positive and negative samples whenever that one is smaller. For more details, see
            `Learning shallow convolutional feature descriptors with triplet losses`.
            Default: ``False``.

        reduction (str, Optional):Indicate how to average the loss by batch_size.
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``

        name (str, Optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Output: Tensor. The tensor variable storing the triplet_margin_loss of input and positive and negative.

    Raises:
        ValueError: if ``reduction`` is invalid, ``margin`` is negative, or the three
            inputs do not share the same shape.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
            positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
            negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
            loss = F.triplet_margin_loss(input, positive, negative, margin=1.0, reduction='none')
            print(loss)
            # Tensor([0.        , 0.57496738, 0.        ])


            loss = F.triplet_margin_loss(input, positive, negative, margin=1.0, reduction='mean')
            print(loss)
            # Tensor([0.19165580])

    """
    # Argument validation: cheap Python-side checks first.
    if reduction not in ('sum', 'mean', 'none'):
        raise ValueError(
            "'reduction' in 'triplet_margin_loss' should be 'sum', 'mean' or 'none', "
            "but received {}.".format(reduction)
        )
    if margin < 0:
        raise ValueError(
            "The margin between positive samples and negative samples should be greater than 0."
        )

    # Static-graph mode: verify each operand's dtype in a single loop.
    if not in_dygraph_mode():
        for operand, operand_name in (
            (input, 'input'),
            (positive, 'positive'),
            (negative, 'negative'),
        ):
            check_variable_and_dtype(
                operand,
                operand_name,
                ['float32', 'float64'],
                'triplet_margin_loss',
            )

    if input.shape != positive.shape or positive.shape != negative.shape:
        raise ValueError(
            "input's shape must equal to "
            "positive's shape and  "
            "negative's shape"
        )

    pairwise = paddle.nn.PairwiseDistance(p, epsilon=epsilon)
    dist_to_pos = pairwise(input, positive)
    dist_to_neg = pairwise(input, negative)
    if swap:
        # Use the positive-negative distance when it is the tighter one.
        dist_to_neg = paddle.minimum(dist_to_neg, pairwise(positive, negative))

    per_sample = paddle.clip(dist_to_pos - dist_to_neg + margin, min=0.0)

    if reduction == 'none':
        return per_sample
    reduce_fn = paddle.mean if reduction == 'mean' else paddle.sum
    return reduce_fn(per_sample, name=name)
3768 3769


3770 3771 3772 3773 3774 3775 3776 3777 3778
def multi_margin_loss(
    input,
    label,
    p: int = 1,
    margin: float = 1.0,
    weight=None,
    reduction='mean',
    name=None,
):
    r"""
    Measures a multi-class classification hinge loss between input :math:`input` and label :math:`label`:

    For i-th mini-batch sample, the loss in terms of the 1D input :math:`input_i` and scalar
    output :math:`label_i` is:

    .. math::
        \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, \text{margin} - input_i[label_i] + input_i[j])^p}{\text{C}}

    where :math:`0 \leq j \leq \text{C}-1`, :math:`0 \leq i \leq \text{N}-1` and :math:`j \neq label_i`.

    Optionally, you can give non-equal weighting on the classes by passing
    a 1D :attr:`weight` tensor into the constructor.

    The loss function for i-th sample then becomes:

    .. math::
        \text{loss}(input_i, label_i) = \frac{\sum_{j} \max(0, weight[label_i] * (\text{margin} - input_i[label_i] + input_i[j]))^p}{\text{C}}

    again summing over :math:`j \neq label_i`.

    Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes.

        label (Tensor): Label tensor, the data type is int32 or int64. The shape of label is (N,)

        p (int, Optional): The power num. Default: :math:`1`.

        margin (float, Optional): Default: :math:`1`.

        weight (Tensor,optional): a manual rescaling weight given to each class.
                If given, has to be a Tensor of shape (C,) and the data type is float32, float64.
                Default is ``'None'`` .

        reduction (str, Optional):Indicate how to calculate the loss by batch_size.
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss of shape (N,) is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``

        name (str, Optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Output: Tensor. The tensor variable storing the multi_margin_loss of input and label.

    Raises:
        ValueError: if ``reduction`` is invalid, if ``label.shape[0] != input.shape[0]``,
            or if ``weight.shape[0] != input.shape[1]``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
            label = paddle.to_tensor([1, 2, 1], dtype=paddle.int32)
            loss = F.multi_margin_loss(input, label, margin=1.0, reduction='none')
            print(loss)
            # approximately [0., 0.6667, 0.]

    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "'reduction' in 'multi_margin_loss' should be 'sum', 'mean' or 'none', "
            "but received {}.".format(reduction)
        )

    if not in_dygraph_mode():
        check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'multi_margin_loss'
        )
        check_variable_and_dtype(
            label, 'label', ['int32', 'int64'], 'multi_margin_loss'
        )
    if not (input.shape[0] == label.shape[0]):
        raise ValueError(
            "The label's shape[0] should be equal to input's shape[0], "
            "but received input's shape[0] {} and label's shape[0]:{}. ".format(
                input.shape[0], label.shape[0]
            )
        )

    label = label.reshape((-1, 1))
    # input_i[label_i] for each sample, shape (N, 1).
    index_sample = paddle.index_sample(input, label)
    # Number of classes C as a tensor in input's dtype: paddle.shape keeps it
    # valid when C is only known at run time, and the cast keeps the
    # divisions below in input's dtype.
    num_classes = paddle.shape(input)[1].astype(input.dtype)

    if weight is not None:
        if not in_dygraph_mode():
            check_variable_and_dtype(
                weight, 'weight', ['float32', 'float64'], 'multi_margin_loss'
            )
        if not (input.shape[1] == weight.shape[0]):
            raise ValueError(
                "The weight's shape[0] should be equal to input's shape[1], "
                "but received weight's shape[0]: {} and input's shape[1]: {}".format(
                    weight.shape[0], input.shape[1]
                )
            )
        # weight[label_i] for each sample, shape (N, 1).
        weight = paddle.gather(weight, label, axis=0).reshape((-1, 1))
        hinge = paddle.pow(
            paddle.clip(weight * (margin - index_sample + input), min=0.0),
            p,
        )
        # The j == label_i column contributes max(0, w * margin)^p to the row
        # sum; it is excluded by the formula, so subtract it. Flattening to
        # (N,) matches the (N,) row means instead of broadcasting to (N, N).
        self_term = paddle.pow(
            paddle.clip(weight * margin, min=0.0), p
        ).reshape((-1,))
        loss = paddle.mean(hinge, axis=1) - self_term / num_classes
    else:
        hinge = paddle.pow(
            paddle.clip(margin - index_sample + input, min=0.0), p
        )
        # The j == label_i column contributes max(0, margin)^p; subtract it.
        loss = paddle.mean(hinge, axis=1) - max(margin, 0.0) ** p / num_classes

    if reduction == 'mean':
        return paddle.mean(loss, name=name)
    elif reduction == 'sum':
        return paddle.sum(loss, name=name)
    elif reduction == 'none':
        return loss


3899 3900
def soft_margin_loss(input, label, reduction='mean', name=None):
    """

    The API measures the soft margin loss between input predictions ``input``
    and target labels ``label`` . It can be described as:

    .. math::
        Out = log(1 + exp((-label * input)))

    The value is computed with ``paddle.nn.functional.softplus``, which evaluates
    the same expression but does not overflow to ``inf`` for large ``-label * input``.

    Parameters:

        input (Tensor): The input predictions tensor with shape: ``[N, *]``,
            N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf.
            Available dtype is float32, float64.

        label (Tensor): The target labels tensor with the same shape as
            ``input``. The target labels which values should be numbers -1 or 1.
            Available dtype is int32, int64, float32, float64.

        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.

        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:

        Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is [1].

    Raises:
        ValueError: if ``reduction`` is invalid or ``input`` and ``label`` differ in shape.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32')
            label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32')
            output = paddle.nn.functional.soft_margin_loss(input, label)
            print(output)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [0.64022040])

            input = paddle.uniform(shape=(5, 5), dtype="float32", min=0.1, max=0.8)
            label = paddle.randint(0, 2, shape=(5, 5), dtype="int64")
            label[label==0]=-1

            output = paddle.nn.functional.soft_margin_loss(input, label, reduction='none')
            print(output)
            # Tensor(shape=[5, 5], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[1.09917796, 0.52613139, 0.56263304, 0.82736146, 0.38776723],
            #         [1.07179427, 1.11924267, 0.49877715, 1.10026348, 0.46184641],
            #         [0.84367639, 0.74795729, 0.44629076, 0.55123353, 0.77659678],
            #         [0.39465919, 0.76651484, 0.54485321, 0.76609844, 0.77166790],
            #         [0.51283568, 0.84757161, 0.78913331, 1.05268764, 0.45318675]])

    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in soft_margin_loss should be 'sum', "
            "'mean' or 'none', but received %s, which is not allowed."
            % reduction
        )

    if not in_dygraph_mode():
        check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'soft_margin_loss'
        )
        check_variable_and_dtype(
            label,
            'label',
            ['int32', 'int64', 'float32', 'float64'],
            'soft_margin_loss',
        )

    if not (input.shape == label.shape):
        raise ValueError("input's shape must equal to " "label's shape")

    # Integer labels are cast so the product below is in input's dtype.
    label = paddle.cast(label, input.dtype)
    # softplus(z) == log(1 + exp(z)); unlike the literal expression it
    # returns z directly above its threshold instead of overflowing exp(z).
    out = paddle.nn.functional.softplus(-label * input)

    if reduction == 'sum':
        return paddle.sum(out, name=name)
    elif reduction == 'mean':
        return paddle.mean(out, name=name)
    else:
        return out