loss.py 74.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
from functools import partial, reduce
19
from paddle.utils import deprecated
20 21 22
from . import nn
from .layer_function_generator import templatedoc
from ..layer_helper import LayerHelper
23 24
from ..framework import Variable, in_dygraph_mode
from .. import core
25
from ..data_feeder import check_variable_and_dtype, check_type
26
from ..param_attr import ParamAttr
S
ShenLiang 已提交
27 28
from ..initializer import NumpyArrayInitializer, Constant
from .. import core
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60

# Public names exported by ``from <this module> import *``.
__all__ = [
    'center_loss', 'bpr_loss', 'cross_entropy', 'square_error_cost',
    'edit_distance', 'warpctc', 'nce', 'hsigmoid',
    'sampled_softmax_with_cross_entropy', 'softmax_with_cross_entropy',
    'rank_loss', 'margin_rank_loss', 'sigmoid_cross_entropy_with_logits',
    'teacher_student_sigmoid_loss', 'huber_loss', 'kldiv_loss', 'npair_loss',
    'mse_loss'
]

# Default ``ignore_index`` of cross_entropy / cross_entropy2: a hard label
# equal to this value is omitted from the loss computation.
kIgnoreIndex = -100


def center_loss(input,
                label,
                num_classes,
                alpha,
                param_attr,
                update_center=True):
    """
    :api_attr: Static Graph
    :alias_main: paddle.nn.functional.center_loss
    :alias: paddle.nn.functional.center_loss,paddle.nn.functional.loss.center_loss
    :old_api: paddle.fluid.layers.center_loss

    **Center loss Cost layer**

    This OP accepts input (deep features, the output of the last hidden layer)
    and target label and returns the center loss cost. The average of the
    distances of each sample in the mini-batch from the center of the
    corresponding category is calculated as the center loss.

    For deep features :math:`X`, target labels :math:`Y` and the learned class
    centers :math:`C`, the equation is:

    .. math::

        Out = \\frac{1}{2}(X - C_Y)^2

    where :math:`C_Y` is the center of the category given by each sample's label.

    Args:
        input (Variable): a 2-D tensor with shape [N x M]. Its dtype should be
            float32 or float64.
        label (Variable): the ground truth which is a 2-D tensor with shape
            [N x 1], where N is the batch size. Its dtype should be int32 or int64.
        num_classes (int): the number of classification categories. The centers
            parameter has shape [num_classes, M].
        alpha (float|int|Variable): learning rate of centers. A Python number is
            stored in a persistable float32 variable named ``centerloss_alpha``;
            a Variable must be float32 or float64.
        param_attr (ParamAttr): Attribute initializer of centers.
        update_center (bool): whether to update value of center.

    Returns:
        Variable: 2-D tensor with shape [N, 1]

    Raises:
        TypeError: if ``alpha`` is neither a Python number nor a Variable, or if
            ``input``, ``label`` or ``alpha`` has an unsupported dtype.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid

          input = fluid.data(name='x', shape=[20, 30], dtype='float32')
          label = fluid.data(name='y', shape=[20, 1], dtype='int64')
          num_classes = 1000
          alpha = 0.01
          param_attr = fluid.initializer.Xavier(uniform=False)
          center_loss = fluid.layers.center_loss(input=input,
                 label=label,
                 num_classes=num_classes,
                 alpha=alpha,
                 param_attr=param_attr,
                 update_center=True)
    """
    helper = LayerHelper('center_loss', **locals())
    dtype = helper.input_dtype()
    check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                             'center_loss')
    check_variable_and_dtype(label, 'label', ['int32', 'int64'], 'center_loss')

    centers_shape = [num_classes, input.shape[1]]
    centers_param = helper.create_parameter(
        attr=param_attr, shape=centers_shape, dtype=dtype)
    # The centers are written back by the op itself (see 'CentersOut' below),
    # so no gradient is propagated to them.
    centers_param.stop_gradient = True

    if isinstance(alpha, Variable):
        check_variable_and_dtype(alpha, 'alpha', ['float32', 'float64'],
                                 'center_loss')
        alpha_param = alpha
    else:
        # Explicit check rather than `assert`, which is stripped under -O.
        if isinstance(alpha, bool) or not isinstance(alpha, (int, float)):
            raise TypeError(
                "The type of 'alpha' in center_loss must be float, int or "
                "Variable, but received %s." % type(alpha))
        # NOTE: the fixed name means every center_loss call shares this
        # persistable variable; kept as-is for checkpoint compatibility.
        alpha_param = helper.create_variable(
            name="centerloss_alpha",
            shape=[1],
            dtype="float32",
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=True,
            stop_gradient=True,
            initializer=Constant(float(alpha)))

    centersdiff = helper.create_variable_for_type_inference(dtype=input.dtype)
    loss = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='center_loss',
        inputs={
            'X': [input],
            'Label': [label],
            'Centers': [centers_param],
            'CenterUpdateRate': [alpha_param]
        },
        outputs={
            'SampleCenterDiff': [centersdiff],
            'Loss': [loss],
            # In-place update of the centers parameter.
            'CentersOut': [centers_param]
        },
        attrs={'cluster_num': num_classes,
               'need_update': update_center})
    return loss


def bpr_loss(input, label, name=None):
    r"""
    :alias_main: paddle.nn.functional.bpr_loss
    :alias: paddle.nn.functional.bpr_loss,paddle.nn.functional.loss.bpr_loss
    :old_api: paddle.fluid.layers.bpr_loss

    **Bayesian Personalized Ranking Loss Operator**

    This operator belongs to pairwise ranking loss. Label is the desired item.
    The loss at a given point in one session is defined as:

    .. math::
        Y[i] = 1/(N[i] - 1) * \sum_j{\log(\sigma(X[i, Label[i]]-X[i, j]))}

    Learn more details by reading paper <session-based recommendations with recurrent
    neural networks>.

    Args:
        input (Variable):  a 2-D tensor with shape [N x D], where N is the
                           batch size and D is the number of positive classes and
                           negative classes. This input is not probability but logits.
                           Its dtype should be float16, float32 or float64.
        label (Variable):  the ground truth which is a 2-D tensor. `label`
                           is a tensor<int64> with shape [N x 1].
        name (str|None):   A name for this layer(optional). If set None, the
                           layer will be named automatically. Default: None.
    Returns:
        A 2-D tensor with shape [N x 1], the bpr loss.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid

          neg_size = 10
          label = fluid.data(
                    name="label", shape=[3, 1], dtype="int64")
          predict = fluid.data(
                    name="predict", shape=[3, neg_size + 1], dtype="float32")
          cost = fluid.layers.bpr_loss(input=predict, label=label)
    """
    helper = LayerHelper('bpr_loss', **locals())
    # Validate `input` before reading `input.dtype`, so a bad argument is
    # reported by the dtype/type check rather than by an attribute access.
    check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'],
                             'bpr_loss')
    out = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='bpr_loss',
        inputs={'X': [input],
                'Label': [label]},
        outputs={'Y': [out]})
    return out


def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
    r"""
    :alias_main: paddle.nn.functional.cross_entropy
    :alias: paddle.nn.functional.cross_entropy,paddle.nn.functional.loss.cross_entropy
    :old_api: paddle.fluid.layers.cross_entropy

    This operator computes the cross entropy between input and label. It
    supports both hard-label and soft-label cross entropy computation.

    1. Hard-label cross entropy: if soft_label=False, :math:`label[i_1, i_2, ..., i_k]`
       is the hard label of each sample.

        .. math::

           output[i_1, i_2, ..., i_k]=-log(input[i_1, i_2, ..., i_k, j]), label[i_1, i_2, ..., i_k] = j, j != ignore\_index

    2. Soft-label cross entropy: if soft_label=True,  :math:`label[i_1, i_2, ..., i_k, j]`
       is the soft label of each sample corresponding to the j-th class.

        .. math::

           output[i_1, i_2, ..., i_k]= -\sum_{j}label[i_1,i_2,...,i_k,j]*log(input[i_1, i_2, ..., i_k,j])

    Args:
        input (Variable): a multidimensional Tensor with shape
                :math:`[N_1, N_2, ..., N_k, D]`, where the last dimension D is
                the class number. The data type should be float32 or float64.
        label (Variable): label value corresponding to input. If
                soft_label=False, the dimension of label should be :math:`[N_1, N_2, ..., N_k]`
                or :math:`[N_1, N_2, ..., N_k, 1]` , and its data type should be int64,
                and the value must be inside [0, D). If soft_label=True, the shape,
                data type of label should be the same with input, and the sum of
                soft label value of each sample should be 1.
        soft_label (bool): indicate whether label is soft. Default False, meaning that
                the label is hard. If soft_label=True, the label is soft.
        ignore_index (int): specify an ignorable label value. The ignored label would be
                omitted when computing. If it is a negative integer, no label would
                be ignored. Only valid when soft_label=False. Default -100.

    Returns:
         A Variable holding Tensor representing the cross entropy, whose data type is the same with input.
         If soft_label=False, the shape of output is the same with label.
         If soft_label=True, the shape of output is :math:`[N_1, N_2, ..., N_k, 1]` .

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            class_num = 7
            x = fluid.data(name='x', shape=[None, 3, 10], dtype='float32')
            label = fluid.data(name='label', shape=[None, 1], dtype='int64')
            predict = fluid.layers.fc(input=x, size=class_num, act='softmax')
            cost = fluid.layers.cross_entropy(input=predict, label=label)
    """
    # Hard labels are handled by the dedicated cross_entropy2 op.
    if not soft_label:
        return cross_entropy2(input, label, ignore_index)

    if in_dygraph_mode():
        return core.ops.cross_entropy(input, label, "soft_label", soft_label,
                                      "ignore_index", ignore_index)

    inputs = {'X': [input], 'Label': [label]}
    attrs = {"soft_label": soft_label, "ignore_index": ignore_index}

    check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'],
                             'cross_entropy')
    helper = LayerHelper('cross_entropy', **locals())
    out = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='cross_entropy', inputs=inputs, outputs={'Y': [out]}, attrs=attrs)
    return out


