loss.py 93.3 KB
Newer Older
1
# -*- coding: utf-8 -*
2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import paddle
17 18 19
from ...fluid.layer_helper import LayerHelper
from ...fluid.data_feeder import check_variable_and_dtype
import paddle.fluid as fluid
20

21
# TODO: define loss functions of neural network
22
import numpy as np
23 24 25 26
import paddle
import paddle.fluid as fluid
from ...fluid.framework import core, in_dygraph_mode
from ...fluid.layers.nn import _elementwise_op_in_dygraph
Z
zhiboniu 已提交
27 28 29
from ...fluid.layers import dice_loss  # noqa: F401
from ...fluid.layers import log_loss  # noqa: F401
from ...fluid.layers import npair_loss  # noqa: F401
30
from ...tensor.manipulation import reshape
Z
zhiboniu 已提交
31 32
from ...fluid.layers import softmax_with_cross_entropy as fluid_softmax_with_cross_entropy
from ...fluid.layers import square_error_cost  # noqa: F401
33

Z
zhiboniu 已提交
34
from ...fluid.layers import edit_distance  # noqa: F401
35
from ...fluid.layers import huber_loss
36
from ...fluid.layer_helper import LayerHelper
37
from ...fluid.framework import in_dygraph_mode
38
from ...fluid.framework import _varbase_creator
39
from ...static import Variable
40
from paddle.utils import deprecated
W
wanghuancoder 已提交
41
from paddle import _C_ops
42

43 44
__all__ = []

45

