From f5d134982698cbec1b61f685b2e0c39a2fedc882 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Sat, 22 Aug 2020 13:53:04 +0800 Subject: [PATCH] add binary cross entropy with logit loss (#26468) * add binary cross entropy with logit loss --- .../unittests/test_bce_with_logits_loss.py | 260 ++++++++++++++++++ python/paddle/nn/__init__.py | 1 + python/paddle/nn/functional/__init__.py | 1 + python/paddle/nn/functional/loss.py | 149 ++++++++++ python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/loss.py | 122 +++++++- 6 files changed, 526 insertions(+), 8 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py diff --git a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py new file mode 100644 index 0000000000..5ba13a6da0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py @@ -0,0 +1,260 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import unittest +from op_test import OpTest + + +def call_bce_layer(logit, label, weight=None, reduction='mean', + pos_weight=None): + bce_logit_loss = paddle.nn.loss.BCEWithLogitsLoss( + weight=weight, reduction=reduction, pos_weight=pos_weight) + res = bce_logit_loss(logit, label) + return res + + +def call_bce_functional(logit, + label, + weight=None, + reduction='mean', + pos_weight=None): + res = paddle.nn.functional.binary_cross_entropy_with_logits( + logit, label, weight=weight, reduction=reduction, pos_weight=pos_weight) + return res + + +def test_static(place, + logit_np, + label_np, + weight_np=None, + reduction='mean', + pos_weight_np=None, + functional=False): + paddle.enable_static() + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): + logit = paddle.data(name='logit', shape=logit_np.shape, dtype='float64') + label = paddle.data(name='label', shape=label_np.shape, dtype='float64') + feed_dict = {"logit": logit_np, "label": label_np} + + pos_weight = None + weight = None + if pos_weight_np is not None: + pos_weight = paddle.data( + name='pos_weight', shape=pos_weight_np.shape, dtype='float64') + feed_dict["pos_weight"] = pos_weight_np + if weight_np is not None: + weight = paddle.data( + name='weight', shape=weight_np.shape, dtype='float64') + feed_dict["weight"] = weight_np + if functional: + res = call_bce_functional(logit, label, weight, reduction, + pos_weight) + else: + res = call_bce_layer(logit, label, weight, reduction, pos_weight) + exe = paddle.static.Executor(place) + static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) + return static_result + + +def test_dygraph(place, + logit_np, + label_np, + weight_np=None, + reduction='mean', + pos_weight_np=None, + functional=False): + paddle.disable_static() + logit = paddle.to_tensor(logit_np) + label = 
paddle.to_tensor(label_np) + weight = None + pos_weight = None + if weight_np is not None: + weight = paddle.to_tensor(weight_np) + if pos_weight_np is not None: + pos_weight = paddle.to_tensor(pos_weight_np) + if functional: + dy_res = call_bce_functional(logit, label, weight, reduction, + pos_weight) + else: + dy_res = call_bce_layer(logit, label, weight, reduction, pos_weight) + dy_result = dy_res.numpy() + paddle.enable_static() + return dy_result + + +def calc_bce_with_logits_loss(logit_np, + label_np, + reduction='mean', + weight_np=None, + pos_weight=None): + expected = np.maximum( + logit_np, + 0) - logit_np * label_np + np.log(1 + np.exp(-np.abs(logit_np))) + if pos_weight is not None: + expected = expected * ((pos_weight - 1) * label_np + 1) + if weight_np is not None: + expected = weight_np * expected + + if reduction == 'mean': + expected = np.mean(expected) + elif reduction == 'sum': + expected = np.sum(expected) + else: + expected = expected + + return expected + + +class TestBCEWithLogitsLoss(unittest.TestCase): + def test_BCEWithLogitsLoss(self): + logit_np = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float64) + label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float64) + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + reductions = ['sum', 'mean', 'none'] + for place in places: + for reduction in reductions: + static_result = test_static( + place, logit_np, label_np, reduction=reduction) + dy_result = test_dygraph( + place, logit_np, label_np, reduction=reduction) + expected = calc_bce_with_logits_loss(logit_np, label_np, + reduction) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static( + place, + logit_np, + label_np, + reduction=reduction, + functional=True) + dy_functional = test_dygraph( + place, + logit_np, + label_np, + reduction=reduction, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + def test_BCEWithLogitsLoss_weight(self): + logit_np = np.random.uniform( + 0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64) + label_np = np.random.randint( + 0, 2, size=(2, 3, 4, 10)).astype(np.float64) + weight_np = np.random.random(size=(2, 3, 4, 10)).astype(np.float64) + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + for reduction in ['sum', 'mean', 'none']: + static_result = test_static( + place, + logit_np, + label_np, + weight_np=weight_np, + reduction=reduction) + dy_result = test_dygraph( + place, + logit_np, + label_np, + weight_np=weight_np, + reduction=reduction) + expected = calc_bce_with_logits_loss( + logit_np, label_np, reduction, weight_np=weight_np) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static( + place, + logit_np, + label_np, + weight_np=weight_np, + reduction=reduction, + functional=True) + dy_functional = test_dygraph( + place, + logit_np, + label_np, + weight_np=weight_np, + reduction=reduction, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + 
+ def test_BCEWithLogitsLoss_pos_weight(self): + logit_np = np.random.uniform( + 0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64) + label_np = np.random.randint( + 0, 2, size=(2, 3, 4, 10)).astype(np.float64) + pos_weight_np = np.random.random(size=(3, 4, 10)).astype(np.float64) + weight_np = np.random.random(size=(2, 3, 4, 10)).astype(np.float64) + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + reduction = "mean" + static_result = test_static(place, logit_np, label_np, weight_np, + reduction, pos_weight_np) + dy_result = test_dygraph(place, logit_np, label_np, weight_np, + reduction, pos_weight_np) + expected = calc_bce_with_logits_loss(logit_np, label_np, reduction, + weight_np, pos_weight_np) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static( + place, + logit_np, + label_np, + weight_np, + reduction, + pos_weight_np, + functional=True) + dy_functional = test_dygraph( + place, + logit_np, + label_np, + weight_np, + reduction, + pos_weight_np, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + def test_BCEWithLogitsLoss_error(self): + paddle.disable_static() + self.assertRaises( + ValueError, + paddle.nn.BCEWithLogitsLoss, + reduction="unsupport reduction") + logit = paddle.to_tensor([[0.1, 0.3]], dtype='float32') + label = paddle.to_tensor([[0.0, 1.0]], dtype='float32') + self.assertRaises( + ValueError, + paddle.nn.functional.binary_cross_entropy_with_logits, + logit=logit, + label=label, + reduction="unsupport reduction") + paddle.enable_static() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 7fbf26df96..0e09eeb6a0 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -107,6 +107,7 @@ from .layer.extension import RowConv #DEFINE_ALIAS # from .layer.learning_rate import PiecewiseDecay #DEFINE_ALIAS # from .layer.learning_rate import PolynomialDecay #DEFINE_ALIAS # from .layer.loss import NCELoss #DEFINE_ALIAS +from .layer.loss import BCEWithLogitsLoss #DEFINE_ALIAS from .layer.loss import CrossEntropyLoss #DEFINE_ALIAS from .layer.loss import MSELoss #DEFINE_ALIAS from .layer.loss import L1Loss #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index a135aea98c..f91caade8f 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -126,6 +126,7 @@ from .lod import hash #DEFINE_ALIAS # from .lod import dynamic_lstm #DEFINE_ALIAS # from .lod import dynamic_lstmp #DEFINE_ALIAS from .loss import binary_cross_entropy #DEFINE_ALIAS +from .loss import binary_cross_entropy_with_logits #DEFINE_ALIAS from .loss import bpr_loss #DEFINE_ALIAS from .loss import center_loss #DEFINE_ALIAS from .loss import cross_entropy #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 79826afb58..ba057f38bb 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -49,6 +49,7 @@ from ...fluid.framework import Variable __all__ = [ 'binary_cross_entropy', + 'binary_cross_entropy_with_logits', 'bpr_loss', 'center_loss', 'cross_entropy', @@ -214,6 +215,154 @@ def 
binary_cross_entropy(input, label, weight=None, reduction='mean',
     return out
 
 
+def binary_cross_entropy_with_logits(logit,
+                                     label,
+                                     weight=None,
+                                     reduction='mean',
+                                     pos_weight=None,
+                                     name=None):
+    """
+    This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
+    It can also be seen as the combination of the ``sigmoid_cross_entropy_with_logits``
+    layer and some reduce operations.
+
+    This measures the element-wise probability error in classification tasks
+    in which each class is independent.
+    This can be thought of as predicting labels for a data point, where labels
+    are not mutually exclusive. For example, a news article can be about
+    politics, technology or sports at the same time or none of these.
+
+    First, this operator calculates the loss as follows:
+
+    .. math::
+           Out = -Labels * \\log(\\sigma(Logit)) - (1 - Labels) * \\log(1 - \\sigma(Logit))
+
+    We know that :math:`\\sigma(Logit) = \\frac{1}{1 + e^{-Logit}}`. By substituting this we get:
+
+    .. math::
+           Out = Logit - Logit * Labels + \\log(1 + e^{-Logit})
+
+    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
+    we reformulate the loss as follows:
+
+    .. math::
+           Out = \\max(Logit, 0) - Logit * Labels + \\log(1 + e^{-|Logit|})
+
+    Then, if ``weight`` or ``pos_weight`` is not None, this operator multiplies the
+    weight tensor with the loss `Out`. The ``weight`` tensor assigns a different
+    weight to every item in the batch. The ``pos_weight`` assigns a different
+    weight to the positive label of each class.
+
+    Finally, this operator applies the reduce operation to the loss.
+    If :attr:`reduction` is set to ``'none'``, the operator returns the original loss `Out`.
+    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
+    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.
+
+    Note that the target labels ``label`` should be numbers between 0 and 1.
+
+    Args:
+        logit (Tensor): The input prediction tensor. 2-D tensor with shape: [N, *],
+            N is batch_size, `*` means number of additional dimensions. The ``logit``
+            is usually the output of a Linear layer. Available dtype is float32, float64.
+        label (Tensor): The target labels tensor. 2-D tensor with the same shape as
+            ``logit``. The target labels, whose values should be numbers between 0 and 1.
+            Available dtype is float32, float64.
+        weight (Tensor, optional): A manual rescaling weight given to the loss of each
+            batch element. If given, it has to be a 1-D Tensor whose size is `[N, ]`.
+            The data type is float32, float64. Default is ``'None'``.
+        reduction (str, optional): Indicate how to reduce the loss over the batch;
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
+            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
+            Default is ``'mean'``.
+        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
+            with length equal to the number of classes. The data type is float32, float64.
+            Default is ``'None'``.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        output (Tensor): If ``reduction`` is ``'none'``, the shape of the output is
+            the same as ``logit``; otherwise the output is a scalar.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+            logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
+            label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
+            output = paddle.nn.functional.binary_cross_entropy_with_logits(logit, label)
+            print(output.numpy())  # [0.45618808]
+
+    """
+    if reduction not in ['sum', 'mean', 'none']:
+        raise ValueError(
+            "The value of 'reduction' in binary_cross_entropy_with_logits "
+            "should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
+            % reduction)
+
+    if in_dygraph_mode():
+        one = _varbase_creator(dtype=logit.dtype)
+        core.ops.fill_constant(one, 'value',
+                               float(1.0), 'force_cpu', False, 'dtype',
+                               one.dtype, 'str_value', '1.0', 'shape', [1])
+        out = core.ops.sigmoid_cross_entropy_with_logits(logit, label)
+        if pos_weight is not None:
+            log_weight = core.ops.elementwise_add(
+                core.ops.elementwise_mul(
+                    label, core.ops.elementwise_sub(pos_weight, one)), one)
+            out = core.ops.elementwise_mul(out, log_weight)
+        if weight is not None:
+            out = core.ops.elementwise_mul(out, weight)
+
+        if reduction == "sum":
+            return core.ops.reduce_sum(out, 'reduce_all', True)
+        elif reduction == "mean":
+            return core.ops.mean(out)
+        else:
+            return out
+
+    fluid.data_feeder.check_variable_and_dtype(
+        logit, 'logit', ['float32', 'float64'],
+        'binary_cross_entropy_with_logits')
+    fluid.data_feeder.check_variable_and_dtype(
+        label, 'label', ['float32', 'float64'],
+        'binary_cross_entropy_with_logits')
+    sigmoid_name = None
+    if reduction == 'none' and pos_weight is None and weight is None:
+        sigmoid_name = name
+
+    out = paddle.nn.functional.sigmoid_cross_entropy_with_logits(
+        logit, label, name=sigmoid_name)
+
+    one = paddle.fill_constant(shape=[1], value=1.0, dtype=logit.dtype)
+    if pos_weight is not None:
+        fluid.data_feeder.check_variable_and_dtype(
+            pos_weight, 'pos_weight', ['float32', 'float64'],
+            'binary_cross_entropy_with_logits')
+        log_weight = paddle.add(
+            paddle.multiply(label, paddle.elementwise_sub(pos_weight, one)),
+            one)
+        pos_weight_name = name if reduction == 'none' and weight is None else None
+        out = paddle.multiply(out, log_weight, name=pos_weight_name)
+
+    if weight is not None:
+        fluid.data_feeder.check_variable_and_dtype(
+            weight, 'weight', ['float32', 'float64'],
+            'binary_cross_entropy_with_logits')
+        weight_name = name if reduction == 'none' else None
+        out = paddle.multiply(out, weight, name=weight_name)
+
+    if reduction == "sum":
+        return paddle.sum(out, name=name)
+    elif reduction == "mean":
+        return paddle.mean(out, name=name)
+    return out
+
+
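+# A NumPy cross-check of the numerically stable form used above; this is an
+# illustrative sketch only, not part of the public API (the new unit test's
+# calc_bce_with_logits_loss evaluates the same expression):
+#
+#   import numpy as np
+#   x = np.array([5.0, 1.0, 3.0])
+#   y = np.array([1.0, 0.0, 1.0])
+#   loss = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
+#   print(loss.mean())  # ~0.4561881, the docstring example up to float32 rounding
+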
 def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None):
     """
     This operator calculates smooth_l1_loss. Creates a criterion that uses a squared
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 342a684c04..f86051d768 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -72,6 +72,7 @@ from .extension import RowConv    #DEFINE_ALIAS
 # from .learning_rate import PiecewiseDecay    #DEFINE_ALIAS
 # from .learning_rate import PolynomialDecay    #DEFINE_ALIAS
 # from .loss import NCELoss    #DEFINE_ALIAS
+from .loss import BCEWithLogitsLoss    #DEFINE_ALIAS
 from .loss import CrossEntropyLoss    #DEFINE_ALIAS
 from .loss import MSELoss    #DEFINE_ALIAS
 from .loss import L1Loss    #DEFINE_ALIAS
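As a quick sanity sketch before the layer definition below (hedged: it uses only
the dygraph APIs introduced by this patch), the Layer form should match the
functional form exactly, since its forward simply delegates to it:

    import paddle
    paddle.disable_static()
    logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
    label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
    layer_out = paddle.nn.BCEWithLogitsLoss(reduction='sum')(logit, label)
    func_out = paddle.nn.functional.binary_cross_entropy_with_logits(
        logit, label, reduction='sum')
    print(layer_out.numpy(), func_out.numpy())  # both ~[1.3685645]
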
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 3306b7a02b..6ce036f41f 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -21,6 +21,7 @@ from .. import functional as F
 from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator
 
 __all__ = [
+    'BCEWithLogitsLoss',
     'CrossEntropyLoss',
     'MSELoss',
     'L1Loss',
@@ -33,6 +34,111 @@ __all__ = [
 ]
 
 
+class BCEWithLogitsLoss(fluid.dygraph.Layer):
+    """
+    This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
+    It can also be seen as the combination of the ``sigmoid_cross_entropy_with_logits``
+    layer and some reduce operations.
+
+    This measures the element-wise probability error in classification tasks
+    in which each class is independent.
+    This can be thought of as predicting labels for a data point, where labels
+    are not mutually exclusive. For example, a news article can be about
+    politics, technology or sports at the same time or none of these.
+
+    First, this operator calculates the loss as follows:
+
+    .. math::
+           Out = -Labels * \\log(\\sigma(Logit)) - (1 - Labels) * \\log(1 - \\sigma(Logit))
+
+    We know that :math:`\\sigma(Logit) = \\frac{1}{1 + e^{-Logit}}`. By substituting this we get:
+
+    .. math::
+           Out = Logit - Logit * Labels + \\log(1 + e^{-Logit})
+
+    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
+    we reformulate the loss as follows:
+
+    .. math::
+           Out = \\max(Logit, 0) - Logit * Labels + \\log(1 + e^{-|Logit|})
+
+    Then, if ``weight`` or ``pos_weight`` is not None, this operator multiplies the
+    weight tensor with the loss `Out`. The ``weight`` tensor assigns a different
+    weight to every item in the batch. The ``pos_weight`` assigns a different
+    weight to the positive label of each class.
+
+    Finally, this operator applies the reduce operation to the loss.
+    If :attr:`reduction` is set to ``'none'``, the operator returns the original loss `Out`.
+    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
+    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.
+
+    Note that the target labels ``label`` should be numbers between 0 and 1.
+
+    Args:
+        weight (Tensor, optional): A manual rescaling weight given to the loss of each
+            batch element. If given, it has to be a 1-D Tensor whose size is `[N, ]`.
+            The data type is float32, float64. Default is ``'None'``.
+        reduction (str, optional): Indicate how to reduce the loss over the batch;
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
+            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
+            Default is ``'mean'``.
+        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
+            with length equal to the number of classes. The data type is float32, float64.
+            Default is ``'None'``.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Shapes:
+        logit (Tensor): The input prediction tensor. 2-D tensor with shape: [N, *],
+            N is batch_size, `*` means number of additional dimensions. The ``logit``
+            is usually the output of a Linear layer. Available dtype is float32, float64.
+        label (Tensor): The target labels tensor. 2-D tensor with the same shape as
+            ``logit``. The target labels, whose values should be numbers between 0 and 1.
+            Available dtype is float32, float64.
+        output (Tensor): If ``reduction`` is ``'none'``, the shape of the output is
+            the same as ``logit``; otherwise the output is a scalar.
+
+    Returns:
+        A callable object of BCEWithLogitsLoss.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+            logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
+            label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
+            bce_logit_loss = paddle.nn.BCEWithLogitsLoss()
+            output = bce_logit_loss(logit, label)
+            print(output.numpy())  # [0.45618808]
+
+    """
+
+    def __init__(self,
+                 weight=None,
+                 reduction='mean',
+                 pos_weight=None,
+                 name=None):
+        if reduction not in ['sum', 'mean', 'none']:
+            raise ValueError(
+                "The value of 'reduction' in BCEWithLogitsLoss should be 'sum', 'mean' or 'none', but "
+                "received %s, which is not allowed." % reduction)
+
+        super(BCEWithLogitsLoss, self).__init__()
+        self.weight = weight
+        self.reduction = reduction
+        self.pos_weight = pos_weight
+        self.name = name
+
+    def forward(self, logit, label):
+        out = paddle.nn.functional.binary_cross_entropy_with_logits(
+            logit, label, self.weight, self.reduction, self.pos_weight,
+            self.name)
+        return out
+
+
 class CrossEntropyLoss(fluid.dygraph.Layer):
     """
     :alias_main: paddle.nn.CrossEntropyLoss
@@ -678,9 +784,9 @@ class CTCLoss(fluid.dygraph.Layer):
     :alias_main: paddle.nn.CTCLoss
     :alias: paddle.nn.CTCLoss, paddle.nn.layer.CTCLoss, paddle.nn.layer.loss.CTCLoss
 
-    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc) 
-    to compute Connectionist Temporal Classification (CTC) loss. 
-    It can be aliased as softmax with CTC, since a native softmax activation 
+    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
+    to compute Connectionist Temporal Classification (CTC) loss.
+    It can be aliased as softmax with CTC, since a native softmax activation
     is integrated into the Warp-CTC library to normalize values for each row of the input tensor.
 
     Parameters:
@@ -695,7 +801,7 @@
 
     Returns:
         Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size]; otherwise the shape of loss is [1]. Data type is the same as ``log_probs``.
-    
+
     Examples:
 
        .. 
code-block:: python @@ -739,13 +845,13 @@ class CTCLoss(fluid.dygraph.Layer): input_lengths = paddle.to_variable(input_lengths) label_lengths = paddle.to_variable(label_lengths) - loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels, - input_lengths, + loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels, + input_lengths, label_lengths) print(loss.numpy()) #[3.9179852 2.9076521] - loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels, - input_lengths, + loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels, + input_lengths, label_lengths) print(loss.numpy()) #[1.1376063] """ -- GitLab
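
For reference, a worked pos_weight example (a hedged sketch checked by hand
against the formula in the docstrings above; values are approximate):
pos_weight rescales only the positive-label terms, which the implementation
realizes via the factor ((pos_weight - 1) * label + 1).

    import numpy as np
    import paddle

    paddle.disable_static()
    logit_np = np.array([5.0, 1.0, 3.0], dtype=np.float32)
    label_np = np.array([1.0, 0.0, 1.0], dtype=np.float32)
    pos_np = np.array([2.0, 2.0, 2.0], dtype=np.float32)  # one weight per class

    out = paddle.nn.functional.binary_cross_entropy_with_logits(
        paddle.to_tensor(logit_np), paddle.to_tensor(label_np),
        pos_weight=paddle.to_tensor(pos_np), reduction='none')

    # Reference: the stable elementwise loss scaled by ((pos_weight - 1) * label + 1);
    # here the two positive-label terms are doubled, the negative one is unchanged.
    ref = np.maximum(logit_np, 0) - logit_np * label_np \
        + np.log1p(np.exp(-np.abs(logit_np)))
    ref = ref * ((pos_np - 1.0) * label_np + 1.0)

    print(out.numpy())  # ~[0.0134 1.3133 0.0972]
    print(ref)          # the same values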