def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
    """
    Hard-label cross entropy built on the ``cross_entropy2`` op.

    :func:`cross_entropy` delegates to this function when ``soft_label=False``.

    Args:
        input (Variable): input tensor, see :func:`cross_entropy`. Its dtype
            should be float16, float32 or float64.
        label (Variable): hard labels corresponding to ``input``.
        ignore_index (int): label value omitted from the computation.
            Default -100.

    Returns:
        Variable: the ``Y`` output of the op, with the same dtype as ``input``.
    """
    if in_dygraph_mode():
        # The op yields (Y, MatchX, XShape); only the loss is returned.
        loss, _, _ = core.ops.cross_entropy2(input, label, 'ignore_index',
                                             ignore_index)
        return loss

    inputs = {'X': [input], 'Label': [label]}
    attrs = {'ignore_index': ignore_index}
    check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'],
                             'cross_entropy2')
    helper = LayerHelper('cross_entropy2', **locals())
    out = helper.create_variable_for_type_inference(dtype=input.dtype)
    # Auxiliary outputs required by the op; only `out` is returned.
    xshape = helper.create_variable_for_type_inference(dtype=input.dtype)
    match_x = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='cross_entropy2',
        inputs=inputs,
        outputs={'Y': [out],
                 'MatchX': [match_x],
                 'XShape': [xshape]},
        attrs=attrs)
    return out


def square_error_cost(input, label):
    """
    This op accepts input predictions and target label and returns the
    element-wise squared error cost.

    For predictions label, and target label, the equation is:

    .. math::

        Out = (input - label)^2

    Parameters:
        input (Tensor): Input tensor, the data type should be float32 or float64.
        label (Tensor): Label tensor, the data type should be float32 or float64.

    Returns:
        The tensor storing the element-wise squared error difference between
        input and label.

    Return type: Tensor.

    Examples:

        .. code-block:: python

            import paddle
            input = paddle.to_tensor([1.1, 1.9])
            label = paddle.to_tensor([1.0, 2.0])
            output = paddle.nn.functional.square_error_cost(input, label)
            print(output.numpy())
            # [0.01, 0.01]

    """
    check_variable_and_dtype(input, "input", ['float32', 'float64'],
                             'square_error_cost')
    check_variable_and_dtype(label, "label", ['float32', 'float64'],
                             'square_error_cost')
    helper = LayerHelper('square_error_cost', **locals())

    # Step 1: input - label
    minus_out = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='elementwise_sub',
        inputs={'X': [input],
                'Y': [label]},
        outputs={'Out': [minus_out]})

    # Step 2: (input - label)^2
    square_out = helper.create_variable_for_type_inference(dtype=input.dtype)
    helper.append_op(
        type='square', inputs={'X': [minus_out]},
        outputs={'Out': [square_out]})
    return square_out


def edit_distance(input,
                  label,
                  normalized=True,
                  ignored_tokens=None,
                  input_length=None,
                  label_length=None):
    """
    This op computes the edit distances, also called Levenshtein distance, between a batch of
    hypothesis strings and their references. It measures how dissimilar two strings are by counting
    the minimum number of operations to transform one string into another.
    The operations include insertion, deletion, and substitution.

    For example, given hypothesis string A = "kitten" and reference
    B = "sitting", A will be transformed into B
    at least after two substitutions and one insertion:

    "kitten" -> "sitten" -> "sittin" -> "sitting"

    So the edit distance between A and B is 3.

    The input is a LoDTensor or Tensor.
    If it is a LoDTensor, the separation is specified by the LoD information.
    If it is a Tensor, both input_length and label_length should be provided;
    the length tensors are only passed to the op when both are given.

    The `batch_size` of labels should be same as `input`.

    The output includes the edit distance value between every pair of input and related label,
    and the number of sequences. If Attr(normalized) is true, the edit distance value will be
    divided by the length of label.

    Parameters:
        input(Variable): The input variable which is a tensor or LoDTensor, its rank should be equal to 2 and its data type should be int64.
        label(Variable): The label variable which is a tensor or LoDTensor, its rank should be equal to 2 and its data type should be int64.
        normalized(bool, default True): Indicates whether to normalize the edit distance.
        ignored_tokens(list<int>, default None): Tokens that will be removed from both
                                     input and label (via the ``sequence_erase`` op) before
                                     calculating edit distance.
        input_length(Variable): The length for each sequence in `input` if it's of Tensor type, it should have shape `(batch_size, )` and its data type should be int64.
        label_length(Variable): The length for each sequence in `label` if it's of Tensor type, it should have shape `(batch_size, )` and its data type should be int64.
        NOTE: To avoid unexpected results, the value of every element in input_length and label_length should be equal to the value of the second dimension of input and label. For example, The input: [[1,2,3,4],[5,6,7,8],[9,10,11,12]], the shape of input is [3,4] and the input_length should be [4,4,4]
        NOTE: This API is different from fluid.metrics.EditDistance

    Returns:
        Tuple:

        distance(Variable): edit distance result, its data type is float32, and its shape is (batch_size, 1).
        sequence_num(Variable): sequence number, its data type is int64, and its shape is (1,).

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import numpy as np

            # using LoDTensor
            x_lod = fluid.data(name='x_lod', shape=[None,1], dtype='int64', lod_level=1)
            y_lod = fluid.data(name='y_lod', shape=[None,1], dtype='int64', lod_level=1)
            distance_lod, seq_num_lod = fluid.layers.edit_distance(input=x_lod, label=y_lod)

            # using Tensor
            input_data = np.array([[1,2,3],[4,5,6],[4,4,4],[1,1,1]]).astype('int64')
            label_data = np.array([[1,3,4,1],[4,5,8,1],[7,7,7,1],[1,1,1,1]]).astype('int64')
            input_len = np.array([3,3,3,3]).astype('int64')
            label_len = np.array([4,4,4,4]).astype('int64')

            input_t = fluid.data(name='input', shape=[None,3], dtype='int64')
            label_t = fluid.data(name='label', shape=[None,4], dtype='int64')
            input_len_t = fluid.data(name='input_length', shape=[None], dtype='int64')
            label_len_t = fluid.data(name='label_length', shape=[None], dtype='int64')

            distance, sequence_num = fluid.layers.edit_distance(input=input_t, label=label_t, input_length=input_len_t, label_length=label_len_t,normalized=False)

            # print(input_data.shape, label_data.shape)
            # ((4,3), (4,4))

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            dis, seq_num = exe.run(fluid.default_main_program(),
                                   feed={"input":input_data,
                                         "label":label_data,
                                         "input_length": input_len,
                                         "label_length": label_len},
                                   fetch_list=[distance,sequence_num])
            # print(dis)
            # [[3.]
            #  [2.]
            #  [4.]
            #  [1.]]
            # if set normalized to True
            # [[0.75]
            #  [0.5 ]
            #  [1.  ]
            #  [0.25]]
            #
            # print(seq_num)
            # [4]

    """
    check_variable_and_dtype(input, 'input', ['int64'], 'edit_distance')
    check_variable_and_dtype(label, 'label', ['int64'], 'edit_distance')
    helper = LayerHelper("edit_distance", **locals())

    # remove some tokens from input and labels
    if ignored_tokens is not None and len(ignored_tokens) > 0:
        erased_input = helper.create_variable_for_type_inference(dtype="int64")
        erased_label = helper.create_variable_for_type_inference(dtype="int64")

        helper.append_op(
            type="sequence_erase",
            inputs={"X": [input]},
            outputs={"Out": [erased_input]},
            attrs={"tokens": ignored_tokens})
        input = erased_input

        helper.append_op(
            type="sequence_erase",
            inputs={"X": [label]},
            outputs={"Out": [erased_label]},
            attrs={"tokens": ignored_tokens})
        label = erased_label

    this_inputs = {"Hyps": [input], "Refs": [label]}
    # Length tensors are only wired in when both are supplied; a single one
    # is ignored.
    if input_length is not None and label_length is not None:
        this_inputs['HypsLength'] = [input_length]
        this_inputs['RefsLength'] = [label_length]

    # edit distance op
    # NOTE(review): the distance output is declared int64 here, although the
    # documented result is float32 (normalized distances are fractional);
    # confirm the op sets the actual output dtype.
    edit_distance_out = helper.create_variable_for_type_inference(dtype="int64")
    sequence_num = helper.create_variable_for_type_inference(dtype="int64")
    helper.append_op(
        type="edit_distance",
        inputs=this_inputs,
        outputs={"Out": [edit_distance_out],
                 "SequenceNum": [sequence_num]},
        attrs={"normalized": normalized})

    return edit_distance_out, sequence_num