46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
def binary_cross_entropy(input, label, weight=None, reduction='mean',
                         name=None):
    """
    Measure the binary cross entropy loss between the predictions ``input``
    and the target labels ``label``. The loss is defined as follows.

    If :attr:`weight` is set, the loss is:

    .. math::
        Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`weight` is None, the loss is:

    .. math::
        Out = -1 * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`reduction` is ``'none'``, the unreduced loss `Out` is returned.

    If :attr:`reduction` is ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(Out)

    If :attr:`reduction` is ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(Out)

    Note that ``input`` is expected to be the output of a sigmoid, and the
    target labels ``label`` should be numbers between 0 and 1.

    Parameters:
        input (Tensor): The input predictions tensor. 2-D tensor with shape: [N, *],
            N is batch_size, `*` means number of additional dimensions. The ``input``
            should always be the output of sigmoid. Available dtype is float32, float64.
        label (Tensor): The target labels tensor. 2-D tensor with the same shape as
            ``input``. The target label values should be numbers between 0 and 1.
            Available dtype is float32, float64.
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, has to be a Tensor of size nbatch and the data type
            is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
            same as ``input`` , else the shape of output is scalar.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``,
            or, in static graph mode, if ``weight`` is given but is not a Tensor.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.to_tensor([0.5, 0.6, 0.7], 'float32')
            label = paddle.to_tensor([1.0, 0.0, 1.0], 'float32')
            output = paddle.nn.functional.binary_cross_entropy(input, label)
            print(output)  # [0.65537095]

    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in binary_cross_entropy should be 'sum', "
            "'mean' or 'none', but received %s, which is not allowed." %
            reduction)

    if in_dygraph_mode():
        out = _C_ops.bce_loss(input, label)
        if weight is not None:
            out = _C_ops.elementwise_mul(out, weight, 'axis', -1)

        if reduction == 'sum':
            return _C_ops.reduce_sum(out, 'dim', [0], 'keep_dim', False,
                                     "reduce_all", True)
        elif reduction == 'mean':
            return _C_ops.mean(out)
        else:
            return out

    fluid.data_feeder.check_variable_and_dtype(
        input, 'input', ['float32', 'float64'], 'binary_cross_entropy')
    fluid.data_feeder.check_variable_and_dtype(
        label, 'label', ['float32', 'float64'], 'binary_cross_entropy')

    # The user-supplied name is attached only to the op that produces the
    # final result, so the bce_loss op gets it only when nothing follows it.
    # Strings must be compared with ``==``; ``is`` tests object identity and
    # only works by accident of interning.
    sub_name = name if weight is None and reduction == 'none' else None
    helper = LayerHelper("binary_cross_entropy", name=sub_name)
    out = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='bce_loss',
        inputs={
            'X': [input],
            'Label': [label],
        },
        outputs={'Out': [out]})

    if weight is not None:
        if isinstance(weight, paddle.static.Variable):
            # The multiply is the final op only when no reduction follows.
            weight_name = name if reduction == 'none' else None
            out = paddle.multiply(out, weight, name=weight_name)
        else:
            raise ValueError(
                "The weight is not a Tensor, please convert to Tensor.")

    if reduction == 'sum':
        return paddle.sum(out, name=name)
    elif reduction == 'mean':
        return paddle.mean(out, name=name)
    else:
        return out


163 164 165 166 167 168
def binary_cross_entropy_with_logits(logit,
                                     label,
                                     weight=None,
                                     reduction='mean',
                                     pos_weight=None,
                                     name=None):
    r"""
    Combine a sigmoid layer with the :ref:`api_nn_loss_BCELoss` layer.
    Equivalently, this is ``sigmoid_cross_entropy_with_logits`` followed by
    optional weighting and a reduce operation.

    It measures the element-wise probability error in classification tasks
    in which each class is independent. This suits multi-label problems where
    labels are not mutually exclusive: for example, a news article can be about
    politics, technology or sports at the same time, or none of these.

    The loss is first calculated as:

    .. math::
           Out = -Labels * \log(\sigma(Logit)) - (1 - Labels) * \log(1 - \sigma(Logit))

    Substituting :math:`\sigma(Logit) = \frac{1}{1 + e^{-Logit}}` gives:

    .. math::
           Out = Logit - Logit * Labels + \log(1 + e^{-Logit})

    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
    the loss is reformulated as:

    .. math::
           Out = \max(Logit, 0) - Logit * Labels + \log(1 + e^{-\lvert Logit \rvert})

    Then, if ``weight`` or ``pos_weight`` is not None, the loss `Out` is multiplied
    by the corresponding weight tensor. ``weight`` attaches a different weight to
    every item in the batch; ``pos_weight`` attaches a different weight to the
    positive label of each class.

    Finally, a reduce operation is applied to the loss.
    If :attr:`reduction` is ``'none'``, the unreduced loss `Out` is returned.
    If :attr:`reduction` is ``'mean'``, the result is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` is ``'sum'``, the result is :math:`Out = SUM(Out)`.

    Note that the target labels ``label`` should be numbers between 0 and 1.

    Args:
        logit (Tensor): The input predictions tensor. 2-D tensor with shape: [N, *],
            N is batch_size, `*` means number of additional dimensions. The ``logit``
            is usually the output of Linear layer. Available dtype is float32, float64.
        label (Tensor): The target labels tensor. 2-D tensor with the same shape as
            ``logit``. The target label values should be numbers between 0 and 1.
            Available dtype is float32, float64.
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`,
            The data type is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
            with length equal to the number of classes. The data type is float32, float64.
            Default is ``'None'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
            same as ``logit`` , else the shape of output is scalar.

    Examples:

        .. code-block:: python

            import paddle

            logit = paddle.to_tensor([5.0, 1.0, 3.0])
            label = paddle.to_tensor([1.0, 0.0, 1.0])
            output = paddle.nn.functional.binary_cross_entropy_with_logits(logit, label)
            print(output)  # [0.45618808]

    """
    if reduction not in ('sum', 'mean', 'none'):
        raise ValueError(
            "The value of 'reduction' in binary_cross_entropy_with_logits "
            "should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
            % reduction)

    if in_dygraph_mode():
        # Constant 1 used to build the positive-class scaling factor.
        unit = _varbase_creator(dtype=logit.dtype)
        _C_ops.fill_constant(unit, 'value', 1.0, 'force_cpu', False, 'dtype',
                             unit.dtype, 'str_value', '1.0', 'shape', [1])
        loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label)

        if pos_weight is not None:
            # scale = label * (pos_weight - 1) + 1
            shifted = _C_ops.elementwise_sub(pos_weight, unit)
            scaled = _C_ops.elementwise_mul(label, shifted)
            loss = _C_ops.elementwise_mul(
                loss, _C_ops.elementwise_add(scaled, unit))

        if weight is not None:
            loss = _C_ops.elementwise_mul(loss, weight)

        if reduction == "none":
            return loss
        if reduction == "sum":
            return _C_ops.reduce_sum(loss, 'reduce_all', True)
        return _C_ops.mean(loss)

    op_name = 'binary_cross_entropy_with_logits'
    float_dtypes = ['float32', 'float64']
    check_variable_and_dtype(logit, 'logit', float_dtypes, op_name)
    check_variable_and_dtype(label, 'label', float_dtypes, op_name)

    # The user-supplied name goes on whichever op produces the final result.
    unreduced = reduction == 'none'
    is_last_op = unreduced and pos_weight is None and weight is None
    loss = fluid.layers.sigmoid_cross_entropy_with_logits(
        logit, label, name=name if is_last_op else None)

    unit = fluid.layers.fill_constant(shape=[1], value=1.0, dtype=logit.dtype)
    if pos_weight is not None:
        check_variable_and_dtype(pos_weight, 'pos_weight', float_dtypes,
                                 op_name)
        # scale = label * (pos_weight - 1) + 1
        scale = paddle.add(
            paddle.multiply(label, paddle.subtract(pos_weight, unit)), unit)
        is_last_op = unreduced and weight is None
        loss = paddle.multiply(
            loss, scale, name=name if is_last_op else None)

    if weight is not None:
        check_variable_and_dtype(weight, 'weight', float_dtypes, op_name)
        loss = paddle.multiply(
            loss, weight, name=name if unreduced else None)

    if reduction == "sum":
        return paddle.sum(loss, name=name)
    if reduction == "mean":
        return paddle.mean(loss, name=name)
    return loss


312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395
def hsigmoid_loss(input,
                  label,
                  num_classes,
                  weight,
                  bias=None,
                  path_table=None,
                  path_code=None,
                  is_sparse=False,
                  name=None):
    """
    Compute the hierarchical sigmoid loss.

    The hierarchical sigmoid organizes the classes into a complete binary tree to
    reduce the computational complexity and speed up model training, especially the
    training of language models. Each leaf node of the tree represents a class (word)
    and each non-leaf node acts as a binary classifier. For each class (word) there is
    a unique path from the root to itself; hsigmoid calculates the cost for each
    non-leaf node on the path and sums them to get the total cost.
    Compared to softmax, the OP can reduce the computational complexity from
    :math:`O(N)` to :math:`O(logN)`, where :math:`N` represents the number of classes
    or the size of the word dict.

    The OP supports a default tree and a custom tree. For the default tree, refer to
    `Hierarchical Probabilistic Neural Network Language Model
    <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_.
    To use a custom tree, pass both :attr:`path_table` and :attr:`path_code`, built
    as follows (taking a language model as an example):

    1. Use a custom word dict to build a binary tree; each leaf node should be a word in the word dict.
    2. Create a dict mapping word_id -> path from the word to the root node; this is the path_table.
    3. Create a dict mapping word_id -> code of the path from the word to the root node; this is the
       path_code. Code means the label of each binary classifier: 1 indicates true, 0 indicates false.
    4. Now each word has its path and code along the path; pass a batch of paths and codes related
       to the same batch of inputs.

    Parameters:
        input (Tensor): A tensor with the shape [N, D], where N is the size of mini-batch,
            and D is the feature size. Its data type supports float32 or float64.
        label (Tensor): A tensor that contains the labels of training data. Its shape is [N, 1]
            and data type is int64.
        num_classes (int): The number of classes or the size of word dict, must be greater than 2.
            If the default tree is used (``path_table`` and ``path_code`` are None), `num_classes`
            should not be None. If the custom tree is used (``path_table`` and ``path_code`` are
            not None), `num_classes` should be the number of non-leaf nodes, which indicates the
            number of classes used by the binary classifier.
        weight (Tensor): A tensor with shape (num_classes - 1, D), with the same data type as `input`.
        bias (Tensor, optional): A tensor with shape (num_classes - 1, 1), with the same data type as `input`.
            If `bias` is None, no bias will be added. Default is None.
        path_table (Tensor, optional): A tensor that stores each batch of samples' path from leaf to root
            node, its shape is [N, L] and data type is int64, where L is the length of path. For each sample i,
            path_table[i] is a np.array like structure and each element in this array is the index in the
            parent nodes' weight matrix. If `path_table` and `path_code` are None, the default tree will be
            used. Default is None.
        path_code (Tensor, optional): A tensor that stores each batch of samples' code of path from leaf
            to root node, its shape is [N, L] and data type is int64, which is the same as :attr:`path_table`.
            Each code of path consists of the codes of the nodes from leaf to root node. If `path_table` and
            `path_code` are None, the default tree will be used. Default is None.
        is_sparse (bool, optional): Whether to use sparse updating instead of dense updating. If `is_sparse`
            is True, the gradient of `weight` and `input` will be sparse. Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A tensor with the cost of hierarchical sigmoid, its shape is [N, 1] and data type is the same as `input`.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            paddle.set_device('cpu')

            # One label per input row; labels lie in [0, num_classes).
            input = paddle.uniform([4, 3])
            label = paddle.to_tensor([0, 1, 4, 2])
            num_classes = 5
            weight = paddle.uniform([num_classes - 1, 3])

            out = F.hsigmoid_loss(input, label, num_classes, weight)
            # out has shape [4, 1]; the values are random because
            # input and weight are drawn from a uniform distribution.
    """

    if in_dygraph_mode():
        # Only the first of the three op results (the cost) is returned.
        cost, _pre_out, _w_out = _C_ops.hierarchical_sigmoid(
            input, weight, label, path_table, path_code, bias,
            'num_classes', num_classes,
            'is_sparse', is_sparse,
            'remote_prefetch', is_sparse)
        return cost

    # Static graph mode: validate argument dtypes before building the op.
    check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                             'hsigmoid_loss')
    check_variable_and_dtype(label, 'label', ['int64'], 'hsigmoid_loss')
    check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
                             'hsigmoid_loss')
    if bias is not None:
        check_variable_and_dtype(bias, 'bias', ['float32', 'float64'],
                                 'hsigmoid_loss')
    if path_table is not None:
        check_variable_and_dtype(path_table, 'path_table', ['int64'],
                                 'hsigmoid_loss')
    if path_code is not None:
        check_variable_and_dtype(path_code, 'path_code', ['int64'],
                                 'hsigmoid_loss')

    # remote_prefetch simply mirrors is_sparse.
    attrs = dict(
        num_classes=num_classes,
        is_sparse=is_sparse,
        remote_prefetch=is_sparse)

    inputs = {
        "X": input,
        "W": weight,
        "Bias": bias,
        "PathTable": path_table,
        "PathCode": path_code,
        "Label": label
    }

    # NOTE: the helper receives every local defined so far via **locals(),
    # so the local names above are kept stable on purpose.
    helper = LayerHelper('hsigmoid_loss', **locals())
    cost = helper.create_variable_for_type_inference(input.dtype)
    pre_activation = helper.create_variable_for_type_inference(input.dtype)

    helper.append_op(
        type="hierarchical_sigmoid",
        inputs=inputs,
        outputs={"Out": cost,
                 "PreOut": pre_activation,
                 "W_Out": weight},
        attrs=attrs)
    return cost


444
def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None):
    r"""
    Calculate the smooth L1 loss. The criterion uses a squared term if the
    absolute element-wise error falls below ``delta`` and an L1 term otherwise.
    In some cases it can prevent exploding gradients, and it is more robust and
    less sensitive to outliers. Also known as the Huber loss:

    .. math::

         loss(x,y) = \frac{1}{n}\sum_{i}z_i


    where z_i is given by:

    .. math::

        \mathop{z_i} = \left\{\begin{array}{rcl}
        0.5(x_i - y_i)^2 & & {if |x_i - y_i| < delta} \\
        delta * |x_i - y_i| - 0.5 * delta^2 & & {otherwise}
        \end{array} \right.

    Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is
            (N, C), where C is number of classes, and if shape is more than 2D, this
            is (N, C, D1, D2,..., Dk), k >= 1.
        label (Tensor): Label tensor, the data type is float32 or float64. The shape of label
            is the same as the shape of input.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        delta (float, optional): Specifies the hyperparameter delta to be used.
            The value determines how large the errors need to be to use L1. Errors
            smaller than delta are minimized with L2. Parameter is ignored for
            negative/zero values. Default = 1.0
        name (str, optional): Name for the operation (optional, default is
            None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        The tensor variable storing the smooth_l1_loss of input and label.

    Return type: Tensor.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            input_data = np.random.rand(3,3).astype("float32")
            label_data = np.random.rand(3,3).astype("float32")
            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)
            output = paddle.nn.functional.smooth_l1_loss(input, label)
            print(output)
    """
    # Validate ``reduction`` up front, like the other losses in this module,
    # so an invalid value fails before any op has been built.
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in smooth_l1_loss should be 'sum', 'mean' or"
            " 'none', but received %s, which is not allowed." % reduction)

    fluid.data_feeder.check_variable_and_dtype(
        input, 'input', ['float32', 'float64'], 'smooth_l1_loss')
    fluid.data_feeder.check_variable_and_dtype(
        label, 'label', ['float32', 'float64'], 'smooth_l1_loss')

    out = huber_loss(input=input, label=label, delta=delta)

    if reduction == 'mean':
        return paddle.mean(out)
    if reduction == 'sum':
        return paddle.sum(out)
    # reduction == 'none'
    return out
519 520


521 522
def margin_ranking_loss(input,
                        other,
                        label,
                        margin=0.0,
                        reduction='mean',
                        name=None):
    r"""

    Calculate the margin rank loss between ``input``, ``other`` and ``label``,
    using the following formula.

    .. math::
        margin\_rank\_loss = max(0, -label * (input - other) + margin)

    If :attr:`reduction` is ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(margin\_rank\_loss)

    If :attr:`reduction` is ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(margin\_rank\_loss)

    If :attr:`reduction` is ``'none'``, the original ``margin_rank_loss`` is returned.

    Parameters:
        input(Tensor): the first input tensor, its data type should be float32, float64.
        other(Tensor): the second input tensor, its data type should be float32, float64.
        label(Tensor): the label value corresponding to input, its data type should be float32, float64.
        margin (float, optional): The margin value to add, default value is 0;
        reduction (str, optional): Indicate the reduction to apply to the loss, the candidates are ``'none'``, ``'mean'``, ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns: Tensor, if :attr:`reduction` is ``'mean'`` or ``'sum'``, the out shape is :math:`[1]`, otherwise the shape is the same as `input`. The same dtype as input tensor.

    Examples:

        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[1, 2], [3, 4]], dtype='float32')
            other = paddle.to_tensor([[2, 1], [2, 4]], dtype='float32')
            label = paddle.to_tensor([[1, -1], [-1, -1]], dtype='float32')
            loss = paddle.nn.functional.margin_ranking_loss(input, other, label)
            print(loss) # [0.75]
    """
    if reduction not in ('sum', 'mean', 'none'):
        raise ValueError(
            "The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
            "received %s, which is not allowed." % reduction)

    if in_dygraph_mode():
        # -label * (input - other) is computed as (other - input) * label.
        loss = _C_ops.elementwise_mul(
            _C_ops.elementwise_sub(other, input), label)
        if margin != 0.0:
            margin_tensor = fluid.dygraph.base.to_variable(
                [margin], dtype=loss.dtype)
            loss = _C_ops.elementwise_add(loss, margin_tensor)
        loss = _C_ops.relu(loss)

        if reduction == 'sum':
            return _C_ops.reduce_sum(loss, 'reduce_all', True)
        if reduction == 'mean':
            return _C_ops.mean(loss)
        return loss

    # Static graph mode. The helper receives the function arguments via
    # **locals(), so no extra locals are introduced before this line.
    helper = LayerHelper("margin_ranking_loss", **locals())
    for tensor, arg_name in ((input, 'input'), (other, 'other'),
                             (label, 'label')):
        check_variable_and_dtype(tensor, arg_name, ['float32', 'float64'],
                                 'margin_rank_loss')

    # -label * (input - other) is computed as (other - input) * label.
    raw = paddle.multiply(paddle.subtract(other, input), label)

    if margin != 0.0:
        margin_var = raw.block.create_var(dtype=raw.dtype)
        fluid.layers.fill_constant([1], raw.dtype, margin, out=margin_var)
        raw = paddle.add(raw, margin_var)

    result = helper.create_variable_for_type_inference(input.dtype)

    if reduction == 'none':
        # The relu op writes straight into the result variable.
        helper.append_op(
            type="relu", inputs={"X": raw}, outputs={"Out": result})
        return result

    clipped = paddle.nn.functional.relu(raw)
    if reduction == 'sum':
        helper.append_op(
            type="reduce_sum",
            inputs={"X": clipped},
            outputs={"Out": result},
            attrs={"dim": [0],
                   "keep_dim": False,
                   "reduce_all": True})
    else:
        # reduction == 'mean' (the only remaining validated value)
        helper.append_op(
            type="mean",
            inputs={"X": clipped},
            outputs={"Out": result},
            attrs={})
    return result


627
def l1_loss(input, label, reduction='mean', name=None):
    r"""
    Computes the L1 loss (element-wise absolute difference) between
    ``input`` and ``label``.

    If `reduction` is set to ``'none'``, the loss is:

    .. math::
        Out = \lvert input - label \rvert

    If `reduction` is set to ``'mean'``, the loss is:

    .. math::
        Out = MEAN(\lvert input - label \rvert)

    If `reduction` is set to ``'sum'``, the loss is:

    .. math::
        Out = SUM(\lvert input - label \rvert)

    Parameters:
        input (Tensor): The input tensor. The shape is [N, `*`], where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32 or int64.
        label (Tensor): The label tensor. The shape is [N, `*`], the same shape as ``input`` . Its data type should be float32, float64, int32 or int64.
        reduction (str, optional): Indicate the reduction to apply to the loss,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If `reduction` is ``'none'``, the unreduced loss is returned;
            If `reduction` is ``'mean'``, the reduced mean loss is returned.
            If `reduction` is ``'sum'``, the reduced sum loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, the L1 Loss of Tensor ``input`` and ``label``.
            If `reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``input`` .
            If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.to_tensor([[1.5, 0.8], [0.2, 1.3]])
            label = paddle.to_tensor([[1.7, 1], [0.4, 0.5]])

            l1_loss = paddle.nn.functional.l1_loss(input, label)
            print(l1_loss.numpy())
            # [0.35]

            l1_loss = paddle.nn.functional.l1_loss(input, label, reduction='none')
            print(l1_loss.numpy())
            # [[0.20000005 0.19999999]
            # [0.2        0.79999995]]

            l1_loss = paddle.nn.functional.l1_loss(input, label, reduction='sum')
            print(l1_loss.numpy())
            # [1.4]
    """
    # Reject unknown modes up front, before any op is created.
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
            "received %s, which is not allowed." % reduction)

    if in_dygraph_mode():
        # |input - label| is produced by a single fused subtract + abs call.
        unreduced = _elementwise_op_in_dygraph(
            input, label, axis=-1, act='abs', op_name='elementwise_sub')
        if reduction == 'mean':
            return _C_ops.mean(unreduced)
        elif reduction == 'sum':
            return _C_ops.reduce_sum(unreduced, 'dim', [0], 'keep_dim', False,
                                     'reduce_all', True)
        else:
            return unreduced

    # Static-graph path: validate the inputs before appending ops.
    fluid.data_feeder.check_variable_and_dtype(
        input, 'input', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
    fluid.data_feeder.check_variable_and_dtype(
        label, 'label', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')

    # ``name`` is attached to the op that produces the returned tensor: the
    # subtract/abs op for 'none', the reduction op otherwise.
    if reduction == 'none':
        return paddle.fluid.layers.elementwise_sub(
            input, label, act='abs', name=name)

    unreduced = paddle.fluid.layers.elementwise_sub(input, label, act='abs')
    if reduction == 'sum':
        return paddle.sum(unreduced, name=name)
    return paddle.mean(unreduced, name=name)
714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751


def nll_loss(input,
             label,
             weight=None,
             ignore_index=-100,
             reduction='mean',
             name=None):
    """
    This api returns negative log likelihood.
    See more detail in :ref:`api_nn_loss_NLLLoss` .

    Parameters:
         input (Tensor): Input tensor, the shape is :math:`[N, C]`, `C` is the number of classes.
             But in K-dimension situation, the shape is :math:`[N, C, d_1, d_2, ..., d_K]`.
             The data type is float32, float64.
         label (Tensor): Label tensor, the shape is :math:`[N,]` or :math:`[N, d_1, d_2, ..., d_K]`.
             The data type is int64.
         weight (Tensor, optional): Weight tensor, a manual rescaling weight given
             to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise,
             it is treated as if having all ones. The data type is
             float32, float64. Default is ``None``.
         ignore_index (int64, optional): Specifies a target value that is ignored
             and does not contribute to the input gradient.
         reduction (str, optional): Indicate how to average the loss,
             the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
             If `reduction` is ``'mean'``, the reduced mean loss is returned;
             if `reduction` is ``'sum'``, the reduced sum loss is returned;
             if `reduction` is ``'none'``, no reduction will be applied.
             Default is ``'mean'``.
         name (str, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`.

    Returns:
         `Tensor`, the value of negative log likelihood loss.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``,
            or if ``input`` has fewer than 2 dimensions.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.nn.functional import nll_loss
            log_softmax = paddle.nn.LogSoftmax(axis=1)

            input = paddle.to_tensor([[0.88103855, 0.9908683 , 0.6226845 ],
                      [0.53331435, 0.07999352, 0.8549948 ],
                      [0.25879037, 0.39530203, 0.698465  ],
                      [0.73427284, 0.63575995, 0.18827209],
                      [0.05689114, 0.0862954 , 0.6325046 ]], "float32")
            log_out = log_softmax(input)
            label = paddle.to_tensor([0, 2, 1, 1, 0], "int64")
            result = nll_loss(log_out, label)
            print(result) # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, [1.07202101])
    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
            "'none', but received %s, which is not allowed." % reduction)

    input_shape = list(input.shape)
    input_dims = len(input_shape)
    if input_dims < 2:
        raise ValueError('Expected 2 or more dimensions (got {})'.format(
            input_dims))
    n = input_shape[0]
    c = input_shape[1]

    # Inputs that are neither 2-D nor 4-D are folded into a 4-D
    # [n, c, 1, -1] layout before the op; for reduction='none' the result is
    # unfolded back to [n, d_1, ..., d_K] afterwards.
    needs_reshape = input_dims not in (2, 4)
    out_shape = [n] + input_shape[2:]

    if in_dygraph_mode():
        if needs_reshape:
            input, _ = _C_ops.reshape2(input, None, 'shape', [n, c, 1, -1])
            label, _ = _C_ops.reshape2(label, None, 'shape', [n, 1, -1])
        out, total_weight = _C_ops.nll_loss(input, label, weight,
                                            'ignore_index', ignore_index,
                                            'reduction', reduction)
        if needs_reshape and reduction == 'none':
            out, _ = _C_ops.reshape2(out, None, 'shape', out_shape)
        return out

    helper = LayerHelper('nll_loss', **locals())

    # Validate before reshaping so a bad input is reported by the validator
    # itself (reshape does not change the dtype, so the check is equivalent).
    fluid.data_feeder.check_variable_and_dtype(
        input, 'input', ['float32', 'float64'], 'nll_loss')
    fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
                                               'nll_loss')

    if needs_reshape:
        input = reshape(input, shape=[n, c, 1, -1])
        label = reshape(label, shape=[n, 1, -1])

    inputs = {'X': input, 'Label': label}
    attrs = {'reduction': reduction, 'ignore_index': ignore_index}
    # NOTE: only a ``Variable`` weight is forwarded to the op; ``None`` or any
    # other object leaves the op without a 'Weight' input.
    if isinstance(weight, Variable):
        inputs['Weight'] = weight

    out = helper.create_variable_for_type_inference(dtype=input.dtype)
    total_weight = helper.create_variable_for_type_inference(dtype=input.dtype)
    outputs = {'Out': out, 'Total_weight': total_weight}

    helper.append_op(
        type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs)
    if needs_reshape and reduction == 'none':
        out = reshape(out, shape=out_shape)

    return out
818 819


820
def kl_div(input, label, reduction='mean', name=None):
    r"""
    This operator calculates the Kullback-Leibler divergence loss
    between ``input`` and ``label``. Note that ``input`` is the
    log-probability and ``label`` is the probability.

    KL divergence loss is calculated as follows:

    .. math::
        l(x, y) = y * (\log(y) - x)

    where :math:`x` is ``input`` and :math:`y` is ``label``.

    When :attr:`reduction` is ``'none'``, the output loss has
    the same shape as ``input``; the loss at each point is calculated
    separately and no reduction is applied.

    When :attr:`reduction` is ``'mean'``, the output loss has
    shape [1] and its value is the mean of all losses.

    When :attr:`reduction` is ``'sum'``, the output loss has
    shape [1] and its value is the sum of all losses.

    When :attr:`reduction` is ``'batchmean'``, the output loss has
    shape [1] and its value is the sum of all losses
    divided by the batch size.

    Args:
        input (Tensor): The input tensor. The shape is [N, *], where N is batch size and `*` means
             any number of additional dimensions. Its data type should be float32, float64.
        label (Tensor): The label tensor. The shape is [N, *], the same shape as ``input`` . Its data type should be float32, float64.
        reduction (str, optional): Indicate how to average the loss,
             the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
             If `reduction` is ``'mean'``, the reduced mean loss is returned;
             If `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
             if `reduction` is ``'sum'``, the reduced sum loss is returned;
             if `reduction` is ``'none'``, no reduction will be applied.
             Default is ``'mean'``.
        name(str, optional): Name for the operation (optional, default is None). For more information,
            please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The KL divergence loss. The data type is the same as the input tensor.

    Raises:
        ValueError: If ``reduction`` is not ``'none'``, ``'batchmean'``,
            ``'mean'`` or ``'sum'``.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            import paddle.nn.functional as F

            shape = (5, 20)
            input = np.random.uniform(-10, 10, shape).astype('float32')
            target = np.random.uniform(-10, 10, shape).astype('float32')

            # 'batchmean' reduction, loss shape will be [1]
            pred_loss = F.kl_div(paddle.to_tensor(input),
                                 paddle.to_tensor(target), reduction='batchmean')
            # shape=[1]

            # 'mean' reduction, loss shape will be [1]
            pred_loss = F.kl_div(paddle.to_tensor(input),
                                 paddle.to_tensor(target), reduction='mean')
            # shape=[1]

            # 'sum' reduction, loss shape will be [1]
            pred_loss = F.kl_div(paddle.to_tensor(input),
                                 paddle.to_tensor(target), reduction='sum')
            # shape=[1]

            # 'none' reduction, loss shape is same with input shape
            pred_loss = F.kl_div(paddle.to_tensor(input),
                                 paddle.to_tensor(target), reduction='none')
            # shape=[5, 20]

    """
    # Without this check a misspelled mode (e.g. 'batch_mean') would fall
    # through every branch below and silently return the unreduced loss.
    if reduction not in ['none', 'batchmean', 'mean', 'sum']:
        raise ValueError(
            "The value of 'reduction' in kl_div should be 'none', 'batchmean', "
            "'mean' or 'sum', but received %s, which is not allowed." %
            reduction)

    # Type promotion: when exactly one operand is float32 and the other is
    # float64, cast the float32 one up so both operands share a dtype.
    if fluid.data_feeder.convert_dtype(
            input.dtype) == 'float32' and fluid.data_feeder.convert_dtype(
                label.dtype) == 'float64':
        input = paddle.cast(input, 'float64')
    elif fluid.data_feeder.convert_dtype(
            input.dtype) == 'float64' and fluid.data_feeder.convert_dtype(
                label.dtype) == 'float32':
        label = paddle.cast(label, 'float64')

    if paddle.in_dynamic_mode():
        # The op is always run unreduced; the reduction is applied here.
        out = _C_ops.kldiv_loss(input, label, 'reduction', 'none')
        if reduction == 'mean':
            out = paddle.mean(out)
        elif reduction == 'sum':
            out = paddle.sum(out)
        elif reduction == 'batchmean':
            # An input with an empty shape has no batch dimension to divide
            # by, so the loss is left as computed.
            if len(input.shape) > 0:
                batch_size = input.shape[0]
                out = paddle.sum(out) / batch_size
        return out

    helper = LayerHelper('kl_div', **locals())

    fluid.data_feeder.check_variable_and_dtype(input, 'input',
                                               ['float32', 'float64'], 'kl_div')
    fluid.data_feeder.check_variable_and_dtype(label, 'label',
                                               ['float32', 'float64'], 'kl_div')
    fluid.data_feeder.check_type(reduction, 'reduction', str, 'kl_div')

    loss = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='kldiv_loss',
        inputs={'X': input,
                'Target': label},
        outputs={'Loss': loss},
        attrs={'reduction': 'none'})

    if reduction == 'mean':
        loss = paddle.mean(loss)
    elif reduction == 'sum':
        loss = paddle.sum(loss)
    elif reduction == 'batchmean':
        # The batch size may be unknown at graph-build time, so read it from
        # the runtime shape tensor.
        batch_size = paddle.shape(input)[0]
        loss = paddle.sum(loss) / batch_size
    return loss


943
def mse_loss(input, label, reduction='mean', name=None):
    r"""
    This op accepts input predictions and label and returns the mean square error.

    If :attr:`reduction` is set to ``'none'``, loss is calculated as:

    .. math::
        Out = (input - label)^2

    If :attr:`reduction` is set to ``'mean'``, loss is calculated as:

    .. math::
        Out = \operatorname{mean}((input - label)^2)

    If :attr:`reduction` is set to ``'sum'``, loss is calculated as:

    .. math::
        Out = \operatorname{sum}((input - label)^2)

    Parameters:
        input (Tensor): Input tensor, the data type should be float32 or float64.
        label (Tensor): Label tensor, the data type should be float32 or float64.
        reduction (string, optional): The reduction method for the output,
            could be 'none' | 'mean' | 'sum'.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The tensor storing the mean square error difference of input and label.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:

        .. code-block:: python

            import paddle
            input = paddle.to_tensor(1.5)
            label = paddle.to_tensor(1.7)
            output = paddle.nn.functional.mse_loss(input, label)
            print(output)
            # [0.04000002]

    """

    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "'reduction' in 'mse_loss' should be 'sum', 'mean' or 'none', "
            "but received {}.".format(reduction))

    # Input checks are only performed when building a static graph.
    if not in_dygraph_mode():
        paddle.fluid.data_feeder.check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'mse_loss')
        paddle.fluid.data_feeder.check_variable_and_dtype(
            label, 'label', ['float32', 'float64'], 'mse_loss')

    # ``name`` is attached to the op that produces the returned tensor: the
    # square op for 'none', the reduction op otherwise.
    if reduction == 'none':
        return paddle.square(paddle.subtract(input, label), name=name)

    squared_error = paddle.square(paddle.subtract(input, label))
    if reduction == 'mean':
        return paddle.mean(squared_error, name=name)
    return paddle.sum(squared_error, name=name)
1012 1013


1014 1015 1016 1017 1018
def ctc_loss(log_probs,
             labels,
             input_lengths,
             label_lengths,
             blank=0,
             reduction='mean',
             norm_by_times=False):
    """

    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation
    is integrated into the Warp-CTC library to normalize values for each row of the input tensor.

    Parameters:
        log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type should be float32 or float64.
        labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
        input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
        label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
        blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
        reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.
        norm_by_times (bool, optional): Whether to normalize the gradients by the number of time-step, which is also the sequence's length. There is no need to normalize the gradients if reduction mode is 'mean'. Default is False.

    Returns:
        Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and  ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.

    Raises:
        ValueError: If ``reduction`` is not ``'mean'``, ``'sum'`` or ``'none'``.

    Examples:

        .. code-block:: python

            import paddle.nn.functional as F
            import numpy as np
            import paddle

            # length of the longest logit sequence
            max_seq_length = 5
            #length of the longest label sequence
            max_label_length = 3
            # number of logit sequences
            batch_size = 2
            # class num
            class_num = 3

            np.random.seed(1)
            log_probs = np.array([[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04],
                                    [3.02332580e-01, 1.46755889e-01, 9.23385918e-02]],

                                    [[1.86260208e-01, 3.45560730e-01, 3.96767467e-01],
                                    [5.38816750e-01, 4.19194520e-01, 6.85219526e-01]],

                                    [[2.04452246e-01, 8.78117442e-01, 2.73875929e-02],
                                    [6.70467496e-01, 4.17304814e-01, 5.58689833e-01]],

                                    [[1.40386939e-01, 1.98101491e-01, 8.00744593e-01],
                                    [9.68261600e-01, 3.13424170e-01, 6.92322612e-01]],

                                    [[8.76389146e-01, 8.94606650e-01, 8.50442126e-02],
                                    [3.90547849e-02, 1.69830427e-01, 8.78142476e-01]]]).astype("float32")
            labels = np.array([[1, 2, 2],
                            [1, 2, 2]]).astype("int32")
            input_lengths = np.array([5, 5]).astype("int64")
            label_lengths = np.array([3, 3]).astype("int64")

            log_probs = paddle.to_tensor(log_probs)
            labels = paddle.to_tensor(labels)
            input_lengths = paddle.to_tensor(input_lengths)
            label_lengths = paddle.to_tensor(label_lengths)

            loss = F.ctc_loss(log_probs, labels,
                input_lengths,
                label_lengths,
                blank=0,
                reduction='none')
            print(loss)  #[3.9179852 2.9076521]

            loss = F.ctc_loss(log_probs, labels,
                input_lengths,
                label_lengths,
                blank=0,
                reduction='mean')
            print(loss)  #[1.1376063]

    """
    # Validate before creating the warpctc op, and with a real exception
    # rather than ``assert`` so the check survives ``python -O``.
    if reduction not in ['mean', 'sum', 'none']:
        raise ValueError(
            "The value of 'reduction' in ctc_loss should be 'mean', 'sum' or "
            "'none', but received %s, which is not allowed." % reduction)

    loss_out = fluid.layers.warpctc(log_probs, labels, blank, norm_by_times,
                                    input_lengths, label_lengths)

    # Drop the trailing axis of the warpctc output.
    loss_out = fluid.layers.squeeze(loss_out, [-1])
    if reduction == 'mean':
        # Normalize each sequence's loss by its label length, then average.
        loss_out = paddle.mean(loss_out / label_lengths)
    elif reduction == 'sum':
        loss_out = paddle.sum(loss_out)
    return loss_out


1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132
def margin_cross_entropy(logits,
                         label,
                         margin1=1.0,
                         margin2=0.5,
                         margin3=0.0,
                         scale=64.0,
                         group=None,
                         return_softmax=False,
                         reduction='mean'):
    """
    .. math::

        L=-\\frac{1}{N}\sum^N_{i=1}\log\\frac{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}}{e^{s(cos(m_{1}\\theta_{y_i}+m_{2})-m_{3})}+\sum^n_{j=1,j\\neq y_i} e^{scos\\theta_{y_i}}}

    where the :math:`\\theta_{y_i}` is the angle between the feature :math:`x` and
    the representation of class :math:`i`. The details of ArcFace loss
    could be referred to https://arxiv.org/abs/1801.07698.

    .. hint::
        The API supports model parallel and single GPU. And logits.shape[-1] can be different at each rank.

    Args:
