未验证 提交 c10cf6d2 编写于 作者: Z Zhong Hui 提交者: GitHub

Fix the API of bce loss and add functional API binary_cross_entropy (#26012)

Fix the API of bce loss and add functional API binary_cross_entropy 
上级 21ea2976
......@@ -19,94 +19,204 @@ import unittest
from op_test import OpTest
class TestBCELoss(unittest.TestCase):
def test_BCELoss(self):
input_np = np.random.random(size=(20, 30)).astype(np.float64)
label_np = np.random.random(size=(20, 30)).astype(np.float64)
prog = fluid.Program()
startup_prog = fluid.Program()
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
reductions = ['sum', 'mean', 'none']
for place in places:
for red in reductions:
with fluid.program_guard(prog, startup_prog):
input = fluid.data(
name='input', shape=[None, 30], dtype='float64')
label = fluid.data(
name='label', shape=[None, 30], dtype='float64')
bce_loss = paddle.nn.loss.BCELoss(reduction=red)
def test_static_layer(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.data(name='input', shape=input_np.shape, dtype='float64')
label = paddle.data(name='label', shape=label_np.shape, dtype='float64')
if weight_np is not None:
weight = paddle.data(
name='weight', shape=weight_np.shape, dtype='float64')
bce_loss = paddle.nn.loss.BCELoss(
weight=weight, reduction=reduction)
else:
bce_loss = paddle.nn.loss.BCELoss(reduction=reduction)
res = bce_loss(input, label)
exe = paddle.static.Executor(place)
static_result = exe.run(prog,
feed={"input": input_np,
"label": label_np}
if weight_np is None else {
"input": input_np,
"label": label_np,
"weight": weight_np
},
fetch_list=[res])
return static_result
exe = fluid.Executor(place)
static_result = exe.run(
prog,
def test_static_functional(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.data(name='input', shape=input_np.shape, dtype='float64')
label = paddle.data(name='label', shape=label_np.shape, dtype='float64')
if weight_np is not None:
weight = paddle.data(
name='weight', shape=weight_np.shape, dtype='float64')
res = paddle.nn.functional.binary_cross_entropy(
input, label, weight=weight, reduction=reduction)
else:
res = paddle.nn.functional.binary_cross_entropy(
input, label, reduction=reduction)
exe = paddle.static.Executor(place)
static_result = exe.run(prog,
feed={"input": input_np,
"label": label_np},
"label": label_np}
if weight_np is None else {
"input": input_np,
"label": label_np,
"weight": weight_np
},
fetch_list=[res])
return static_result
with fluid.dygraph.guard():
bce_loss = paddle.nn.loss.BCELoss(reduction=red)
dy_res = bce_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
def test_dygraph_layer(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
paddle.disable_static()
if weight_np is not None:
weight = paddle.to_tensor(weight_np)
bce_loss = paddle.nn.loss.BCELoss(weight=weight, reduction=reduction)
else:
bce_loss = paddle.nn.loss.BCELoss(reduction=reduction)
dy_res = bce_loss(paddle.to_tensor(input_np), paddle.to_tensor(label_np))
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result
def test_dygraph_functional(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
paddle.disable_static()
input = paddle.to_tensor(input_np)
label = paddle.to_tensor(label_np)
if weight_np is not None:
weight = paddle.to_tensor(weight_np)
dy_res = paddle.nn.functional.binary_cross_entropy(
input, label, weight=weight, reduction=reduction)
else:
dy_res = paddle.nn.functional.binary_cross_entropy(
input, label, reduction=reduction)
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result
def calc_bceloss(input_np, label_np, reduction='mean', weight_np=None):
if weight_np is None:
expected = -1 * (label_np * np.log(input_np) +
(1. - label_np) * np.log(1. - input_np))
if red == 'mean':
else:
expected = -1 * weight_np * (label_np * np.log(input_np) +
(1. - label_np) * np.log(1. - input_np))
if reduction == 'mean':
expected = np.mean(expected)
elif red == 'sum':
elif reduction == 'sum':
expected = np.sum(expected)
else:
expected = expected
return expected
class TestBCELoss(unittest.TestCase):
def test_BCELoss(self):
input_np = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float64)
label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float64)
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
static_result = test_static_layer(place, input_np, label_np,
reduction)
dy_result = test_dygraph_layer(place, input_np, label_np,
reduction)
expected = calc_bceloss(input_np, label_np, reduction)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static_functional(place, input_np,
label_np, reduction)
dy_functional = test_dygraph_functional(place, input_np,
label_np, reduction)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCELoss_weight(self):
input_np = np.random.random(size=(2, 3, 4, 10)).astype(np.float64)
label_np = np.random.random(size=(2, 3, 4, 10)).astype(np.float64)
input_np = np.random.uniform(
0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64)
label_np = np.random.randint(
0, 2, size=(2, 3, 4, 10)).astype(np.float64)
weight_np = np.random.random(size=(3, 4, 10)).astype(np.float64)
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
input = fluid.data(
name='input', shape=[None, 3, 4, 10], dtype='float64')
label = fluid.data(
name='label', shape=[None, 3, 4, 10], dtype='float64')
weight = fluid.data(
name='weight', shape=[3, 4, 10], dtype='float64')
bce_loss = paddle.nn.loss.BCELoss(weight=weight)
res = bce_loss(input, label)
exe = fluid.Executor(place)
static_result = exe.run(prog,
feed={
"input": input_np,
"label": label_np,
"weight": weight_np
},
fetch_list=[res])
for reduction in ['sum', 'mean', 'none']:
static_result = test_static_layer(
place, input_np, label_np, reduction, weight_np=weight_np)
dy_result = test_dygraph_layer(
place, input_np, label_np, reduction, weight_np=weight_np)
expected = calc_bceloss(
input_np, label_np, reduction, weight_np=weight_np)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static_functional(
place, input_np, label_np, reduction, weight_np=weight_np)
dy_functional = test_dygraph_functional(
place, input_np, label_np, reduction, weight_np=weight_np)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
with fluid.dygraph.guard():
bce_loss = paddle.nn.loss.BCELoss(
weight=fluid.dygraph.to_variable(weight_np))
dy_res = bce_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
dy_result = dy_res.numpy()
def test_BCELoss_boardcast(self):
input_np = np.random.uniform(
0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64)
label_np = np.random.randint(0, 2, size=(3, 4, 10)).astype(np.float64)
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
expected = np.mean(-1 * weight_np *
(label_np * np.log(input_np) +
(1. - label_np) * np.log(1. - input_np)))
static_result = test_static_layer(place, input_np, label_np)
dy_result = test_dygraph_layer(place, input_np, label_np)
expected = calc_bceloss(input_np, label_np)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
def test_BCELoss_error(self):
paddle.disable_static()
self.assertRaises(
ValueError, paddle.nn.loss.BCELoss, reduction="unsupport reduction")
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
label = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
self.assertRaises(
ValueError,
paddle.nn.functional.binary_cross_entropy,
input=input,
label=label,
reduction="unsupport reduction")
paddle.enable_static()
def bce_loss(input, label):
return -1 * (label * np.log(input) + (1. - label) * np.log(1. - input))
......
......@@ -120,6 +120,7 @@ from .lod import hash #DEFINE_ALIAS
# from .lod import dynamic_gru #DEFINE_ALIAS
# from .lod import dynamic_lstm #DEFINE_ALIAS
# from .lod import dynamic_lstmp #DEFINE_ALIAS
from .loss import binary_cross_entropy #DEFINE_ALIAS
from .loss import bpr_loss #DEFINE_ALIAS
from .loss import center_loss #DEFINE_ALIAS
from .loss import cross_entropy #DEFINE_ALIAS
......
......@@ -42,9 +42,11 @@ from ...fluid.layers import huber_loss #DEFINE_ALIAS
from ...fluid.layers import sampled_softmax_with_cross_entropy #DEFINE_ALIAS
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode
from ...fluid.framework import _varbase_creator
from ...fluid.framework import Variable
__all__ = [
'binary_cross_entropy',
'bpr_loss',
'center_loss',
'cross_entropy',
......@@ -73,6 +75,142 @@ __all__ = [
]
def binary_cross_entropy(input, label, weight=None, reduction='mean',
name=None):
"""
This op measures the binary_cross_entropy loss between input predictions ``input``
and target labels ``label`` . The binary_cross_entropy loss can be described as:
If :attr:`weight` is set, the loss is:
.. math::
Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`weight` is None, the loss is:
.. math::
Out = -1 * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`reduction` set to ``'none'``, the interface will return the original loss `Out`.
If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:
.. math::
Out = MEAN(Out)
If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:
.. math::
Out = SUM(Out)
Note that the input predictions ``input`` always be the output of sigmoid, and the target labels ``label``
should be numbers between 0 and 1.
Parameters:
input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *],
N is batch_size, `*` means number of additional dimensions. The ``input``
should always be the output of sigmod. Available dtype is float32, float64.
label (Tensor): The target labels tensor. 2-D tensor with the same shape as
``input``. The target labels which values should be numbers between 0 and 1.
Available dtype is float32, float64.
weight (Tensor, optional): A manual rescaling weight given to the loss of each
batch element. If given, has to be a Tensor of size nbatch and the data type
is float32, float64. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
same as ``input`` , else the shape of output is scalar.
Examples:
.. code-block:: python
import paddle
import numpy as np
input_data = np.array([0.5, 0.6, 0.7]).astype("float32")
label_data = np.array([1.0, 0.0, 1.0]).astype("float32")
paddle.disable_static()
input = paddle.to_tensor(input_data)
label = paddle.to_tensor(label_data)
output = paddle.nn.functional.binary_cross_entropy(input, label)
print(output.numpy()) # [0.65537095]
paddle.enable_static()
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in binary_cross_entropy should be 'sum', "
"'mean' or 'none', but received %s, which is not allowed." %
reduction)
if in_dygraph_mode():
one = _varbase_creator(dtype=input.dtype)
core.ops.fill_constant(one, 'value',
float(1.0), 'force_cpu', False, 'dtype',
one.dtype, 'str_value', '1.0', 'shape', [1])
one.stop_gradient = True
label_minus = core.ops.elementwise_sub(label, one)
input_minus = core.ops.elementwise_sub(one, input)
input_minus_log = core.ops.log(input_minus)
input_log = core.ops.log(input)
loss_1 = core.ops.elementwise_mul(label_minus, input_minus_log)
loss_2 = core.ops.elementwise_mul(label, input_log)
out = core.ops.elementwise_sub(loss_1, loss_2)
if weight is not None:
out = core.ops.elementwise_mul(out, weight, 'axis', -1)
if reduction == 'sum':
return core.ops.reduce_sum(out, 'dim', [0], 'keep_dim', False,
"reduce_all", True)
elif reduction == 'mean':
return core.ops.reduce_mean(out, 'dim', [0], 'keep_dim', False,
"reduce_all", True)
else:
return out
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'binary_cross_entropy')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'binary_cross_entropy')
one = paddle.fill_constant(shape=[1], value=1.0, dtype=input.dtype)
one.stop_gradient = True
label_minus = paddle.elementwise_sub(label, one)
input_minus = paddle.elementwise_sub(one, input)
input_minus_log = paddle.log(input_minus)
input_log = paddle.log(input)
loss_1 = paddle.multiply(label_minus, input_minus_log)
loss_2 = paddle.multiply(label, input_log)
sub_name = name if weight is None and reduction is 'none' else None
out = paddle.elementwise_sub(loss_1, loss_2, name=sub_name)
if weight is not None:
if isinstance(weight, paddle.framework.Variable):
weight_name = name if reduction is 'none' else None
out = paddle.multiply(out, weight, axis=-1, name=weight_name)
else:
raise ValueError(
"The weight is not a Tensor, please convert to Tensor.")
if reduction == 'sum':
return paddle.sum(out, name=name)
elif reduction == 'mean':
return paddle.mean(out, name=name)
else:
return out
def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None):
"""
This operator calculates smooth_l1_loss. Creates a criterion that uses a squared
......
......@@ -18,6 +18,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle
from .. import functional as F
from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator
__all__ = [
# 'NCELoss',
......@@ -335,45 +336,38 @@ class L1Loss(fluid.dygraph.Layer):
class BCELoss(fluid.dygraph.Layer):
"""
:alias_main: paddle.nn.BCELoss
:alias: paddle.nn.BCELoss,paddle.nn.layer.BCELoss,paddle.nn.layer.loss.BCELoss
This interface is used to construct a callable object of the ``BCELoss`` class.
The BCELoss layer measures the binary_cross_entropy loss between input predictions
and target labels. The binary_cross_entropy loss can be described as:
The BCELoss layer measures the binary_cross_entropy loss between input predictions ``input``
and target labels ``label`` . The binary_cross_entropy loss can be described as:
If :attr:`weight` is set, the loss is:
.. math::
Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`weight` is None, the loss is:
.. math::
Out = -1 * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`reduction` set to ``'none'``, the unreduced loss is:
If :attr:`reduction` set to ``'none'``, the interface will return the original loss `Out`.
.. math::
Out = Out
If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:
.. math::
Out = MEAN(Out)
If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:
.. math::
Out = SUM(Out)
Note that the input predictions always be the output of sigmoid, and the target labels
Note that the input predictions ``input`` always be the output of sigmoid, and the target labels ``label``
should be numbers between 0 and 1.
The shape of input predictions and target labels are [N, *], where N is batch_size and `*`
means any number of additional dimensions. If ``reduction`` is ``'none'``, the shape of
output is scalar, else the shape of output is same as input.
Parameters:
weight (Variable, optional): A manual rescaling weight given to the loss of each
batch element. If given, has to be a Variable of size nbatch and the data type
weight (Tensor, optional): A manual rescaling weight given to the loss of each
batch element. If given, has to be a Tensor of size nbatch and the data type
is float32, float64. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
......@@ -381,6 +375,18 @@ class BCELoss(fluid.dygraph.Layer):
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
input (Tensor): 2-D tensor with shape: (N, *), N is batch_size, `*` means
number of additional dimensions. The input ``input`` should always
be the output of sigmod. Available dtype is float32, float64.
label (Tensor): 2-D tensor with the same shape as ``input``. The target
labels which values should be numbers between 0 and 1. Available
dtype is float32, float64.
output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
same as ``input`` , else the shape of output is scalar.
Returns:
A callable object of BCELoss.
......@@ -388,37 +394,22 @@ class BCELoss(fluid.dygraph.Layer):
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
import paddle
input = fluid.data(name="input", shape=[3, 1], dtype='float32')
label = fluid.data(name="label", shape=[3, 1], dtype='float32')
bce_loss = paddle.nn.loss.BCELoss()
output = bce_loss(input, label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.array([0.5, 0.6, 0.7]).astype("float32")
label_data = np.array([1.0, 0.0, 1.0]).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data, "label":label_data},
fetch_list=[output],
return_numpy=True)
print(output_data) # [array([0.65537095], dtype=float32)]
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
label = dg.to_variable(label_data)
paddle.disable_static()
input = paddle.to_variable(input_data)
label = paddle.to_variable(label_data)
bce_loss = paddle.nn.loss.BCELoss()
output = bce_loss(input, label)
print(output.numpy()) # [0.65537095]
paddle.enable_static()
"""
def __init__(self, weight=None, reduction='mean'):
def __init__(self, weight=None, reduction='mean', name=None):
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
......@@ -427,37 +418,11 @@ class BCELoss(fluid.dygraph.Layer):
super(BCELoss, self).__init__()
self.weight = weight
self.reduction = reduction
self.name = name
def forward(self, input, label):
dtype = self._helper.input_dtype(input)
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'bce_loss')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'bce_loss')
out = self._helper.create_variable_for_type_inference(dtype=input.dtype)
self._helper.append_op(
type='bce_loss',
inputs={
'X': [input],
'Label': [label],
},
outputs={'Out': [out]})
if self.weight is not None:
if isinstance(self.weight, fluid.framework.Variable):
w = self.weight
out = fluid.layers.elementwise_mul(out, w, axis=-1)
else:
raise ValueError(
"The weight is not a Variable, please convert to Variable.")
if self.reduction == 'sum':
return fluid.layers.reduce_sum(out)
elif self.reduction == 'mean':
return fluid.layers.reduce_mean(out)
else:
out = paddle.nn.functional.binary_cross_entropy(
input, label, self.weight, self.reduction, self.name)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册