def warpctc(input,
            label,
            blank=0,
            norm_by_times=False,
            input_length=None,
            label_length=None):
    """
    An operator integrating the open source Warp-CTC library
    (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation is
    integrated into the Warp-CTC library to normalize values for each row of the
    input tensor.

    Args:
       input (Variable): The unscaled probabilities of variable-length sequences,
         which is a 2-D Tensor with LoD information, or a 3-D Tensor without LoD
         information. When it is a 2-D LoDTensor, its shape is
         `[Lp, num_classes + 1]`, where `Lp` is the sum of all input
         sequences' length and `num_classes` is the true number of classes
         (not including the blank label). When it is a 3-D Tensor, its shape
         is `[max_logit_length, batch_size, num_classes + 1]`,
         where `max_logit_length` is the longest length of
         input logit sequence. The data type should be float32 or float64.
       label (Variable): The ground truth of variable-length sequence,
         which must be a 2-D Tensor with LoD information or a 2-D Tensor without
         LoD information, and needs to be consistent with the corresponding input.
         When it is a 2-D LoDTensor, its shape is `[Lg, 1]`, where `Lg` is the sum
         of all labels' length. When it is a Tensor without LoD, its shape is
         `[batch_size, max_label_length]`, where `max_label_length` is the longest
         length of label sequence. Data type must be int32.
       blank (int, default 0): The blank label index of Connectionist
         Temporal Classification (CTC) loss, which is in the
         half-opened interval `[0, num_classes + 1)`.
       norm_by_times(bool, default false): Whether to normalize the gradients
         by the number of time-step, which is also the sequence's length.
         There is no need to normalize the gradients if warpctc layer was
         followed by a mean_op.
       input_length(Variable): The length for each input sequence if it is
         of Tensor type, it should have shape `[batch_size]` and dtype int64.
       label_length(Variable): The length for each label sequence if it is
         of Tensor type, it should have shape `[batch_size]` and dtype int64.
         In static graph mode the two length tensors are only used when both
         are provided; in dygraph mode both are required.

    Returns:
        Variable: The Connectionist Temporal Classification (CTC) loss,
        which is a 2-D Tensor with the shape `[batch_size, 1]`.
        The data type is the same as input.

    Raises:
        ValueError: in dygraph mode, if input_length or label_length is None.

    Examples:

        .. code-block:: python

            # using LoDTensor
            import paddle
            import paddle.fluid as fluid
            import numpy as np

            # lengths of logit sequences
            seq_lens = [2,6]
            # lengths of label sequences
            label_lens = [2,3]
            # class num
            class_num = 5

            paddle.enable_static()
            logits = fluid.data(name='logits',shape=[None, class_num+1],
                                 dtype='float32',lod_level=1)
            label = fluid.data(name='label', shape=[None, 1],
                               dtype='int32', lod_level=1)
            cost = fluid.layers.warpctc(input=logits, label=label)
            place = fluid.CPUPlace()
            x = fluid.create_lod_tensor(
                     np.random.rand(np.sum(seq_lens), class_num+1).astype("float32"),
                     [seq_lens], place)
            y = fluid.create_lod_tensor(
                     np.random.randint(0, class_num, [np.sum(label_lens), 1]).astype("int32"),
                     [label_lens], place)
            exe = fluid.Executor(place)
            output= exe.run(fluid.default_main_program(),
                            feed={"logits": x,"label": y},
                            fetch_list=[cost.name])
            print(output)

        .. code-block:: python

            # using Tensor
            import paddle
            import paddle.fluid as fluid
            import numpy as np

            # length of the longest logit sequence
            max_seq_length = 5
            #length of the longest label sequence
            max_label_length = 3
            # number of logit sequences
            batch_size = 16
            # class num
            class_num = 5
            paddle.enable_static()
            logits = fluid.data(name='logits',
                           shape=[max_seq_length, batch_size, class_num+1],
                           dtype='float32')
            logits_length = fluid.data(name='logits_length', shape=[None],
                             dtype='int64')
            label = fluid.data(name='label', shape=[batch_size, max_label_length],
                           dtype='int32')
            label_length = fluid.data(name='labels_length', shape=[None],
                             dtype='int64')
            cost = fluid.layers.warpctc(input=logits, label=label,
                            input_length=logits_length,
                            label_length=label_length)
            place = fluid.CPUPlace()
            x = np.random.rand(max_seq_length, batch_size, class_num+1).astype("float32")
            y = np.random.randint(0, class_num, [batch_size, max_label_length]).astype("int32")
            exe = fluid.Executor(place)
            output= exe.run(fluid.default_main_program(),
                            feed={"logits": x,
                                  "label": y,
                                  "logits_length": np.array([max_seq_length]*batch_size).astype("int64"),
                                  "labels_length": np.array([max_label_length]*batch_size).astype("int64")},
                                  fetch_list=[cost.name])
            print(output)
    """
    if in_dygraph_mode():
        if input_length is None or label_length is None:
            raise ValueError(
                "input_length and label_length must not be None in dygraph mode!"
            )
        # The op yields (WarpCTCGrad, Loss); only the loss is returned.
        _, loss_out = core.ops.warpctc(
            input,
            label,
            input_length,
            label_length,
            'blank',
            blank,
            'norm_by_times',
            norm_by_times)
        return loss_out

    helper = LayerHelper('warpctc', **locals())
    check_variable_and_dtype(input, 'input', ['float32', 'float64'], "warpctc")
    check_variable_and_dtype(label, 'label', ['int32'], "warpctc")
    this_inputs = {'Logits': [input], 'Label': [label]}
    # Length tensors are only wired in when both are supplied.
    if input_length is not None and label_length is not None:
        check_variable_and_dtype(input_length, 'LogitsLength', ['int64'],
                                 "warpctc")
        check_variable_and_dtype(label_length, 'LabelLength', ['int64'],
                                 "warpctc")
        this_inputs['LogitsLength'] = [input_length]
        this_inputs['LabelLength'] = [label_length]

    loss_out = helper.create_variable_for_type_inference(dtype=input.dtype)
    grad_out = helper.create_variable_for_type_inference(dtype=input.dtype)

    helper.append_op(
        type='warpctc',
        inputs=this_inputs,
        outputs={'WarpCTCGrad': [grad_out],
                 'Loss': [loss_out]},
        attrs={
            'blank': blank,
            'norm_by_times': norm_by_times,
        })
    return loss_out


def _build_alias_table(custom_dist, num_total_classes):
    """Build the Walker/Vose alias table consumed by nce's 'custom_dist' sampler.

    Args:
        custom_dist (sequence of float): Sampling probability of each class.
            Only the first ``num_total_classes`` entries are read.
        num_total_classes (int): Number of classes (slots) in the table.

    Returns:
        tuple(list, list): ``(alias, alias_probs)``. For slot ``i``,
        ``alias_probs[i]`` is the probability of keeping class ``i`` and
        ``alias[i]`` is the class used otherwise; ``-1`` marks a slot whose
        keep-probability is 1 and therefore never aliases.
    """
    # deque gives O(1) pops from the left; list.pop(0) made the pairing loop
    # quadratic in the number of classes.
    from collections import deque

    alias_probs = [0] * num_total_classes
    alias = [0] * num_total_classes
    bigs = deque()
    littles = deque()
    for i in range(num_total_classes):
        # Scale so the average slot mass is exactly 1.
        normal_prob = custom_dist[i] * num_total_classes
        if normal_prob - 1.0 > 0:
            bigs.append((i, normal_prob))
        elif 1.0 - normal_prob > 0:
            littles.append((i, normal_prob))
        else:
            alias_probs[i] = normal_prob
            alias[i] = -1

    # Fill each under-full slot with mass donated by an over-full class.
    while bigs and littles:
        big_idx, big_prob = bigs.popleft()
        little_idx, little_prob = littles.popleft()

        alias_probs[little_idx] = little_prob
        alias[little_idx] = big_idx
        big_left = big_prob + little_prob - 1
        if big_left - 1.0 > 0:
            bigs.append((big_idx, big_left))
        elif 1.0 - big_left > 0:
            littles.append((big_idx, big_left))
        else:
            alias_probs[big_idx] = big_left
            alias[big_idx] = -1

    # Whatever is left over (floating-point drift, or an unnormalized
    # distribution) keeps its own class with probability 1. Every leftover
    # must be fixed up: an untouched slot would keep prob 0 / alias 0 and its
    # class would never be sampled.
    for leftover_idx, _ in bigs:
        alias_probs[leftover_idx] = 1.0
        alias[leftover_idx] = -1
    for leftover_idx, _ in littles:
        alias_probs[leftover_idx] = 1.0
        alias[leftover_idx] = -1

    return alias, alias_probs