G
Guoxia Wang 已提交
1133
        logits (Tensor): shape[N, local_num_classes], the output of the normalized X multiply the normalized W.
1134
                The logits is shard_logits when using model parallel.
G
Guoxia Wang 已提交
1135 1136 1137 1138 1139
        label (Tensor): shape[N] or shape[N, 1], the groud truth label.
        margin1 (float, optional): m1 of margin loss, default value is `1.0`.
        margin2 (float, optional): m2 of margin loss, default value is `0.5`.
        margin3 (float, optional): m3 of margin loss, default value is `0.0`.
        scale (float, optional): s of margin loss, default value is `64.0`.
1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159
        group (Group, optional): The abstract representation of group, see paddle.distributed.collective.Group.
            Default `None`.
        return_softmax (bool, optional): Whether return softmax probability. Default value is `False`.
        reduction (str, optional): The candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
                    If :attr:`reduction` is ``'mean'``, return the average of loss;
                    If :attr:`reduction` is ``'sum'``, return the sum of loss;
                    If :attr:`reduction` is ``'none'``, no reduction will be applied.
                    Default value is `'mean'`.

    Returns:
        ``Tensor`` or Tuple of two ``Tensor`` : Return the cross entropy loss if \
            `return_softmax` is False, otherwise the tuple \
            (loss, softmax), softmax is shard_softmax when \
            using model parallel, otherwise softmax is in \
            the same shape with input logits. If ``reduction == None``, \
            the shape of loss is ``[N, 1]``, otherwise the shape is ``[1]``.

    Examples:

    .. code-block:: python
