Unverified commit f5d13498, authored by Zhong Hui, committed by GitHub

add binary cross entropy with logit loss (#26468)

* add binary cross entropy with logit loss
Parent 4e0c6d91
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import numpy as np
import unittest
from op_test import OpTest
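

# The helpers below exercise BCEWithLogitsLoss through both the Layer API and
# the functional API, in static-graph and dygraph modes, so the tests can
# cross-check all of them against a NumPy reference implementation.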
def call_bce_layer(logit, label, weight=None, reduction='mean',
pos_weight=None):
bce_logit_loss = paddle.nn.loss.BCEWithLogitsLoss(
weight=weight, reduction=reduction, pos_weight=pos_weight)
res = bce_logit_loss(logit, label)
return res
def call_bce_functional(logit,
label,
weight=None,
reduction='mean',
pos_weight=None):
res = paddle.nn.functional.binary_cross_entropy_with_logits(
logit, label, weight=weight, reduction=reduction, pos_weight=pos_weight)
return res
def test_static(place,
logit_np,
label_np,
weight_np=None,
reduction='mean',
pos_weight_np=None,
functional=False):
paddle.enable_static()
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
logit = paddle.data(name='logit', shape=logit_np.shape, dtype='float64')
label = paddle.data(name='label', shape=label_np.shape, dtype='float64')
feed_dict = {"logit": logit_np, "label": label_np}
pos_weight = None
weight = None
if pos_weight_np is not None:
pos_weight = paddle.data(
name='pos_weight', shape=pos_weight_np.shape, dtype='float64')
feed_dict["pos_weight"] = pos_weight_np
if weight_np is not None:
weight = paddle.data(
name='weight', shape=weight_np.shape, dtype='float64')
feed_dict["weight"] = weight_np
if functional:
res = call_bce_functional(logit, label, weight, reduction,
pos_weight)
else:
res = call_bce_layer(logit, label, weight, reduction, pos_weight)
exe = paddle.static.Executor(place)
static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
return static_result
def test_dygraph(place,
logit_np,
label_np,
weight_np=None,
reduction='mean',
pos_weight_np=None,
functional=False):
paddle.disable_static()
logit = paddle.to_tensor(logit_np)
label = paddle.to_tensor(label_np)
weight = None
pos_weight = None
if weight_np is not None:
weight = paddle.to_tensor(weight_np)
if pos_weight_np is not None:
pos_weight = paddle.to_tensor(pos_weight_np)
if functional:
dy_res = call_bce_functional(logit, label, weight, reduction,
pos_weight)
else:
dy_res = call_bce_layer(logit, label, weight, reduction, pos_weight)
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result
def calc_bce_with_logits_loss(logit_np,
label_np,
reduction='mean',
weight_np=None,
pos_weight=None):
expected = np.maximum(
logit_np,
0) - logit_np * label_np + np.log(1 + np.exp(-np.abs(logit_np)))
if pos_weight is not None:
expected = expected * ((pos_weight - 1) * label_np + 1)
if weight_np is not None:
expected = weight_np * expected
if reduction == 'mean':
expected = np.mean(expected)
elif reduction == 'sum':
expected = np.sum(expected)
else:
expected = expected
return expected
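

# Note: calc_bce_with_logits_loss above is the NumPy reference. It uses the
# numerically stable form  max(logit, 0) - logit * label + log(1 + exp(-|logit|)),
# optionally scaled by ((pos_weight - 1) * label + 1) and by weight, and then
# reduced with 'mean', 'sum', or left element-wise for 'none'.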
class TestBCEWithLogitsLoss(unittest.TestCase):
def test_BCEWithLogitsLoss(self):
logit_np = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float64)
label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float64)
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
static_result = test_static(
place, logit_np, label_np, reduction=reduction)
dy_result = test_dygraph(
place, logit_np, label_np, reduction=reduction)
expected = calc_bce_with_logits_loss(logit_np, label_np,
reduction)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(
place,
logit_np,
label_np,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(
place,
logit_np,
label_np,
reduction=reduction,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCEWithLogitsLoss_weight(self):
logit_np = np.random.uniform(
0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64)
label_np = np.random.randint(
0, 2, size=(2, 3, 4, 10)).astype(np.float64)
weight_np = np.random.random(size=(2, 3, 4, 10)).astype(np.float64)
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
for reduction in ['sum', 'mean', 'none']:
static_result = test_static(
place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction)
dy_result = test_dygraph(
place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction)
expected = calc_bce_with_logits_loss(
logit_np, label_np, reduction, weight_np=weight_np)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(
place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(
place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCEWithLogitsLoss_pos_weight(self):
logit_np = np.random.uniform(
0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64)
label_np = np.random.randint(
0, 2, size=(2, 3, 4, 10)).astype(np.float64)
pos_weight_np = np.random.random(size=(3, 4, 10)).astype(np.float64)
weight_np = np.random.random(size=(2, 3, 4, 10)).astype(np.float64)
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
reduction = "mean"
static_result = test_static(place, logit_np, label_np, weight_np,
reduction, pos_weight_np)
dy_result = test_dygraph(place, logit_np, label_np, weight_np,
reduction, pos_weight_np)
expected = calc_bce_with_logits_loss(logit_np, label_np, reduction,
weight_np, pos_weight_np)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(
place,
logit_np,
label_np,
weight_np,
reduction,
pos_weight_np,
functional=True)
dy_functional = test_dygraph(
place,
logit_np,
label_np,
weight_np,
reduction,
pos_weight_np,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCEWithLogitsLoss_error(self):
paddle.disable_static()
self.assertRaises(
ValueError,
paddle.nn.BCEWithLogitsLoss,
reduction="unsupport reduction")
logit = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
label = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
self.assertRaises(
ValueError,
paddle.nn.functional.binary_cross_entropy_with_logits,
logit=logit,
label=label,
reduction="unsupport reduction")
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......
@@ -107,6 +107,7 @@ from .layer.extension import RowConv #DEFINE_ALIAS
# from .layer.learning_rate import PiecewiseDecay #DEFINE_ALIAS
# from .layer.learning_rate import PolynomialDecay #DEFINE_ALIAS
# from .layer.loss import NCELoss #DEFINE_ALIAS
from .layer.loss import BCEWithLogitsLoss #DEFINE_ALIAS
from .layer.loss import CrossEntropyLoss #DEFINE_ALIAS
from .layer.loss import MSELoss #DEFINE_ALIAS
from .layer.loss import L1Loss #DEFINE_ALIAS
......
......
@@ -126,6 +126,7 @@ from .lod import hash #DEFINE_ALIAS
# from .lod import dynamic_lstm #DEFINE_ALIAS
# from .lod import dynamic_lstmp #DEFINE_ALIAS
from .loss import binary_cross_entropy #DEFINE_ALIAS
from .loss import binary_cross_entropy_with_logits #DEFINE_ALIAS
from .loss import bpr_loss #DEFINE_ALIAS
from .loss import center_loss #DEFINE_ALIAS
from .loss import cross_entropy #DEFINE_ALIAS
......
......
@@ -49,6 +49,7 @@ from ...fluid.framework import Variable
__all__ = [
'binary_cross_entropy',
'binary_cross_entropy_with_logits',
'bpr_loss',
'center_loss',
'cross_entropy',
......
@@ -214,6 +215,154 @@ def binary_cross_entropy(input, label, weight=None, reduction='mean',
return out
def binary_cross_entropy_with_logits(logit,
label,
weight=None,
reduction='mean',
pos_weight=None,
name=None):
"""
    This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
    It can also be seen as the combination of ``sigmoid_cross_entropy_with_logits``
    and some reduce operations.

    This measures the element-wise probability error in classification tasks
    in which each class is independent. It can be thought of as predicting labels
    for a data point whose labels are not mutually exclusive. For example, a news
    article can be about politics, technology or sports at the same time, or about
    none of these.

    First, this operator calculates the loss as follows:

    .. math::
           Out = -Labels * \\log(\\sigma(Logit)) - (1 - Labels) * \\log(1 - \\sigma(Logit))

    We know that :math:`\\sigma(Logit) = \\frac{1}{1 + e^{-Logit}}`. By substituting this we get:

    .. math::
           Out = Logit - Logit * Labels + \\log(1 + e^{-Logit})

    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
    we reformulate the loss as follows:

    .. math::
           Out = \\max(Logit, 0) - Logit * Labels + \\log(1 + e^{-|Logit|})

    Then, if ``weight`` or ``pos_weight`` is not None, this operator multiplies the
    weight tensor with the loss ``Out``. The ``weight`` tensor assigns a different
    weight to each item in the batch, and ``pos_weight`` assigns a different weight
    to the positive label of each class.

    Finally, this operator applies a reduce operation to the loss.
    If :attr:`reduction` is set to ``'none'``, the operator returns the original loss ``Out``.
    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target labels ``label`` should be numbers between 0 and 1.
    Args:
        logit (Tensor): The input prediction tensor, with shape [N, *], where N is the
            batch size and `*` means any number of additional dimensions. The ``logit``
            is usually the output of a Linear layer. Available dtypes are float32, float64.
        label (Tensor): The target labels tensor, with the same shape as ``logit``.
            Its values should be numbers between 0 and 1. Available dtypes are
            float32, float64.
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, it has to be a 1-D Tensor whose size is `[N, ]`.
            The data type is float32, float64. Default is ``None``.
        reduction (str, optional): Indicates how to reduce the loss over the batch;
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            if :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            if :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
            with length equal to the number of classes. The data type is float32, float64.
            Default is ``None``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        output (Tensor): If ``reduction`` is ``'none'``, the shape of the output is
            the same as ``logit``; otherwise the output is a scalar Tensor.
    Examples:
        .. code-block:: python

            import paddle

            paddle.disable_static()
            logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
            label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
            output = paddle.nn.functional.binary_cross_entropy_with_logits(logit, label)
            print(output.numpy())  # [0.45618808]
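
            # A minimal additional sketch (argument values are illustrative):
            # with reduction='none', per-element losses are returned and the
            # output keeps the shape of ``logit``.
            weight = paddle.to_tensor([0.5, 1.0, 2.0], dtype="float32")
            pos_weight = paddle.to_tensor([2.0, 1.0, 2.0], dtype="float32")
            output_none = paddle.nn.functional.binary_cross_entropy_with_logits(
                logit, label, weight=weight, pos_weight=pos_weight,
                reduction='none')
            print(output_none.shape)  # [3]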
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in binary_cross_entropy_with_logits "
"should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
% reduction)
if in_dygraph_mode():
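        # Dygraph branch: build a constant tensor holding 1.0 and call the
        # underlying C++ ops (core.ops.*) directly.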
one = _varbase_creator(dtype=logit.dtype)
core.ops.fill_constant(one, 'value',
float(1.0), 'force_cpu', False, 'dtype',
one.dtype, 'str_value', '1.0', 'shape', [1])
out = core.ops.sigmoid_cross_entropy_with_logits(logit, label)
if pos_weight is not None:
log_weight = core.ops.elementwise_add(
core.ops.elementwise_mul(
label, core.ops.elementwise_sub(pos_weight, one)), one)
out = core.ops.elementwise_mul(out, log_weight)
if weight is not None:
out = core.ops.elementwise_mul(out, weight)
if reduction == "sum":
return core.ops.reduce_sum(out, 'reduce_all', True)
elif reduction == "mean":
return core.ops.mean(out)
else:
return out
fluid.data_feeder.check_variable_and_dtype(
logit, 'logit', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
sigmoid_name = None
if reduction == 'none' and pos_weight is None and weight is None:
sigmoid_name = name
out = paddle.nn.functional.sigmoid_cross_entropy_with_logits(
logit, label, name=sigmoid_name)
one = paddle.fill_constant(shape=[1], value=1.0, dtype=logit.dtype)
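    # When pos_weight is given, scale the loss by label * (pos_weight - 1) + 1,
    # so positive labels are weighted by pos_weight and negative labels by 1.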
if pos_weight is not None:
fluid.data_feeder.check_variable_and_dtype(
pos_weight, 'pos_weight', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
log_weight = paddle.add(
paddle.multiply(label, paddle.elementwise_sub(pos_weight, one)),
one)
pos_weight_name = name if reduction == 'none' and weight is None else None
out = paddle.multiply(out, log_weight, name=pos_weight_name)
if weight is not None:
fluid.data_feeder.check_variable_and_dtype(
weight, 'weight', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
weight_name = name if reduction == 'none' else None
out = paddle.multiply(out, weight, name=weight_name)
if reduction == "sum":
return paddle.sum(out, name=name)
elif reduction == "mean":
return paddle.mean(out, name=name)
return out
def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None):
"""
This operator calculates smooth_l1_loss. Creates a criterion that uses a squared
......
......
@@ -72,6 +72,7 @@ from .extension import RowConv #DEFINE_ALIAS
# from .learning_rate import PiecewiseDecay #DEFINE_ALIAS
# from .learning_rate import PolynomialDecay #DEFINE_ALIAS
# from .loss import NCELoss #DEFINE_ALIAS
from .loss import BCEWithLogitsLoss #DEFINE_ALIAS
from .loss import CrossEntropyLoss #DEFINE_ALIAS
from .loss import MSELoss #DEFINE_ALIAS
from .loss import L1Loss #DEFINE_ALIAS
......
......
@@ -21,6 +21,7 @@ from .. import functional as F
from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator
__all__ = [
'BCEWithLogitsLoss',
'CrossEntropyLoss',
'MSELoss',
'L1Loss',
......
@@ -33,6 +34,111 @@ __all__ = [
]
class BCEWithLogitsLoss(fluid.dygraph.Layer):
"""
    This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
    It can also be seen as the combination of ``sigmoid_cross_entropy_with_logits``
    and some reduce operations.

    This measures the element-wise probability error in classification tasks
    in which each class is independent. It can be thought of as predicting labels
    for a data point whose labels are not mutually exclusive. For example, a news
    article can be about politics, technology or sports at the same time, or about
    none of these.

    First, this operator calculates the loss as follows:

    .. math::
           Out = -Labels * \\log(\\sigma(Logit)) - (1 - Labels) * \\log(1 - \\sigma(Logit))

    We know that :math:`\\sigma(Logit) = \\frac{1}{1 + e^{-Logit}}`. By substituting this we get:

    .. math::
           Out = Logit - Logit * Labels + \\log(1 + e^{-Logit})

    For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
    we reformulate the loss as follows:

    .. math::
           Out = \\max(Logit, 0) - Logit * Labels + \\log(1 + e^{-|Logit|})

    Then, if ``weight`` or ``pos_weight`` is not None, this operator multiplies the
    weight tensor with the loss ``Out``. The ``weight`` tensor assigns a different
    weight to each item in the batch, and ``pos_weight`` assigns a different weight
    to the positive label of each class.

    Finally, this operator applies a reduce operation to the loss.
    If :attr:`reduction` is set to ``'none'``, the operator returns the original loss ``Out``.
    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.

    Note that the target labels ``label`` should be numbers between 0 and 1.
    Args:
        weight (Tensor, optional): A manual rescaling weight given to the loss of each
            batch element. If given, it has to be a 1-D Tensor whose size is `[N, ]`.
            The data type is float32, float64. Default is ``None``.
        reduction (str, optional): Indicates how to reduce the loss over the batch;
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            if :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            if :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
        pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
            with length equal to the number of classes. The data type is float32, float64.
            Default is ``None``.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shapes:
        logit (Tensor): The input prediction tensor, with shape [N, *], where N is the
            batch size and `*` means any number of additional dimensions. The ``logit``
            is usually the output of a Linear layer. Available dtypes are float32, float64.
        label (Tensor): The target labels tensor, with the same shape as ``logit``.
            Its values should be numbers between 0 and 1. Available dtypes are
            float32, float64.
        output (Tensor): If ``reduction`` is ``'none'``, the shape of the output is
            the same as ``logit``; otherwise the output is a scalar Tensor.

    Returns:
        A callable object of BCEWithLogitsLoss.
    Examples:
        .. code-block:: python

            import paddle

            paddle.disable_static()
            logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32")
            label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32")
            bce_logit_loss = paddle.nn.BCEWithLogitsLoss()
            output = bce_logit_loss(logit, label)
            print(output.numpy())  # [0.45618808]
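
            # A minimal additional sketch: with reduction='none' the layer
            # returns the per-element loss, keeping the shape of ``logit``.
            bce_logit_loss_none = paddle.nn.BCEWithLogitsLoss(reduction='none')
            output_none = bce_logit_loss_none(logit, label)
            print(output_none.shape)  # [3]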
"""
def __init__(self,
weight=None,
reduction='mean',
pos_weight=None,
name=None):
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in BCEWithLogitsLoss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)
super(BCEWithLogitsLoss, self).__init__()
self.weight = weight
self.reduction = reduction
self.pos_weight = pos_weight
self.name = name
def forward(self, logit, label):
out = paddle.nn.functional.binary_cross_entropy_with_logits(
logit, label, self.weight, self.reduction, self.pos_weight,
self.name)
return out
class CrossEntropyLoss(fluid.dygraph.Layer):
"""
:alias_main: paddle.nn.CrossEntropyLoss
......
@@ -678,9 +784,9 @@ class CTCLoss(fluid.dygraph.Layer):
:alias_main: paddle.nn.CTCLoss
:alias: paddle.nn.CTCLoss, paddle.nn.layer.CTCLoss, paddle.nn.layer.loss.CTCLoss
    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation
    is integrated into the Warp-CTC library to normalize values for each row of the input tensor.
Parameters:
......
@@ -695,7 +801,7 @@ class CTCLoss(fluid.dygraph.Layer):
Returns:
Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.
Examples:
.. code-block:: python
......
@@ -739,13 +845,13 @@ class CTCLoss(fluid.dygraph.Layer):
input_lengths = paddle.to_variable(input_lengths)
label_lengths = paddle.to_variable(label_lengths)
            loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels,
                                                                input_lengths,
                                                                label_lengths)
            print(loss.numpy())  # [3.9179852 2.9076521]

            loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels,
                                                                input_lengths,
                                                                label_lengths)
            print(loss.numpy())  # [1.1376063]
"""
......