# FIXME(wuyi): let docstring_checker.py understand @autodoc.
# For now, the comments in c++ use types like Tensor, but in python side
# the type is often "Variable", and arguments may vary.
@templatedoc(op_type="nce")
def nce(input,
        label,
        num_total_classes,
        sample_weight=None,
        param_attr=None,
        bias_attr=None,
        num_neg_samples=None,
        name=None,
        sampler="uniform",
        custom_dist=None,
        seed=0,
        is_sparse=False):
    """
    :api_attr: Static Graph

    ${comment}

    Args:
        input (Variable): Input variable, 2-D tensor with shape [batch_size, dim],
            and data type is float32 or float64.
        label (Variable): Input label, 2-D tensor with shape [batch_size, num_true_class],
            and data type is int64.
        num_total_classes (int):${num_total_classes_comment}.
        sample_weight (Variable|None): A Variable of shape [batch_size, 1]
            storing a weight for each sample. The default weight for each
            sample is 1.0.
        param_attr (ParamAttr|None): To specify the weight parameter attribute.
            Default: None, which means the default weight parameter property is
            used. See usage for details in :ref:`api_fluid_ParamAttr` .
        bias_attr (ParamAttr|None): To specify the bias parameter attribute.
            Default: None, which means the default bias parameter property is
            used. See usage for details in :ref:`api_fluid_ParamAttr` .
        num_neg_samples (int): ${num_neg_samples_comment}. Default: None,
            which means 10 is used.
        name(str|None): For detailed information, please refer to
            :ref:`api_guide_Name` . Usually name does not need to be set and is None by default.
        sampler (str, optional): The sampler used to sample class from negative classes.
            It can be 'uniform', 'log_uniform' or 'custom_dist'.
            Default: 'uniform'.
        custom_dist (np.ndarray|None): A numpy ndarray with size=num_total_classes.
            It is used when sampler is set to 'custom_dist'.
            custom_dist[i] is the probability of i-th class to be sampled.
            Default: None.
        seed (int, optional): The seed used in sampler. Default 0, means no random seed.
        is_sparse(bool, optional): The flag indicating whether to use sparse update,
            the weight@GRAD and bias@GRAD will be changed to SelectedRows. Default False.
            When True, a note about parameter prefetch speed is printed.

    Returns:
        Variable: The output nce loss, i.e. the cost produced by the nce op
        divided by (num_neg_samples + 1).

    Raises:
        AssertionError: If sampler is 'custom_dist' but custom_dist is None.
        ValueError: If sampler is not 'uniform', 'log_uniform' or 'custom_dist'.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import numpy as np

            window_size = 5
            words = []
            for i in range(window_size):
                words.append(fluid.data(
                    name='word_{0}'.format(i), shape=[-1, 1], dtype='int64'))

            dict_size = 10000
            label_word = int(window_size / 2) + 1

            embs = []
            for i in range(window_size):
                if i == label_word:
                    continue

                emb = fluid.layers.embedding(input=words[i], size=[dict_size, 32],
                                   param_attr='embed', is_sparse=True)
                embs.append(emb)

            embs = fluid.layers.concat(input=embs, axis=1)
            loss = fluid.layers.nce(input=embs, label=words[label_word],
                      num_total_classes=dict_size, param_attr='nce.w_0',
                      bias_attr='nce.b_0')

            #or use custom distribution
            dist = np.array([0.05,0.5,0.1,0.3,0.05])
            loss = fluid.layers.nce(input=embs, label=words[label_word],
                    num_total_classes=5, param_attr='nce.w_1',
                    bias_attr='nce.b_1',
                    num_neg_samples=3,
                    sampler="custom_dist",
                    custom_dist=dist)

    """
    # Must be the first statement: LayerHelper receives this function's locals.
    helper = LayerHelper('nce', **locals())
    check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'nce')
    check_variable_and_dtype(label, 'label', ['int64'], 'nce')

    dim = input.shape[1]
    w = helper.create_parameter(
        attr=helper.param_attr,
        shape=[num_total_classes, dim],
        is_bias=False,
        dtype=input.dtype)
    inputs = {}
    if helper.bias_attr:
        b = helper.create_parameter(
            attr=helper.bias_attr,
            shape=[num_total_classes, 1],
            is_bias=True,
            dtype=input.dtype)
        inputs['Bias'] = b
    cost = helper.create_variable_for_type_inference(dtype=input.dtype)
    sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype)
    sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype)

    inputs['Input'] = input
    inputs['Label'] = label
    inputs['Weight'] = w
    inputs['SampleWeight'] = sample_weight if sample_weight is not None else []

    # The op takes the sampler as an integer code.
    if sampler == "uniform":
        sampler = 0
    elif sampler == "log_uniform":
        sampler = 1
    elif sampler == "custom_dist":
        assert custom_dist is not None

        alias, alias_probs = _build_alias_table(custom_dist,
                                                num_total_classes)

        def _init_by_numpy_array(numpy_array):
            # Constant, non-trainable parameter holding a sampling table.
            ret = helper.create_parameter(
                attr=ParamAttr(),
                shape=numpy_array.shape,
                dtype=numpy_array.dtype,
                default_initializer=NumpyArrayInitializer(numpy_array))
            ret.stop_gradient = True
            return ret

        inputs['CustomDistProbs'] = _init_by_numpy_array(
            np.array(custom_dist).astype('float32'))
        inputs['CustomDistAlias'] = _init_by_numpy_array(
            np.array(alias).astype('int32'))
        inputs['CustomDistAliasProbs'] = _init_by_numpy_array(
            np.array(alias_probs).astype('float32'))
        sampler = 2
    else:
        raise ValueError("Unsupported sampler type.")

    if num_neg_samples is None:
        num_neg_samples = 10
    else:
        num_neg_samples = int(num_neg_samples)

    remote_prefetch = is_sparse
    # The prefetch note is only relevant when sparse update is enabled.
    if is_sparse:
        print(
            "With sparse mode, if your models has only small parameter prefetch may cause speed down"
        )

    attrs = {
        'num_total_classes': int(num_total_classes),
        'num_neg_samples': num_neg_samples,
        'seed': seed,
        'sampler': sampler,
        'is_sparse': is_sparse,
        'remote_prefetch': remote_prefetch
    }

    helper.append_op(
        type='nce',
        inputs=inputs,
        outputs={
            'Cost': cost,
            'SampleLogits': sample_logits,
            'SampleLabels': sample_labels
        },
        attrs=attrs)
    # Normalize by the number of scored classes (1 true + negatives).
    return cost / (num_neg_samples + 1)