G
Guoxia Wang 已提交
1160
        :name: code-example1
1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208

        # required: gpu
        # Single GPU
        import paddle
        m1 = 1.0
        m2 = 0.5
        m3 = 0.0
        s = 64.0
        batch_size = 2
        feature_length = 4
        num_classes = 4

        label = paddle.randint(low=0, high=num_classes, shape=[batch_size], dtype='int64')

        X = paddle.randn(
            shape=[batch_size, feature_length],
            dtype='float64')
        X_l2 = paddle.sqrt(paddle.sum(paddle.square(X), axis=1, keepdim=True))
        X = paddle.divide(X, X_l2)

        W = paddle.randn(
            shape=[feature_length, num_classes],
            dtype='float64')
        W_l2 = paddle.sqrt(paddle.sum(paddle.square(W), axis=0, keepdim=True))
        W = paddle.divide(W, W_l2)

        logits = paddle.matmul(X, W)
        loss, softmax = paddle.nn.functional.margin_cross_entropy(
            logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, return_softmax=True, reduction=None)

        print(logits)
        print(label)
        print(loss)
        print(softmax)
        
        #Tensor(shape=[2, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[ 0.85204151, -0.55557678,  0.04994566,  0.71986042],
        #        [-0.20198586, -0.35270476, -0.55182702,  0.09749021]])
        #Tensor(shape=[2], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
        #       [2, 3])
        #Tensor(shape=[2, 1], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[82.37059586],
        #        [12.13448420]])
        #Tensor(shape=[2, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[0.99978819, 0.00000000, 0.00000000, 0.00021181],
        #        [0.99992995, 0.00006468, 0.00000000, 0.00000537]])

    .. code-block:: python