def hsigmoid(input,
             label,
             num_classes,
             param_attr=None,
             bias_attr=None,
             name=None,
             path_table=None,
             path_code=None,
             is_custom=False,
             is_sparse=False):
    """
    :api_attr: Static Graph

    The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity
    and speed up the model training, especially the training of language model.
    Each leaf node of the complete binary tree represents a class(word) and each non-leaf node acts as a binary classifier.
    For each class(word), there's a unique path from root to itself, hsigmoid calculates the cost for each non-leaf node on
    the path, and sums them to get a total cost.
    Compared to softmax, the OP can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
    represents the number of classes or the size of word dict.

    The OP supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural
    Network Language Model <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`. For the custom
    tree, you need to set :attr:`is_custom` to True, and do the following steps (take the language model as an example):

    1. Using a custom word dict to build a binary tree, each leaf node should be a word in the word dict.
    2. Creating a dict map word_id -> path that from the word to the root node, we call it path_table.
    3. Creating a dict map word_id -> code of path that from the word to the root node, we call it path_code.
       Code means the label of each binary classifier, 1 indicates true, 0 indicates false.
    4. Now, each word should have its path and code along the path, you can pass a batch of path and code related
       to the same batch of inputs.

    Parameters:
        input (Variable): A tensor with the shape [N, D], where N is the size of mini-batch,
            and D is the feature size. Its data type supports float32 and float64.
        label (Variable): A tensor contains the labels of training data. Its shape is [N, 1]
            and data type is int64.
        num_classes (int): The number of classes or the size of word dict. If the default tree
            is used (:attr:`is_custom` is set to False), :attr:`num_classes` must not be None
            and must be at least 2. If the custom tree is used (:attr:`is_custom` is set to True),
            :attr:`num_classes` should be the number of non-leaf nodes, which indicates the num of
            classes using by the binary classifier.
        param_attr (ParamAttr, optional): The parameter attribute for the learnable parameters/weights
            of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid will create a
            ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is
            initialized with Xavier. Default: None.
        bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of hsigmoid. If it
            is set to False, no bias will be added. If it is set to None or one attribute of ParamAttr,
            hsigmoid will create a ParamAttr as bias_attr. If the Initializer of the bias_attr is not
            set, the bias is initialized zero. Default: None.
        name (str, optional): Normally there is no need for user to set this property. For more information,
            please refer to :ref:`api_guide_Name`. Default: None.
        path_table (Variable, optional): A tensor that stores each batch of samples' path from leaf to root
            node, its shape is [N, L] and data type is int64, where L is the length of path. For each sample i,
            path_table[i] is a np.array like structure and each element in this array is the indexes in parent
            nodes' weight matrix. Default: None.
        path_code (Variable, optional): A tensor that stores each batch of samples' code of path from leaf
            to root node, its shape is [N, L] and data type is int64, which is the same as :attr:`path_table`.
            Each code of path is consisted with the code of nodes from leaf to root node. Default: None.
        is_custom (bool, optional): Whether use custom binary tree. If it's True, :attr:`path_table`,
            :attr:`path_code` and :attr:`num_classes` should be set, otherwise :attr:`num_classes` should
            be set. Default: False.
        is_sparse (bool, optional): Whether use sparse updating instead of dense updating, if it's True, the
            gradient of W and input will be sparse. Sparse updating is only supported with a custom tree;
            without one a message is printed and dense updating is used. Default: False.

    Returns:
        Variable: A tensor with the cost of hierarchical sigmoid, its shape is [N, 1] and data type is the same as :attr:`input`.

    Raises:
        ValueError: With the default tree, if :attr:`num_classes` is None or less than 2, or if
            :attr:`path_table` / :attr:`path_code` is passed. With a custom tree, if any of
            :attr:`path_code`, :attr:`path_table` or :attr:`num_classes` is None.

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid
            x = fluid.layers.fill_constant(shape=[4, 3], value=0.9, dtype='float32')
            # x = [[0.9, 0.9, 0.9], [0.9, 0.9, 0.9], [0.9, 0.9, 0.9], [0.9, 0.9, 0.9]]
            y = fluid.layers.fill_constant(
                shape=[4, 1], value=1, dtype='int64')
            # y = [[1], [1], [1], [1]]
            out = fluid.layers.hsigmoid(input=x, label=y, num_classes=2, param_attr=fluid.initializer.Constant(
                value=0.05), bias_attr=fluid.initializer.Constant(value=.0))
            # out = [[0.62792355], [0.62792355], [0.62792355], [0.62792355]]
    """
    check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'hsigmoid')
    check_variable_and_dtype(label, 'label', ['int64'], 'hsigmoid')

    # LayerHelper receives this function's locals; no new locals are defined
    # above this line.
    helper = LayerHelper('hierarchical_sigmoid', **locals())
    dtype = helper.input_dtype()
    out = helper.create_variable_for_type_inference(dtype)
    pre_out = helper.create_variable_for_type_inference(dtype)
    dim = input.shape[1]
    if ((num_classes is None) or (num_classes < 2)) and (not is_custom):
        raise ValueError(
            "num_classes must not be less than 2 with default tree")

    if (not is_custom) and (is_sparse):
        print("Sparse mode should not be used without custom tree")
        is_sparse = False

    if (not is_custom) and ((path_table is not None) or
                            (path_code is not None)):
        raise ValueError(
            "only num_classes should be passed without custom tree")

    if is_custom:
        if path_code is None:
            raise ValueError("path_code should not be None with custom tree")
        if path_table is None:
            raise ValueError("path_table should not be None with custom tree")
        if num_classes is None:
            raise ValueError("num_classes should not be None with custom tree")

    remote_prefetch = is_sparse
    # The prefetch note is only relevant when sparse update is enabled.
    if is_sparse:
        print(
            "With sparse mode, if your models has only small parameter prefetch may cause speed down"
        )

    # The default tree has num_classes - 1 non-leaf nodes; for a custom tree
    # num_classes already counts the non-leaf nodes (see docstring).
    num_nodes = num_classes if is_custom else num_classes - 1
    weights = helper.create_parameter(
        attr=helper.param_attr,
        shape=[num_nodes, dim],
        is_bias=False,
        dtype=input.dtype)
    inputs = {
        "X": input,
        "W": weights,
        "PathTable": path_table,
        "PathCode": path_code,
        "Label": label
    }
    if helper.bias_attr:
        inputs['Bias'] = helper.create_parameter(
            attr=helper.bias_attr,
            shape=[num_nodes, 1],
            is_bias=True,
            dtype=input.dtype)
    helper.append_op(
        type="hierarchical_sigmoid",
        inputs=inputs,
        # W_Out is bound to the same variable as the W input.
        outputs={"Out": out,
                 "PreOut": pre_out,
                 "W_Out": weights},
        attrs={
            "num_classes": num_classes,
            "is_sparse": is_sparse,
            "remote_prefetch": remote_prefetch
        })
    return out


def sampled_softmax_with_cross_entropy(logits,
                                       label,
                                       num_samples,
                                       num_true=1,
                                       remove_accidental_hits=True,
                                       use_customized_samples=False,
                                       customized_samples=None,
                                       customized_probabilities=None,
                                       seed=0):
    """
    **Sampled Softmax With Cross Entropy Operator.**

    Cross entropy loss with sampled softmax is used as the output layer for
    larger output classes extensively. This operator samples a number of samples
    for all examples, and computes the softmax normalized values for each
    row of the sampled tensor, after which cross-entropy loss is computed.

    Because this operator performs a softmax on logits internally, it expects
    unscaled logits. This operator should not be used with the output of
    softmax operator since that would produce incorrect results.

    For examples with T true labels (T >= 1), we assume that each true label has
    a probability of 1/T. For each sample, S samples are generated using a
    log uniform distribution. True labels are concatenated with these samples to
    form T + S samples for each example. So, assume the shape of logits is
    [N x K], the shape for samples is [N x (T+S)]. For each sampled label, a
    probability is calculated, which corresponds to the Q(y|x) in
    [Jean et al., 2014](http://arxiv.org/abs/1412.2007).

    Logits are sampled according to the sampled labels. Then if
    remove_accidental_hits is True, if a sample[i, j] accidentally hits true
    labels, then the corresponding sampled_logits[i, j] is decreased by 1e20 to
    make its softmax result close to zero. Then sampled logits are subtracted by
    logQ(y|x), these sampled logits and re-indexed labels are used to compute
    a softmax with cross entropy.

    Args:
        logits (Variable): The unscaled log probabilities, which is a 2-D tensor
            with shape [N x K]. N is the batch_size, and K is the class number.
        label (Variable): The ground truth which is a 2-D tensor. Label is an
            int64 Tensor with shape [N x T], where T is the number of true
            labels per example.
        num_samples (int): The number of sampled classes for each example;
            num_samples should be less than the number of classes.
        num_true(int): The number of target classes per training example.
            The returned loss is divided by this value. Default is 1.
        remove_accidental_hits (bool): A flag indicating whether to remove
            accidental hits when sampling. If True and if a sample[i, j]
            accidentally hits true labels, then the corresponding
            sampled_logits[i, j] is decreased by 1e20 to make its softmax result
            close to zero. Default is True.
        use_customized_samples (bool): Whether to use custom samples and probabilities
            to sample logits. Default is False.
        customized_samples (Variable): User defined samples, which is a 2-D tensor
            with shape [N, T + S]. S is the num_samples, and T is the number of true
            labels per example. Required when use_customized_samples is True.
        customized_probabilities (Variable): User defined probabilities of samples,
            a 2-D tensor which has the same shape with customized_samples.
            Required when use_customized_samples is True.
        seed (int): The random seed for generating random number, which is used
            in the process of sampling. Default is 0.

    Returns:
        Variable: Return the cross entropy loss which is a 2-D tensor with shape
                  [N x 1].

    Raises:
        ValueError: If use_customized_samples is True but customized_samples or
            customized_probabilities is None.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            input = fluid.layers.data(name='data', shape=[256], dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            fc = fluid.layers.fc(input=input, size=100)
            out = fluid.layers.sampled_softmax_with_cross_entropy(
                      logits=fc, label=label, num_samples=25)
    """
    # Validate before any graph construction; no locals are created here, so
    # the locals() handed to LayerHelper below are unchanged.
    if use_customized_samples and (customized_samples is None or
                                   customized_probabilities is None):
        raise ValueError(
            "customized_samples and customized_probabilities must be provided "
            "when use_customized_samples is True")

    helper = LayerHelper('sample_logits', **locals())
    if use_customized_samples:
        samples = customized_samples
        probabilities = customized_probabilities
    else:
        samples = helper.create_variable_for_type_inference(dtype='int64')
        probabilities = helper.create_variable_for_type_inference(
            dtype=logits.dtype)
    sampled_logits = helper.create_variable_for_type_inference(
        dtype=logits.dtype)
    sampled_label = helper.create_variable_for_type_inference(dtype='int64')
    sampled_softlabel = helper.create_variable_for_type_inference(
        dtype=logits.dtype)
    logits_dim = helper.create_variable_for_type_inference(dtype=logits.dtype)
    # Use the label's element dtype; `label.type` is the variable kind
    # (e.g. LOD_TENSOR), not a data type.
    labels_dim = helper.create_variable_for_type_inference(dtype=label.dtype)

    helper.append_op(
        type='sample_logits',
        inputs={
            'Logits': logits,
            'Labels': label,
            'CustomizedSamples': customized_samples,
            'CustomizedProbabilities': customized_probabilities
        },
        outputs={
            'Samples': samples,
            'Probabilities': probabilities,
            'SampledLabels': sampled_label,
            'SampledLogits': sampled_logits,
            'LogitsDim': logits_dim,
            'LabelsDim': labels_dim
        },
        attrs={
            'use_customized_samples': use_customized_samples,
            'uniq': True,
            'remove_accidental_hits': remove_accidental_hits,
            'num_samples': num_samples,
            'seed': seed
        })
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    # NOTE(review): the one-hot depth is num_samples + 1, which matches the
    # T + S sampled columns only when num_true == 1 -- confirm the intended
    # behavior for num_true > 1.
    helper.append_op(
        type='one_hot',
        inputs={'X': sampled_label},
        attrs={'depth': num_samples + 1},
        outputs={'Out': sampled_softlabel})

    helper.append_op(
        type='softmax_with_cross_entropy',
        inputs={'Logits': sampled_logits,
                'Label': sampled_softlabel},
        outputs={'Softmax': softmax,
                 'Loss': loss},
        attrs={
            'soft_label': True,
            'ignore_index': False,
            'numeric_stable_mode': False
        })
    return loss / num_true


def softmax_with_cross_entropy(logits,
                               label,
                               soft_label=False,
                               ignore_index=kIgnoreIndex,
                               numeric_stable_mode=True,
                               return_softmax=False,
                               axis=-1):
    """
    :alias_main: paddle.nn.functional.softmax_with_cross_entropy
    :alias: paddle.nn.functional.softmax_with_cross_entropy,paddle.nn.functional.loss.softmax_with_cross_entropy
    :old_api: paddle.fluid.layers.softmax_with_cross_entropy

    This operator implements the cross entropy loss function with softmax. This function
    combines the calculation of the softmax operation and the cross entropy loss function
    to provide a more numerically stable gradient.

    Because this operator performs a softmax on logits internally, it expects
    unscaled logits. This operator should not be used with the output of
    softmax operator since that would produce incorrect results.

    When the attribute :attr:`soft_label` is set :attr:`False`, this operator
    expects mutually exclusive hard labels, each sample in a batch is in exactly
    one class with a probability of 1.0. Each sample in the batch will have a
    single label.

    The equation is as follows:

    1) Hard label (one-hot label, so every sample has exactly one class)

    .. math::

        loss_j =  -\\text{logits}_{label_j} +
        \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logits}_i)\\right), j = 1,..., K

    2) Soft label (each sample can have a distribution over all classes)

    .. math::

        loss_j =  -\\sum_{i=0}^{K}\\text{label}_i
        \\left(\\text{logits}_i - \\log\\left(\\sum_{i=0}^{K}
        \\exp(\\text{logits}_i)\\right)\\right), j = 1,...,K

    3) If :attr:`numeric_stable_mode` is :attr:`True`, softmax is calculated first by:

    .. math::

        max_j &= \\max_{i=0}^{K}{\\text{logits}_i}

        log\\_max\\_sum_j &= \\log\\sum_{i=0}^{K}\\exp(logits_i - max_j)

        softmax_j &= \\exp(logits_j - max_j - {log\\_max\\_sum}_j)

    and then cross entropy loss is calculated by softmax and label.

    Args:
        logits (Variable): A multi-dimension ``Tensor`` , and the data type is float32 or float64. The input tensor of unscaled log probabilities.
        label (Variable): The ground truth ``Tensor``. If :attr:`soft_label` is set to :attr:`True`,
            Label is a ``Tensor`` with the same shape and data type as :attr:`logits`.
            If :attr:`soft_label` is set to :attr:`False`, Label is a ``Tensor`` of class
            indices (int64 in the example below) in the same shape with :attr:`logits`
            except that its size in dimension :attr:`axis` is 1.
        soft_label (bool, optional): A flag to indicate whether to interpret the given
            labels as soft labels. Default False.
        ignore_index (int, optional): Specifies a target value that is ignored and does
                                      not contribute to the input gradient. Only valid
                                      if :attr:`soft_label` is set to :attr:`False`.
                                      Default: kIgnoreIndex(-100).
        numeric_stable_mode (bool, optional): A flag to indicate whether to use a more
                                              numerically stable algorithm. Only valid
                                              when :attr:`soft_label` is :attr:`False`
                                              and GPU is used. When :attr:`soft_label`
                                              is :attr:`True` or CPU is used, the
                                              algorithm is always numerically stable.
                                              Note that the speed may be slower when use
                                              stable algorithm. Default: True.
        return_softmax (bool, optional): A flag indicating whether to return the softmax
                                         along with the cross entropy loss. Default: False.
        axis (int, optional): The index of dimension to perform softmax calculations. It
                              should be in range :math:`[-1, rank - 1]`, while :math:`rank`
                              is the rank of input :attr:`logits`. Default: -1.

    Returns:
        ``Variable`` or Tuple of two ``Variable`` : Return the cross entropy loss if \
                                                    `return_softmax` is False, otherwise the tuple \
                                                    (loss, softmax), softmax is in the same shape \
                                                    with input logits and cross entropy loss is in \
                                                    the same shape with input logits except shape \
                                                    in dimension :attr:`axis` as 1.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            data = fluid.data(name='data', shape=[-1, 128], dtype='float32')
            label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
            fc = fluid.layers.fc(input=data, size=100)
            out = fluid.layers.softmax_with_cross_entropy(
                logits=fc, label=label)
    """
    # Dygraph: run the op eagerly; it returns (softmax, loss).
    if in_dygraph_mode():
        softmax, loss = core.ops.softmax_with_cross_entropy(
            logits, label, 'soft_label', soft_label, 'ignore_index',
            ignore_index, 'numeric_stable_mode', numeric_stable_mode, 'axis',
            axis)
        if not return_softmax:
            return loss
        else:
            return loss, softmax

    # Static graph: build the op via LayerHelper (which receives locals(),
    # including `attrs`, so this statement order is kept).
    attrs = {
        'soft_label': soft_label,
        'ignore_index': ignore_index,
        'numeric_stable_mode': numeric_stable_mode,
        'axis': axis
    }
    helper = LayerHelper('softmax_with_cross_entropy', **locals())
    softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
    loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
    helper.append_op(
        type='softmax_with_cross_entropy',
        inputs={'Logits': logits,
                'Label': label},
        outputs={'Softmax': softmax,
                 'Loss': loss},
        attrs=attrs)

    if return_softmax:
        return loss, softmax

    return loss


def rank_loss(label, left, right, name=None):
    """
    :alias_main: paddle.nn.functional.rank_loss
    :alias: paddle.nn.functional.rank_loss,paddle.nn.functional.loss.rank_loss
    :old_api: paddle.fluid.layers.rank_loss

    This operator implements the rank loss layer of the RankNet model. RankNet is a
    pairwise ranking model whose training samples are pairs of documents (A and B);
    the label (P) indicates whether A is ranked higher than B or not. For more details
    please refer to `RankNet <http://icml.cc/2015/wp-content/uploads/2015/06/icml_ranking.pdf>`_

    The rank loss layer takes three inputs: left ( :math:`o_i` ), right ( :math:`o_j` ) and
    label ( :math:`P_{i,j}` ). They respectively represent RankNet's output scores
    for documents A and B and the value of label P. The layer takes batch inputs
    of size batch_size (batch_size >= 1), with P = {0, 1} or {0, 0.5, 1},
    where 0.5 means that there is no information about the rank of the input pair.
    The following equations compute the rank loss :math:`C_{i,j}` from the inputs:

    .. math::
      C_{i,j} &= -\\tilde{P_{i,j}} * o_{i,j} + \log(1 + e^{o_{i,j}}) \\\\
    .. math::
      o_{i,j} &=  o_i - o_j  \\\\
    .. math::
      \\tilde{P_{i,j}} &= \\left \{0, 0.5, 1 \\right \} \ or \ \\left \{0, 1 \\right \}

    Parameters:
        label (Variable): 2-D ``Tensor`` with the shape of :math:`[batch,1]`, the data type is float32, batch indicates the size of the data. Indicates whether A is ranked higher than B or not.
        left (Variable): 2-D ``Tensor`` with the shape of :math:`[batch,1]`, the data type is float32. RankNet's output score for doc A.
        right (Variable): 2-D ``Tensor`` with the shape of :math:`[batch,1]`, the data type is float32. RankNet's output score for doc B.
        name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .

    Returns:
        Variable: ``Tensor`` indicating the output value of the rank loss layer, the data type is float32, and the return value's shape is :math:`[batch,1]` .

    Raises:
        ValueError: Any of label, left, and right is not a ``Variable`` .

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid
            label = fluid.data(name="label", shape=[-1, 1], dtype="float32")
            left = fluid.data(name="left", shape=[-1, 1], dtype="float32")
            right = fluid.data(name="right", shape=[-1, 1], dtype="float32")
            out = fluid.layers.rank_loss(label, left, right)

    """
    # Created first so that ``locals()`` only captures the public arguments.
    helper = LayerHelper('rank_loss', **locals())
    check_variable_and_dtype(label, 'label', ['float32'], "rank_loss")
    check_variable_and_dtype(left, 'left', ['float32'], "rank_loss")
    check_variable_and_dtype(right, 'right', ['float32'], "rank_loss")

    # Every input was just validated as float32, so the output is float32 too.
    out = helper.create_variable_for_type_inference("float32")

    helper.append_op(
        type='rank_loss',
        inputs={"Label": label,
                "Left": left,
                "Right": right},
        outputs={'Out': out})
    return out