G
Guoxia Wang 已提交
1209
        :name: code-example2
1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322

        # required: distributed
        # Multi GPU, test_margin_cross_entropy.py
        import paddle
        import paddle.distributed as dist
        strategy = dist.fleet.DistributedStrategy()
        dist.fleet.init(is_collective=True, strategy=strategy)
        rank_id = dist.get_rank()
        m1 = 1.0
        m2 = 0.5
        m3 = 0.0
        s = 64.0
        batch_size = 2
        feature_length = 4
        num_class_per_card = [4, 8]
        num_classes = paddle.sum(paddle.to_tensor(num_class_per_card))

        label = paddle.randint(low=0, high=num_classes.item(), shape=[batch_size], dtype='int64')
        label_list = []
        dist.all_gather(label_list, label)
        label = paddle.concat(label_list, axis=0)

        X = paddle.randn(
            shape=[batch_size, feature_length],
            dtype='float64')
        X_list = []
        dist.all_gather(X_list, X)
        X = paddle.concat(X_list, axis=0)
        X_l2 = paddle.sqrt(paddle.sum(paddle.square(X), axis=1, keepdim=True))
        X = paddle.divide(X, X_l2)

        W = paddle.randn(
            shape=[feature_length, num_class_per_card[rank_id]],
            dtype='float64')
        W_l2 = paddle.sqrt(paddle.sum(paddle.square(W), axis=0, keepdim=True))
        W = paddle.divide(W, W_l2)

        logits = paddle.matmul(X, W)
        loss, softmax = paddle.nn.functional.margin_cross_entropy(
            logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, return_softmax=True, reduction=None)

        print(logits)
        print(label)
        print(loss)
        print(softmax)

        # python -m paddle.distributed.launch --gpus=0,1 test_margin_cross_entropy.py 
        ## for rank0 input
        #Tensor(shape=[4, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[ 0.32888934,  0.02408748, -0.02763289,  0.18173063],
        #        [-0.52893978, -0.10623845, -0.21596515, -0.06432517],
        #        [-0.00536345, -0.03924667,  0.66735314, -0.28640926],
        #        [-0.09907366, -0.48534973, -0.10365338, -0.39472322]])
        #Tensor(shape=[4], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
        #       [11, 1 , 10, 11])

        ## for rank1 input
        #Tensor(shape=[4, 8], dtype=float64, place=CUDAPlace(1), stop_gradient=True,
        #       [[ 0.68654754,  0.28137170,  0.69694954, -0.60923933, -0.57077653,  0.54576703, -0.38709028,  0.56028204],
        #        [-0.80360371, -0.03042448, -0.45107338,  0.49559349,  0.69998950, -0.45411693,  0.61927630, -0.82808600],
        #        [ 0.11457570, -0.34785879, -0.68819499, -0.26189226, -0.48241491, -0.67685711,  0.06510185,  0.49660849],
        #        [ 0.31604851,  0.52087884,  0.53124749, -0.86176582, -0.43426329,  0.34786144, -0.10850784,  0.51566383]])
        #Tensor(shape=[4], dtype=int64, place=CUDAPlace(1), stop_gradient=True,
        #       [11, 1 , 10, 11])

        ## for rank0 output
        #Tensor(shape=[4, 1], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[38.96608230],
        #        [81.28152394],
        #        [69.67229865],
        #        [31.74197251]])
        #Tensor(shape=[4, 4], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
        #       [[0.00000000, 0.00000000, 0.00000000, 0.00000000],
        #        [0.00000000, 0.00000000, 0.00000000, 0.00000000],
        #        [0.00000000, 0.00000000, 0.99998205, 0.00000000],
        #        [0.00000000, 0.00000000, 0.00000000, 0.00000000]])
        ## for rank1 output
        #Tensor(shape=[4, 1], dtype=float64, place=CUDAPlace(1), stop_gradient=True,
        #       [[38.96608230],
        #        [81.28152394],
        #        [69.67229865],
        #        [31.74197251]])
        #Tensor(shape=[4, 8], dtype=float64, place=CUDAPlace(1), stop_gradient=True,
        #       [[0.33943993, 0.00000000, 0.66051859, 0.00000000, 0.00000000, 0.00004148, 0.00000000, 0.00000000],
        #        [0.00000000, 0.00000000, 0.00000000, 0.00000207, 0.99432097, 0.00000000, 0.00567696, 0.00000000],
        #        [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00001795],
        #        [0.00000069, 0.33993085, 0.66006319, 0.00000000, 0.00000000, 0.00000528, 0.00000000, 0.00000000]])
    """

    assert reduction in ['mean', 'sum', 'none', None]
    if group is not None and not group.is_member():
        return

    ring_id = 0 if group is None else group.id
    rank = 0
    nranks = 1
    if core.is_compiled_with_dist():
        parallel_env = paddle.distributed.ParallelEnv()
        global_rank = parallel_env.rank
        rank = global_rank if group is None else group.get_group_rank(
            global_rank)
        nranks = parallel_env.world_size if group is None else group.nranks

    input_dims = len(list(logits.shape))
    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected nput_dims - 1 = label_dims or input_dims == label_dims\
             (got nput_dims{}, label_dims{})'.format(input_dims, label_dims))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=-1)

    if in_dygraph_mode():
        softmax, loss = _C_ops.margin_cross_entropy(
            logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks,
            'margin1', margin1, 'margin2', margin2, 'margin3', margin3, 'scale',
            scale, 'return_softmax', return_softmax)
        if reduction == 'mean':
            loss = paddle.mean(loss)
        elif reduction == 'sum':
            loss = paddle.sum(loss)
        if not return_softmax:
            return loss
        else:
            return loss, softmax

    op_type = 'margin_cross_entropy'
    helper = LayerHelper(op_type, **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)

    check_variable_and_dtype(logits, 'logits',
                             ['float16', 'float32', 'float64'],
                             'margin_cross_entropy')
    check_variable_and_dtype(label, 'label', ['int32', 'int64'],
                             'margin_cross_entropy')

    helper.append_op(
        type=op_type,
        inputs={'Logits': logits,
                'Label': label},
        outputs={'Softmax': softmax,
                 'Loss': loss},
        attrs={
            'return_softmax': return_softmax,
            'ring_id': ring_id,
            'rank': rank,
            'nranks': nranks,
            'margin1': margin1,
            'margin2': margin2,
            'margin3': margin3,
            'scale': scale,
        })

    if reduction == 'mean':
        loss = paddle.mean(loss)
    elif reduction == 'sum':
        loss = paddle.sum(loss)

    if not return_softmax:
        return loss
    else:
        return loss, softmax


@deprecated(
    since="2.0.0",
    update_to="paddle.nn.functional.cross_entropy",
    level=1,
    reason=(
        'Please notice that behavior of "paddle.nn.functional.softmax_with_cross_entropy" '
        'and "paddle.nn.functional.cross_entropy" is different.'))
def softmax_with_cross_entropy(logits,
                               label,
                               soft_label=False,
                               ignore_index=-100,
                               numeric_stable_mode=True,
                               return_softmax=False,
                               axis=-1):
    """
    Deprecated alias that forwards to the fluid ``softmax_with_cross_entropy``.

    Every argument is passed through unchanged, in the same order, to
    ``fluid_softmax_with_cross_entropy`` (the ``fluid.layers`` implementation
    imported at the top of this file), and its result is returned as is.
    New code should call ``paddle.nn.functional.cross_entropy`` instead; note
    that, as the deprecation message says, the two do not behave identically.

    Parameters:
        logits: Forwarded as the first positional argument.
        label: Forwarded as the second positional argument.
        soft_label: Forwarded unchanged. Default is ``False``.
        ignore_index: Forwarded unchanged. Default is ``-100``.
        numeric_stable_mode: Forwarded unchanged. Default is ``True``.
        return_softmax: Forwarded unchanged. Default is ``False``.
        axis: Forwarded unchanged. Default is ``-1``.

    Returns:
        Whatever ``fluid_softmax_with_cross_entropy`` returns for the given
        arguments.
    """
    return fluid_softmax_with_cross_entropy(logits, label, soft_label,
                                            ignore_index, numeric_stable_mode,
                                            return_softmax, axis)


def _class_axis_last_perm(axis, ndim):
    """Return the permutation that moves dimension ``axis`` to the last position."""
    class_axis = axis % ndim
    return (list(range(class_axis)) + list(range(class_axis + 1, ndim)) +
            [class_axis])


def _check_hard_label_range(valid_label, num_classes):
    """Raise ``ValueError`` if any hard label is outside ``[0, num_classes - 1]``."""
    # TODO: Temporarily use paddle.nonzero instead of paddle.max
    # to detect and find out possible illegal label values
    below = paddle.nonzero(valid_label < 0)
    if len(below) > 0:
        invalid_label = paddle.gather_nd(valid_label, below)
        raise ValueError(
            "Target({}) is out of class_dimension's lower bound({})".format(
                invalid_label[0], 0))
    above = paddle.nonzero(valid_label >= num_classes)
    if len(above) > 0:
        invalid_label = paddle.gather_nd(valid_label, above)
        raise ValueError(
            "Target({}) is out of class_dimension's upper bound({})".format(
                invalid_label[0], num_classes - 1))


def cross_entropy(input,
                  label,
                  weight=None,
                  ignore_index=-100,
                  reduction='mean',
                  soft_label=False,
                  axis=-1,
                  use_softmax=True,
                  name=None):
    r"""
    By default, this operator implements the cross entropy loss function with softmax. This function
    combines the calculation of the softmax operation and the cross entropy loss function
    to provide a more numerically stable computing.

    This operator will calculate the cross entropy loss function without softmax when use_softmax=False.

    By default, this operator will calculate the mean of the result, and you can also affect
    the default behavior by using the reduction parameter. Please refer to the part of
    parameters for details.

    This operator can be used to calculate the softmax cross entropy loss with soft and hard labels.
    Where, the hard labels mean the actual label value, 0, 1, 2, etc.  And the soft labels
    mean the probability of the actual label, 0.6, 0.8, 0.2, etc.

    The calculation of this operator includes the following two steps.

    - **1.softmax cross entropy**

        1. Hard label (each sample can only be assigned into one category)

        1.1. when use_softmax=True

            .. math::
              \\loss_j=-\text{logits}_{label_j}+\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories.

        1.2. when use_softmax=False

            .. math::
              \\loss_j=-\log\left({P}_{label_j}\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).


        2. Soft label (each sample is assigned to multiple categories with a certain probability, and the probability sum is 1).

        2.1. when use_softmax=True

            .. math::
              \\loss_j=-\sum_{i=0}^{C}\text{label}_i\left(\text{logits}_i-\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories.

        2.2. when use_softmax=False

            .. math::
              \\loss_j=-\sum_{j=0}^{C}\left({label}_j*\log\left({P}_{label_j}\right)\right) , j = 1,...,N

            where, N is the number of samples and C is the number of categories, P is input(the output of softmax).


    - **2. Weight and reduction processing**

        1. Weight

            If the ``weight`` parameter is ``None`` , go to the next step directly.

            If the ``weight`` parameter is not ``None`` , the cross entropy of each sample is weighted by weight
            according to soft_label = False or True as follows.

            1.1. Hard labels (soft_label = False)

            .. math::
                \\loss_j=loss_j*weight[label_j]

            (samples whose label equals ``ignore_index`` get a weight of 0)

            1.2. Soft labels (soft_label = True)

            .. math::
                \\loss_j=loss_j*\sum_{i}\left(weight[i]*label_{j,i}\right)

        2. reduction

            2.1 if the ``reduction`` parameter is ``none``

                Return the previous result directly

            2.2 if the ``reduction`` parameter is ``sum``

                Return the sum of the previous results

            .. math::
               \\loss=\sum_{j}loss_j

            2.3 if the ``reduction`` parameter is ``mean`` , it will be processed according to
            the ``weight`` parameter as follows.

            2.3.1. If the ``weight`` parameter is ``None``

                Return the average value of the previous results

            .. math::
                \\loss=\sum_{j}loss_j/N

            where, N is the number of samples and C is the number of categories.

            2.3.2. If the ``weight`` parameter is not ``None`` , the weighted average value of the previous result will be returned

            1. Hard labels (soft_label = False)

            .. math::
                \\loss=\sum_{j}loss_j/\sum_{j}weight[label_j]

            2. Soft labels (soft_label = True)

            .. math::
                \\loss=\sum_{j}loss_j/\sum_{j}\left(\sum_{i}weight[i]*label_{j,i}\right)


    Parameters:

        - **input** (Tensor)

            Input tensor, the data type is float32, float64. Shape is
            :math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes ,  ``k >= 1`` .

            Note:

                1. when use_softmax=True, it expects unscaled logits. This operator should not be used with the
                output of softmax operator, which will produce incorrect results.

                2. when use_softmax=False, it expects the output of softmax operator.

        - **label** (Tensor)

            1. If soft_label=False, the shape is
            :math:`[N_1, N_2, ..., N_k]` or :math:`[N_1, N_2, ..., N_k, 1]`, k >= 1.
            the data type is int32, int64, float32, float64, where each value is [0, C-1].

            2. If soft_label=True, the shape and data type should be same with ``input`` ,
            and the sum of the labels for each sample should be 1.

        - **weight** (Tensor, optional)

            a manual rescaling weight given to each class.
            If given, has to be a Tensor of size C and the data type is float32, float64.
            Default is ``'None'`` .

        - **ignore_index** (int64, optional)

            Specifies a target value that is ignored
            and does not contribute to the loss. A negative value means that no label
            value needs to be ignored. Only valid when soft_label = False.
            Default is ``-100`` .

        - **reduction** (str, optional)

            Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.

        - **soft_label** (bool, optional)

            Indicate whether label is soft.
            Default is ``False``.

        - **axis** (int, optional)

            The index of dimension to perform softmax calculations.
            It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the
            number of dimensions of input :attr:`input`.
            Default is ``-1`` .

        - **use_softmax** (bool, optional)

            Indicate whether compute softmax before cross_entropy.
            Default is ``True``.

        - **name** (str, optional)

            The name of the operator. Default is ``None`` .
            For more information, please refer to :ref:`api_guide_Name` .

    Returns:

        Tensor. Return the softmax cross_entropy loss of ``input`` and ``label``.
        The data type is the same as input.

        If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the dimension of return value is ``1``.

        If :attr:`reduction` is ``'none'``:

        1. If soft_label = False, the dimension of return value is the same with ``label`` .

        2. if soft_label = True, the dimension of return value is :math:`[N_1, N_2, ..., N_k, 1]` .

    Raises:

        ValueError: If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'``;
        if ``ignore_index > 0`` while ``soft_label`` is set; if ``input`` has zero
        dimensions; if the number of dimensions of ``label`` is neither that of
        ``input`` nor one less; if ``weight`` is given with hard labels and its last
        dimension differs from the class dimension of ``input``; or if a hard label
        (after masking ``ignore_index``) is outside ``[0, C-1]`` while ``weight`` is given.


    Example1(hard labels):

        .. code-block:: python

            import paddle
            paddle.seed(99999)
            N=100
            C=200
            reduction='mean'
            input =  paddle.rand([N, C], dtype='float64')
            label =  paddle.randint(0, C, shape=[N], dtype='int64')
            weight = paddle.rand([C], dtype='float64')

            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
                weight=weight, reduction=reduction)
            dy_ret = cross_entropy_loss(
                                       input,
                                       label)
            print(dy_ret.numpy()) #[5.41993642]


    Example2(soft labels):

        .. code-block:: python

            import paddle
            paddle.seed(99999)
            axis = -1
            ignore_index = -100
            N = 4
            C = 3
            shape = [N, C]
            reduction='mean'
            weight = None
            logits = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            labels = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0)
            labels /= paddle.sum(labels, axis=axis, keepdim=True)
            paddle_loss_mean = paddle.nn.functional.cross_entropy(
                                                                  logits,
                                                                  labels,
                                                                  soft_label=True,
                                                                  axis=axis,
                                                                  weight=weight,
                                                                  reduction=reduction)
            print(paddle_loss_mean.numpy()) #[1.12908343]

    """

    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in softmax_cross_entropy "
            "should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
            % reduction)
    if ignore_index > 0 and soft_label:
        raise ValueError(
            "When soft_label == True, the value of 'ignore_index' in softmax_cross_entropy "
            "should be '-100', but received %s, which is not allowed." %
            ignore_index)

    input_dims = len(list(input.shape))
    if input_dims == 0:
        raise ValueError('The dimension of input should be larger than zero!')

    label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected input_dims - 1 = label_dims or input_dims == label_dims '
            '(got input_dims{}, label_dims{})'.format(input_dims, label_dims))
    if input_dims - 1 == label_dims:
        # give the label an explicit (size-1) class dimension, as the op expects
        label = paddle.unsqueeze(label, axis=axis)

    if in_dygraph_mode():
        # the NPU build of the op returns one extra leading output
        if core.is_compiled_with_npu():
            _, _, out = _C_ops.softmax_with_cross_entropy(
                input, label, 'soft_label', soft_label, 'ignore_index',
                ignore_index, 'numeric_stable_mode', True, 'axis', axis,
                'use_softmax', use_softmax)
        else:
            _, out = _C_ops.softmax_with_cross_entropy(
                input, label, 'soft_label', soft_label, 'ignore_index',
                ignore_index, 'numeric_stable_mode', True, 'axis', axis,
                'use_softmax', use_softmax)

        if weight is not None:

            # trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases.
            if soft_label:
                # chajchaj:
                # weight's shape is C, where C is class num.
                # for 1d case: label's shape is [N,C], weight_gather's shape is N.
                # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W].
                weight_gather = paddle.matmul(
                    x=paddle.cast(label, weight.dtype),
                    y=weight,
                    transpose_x=False,
                    transpose_y=True)
                out_shape = list(out.shape)
                weight_gather_reshape = reshape(weight_gather, shape=out_shape)
                out = paddle.cast(out, weight_gather_reshape.dtype)

                out = _C_ops.elementwise_mul(out, weight_gather_reshape)

            else:
                if input.shape[axis] != weight.shape[-1]:
                    raise ValueError(
                        "input's class_dimension({}) must equal to "
                        "weight's class_dimension({}) "
                        "when weight is provided" \
                            .format(input.shape[axis], weight.shape[-1]))

                ignore_weight_mask = (
                    label != ignore_index)  # ignored position will be False

                valid_label = paddle.cast(
                    ignore_weight_mask,
                    dtype=label.dtype) * label  # ignored position will be 0

                _check_hard_label_range(valid_label, input.shape[axis])

                ignore_weight_mask = paddle.cast(
                    ignore_weight_mask, out.dtype)  # convert from 0 to 0.0

                if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
                        axis] == 1:
                    # TODO: Temporarily use squeeze instead of squeeze_
                    ignore_weight_mask = paddle.squeeze(ignore_weight_mask,
                                                        axis)
                if axis != -1 and axis != valid_label.ndim - 1:
                    # gather_nd indexes with the last dimension, so move the
                    # class axis of the label there first
                    temp_perm = _class_axis_last_perm(axis, valid_label.ndim)
                    weight_gather = _C_ops.gather_nd(
                        weight, valid_label.transpose(temp_perm))
                else:
                    weight_gather = _C_ops.gather_nd(weight, valid_label)
                weight_gather = _C_ops.elementwise_mul(weight_gather,
                                                       ignore_weight_mask)
                input_shape = list(label.shape)
                weight_gather_reshape = reshape(
                    weight_gather, shape=input_shape)
                out = paddle.cast(out, weight_gather_reshape.dtype)
                out = _C_ops.elementwise_mul(out, weight_gather_reshape)

        if reduction == "sum":
            #   because of fluid_softmax_with_cross_entropy op's inner logic,
            #   in the out tensor of this op, the loss of sample with class_index==ignore_index is 0
            #   so, reduce_sum all directly is ok
            return _C_ops.reduce_sum(out, 'reduce_all', True)
        elif reduction == "mean":
            # 1. if weight==none,
            #     numerator: reduce_sum all loss directly is ok causeof fluid_softmax_with_cross_entropy's inner logic
            #     denominator: count sample num with class_index!=ignore_index
            # 2. else
            #     numerator: loss's weighted sum
            #     denominator: cal the sum of weight where the sample's class_index!=ignore_index
            if ignore_index != -100:
                out_sum = _C_ops.reduce_sum(out, 'reduce_all', True)
                # for each label[i],set 1 or 0, according to ignore_index
                # mask[i]=0, if label[i]==ignore_index
                # mask[i]=1, otherwise
                mask = (label != ignore_index)
                if weight is None:
                    mask = paddle.cast(mask, dtype=out_sum.dtype)
                    count = _C_ops.reduce_sum(mask, 'reduce_all', True)
                    # "+ (count == 0.0)" turns a zero denominator into 1
                    ret = out_sum / (count + (count == 0.0))
                else:
                    mask = paddle.cast(mask, weight_gather_reshape.dtype)
                    weight_ignored = _C_ops.elementwise_mul(
                        mask, weight_gather_reshape)
                    weight_sum = _C_ops.reduce_sum(weight_ignored, 'reduce_all',
                                                   True)
                    ret = out_sum / (weight_sum + (weight_sum == 0.0))
                return ret
            elif weight is not None:
                out_sum = _C_ops.reduce_sum(out, 'reduce_all', True)
                total_weight = _C_ops.reduce_sum(weight_gather_reshape,
                                                 'reduce_all', True)
                return out_sum / (total_weight + (total_weight == 0.0))
            else:
                return _C_ops.mean(out)

        else:
            # undo the unsqueeze applied to the label above
            if input_dims - 1 == label_dims:
                out = paddle.squeeze(out, axis=axis)
            return out

    fluid.data_feeder.check_variable_and_dtype(
        input, 'input', ['float32', 'float64'], 'softmax_cross_entropy')
    fluid.data_feeder.check_variable_and_dtype(
        label, 'label', ['int32', 'int64', 'float32', 'float64'],
        'softmax_cross_entropy')
    attrs = {
        'soft_label': soft_label,
        'ignore_index': ignore_index,
        'numeric_stable_mode': True,
        'axis': axis,
        'use_softmax': use_softmax
    }
    helper = LayerHelper('softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=input.dtype)
    out = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='softmax_with_cross_entropy',
        inputs={'Logits': input,
                'Label': label},
        outputs={'Softmax': softmax,
                 'Loss': out},
        attrs=attrs)

    if weight is not None:
        fluid.data_feeder.check_variable_and_dtype(
            weight, 'weight', ['float32', 'float64'], 'softmax_cross_entropy')
        # only the final op of the chain carries the user-supplied name
        weight_name = name if reduction == 'none' else None
        if soft_label:
            # chajchaj:
            # trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases.
            # weight's shape is C, where C is class num.
            # for 1d case: label's shape is [N,C], weight_gather's shape is N.
            # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W].
            weight_gather = paddle.matmul(
                x=paddle.cast(label, weight.dtype),
                y=weight,
                transpose_x=False,
                transpose_y=True)

            out_shape = list(out.shape)
            weight_gather_reshape = reshape(weight_gather, shape=out_shape)
            out = paddle.cast(out, weight_gather_reshape.dtype)
        else:
            if input.shape[axis] != weight.shape[-1]:
                raise ValueError("input's class_dimension({}) must equal to "
                                 "weight's class_dimension({}) "
                                 "when weight is provided" \
                                 .format(input.shape[axis], weight.shape[-1]))

            ignore_weight_mask = (
                label != ignore_index)  # ignored position will be False

            valid_label = paddle.cast(
                ignore_weight_mask,
                dtype=label.dtype) * label  # ignored position will be 0

            _check_hard_label_range(valid_label, input.shape[axis])

            ignore_weight_mask = paddle.cast(ignore_weight_mask,
                                             out.dtype)  # convert from 0 to 0.0

            if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
                    axis] == 1:
                ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)
            if axis != -1 and axis != valid_label.ndim - 1:
                # gather_nd indexes with the last dimension, so move the
                # class axis of the label there first
                temp_perm = _class_axis_last_perm(axis, valid_label.ndim)
                weight_gather = paddle.gather_nd(
                    weight, paddle.transpose(valid_label, temp_perm))
            else:
                weight_gather = paddle.gather_nd(weight, valid_label)
            weight_gather = paddle.multiply(weight_gather, ignore_weight_mask)

            input_shape = list(label.shape)
            weight_gather_reshape = reshape(weight_gather, shape=input_shape)
        out = paddle.multiply(out, weight_gather_reshape, name=weight_name)

    if reduction == "sum":
        return paddle.sum(out, name=name)
    elif reduction == "mean":
        if ignore_index != -100:
            out_sum = paddle.sum(out, name=name)
            # for each label[i],set 1 or 0, according to ignore_index
            # mask[i]=0, if label[i]==ignore_index
            # mask[i]=1, otherwise
            mask = (label != ignore_index)
            if weight is None:
                mask = paddle.cast(mask, dtype=out_sum.dtype)
                count = paddle.sum(mask, name=name)
                # "+ (count == 0.0)" turns a zero denominator into 1
                ret = out_sum / (count + (count == 0.0))
            else:
                mask = paddle.cast(mask, weight_gather_reshape.dtype)
                weight_ignored = paddle.multiply(mask, weight_gather_reshape)
                weight_sum = paddle.sum(weight_ignored, name=name)
                ret = out_sum / (weight_sum + (weight_sum == 0.0))
            return ret
        elif weight is not None:
            out_sum = paddle.sum(out, name=name)
            total_weight = paddle.sum(weight_gather_reshape)
            return out_sum / (total_weight + (total_weight == 0.0))
        else:
            return paddle.mean(out, name=name)

    else:
        # undo the unsqueeze applied to the label above
        if input_dims - 1 == label_dims:
            out = paddle.squeeze(out, axis=axis)

        return out


def sigmoid_focal_loss(logit,
                       label,
                       normalizer=None,
                       alpha=0.25,
                       gamma=2.0,
                       reduction='sum',
                       name=None):
    r"""
    `Focal Loss <https://arxiv.org/abs/1708.02002>`_ is proposed to address the
    foreground-background class imbalance for classification tasks. It down-weights
    easily-classified examples and thus focuses training on hard examples. For example,
    it is used in one-stage object detection where the foreground-background class
    imbalance is extremely high.

    This operator measures focal loss function as follows:

    .. math::
           Out = -Labels * alpha * {(1 - \sigma(Logit))}^{gamma}\log(\sigma(Logit)) - (1 - Labels) * (1 - alpha) * {\sigma(Logit)}^{gamma}\log(1 - \sigma(Logit))

    We know that :math:`\sigma(Logit) = \frac{1}{1 + \exp(-Logit)}`.

    Then, if :attr:`normalizer` is not None, this operator divides the loss
    `Out` by the normalizer tensor:

    .. math::
           Out = \frac{Out}{normalizer}

    Finally, this operator applies reduce operation on the loss.
    If :attr:`reduction` set to ``'none'``, the operator will return the original loss `Out`.
    If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target ``label`` is 0 for the negative class and is 1 for the positive class.

    Args:
        logit (Tensor): The input logit tensor. The shape is [N, *], where N is batch_size,
            `*` means any number of additional dimensions. The ``logit`` is usually the
            output of a convolution layer. Available dtype is float32, float64.
        label (Tensor): The target label tensor with the same shape as
            ``logit``. The target label whose value should be numbers between 0 and 1.
            Available dtype is float32, float64.
        normalizer (Tensor, optional): The number normalizes the focal loss. It has to be
            a 1-D Tensor whose shape is `[1, ]`. The data type is float32, float64.
            For object detection task, it is the number of positive samples.
            If set to None, the focal loss will not be normalized. Default is None.
        alpha(int|float, optional): Hyper-parameter to balance the positive and negative example,
            it should be between 0 and 1.  Default value is set to 0.25.
        gamma(int|float, optional): Hyper-parameter to modulate the easy and hard examples.
            Default value is set to 2.0.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'sum'``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, if :attr:`reduction` is ``'mean'`` or ``'sum'``, the out shape is :math:`[1]`, otherwise the shape is the same as ``logit``. The same dtype as ``logit`` tensor.

    Raises:
        ValueError: If :attr:`reduction` is not one of ``'sum'``, ``'mean'``
            or ``'none'``, or if :attr:`normalizer` has more than one dimension.

    Examples:

        .. code-block:: python

            import paddle

            logit = paddle.to_tensor([[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]], dtype='float32')
            label = paddle.to_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32')
            one = paddle.to_tensor([1.], dtype='float32')
            fg_label = paddle.greater_equal(label, one)
            fg_num = paddle.sum(paddle.cast(fg_label, dtype='float32'))
            output = paddle.nn.functional.sigmoid_focal_loss(logit, label, normalizer=fg_num)
            print(output)  # [0.65782464]

    """
    # Validate `reduction` first so a bad value fails before any op is built.
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in sigmoid_focal_loss "
            "should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
            % reduction)

    if normalizer is not None:
        fluid.data_feeder.check_variable_and_dtype(normalizer, 'normalizer',
                                                   ['float32', 'float64'],
                                                   'sigmoid_focal_loss')
        normalizer_dims = len(normalizer.shape)
        if normalizer_dims > 1:
            raise ValueError(
                "Expected one dimension of normalizer in sigmoid_focal_loss but got {}.".
                format(normalizer_dims))

    if in_dygraph_mode():
        # Dygraph fast path: call the C++ ops directly.
        # `one` is a tensor of ones with the same shape as `logit`, used to
        # form the (1 - x) terms below.
        one = _varbase_creator(dtype=logit.dtype)
        _C_ops.fill_constant(one, 'value',
                             float(1.0), 'force_cpu', False, 'dtype', one.dtype,
                             'str_value', '1.0', 'shape', logit.shape)
        loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label)
        pred = _C_ops.sigmoid(logit)
        # p_t = pred * label + (1 - pred) * (1 - label)
        p_t = _C_ops.elementwise_add(
            _C_ops.elementwise_mul(pred, label),
            _C_ops.elementwise_mul(
                _C_ops.elementwise_sub(one, pred),
                _C_ops.elementwise_sub(one, label)))

        # alpha_t = alpha * label + (1 - alpha) * (1 - label)
        alpha = fluid.dygraph.base.to_variable([alpha], dtype=loss.dtype)
        alpha_t = _C_ops.elementwise_add(
            _C_ops.elementwise_mul(alpha, label),
            _C_ops.elementwise_mul(
                _C_ops.elementwise_sub(one, alpha),
                _C_ops.elementwise_sub(one, label)))
        loss = _C_ops.elementwise_mul(alpha_t, loss)

        # Modulating factor (1 - p_t) ** gamma.
        gamma = fluid.dygraph.base.to_variable([gamma], dtype=loss.dtype)
        gamma_t = _C_ops.elementwise_pow(
            _C_ops.elementwise_sub(one, p_t), gamma)
        loss = _C_ops.elementwise_mul(gamma_t, loss)

        if normalizer is not None:
            loss = _C_ops.elementwise_div(loss, normalizer)

        if reduction == "sum":
            return _C_ops.reduce_sum(loss, 'reduce_all', True)
        elif reduction == "mean":
            return _C_ops.mean(loss)

        return loss

    # Static-graph path: dtype checks are only performed here.
    fluid.data_feeder.check_variable_and_dtype(
        logit, 'logit', ['float32', 'float64'], 'sigmoid_focal_loss')
    fluid.data_feeder.check_variable_and_dtype(
        label, 'label', ['float32', 'float64'], 'sigmoid_focal_loss')

    # `name` is attached to whichever op produces the returned tensor: the
    # BCE op only when neither normalization nor reduction follows it.
    bce_name = None
    if reduction == 'none' and normalizer is None:
        bce_name = name
    loss = paddle.nn.functional.binary_cross_entropy_with_logits(
        logit, label, reduction='none', name=bce_name)

    pred = fluid.layers.sigmoid(logit)
    p_t = pred * label + (1 - pred) * (1 - label)

    alpha_t = alpha * label + (1 - alpha) * (1 - label)
    loss = paddle.multiply(alpha_t, loss)

    # Modulating factor (1 - p_t) ** gamma.
    gamma_t = paddle.pow((1 - p_t), gamma)
    loss = paddle.multiply(gamma_t, loss)

    if normalizer is not None:
        # The divide op is the final op only when no reduction follows.
        normalizer_name = name if reduction == 'none' else None
        loss = paddle.divide(loss, normalizer, name=normalizer_name)

    if reduction == 'mean':
        loss = paddle.mean(loss, name=name)
    elif reduction == 'sum':
        loss = paddle.sum(loss, name=name)

    return loss
2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176


def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):
    r"""
    Compute the hinge embedding loss between an input tensor :math:`x` and a
    label tensor :math:`y` whose elements are 1 or -1.

    It is commonly used to measure whether two inputs are similar or
    dissimilar (e.g. with the L1 pairwise distance as :math:`x`), and is
    typically applied to learning nonlinear embeddings or semi-supervised
    learning.

    For the :math:`n`-th sample in the mini-batch the loss is

    .. math::
        l_n = \begin{cases}
            x_n, & \text{if}\; y_n = 1,\\
            \max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
        \end{cases}

    and the total loss is

    .. math::
        \ell(x, y) = \begin{cases}
            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
        \end{cases}

    where :math:`L = \{l_1,\dots,l_N\}^\top` and :math:`\Delta` is ``margin``.

    Parameters:
        input (Tensor): Input tensor of shape [N, \*], where N is the batch
            size and `\*` means any number of additional dimensions.
            Available dtype is float32, float64.
        label (Tensor): Label tensor containing 1 or -1, with the same shape
            as ``input``. Available dtype is float32, float64.
        margin (float, optional): The margin :math:`\Delta` of the hinge.
            Where the label is -1, inputs smaller than ``margin`` contribute
            ``margin - input`` to the loss; larger inputs contribute 0.
            Default = 1.0
        reduction (str, optional): How to reduce the element-wise loss.
            The candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:

        input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64. The sum operation operates over all the elements.

        label: N-D Tensor, same shape as the input. Tensor elements should be 1 or -1, the data type is float32 or float64.

        output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.

    Returns:
        Tensor. The tensor variable storing the hinge_embedding_loss of input and label.

    Raises:
        ValueError: If :attr:`reduction` is not ``'sum'``, ``'mean'`` or ``'none'``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            # label elements in {1., -1.}
            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

            loss = F.hinge_embedding_loss(input, label, margin=1.0, reduction='none')
            print(loss)
            # Tensor([[0., -2., 0.],
            #         [0., -1., 2.],
            #         [1., 1., 1.]])

            loss = F.hinge_embedding_loss(input, label, margin=1.0, reduction='mean')
            print(loss)
            # Tensor([0.22222222])
    """

    if reduction not in ('sum', 'mean', 'none'):
        raise ValueError(
            "'reduction' in 'hinge_embedding_loss' should be 'sum', 'mean' or 'none', "
            "but received {}.".format(reduction))

    # dtype validation is only done when building a static graph.
    if not paddle.fluid.framework.in_dygraph_mode():
        for tensor, arg_name in ((input, 'input'), (label, 'label')):
            check_variable_and_dtype(tensor, arg_name, ['float32', 'float64'],
                                     'hinge_embedding_loss')

    # One-element zero tensor used as the "otherwise" value in both selections.
    zeros = paddle.zeros([1], dtype=input.dtype)

    # Where label == 1 the loss is the input itself.
    similar_term = paddle.where(label == 1., input, zeros)

    # Where label == -1 the loss is max(0, margin - input).
    dissimilar_mask = label == -1.
    hinge = paddle.nn.functional.relu(margin - input)
    dissimilar_term = paddle.where(dissimilar_mask, hinge, zeros)

    loss = similar_term + dissimilar_term

    if reduction == 'none':
        return loss
    # `reduction` was validated above, so it is 'mean' or 'sum' here.
    reduce_fn = paddle.mean if reduction == 'mean' else paddle.sum
    return reduce_fn(loss, name=name)