def margin_rank_loss(label, left, right, margin=0.1, name=None):
    """
    Margin Ranking Loss Layer for ranking problems,
    which compares the left score and the right score passed in.
    The ranking loss can be defined as the following equation:

    .. math::

        rank\_loss = max(0, -label * (left - right) + margin)

    Args:
       label (Variable): Indicates whether the left is ranked higher than the right or not.
           Data type is float32.
       left (Variable): Ranking score for left. Data type float32.
       right (Variable): Ranking score for right. Data type float32.
       margin (float): Indicates the given margin. Default: 0.1.
       name(str|None): For detailed information, please refer to
           :ref:`api_guide_Name` . Usually name is no need to set and None by default.

    Returns:
       Variable: The ranking loss, with the same data type as ``left``.

    Raises:
       ValueError: Any of label, left, and right is not a Variable.

    Examples:

        .. code-block:: python

           import paddle.fluid as fluid
           label = fluid.data(name="label", shape=[-1, 1], dtype="float32")
           left = fluid.data(name="left", shape=[-1, 1], dtype="float32")
           right = fluid.data(name="right", shape=[-1, 1], dtype="float32")
           out = fluid.layers.margin_rank_loss(label, left, right)
    """
    # Created first so that ``locals()`` only captures the public arguments.
    helper = LayerHelper('margin_rank_loss', **locals())
    # Validate every input against its own name; each must be float32.
    check_variable_and_dtype(label, 'label', ['float32'], 'margin_rank_loss')
    check_variable_and_dtype(left, 'left', ['float32'], 'margin_rank_loss')
    check_variable_and_dtype(right, 'right', ['float32'], 'margin_rank_loss')

    out = helper.create_variable_for_type_inference(left.dtype)
    # The op also produces an auxiliary 'Activated' output; it must be
    # allocated even though only 'Out' is returned to the caller.
    act = helper.create_variable_for_type_inference(left.dtype)
    helper.append_op(
        type='margin_rank_loss',
        inputs={"Label": label,
                "X1": left,
                "X2": right},
        outputs={'Out': out,
                 'Activated': act},
        attrs={'margin': margin})
    return out


@templatedoc()
def sigmoid_cross_entropy_with_logits(x,
                                      label,
                                      ignore_index=kIgnoreIndex,
                                      name=None,
                                      normalize=False):
    """
    :alias_main: paddle.nn.functional.sigmoid_cross_entropy_with_logits
    :alias: paddle.nn.functional.sigmoid_cross_entropy_with_logits,paddle.nn.functional.loss.sigmoid_cross_entropy_with_logits
    :old_api: paddle.fluid.layers.sigmoid_cross_entropy_with_logits

    ${comment}

    Args:
        x(Variable): a 2-D tensor with shape N x D, where N is the batch size and
                D is the number of classes. This input is a tensor of logits computed
                by the previous operator. Logits are unscaled log probabilities given
                as log(p/(1-p)). The data type should be float16, float32 or float64.
        label (Variable): a 2-D tensor of the same type and shape as X.
                This input is a tensor of probabilistic labels for each logit.
        ignore_index(int): Specifies a target value that is ignored and
                does not contribute to the input gradient. Default: kIgnoreIndex (-100).
        name(str|None): The default value is None.  Normally there is
            no need for user to set this property.  For more information,
            please refer to :ref:`api_guide_Name`
        normalize(bool): If true, divide the output by the number of
            targets != ignore_index. Default: False.

    Returns:
        out(${out_type}): ${out_comment}

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            input = fluid.data(
                name='data', shape=[10], dtype='float32')
            label = fluid.data(
                name='label', shape=[10], dtype='float32')
            loss = fluid.layers.sigmoid_cross_entropy_with_logits(
                x=input,
                label=label,
                ignore_index=-1,
                normalize=True) # or False
            # loss = fluid.layers.reduce_sum(loss) # summation of loss
    """
    check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
                             'sigmoid_cross_entropy_with_logits')

    helper = LayerHelper("sigmoid_cross_entropy_with_logits", **locals())

    # The loss has the same dtype as the logits.
    out = helper.create_variable_for_type_inference(dtype=x.dtype)

    helper.append_op(
        type="sigmoid_cross_entropy_with_logits",
        inputs={"X": x,
                "Label": label},
        attrs={"ignore_index": ignore_index,
               'normalize': normalize},
        outputs={"Out": out})
    return out


def teacher_student_sigmoid_loss(input,
                                 label,
                                 soft_max_up_bound=15.0,
                                 soft_max_lower_bound=-15.0):
    """
    :alias_main: paddle.nn.functional.teacher_student_sigmoid_loss
    :alias: paddle.nn.functional.teacher_student_sigmoid_loss,paddle.nn.functional.loss.teacher_student_sigmoid_loss
    :old_api: paddle.fluid.layers.teacher_student_sigmoid_loss

    **Teacher Student Log Loss Layer**

    This layer accepts input predictions and target label and returns the
    teacher_student loss. Z is whether the sample was clicked (clk), z' is the
    value of the teacher loss, and label lies in {-2, -1, [0, 2]}, encoded as:

    - z' does not exist, clk = 0: label = -2
    - z' does not exist, clk = 1: label = -1
    - z' exists, clk = 0: label = 0 + z'
    - z' exists, clk = 1: label = 1 + z'

    .. math::
        loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + max(x, 0) - x * z' + log(1 + exp(-abs(x)))

    Args:
        input (Variable): a 2-D tensor with shape [N x 1], where N is the
            batch size. This is the prediction ``x`` in the formula above; it
            enters through ``max(x, 0)`` and ``exp(-abs(x))``, i.e. it is used
            as a logit. Data type: float32, float64, int32 or int64.
        label (Variable): the ground truth, a 2-D tensor with shape [N x 1],
            where N is the batch size, encoded as described above.
            Data type: float32, float64, int32 or int64.
        soft_max_up_bound (float): if input > soft_max_up_bound, it will be
            bounded. Converted to float and passed to the op. Default: 15.0.
        soft_max_lower_bound (float): if input < soft_max_lower_bound, it will
            be bounded. Converted to float and passed to the op. Default: -15.0.

    Returns:
        Variable: A 2-D tensor with shape [N x 1], the teacher_student_sigmoid_loss,
        with the same data type as ``input``.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            batch_size = 64
            label = fluid.data(
                name="label", shape=[batch_size, 1], dtype="int64")
            similarity = fluid.data(
                name="similarity", shape=[batch_size, 1], dtype="float32")
            cost = fluid.layers.teacher_student_sigmoid_loss(input=similarity, label=label)

    """
    check_variable_and_dtype(input, "input",
                             ['float32', 'float64', 'int32', 'int64'],
                             'teacher_student_sigmoid_loss')
    check_variable_and_dtype(label, "label",
                             ['float32', 'float64', 'int32', 'int64'],
                             'teacher_student_sigmoid_loss')

    helper = LayerHelper('teacher_student_sigmoid_loss', **locals())
    out = helper.create_variable(dtype=input.dtype)
    helper.append_op(
        type='teacher_student_sigmoid_loss',
        inputs={'X': [input],
                'Label': [label]},
        outputs={'Y': [out]},
        # The op expects float attributes, so ints passed by callers are coerced.
        attrs={
            "soft_max_lower_bound": float(soft_max_lower_bound),
            "soft_max_up_bound": float(soft_max_up_bound)
        })
    return out


def huber_loss(input, label, delta):
    """
    This operator computes the Huber loss between input and label.
    Huber loss is commonly used in regression tasks. Compared to square_error_cost, Huber loss is more robust and less sensitive to outliers.

    When the absolute difference between input and label is greater than delta, the linear error is calculated:

    .. math::
            huber\_loss = delta * |label - input| - 0.5 * delta * delta

    When the absolute difference between input and label is less than or equal to delta, the square error is calculated:

    .. math::
            huber\_loss = 0.5 * (label - input) * (label - input)


    Args:
        input (Variable): Predicted data, 2D-Tensor with the shape of [batch_size, 1]. The data type should be float32 or float64.
        label (Variable): Ground truth label, 2D-Tensor with the shape of [batch_size, 1]. The data type should be float32 or float64.
        delta (float): The threshold for Huber loss, which is used to control the balance between the linear error and square error.

    Returns:
        Variable: The huber loss, a tensor with the same shape and data type as input.


    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import numpy as np

            DATATYPE='float32'
            input_data = np.array([[1.],[2.],[3.],[4.]]).astype(DATATYPE)
            label_data = np.array([[3.],[3.],[4.],[4.]]).astype(DATATYPE)

            x = fluid.data(name='input', shape=[None, 1], dtype=DATATYPE)
            y = fluid.data(name='label', shape=[None, 1], dtype=DATATYPE)
            loss = fluid.layers.huber_loss(input=x, label=y, delta=1.0)

            place = fluid.CPUPlace()
            #place = fluid.CUDAPlace(0)
            exe = fluid.Executor(place)
            HuberLoss, = exe.run(feed={'input':input_data ,'label':label_data}, fetch_list=[loss.name])
            print(HuberLoss)  #[[1.5], [0.5], [0.5], [0. ]], dtype=float32
    """
    # Created first so that ``locals()`` only captures the public arguments.
    helper = LayerHelper('huber_loss', **locals())
    check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                             'huber_loss')
    check_variable_and_dtype(label, 'label', ['float32', 'float64'],
                             'huber_loss')
    # The op emits an auxiliary 'Residual' output alongside the loss; both
    # take the dtype inferred from the inputs.
    residual = helper.create_variable_for_type_inference(
        dtype=helper.input_dtype())
    out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
    helper.append_op(
        type='huber_loss',
        inputs={'X': input,
                'Y': label},
        outputs={'Out': out,
                 'Residual': residual},
        attrs={'delta': delta})
    return out


@deprecated(since="2.0.0", update_to="paddle.nn.functional.kl_div")
@templatedoc()
def kldiv_loss(x, target, reduction='mean', name=None):
    """
    :alias_main: paddle.nn.functional.kldiv_loss
    :alias: paddle.nn.functional.kldiv_loss,paddle.nn.functional.loss.kldiv_loss
    :old_api: paddle.fluid.layers.kldiv_loss

    ${comment}

    Args:
        x (Variable): ${x_comment} Data type: float32 or float64.
        target (Variable): ${target_comment} Data type: float32 or float64.
        reduction (str): ${reduction_comment}
        name(str, optional): For detailed information, please refer
                             to :ref:`api_guide_Name`. Usually name is no need to set and
                             None by default.

    Returns:
        Variable(Tensor): The KL divergence loss. The data type is same as input tensor

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            # 'batchmean' reduction, loss shape will be [1]
            x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32') # shape=[-1, 4, 2, 2]
            target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32')
            loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean') # shape=[1]

            # 'mean' reduction, loss shape will be [1]
            x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32') # shape=[-1, 4, 2, 2]
            target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32')
            loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='mean') # shape=[1]

            # 'sum' reduction, loss shape will be [1]
            x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32') # shape=[-1, 4, 2, 2]
            target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32')
            loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='sum') # shape=[1]

            # 'none' reduction, loss shape is same with X shape
            x = fluid.data(name='x', shape=[None,4,2,2], dtype='float32') # shape=[-1, 4, 2, 2]
            target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32')
            loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='none') # shape=[-1, 4, 2, 2]

    """
    # Created first so that ``locals()`` only captures the public arguments.
    helper = LayerHelper('kldiv_loss', **locals())
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'kldiv_loss')
    check_variable_and_dtype(target, 'target', ['float32', 'float64'],
                             'kldiv_loss')
    # ``reduction`` is a plain string attribute of the op, not a tensor.
    check_type(reduction, 'reduction', str, 'kldiv_loss')
    loss = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='kldiv_loss',
        inputs={'X': x,
                'Target': target},
        outputs={'Loss': loss},
        attrs={'reduction': reduction})
    return loss


from .ops import square
from .control_flow import equal


def npair_loss(anchor, positive, labels, l2_reg=0.002):
    '''
    :alias_main: paddle.nn.functional.npair_loss
    :alias: paddle.nn.functional.npair_loss,paddle.nn.functional.loss.npair_loss
    :old_api: paddle.fluid.layers.npair_loss

    **Npair Loss Layer**

    Read `Improved Deep Metric Learning with Multi class N pair Loss Objective
    <http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf>`_ .

    Npair loss requires paired data. Npair loss has two parts: the first part is an L2
    regularizer on the embedding vectors; the second part is a cross entropy loss which
    takes the similarity matrix of anchor and positive as logits.

    The L2 part is ``0.25 * l2_reg * (mean_i ||anchor_i||^2 + mean_i ||positive_i||^2)``.

    Args:
        anchor(Variable): embedding vector for the anchor image. shape=[batch_size, embedding_dims],
            the data type is float32 or float64.
        positive(Variable): embedding vector for the positive image. shape=[batch_size, embedding_dims],
            the data type is float32 or float64.
        labels(Variable): 1-D tensor. shape=[batch_size], the data type is float32 or float64 or int64.
            Its static first dimension (``labels.shape[0]``) is used as the batch size
            when building the label-similarity matrix.
        l2_reg(float32): L2 regularization term on embedding vector, default: 0.002.

    Returns:
        A Variable holding Tensor representing the npair loss, the data type is the same as
        anchor, the shape is [1].

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            anchor = fluid.data(
                name = 'anchor', shape = [18, 6], dtype = 'float32')
            positive = fluid.data(
                name = 'positive', shape = [18, 6], dtype = 'float32')
            labels = fluid.data(
                name = 'labels', shape = [18], dtype = 'float32')

            npair_loss = fluid.layers.npair_loss(anchor, positive, labels, l2_reg = 0.002)
    '''
    check_variable_and_dtype(anchor, 'anchor', ['float32', 'float64'],
                             'npair_loss')
    check_variable_and_dtype(positive, 'positive', ['float32', 'float64'],
                             'npair_loss')
    check_variable_and_dtype(labels, 'labels', ['float32', 'float64', 'int64'],
                             'npair_loss')

    beta = 0.25
    # NOTE(review): this is the static batch dimension; a dynamic (-1) batch
    # size is presumably unsupported by the reshape/expand below — confirm.
    batch_size = labels.shape[0]

    # Build a [batch_size, batch_size] matrix whose (i, j) entry is 1 when
    # labels[i] == labels[j], then row-normalize it into soft labels.
    labels = nn.reshape(labels, shape=[batch_size, 1])
    labels = nn.expand(labels, expand_times=[1, batch_size])
    # Cast to the embedding dtype so the soft labels match the similarity
    # logits (float32 or float64) in softmax_with_cross_entropy and in the
    # element-wise product below.
    labels = equal(labels, nn.transpose(labels, perm=[1, 0])).astype(
        anchor.dtype)
    labels = labels / nn.reduce_sum(labels, dim=1, keep_dim=True)

    l2loss = nn.reduce_mean(nn.reduce_sum(square(anchor), 1)) \
             + nn.reduce_mean(nn.reduce_sum(square(positive), 1))
    l2loss = l2loss * beta * l2_reg

    similarity_matrix = nn.matmul(
        anchor, positive, transpose_x=False, transpose_y=True)
    softmax_ce = softmax_with_cross_entropy(
        logits=similarity_matrix, label=labels, soft_label=True)
    # Named to avoid shadowing the module-level ``cross_entropy`` function.
    weighted_ce = nn.reduce_sum(labels * softmax_ce, 0)
    celoss = nn.reduce_mean(weighted_ce)

    return l2loss + celoss


def mse_loss(input, label):
    """
    This op accepts input predictions and target label and returns the mean square error.

    The loss can be described as:

    .. math::

        Out = MEAN((input - label)^2)

    Parameters:
        input (Tensor): Input tensor, the data type should be float32 or float64.
        label (Tensor): Label tensor, the data type should be float32 or float64.

    Returns:
        Tensor: The tensor storing the mean square error difference of input and label.

    Return type: Tensor.

    Examples:
        .. code-block:: python

            import paddle
            input = paddle.to_tensor([1.1, 1.9])
            label = paddle.to_tensor([1.0, 2.0])
            output = paddle.fluid.layers.mse_loss(input, label)
            print(output.numpy())
            # [0.01]
    """
    check_variable_and_dtype(input, "input", ['float32', 'float64'], 'mse_loss')
    check_variable_and_dtype(label, "label", ['float32', 'float64'], 'mse_loss')
    # square_error_cost (defined elsewhere in this module) presumably yields the
    # element-wise squared difference; averaging over all elements gives MSE.
    return nn.reduce_mean(square_error_cost(input, label))