Unverified commit c10cf6d2 authored by Z Zhong Hui, committed by GitHub

Fix the API of bce loss and add functional API binary_cross_entropy (#26012)

Parent 21ea2976
......@@ -19,94 +19,204 @@ import unittest
from op_test import OpTest
def test_static_layer(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.data(name='input', shape=input_np.shape, dtype='float64')
label = paddle.data(name='label', shape=label_np.shape, dtype='float64')
if weight_np is not None:
weight = paddle.data(
name='weight', shape=weight_np.shape, dtype='float64')
bce_loss = paddle.nn.loss.BCELoss(
weight=weight, reduction=reduction)
else:
bce_loss = paddle.nn.loss.BCELoss(reduction=reduction)
res = bce_loss(input, label)
exe = paddle.static.Executor(place)
static_result = exe.run(prog,
feed={"input": input_np,
"label": label_np}
if weight_np is None else {
"input": input_np,
"label": label_np,
"weight": weight_np
},
fetch_list=[res])
return static_result
def test_static_functional(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.data(name='input', shape=input_np.shape, dtype='float64')
label = paddle.data(name='label', shape=label_np.shape, dtype='float64')
if weight_np is not None:
weight = paddle.data(
name='weight', shape=weight_np.shape, dtype='float64')
res = paddle.nn.functional.binary_cross_entropy(
input, label, weight=weight, reduction=reduction)
else:
res = paddle.nn.functional.binary_cross_entropy(
input, label, reduction=reduction)
exe = paddle.static.Executor(place)
static_result = exe.run(prog,
feed={"input": input_np,
"label": label_np}
if weight_np is None else {
"input": input_np,
"label": label_np,
"weight": weight_np
},
fetch_list=[res])
return static_result
def test_dygraph_layer(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
paddle.disable_static()
if weight_np is not None:
weight = paddle.to_tensor(weight_np)
bce_loss = paddle.nn.loss.BCELoss(weight=weight, reduction=reduction)
else:
bce_loss = paddle.nn.loss.BCELoss(reduction=reduction)
dy_res = bce_loss(paddle.to_tensor(input_np), paddle.to_tensor(label_np))
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result
def test_dygraph_functional(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
paddle.disable_static()
input = paddle.to_tensor(input_np)
label = paddle.to_tensor(label_np)
if weight_np is not None:
weight = paddle.to_tensor(weight_np)
dy_res = paddle.nn.functional.binary_cross_entropy(
input, label, weight=weight, reduction=reduction)
else:
dy_res = paddle.nn.functional.binary_cross_entropy(
input, label, reduction=reduction)
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result
def calc_bceloss(input_np, label_np, reduction='mean', weight_np=None):
if weight_np is None:
expected = -1 * (label_np * np.log(input_np) +
(1. - label_np) * np.log(1. - input_np))
else:
expected = -1 * weight_np * (label_np * np.log(input_np) +
(1. - label_np) * np.log(1. - input_np))
if reduction == 'mean':
expected = np.mean(expected)
elif reduction == 'sum':
expected = np.sum(expected)
else:
expected = expected
return expected
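As a quick sanity check of this NumPy reference (hypothetical values matching the docstring example added by this PR; assumes calc_bceloss above is in scope):
# mean BCE of inputs [0.5, 0.6, 0.7] vs labels [1, 0, 1]:
#   -(log(0.5) + log(0.4) + log(0.7)) / 3 ≈ 0.65537095
# print(calc_bceloss(np.array([0.5, 0.6, 0.7]), np.array([1.0, 0.0, 1.0])))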
class TestBCELoss(unittest.TestCase):
def test_BCELoss(self):
input_np = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float64)
label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float64)
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
static_result = test_static_layer(place, input_np, label_np,
reduction)
dy_result = test_dygraph_layer(place, input_np, label_np,
reduction)
expected = calc_bceloss(input_np, label_np, reduction)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static_functional(place, input_np,
label_np, reduction)
dy_functional = test_dygraph_functional(place, input_np,
label_np, reduction)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCELoss_weight(self):
input_np = np.random.uniform(
0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64)
label_np = np.random.randint(
0, 2, size=(2, 3, 4, 10)).astype(np.float64)
weight_np = np.random.random(size=(3, 4, 10)).astype(np.float64)
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
for reduction in ['sum', 'mean', 'none']:
static_result = test_static_layer(
place, input_np, label_np, reduction, weight_np=weight_np)
dy_result = test_dygraph_layer(
place, input_np, label_np, reduction, weight_np=weight_np)
expected = calc_bceloss(
input_np, label_np, reduction, weight_np=weight_np)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static_functional(
place, input_np, label_np, reduction, weight_np=weight_np)
dy_functional = test_dygraph_functional(
place, input_np, label_np, reduction, weight_np=weight_np)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCELoss_broadcast(self):
input_np = np.random.uniform(
0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64)
label_np = np.random.randint(0, 2, size=(3, 4, 10)).astype(np.float64)
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
static_result = test_static_layer(place, input_np, label_np)
dy_result = test_dygraph_layer(place, input_np, label_np)
expected = calc_bceloss(input_np, label_np)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
def test_BCELoss_error(self):
paddle.disable_static()
self.assertRaises(
ValueError, paddle.nn.loss.BCELoss, reduction="unsupport reduction")
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
label = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
self.assertRaises(
ValueError,
paddle.nn.functional.binary_cross_entropy,
input=input,
label=label,
reduction="unsupport reduction")
paddle.enable_static()
def bce_loss(input, label):
return -1 * (label * np.log(input) + (1. - label) * np.log(1. - input))
......
......@@ -120,6 +120,7 @@ from .lod import hash #DEFINE_ALIAS
# from .lod import dynamic_gru #DEFINE_ALIAS
# from .lod import dynamic_lstm #DEFINE_ALIAS
# from .lod import dynamic_lstmp #DEFINE_ALIAS
from .loss import binary_cross_entropy #DEFINE_ALIAS
from .loss import bpr_loss #DEFINE_ALIAS
from .loss import center_loss #DEFINE_ALIAS
from .loss import cross_entropy #DEFINE_ALIAS
......
......@@ -14,7 +14,7 @@
# TODO: define loss functions of neural network
import numpy as np
import paddle
import paddle.fluid as fluid
......@@ -42,9 +42,11 @@ from ...fluid.layers import huber_loss #DEFINE_ALIAS
from ...fluid.layers import sampled_softmax_with_cross_entropy #DEFINE_ALIAS
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import core, in_dygraph_mode
from ...fluid.framework import _varbase_creator
from ...fluid.framework import Variable
__all__ = [
'binary_cross_entropy',
'bpr_loss',
'center_loss',
'cross_entropy',
......@@ -73,6 +75,142 @@ __all__ = [
]
def binary_cross_entropy(input, label, weight=None, reduction='mean',
name=None):
"""
This op measures the binary_cross_entropy loss between input predictions ``input``
and target labels ``label``. The binary_cross_entropy loss can be described as:
If :attr:`weight` is set, the loss is:
.. math::
Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`weight` is None, the loss is:
.. math::
Out = -1 * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`reduction` is set to ``'none'``, the interface will return the original loss `Out`.
If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is:
.. math::
Out = MEAN(Out)
If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is:
.. math::
Out = SUM(Out)
Note that the input predictions ``input`` should always be the output of a sigmoid, and the target labels ``label``
should be numbers between 0 and 1.
Parameters:
input (Tensor): The input predictions tensor. 2-D tensor with shape: [N, *],
N is batch_size, `*` means number of additional dimensions. The ``input``
should always be the output of a sigmoid. Available dtype is float32, float64.
label (Tensor): The target labels tensor. 2-D tensor with the same shape as
``input``. The target labels whose values should be numbers between 0 and 1.
Available dtype is float32, float64.
weight (Tensor, optional): A manual rescaling weight given to the loss of each
batch element. If given, has to be a Tensor of size nbatch and the data type
is float32, float64. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size,
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
output (Tensor): If ``reduction`` is ``'none'``, the shape of the output is
the same as ``input``, otherwise the shape of the output is scalar.
Examples:
.. code-block:: python
import paddle
import numpy as np
input_data = np.array([0.5, 0.6, 0.7]).astype("float32")
label_data = np.array([1.0, 0.0, 1.0]).astype("float32")
paddle.disable_static()
input = paddle.to_tensor(input_data)
label = paddle.to_tensor(label_data)
output = paddle.nn.functional.binary_cross_entropy(input, label)
print(output.numpy()) # [0.65537095]
paddle.enable_static()
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in binary_cross_entropy should be 'sum', "
"'mean' or 'none', but received %s, which is not allowed." %
reduction)
if in_dygraph_mode():
one = _varbase_creator(dtype=input.dtype)
core.ops.fill_constant(one, 'value',
float(1.0), 'force_cpu', False, 'dtype',
one.dtype, 'str_value', '1.0', 'shape', [1])
one.stop_gradient = True
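# Compose the BCE loss from primitive ops:
#   out = (label - 1) * log(1 - input) - label * log(input)
#       = -(label * log(input) + (1 - label) * log(1 - input))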
label_minus = core.ops.elementwise_sub(label, one)
input_minus = core.ops.elementwise_sub(one, input)
input_minus_log = core.ops.log(input_minus)
input_log = core.ops.log(input)
loss_1 = core.ops.elementwise_mul(label_minus, input_minus_log)
loss_2 = core.ops.elementwise_mul(label, input_log)
out = core.ops.elementwise_sub(loss_1, loss_2)
if weight is not None:
out = core.ops.elementwise_mul(out, weight, 'axis', -1)
if reduction == 'sum':
return core.ops.reduce_sum(out, 'dim', [0], 'keep_dim', False,
"reduce_all", True)
elif reduction == 'mean':
return core.ops.reduce_mean(out, 'dim', [0], 'keep_dim', False,
"reduce_all", True)
else:
return out
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'binary_cross_entropy')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'binary_cross_entropy')
one = paddle.fill_constant(shape=[1], value=1.0, dtype=input.dtype)
one.stop_gradient = True
label_minus = paddle.elementwise_sub(label, one)
input_minus = paddle.elementwise_sub(one, input)
input_minus_log = paddle.log(input_minus)
input_log = paddle.log(input)
loss_1 = paddle.multiply(label_minus, input_minus_log)
loss_2 = paddle.multiply(label, input_log)
sub_name = name if weight is None and reduction == 'none' else None
out = paddle.elementwise_sub(loss_1, loss_2, name=sub_name)
if weight is not None:
if isinstance(weight, paddle.framework.Variable):
weight_name = name if reduction == 'none' else None
out = paddle.multiply(out, weight, axis=-1, name=weight_name)
else:
raise ValueError(
"The weight is not a Tensor, please convert to Tensor.")
if reduction == 'sum':
return paddle.sum(out, name=name)
elif reduction == 'mean':
return paddle.mean(out, name=name)
else:
return out
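A minimal usage sketch of the new functional API (hypothetical values, in the 2.0-beta imperative style used by the docstring above):
# import numpy as np
# import paddle
#
# paddle.disable_static()
# input = paddle.to_tensor(np.array([0.5, 0.6, 0.7], dtype='float32'))
# label = paddle.to_tensor(np.array([1.0, 0.0, 1.0], dtype='float32'))
# for reduction in ('mean', 'sum', 'none'):
#     out = paddle.nn.functional.binary_cross_entropy(
#         input, label, reduction=reduction)
#     print(reduction, out.numpy())  # 'mean' prints [0.65537095]
# paddle.enable_static()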
def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None):
"""
This operator calculates smooth_l1_loss. Creates a criterion that uses a squared
......@@ -106,7 +244,7 @@ def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None):
If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
Default is ``'mean'``.
delta (float, optional): Specifies the hyperparameter delta to be used.
The value determines how large the errors need to be to use L1. Errors
smaller than delta are minimized with L2. Parameter is ignored for
negative/zero values. Default = 1.0
......@@ -159,9 +297,9 @@ def margin_ranking_loss(input,
name=None):
"""
This op calculates the margin rank loss between the input, other and label, using the math function as follows.
.. math::
margin\_rank\_loss = max(0, -label * (input - other) + margin)
If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:
......@@ -179,7 +317,7 @@ def margin_ranking_loss(input,
Parameters:
input(Tensor): the first input tensor, its data type should be float32, float64.
other(Tensor): the second input tensor, its data type should be float32, float64.
label(Tensor): the label value corresponding to input, its data type should be float32, float64.
margin (float, optional): The margin value to add, default value is 0;
reduction (str, optional): Indicate the reduction to apply to the loss, the candidates are ``'none'``, ``'mean'``, ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
......@@ -190,15 +328,15 @@ def margin_ranking_loss(input,
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
input = paddle.to_variable(np.array([[1, 2], [3, 4]]).astype('float32'))
other = paddle.to_variable(np.array([[2, 1], [2, 4]]).astype('float32'))
label = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype('float32'))
loss = paddle.nn.functional.margin_ranking_loss(input, other, label)
print(loss.numpy()) # [0.75]
"""
if fluid.framework.in_dygraph_mode():
......@@ -274,15 +412,15 @@ def l1_loss(input, label, reduction='mean', name=None):
.. math::
Out = SUM(\lvert input - label\rvert)
Parameters:
input (Tensor): The input tensor. The shape is [N, *], where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32, int64.
label (Tensor): label. The shape is [N, *], same shape as ``input``. Its data type should be float32, float64, int32, int64.
reduction (str, optional): Indicate the reduction to apply to the loss,
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If `reduction` is ``'none'``, the unreduced loss is returned;
If `reduction` is ``'mean'``, the reduced mean loss is returned.
If `reduction` is ``'sum'``, the reduced sum loss is returned.
Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -293,7 +431,7 @@ def l1_loss(input, label, reduction='mean', name=None):
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
......@@ -301,16 +439,16 @@ def l1_loss(input, label, reduction='mean', name=None):
label = paddle.to_variable(label_data)
l1_loss = paddle.nn.functional.l1_loss(input, label)
print(l1_loss.numpy())
# [0.35]
l1_loss = paddle.nn.functional.l1_loss(input, label, reduction='none')
print(l1_loss.numpy())
# [[0.20000005 0.19999999]
# [0.2 0.79999995]]
l1_loss = paddle.nn.functional.l1_loss(input, label, reduction='sum')
print(l1_loss.numpy())
# [1.4]
"""
if reduction not in ['sum', 'mean', 'none']:
......@@ -466,21 +604,21 @@ def kl_div(input, label, reduction='mean', name=None):
While :math:`x` is input and :math:`y` is label.
While :attr:`reduction` is :attr:`none`, output loss is in
the same shape as input, loss in each point is calculated
separately and no reduction is applied.
While :attr:`reduction` is :attr:`mean`, output loss is in
shape of [1] and loss value is the mean value of all losses.
While :attr:`reduction` is :attr:`sum`, output loss is in
shape of [1] and loss value is the sum value of all losses.
While :attr:`reduction` is :attr:`batchmean`, output loss is
in shape of [1] and loss value is the sum value of all losses
divided by batch size.
Args:
input (Tensor): The input tensor. The shape is [N, *], where N is batch size and `*` means
any number of additional dimensions. Its data type should be float32, float64.
label (Tensor): label. The shape is [N, *], same shape as ``input``. Its data type should be float32, float64.
reduction (str, optional): Indicate how to average the loss,
......@@ -490,7 +628,7 @@ def kl_div(input, label, reduction='mean', name=None):
if `reduction` is ``'sum'``, the reduced sum loss is returned;
if `reduction` is ``'none'``, no reduction will be applied.
Default is ``'mean'``.
name(str, optional): Name for the operation (optional, default is None). For more information,
please refer to :ref:`api_guide_Name`.
Returns:
......@@ -502,9 +640,9 @@ def kl_div(input, label, reduction='mean', name=None):
import paddle
import numpy as np
import paddle.nn.functional as F
paddle.enable_imperative()
shape = (5, 20)
input = np.random.uniform(-10, 10, shape).astype('float32')
target = np.random.uniform(-10, 10, shape).astype('float32')
......@@ -513,7 +651,7 @@ def kl_div(input, label, reduction='mean', name=None):
pred_loss = F.kl_div(paddle.to_variable(input),
paddle.to_variable(target), reduction='batchmean')
# shape=[5]
# 'mean' reduction, loss shape will be [1]
pred_loss = F.kl_div(paddle.to_variable(input),
paddle.to_variable(target), reduction='mean')
......@@ -587,7 +725,7 @@ def mse_loss(input, label, reduction='mean', name=None):
Tensor: The tensor storing the mean square error difference of input and label.
Return type: Tensor.
Examples:
.. code-block:: python
......
......@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define loss functions of neural network
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle
from .. import functional as F
from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator
__all__ = [
# 'NCELoss',
......@@ -61,8 +62,8 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
Parameters:
input (Variable): Input tensor, the data type is float32, float64. Shape is
(N, C), where C is number of classes, and if shape is more than 2D, this
is (N, C, D1, D2,..., Dk), k >= 1.
label (Variable): Label tensor, the data type is int64. Shape is (N), where each
value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
(N, D1, D2,..., Dk), k >= 1.
weight (Variable, optional): Weight tensor, a manual rescaling weight given
......@@ -180,9 +181,9 @@ class MSELoss(fluid.dygraph.layers.Layer):
label (Variable): Label tensor, the data type is float32,
reduction (string, optional): The reduction method for the output,
could be 'none' | 'mean' | 'sum'.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
Default is ``'mean'``.
Returns:
......@@ -274,23 +275,23 @@ class L1Loss(fluid.dygraph.Layer):
.. math::
Out = SUM(\lvert input - label\rvert)
Parameters:
reduction (str, optional): Indicate the reduction to apply to the loss,
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If `reduction` is ``'none'``, the unreduced loss is returned;
If `reduction` is ``'mean'``, the reduced mean loss is returned.
If `reduction` is ``'sum'``, the reduced sum loss is returned.
Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shape:
input (Tensor): The input tensor. The shape is [N, *], where N is batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32, int64.
label (Tensor): label. The shape is [N, *], same shape as ``input``. Its data type should be float32, float64, int32, int64.
output (Tensor): The L1 Loss of ``input`` and ``label``.
If `reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``input`` .
If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].
Examples:
.. code-block:: python
import paddle
......@@ -304,17 +305,17 @@ class L1Loss(fluid.dygraph.Layer):
l1_loss = paddle.nn.loss.L1Loss()
output = l1_loss(input, label)
print(output.numpy())
# [0.35]
l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
output = l1_loss(input, label)
print(output.numpy())
# [1.4]
l1_loss = paddle.nn.loss.L1Loss(reduction='none')
output = l1_loss(input, label)
print(output.numpy())
# [[0.20000005 0.19999999]
# [0.2 0.79999995]]
"""
......@@ -335,90 +336,80 @@ class L1Loss(fluid.dygraph.Layer):
class BCELoss(fluid.dygraph.Layer):
"""
:alias_main: paddle.nn.BCELoss
:alias: paddle.nn.BCELoss,paddle.nn.layer.BCELoss,paddle.nn.layer.loss.BCELoss
This interface is used to construct a callable object of the ``BCELoss`` class.
The BCELoss layer measures the binary_cross_entropy loss between input predictions ``input``
and target labels ``label``. The binary_cross_entropy loss can be described as:
If :attr:`weight` is set, the loss is:
.. math::
Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`weight` is None, the loss is:
.. math::
Out = -1 * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`reduction` is set to ``'none'``, the interface will return the original loss `Out`.
If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is:
.. math::
Out = MEAN(Out)
If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is:
.. math::
Out = SUM(Out)
Note that the input predictions ``input`` should always be the output of a sigmoid, and the target labels ``label``
should be numbers between 0 and 1.
The shape of input predictions and target labels are [N, *], where N is batch_size and `*`
means any number of additional dimensions. If ``reduction`` is ``'none'``, the shape of
the output is the same as the input, otherwise the shape of the output is scalar.
Parameters:
weight (Tensor, optional): A manual rescaling weight given to the loss of each
batch element. If given, has to be a Tensor of size nbatch and the data type
is float32, float64. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size,
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
input (Tensor): 2-D tensor with shape: (N, *), N is batch_size, `*` means
number of additional dimensions. The input ``input`` should always
be the output of a sigmoid. Available dtype is float32, float64.
label (Tensor): 2-D tensor with the same shape as ``input``. The target
labels whose values should be numbers between 0 and 1. Available
dtype is float32, float64.
output (Tensor): If ``reduction`` is ``'none'``, the shape of the output is
the same as ``input``, otherwise the shape of the output is scalar.
Returns:
A callable object of BCELoss.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
import paddle
input = fluid.data(name="input", shape=[3, 1], dtype='float32')
label = fluid.data(name="label", shape=[3, 1], dtype='float32')
bce_loss = paddle.nn.loss.BCELoss()
output = bce_loss(input, label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.array([0.5, 0.6, 0.7]).astype("float32")
label_data = np.array([1.0, 0.0, 1.0]).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data, "label":label_data},
fetch_list=[output],
return_numpy=True)
print(output_data) # [array([0.65537095], dtype=float32)]
# imperative mode
paddle.disable_static()
input = paddle.to_variable(input_data)
label = paddle.to_variable(label_data)
bce_loss = paddle.nn.loss.BCELoss()
output = bce_loss(input, label)
print(output.numpy()) # [0.65537095]
paddle.enable_static()
"""
def __init__(self, weight=None, reduction='mean', name=None):
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
......@@ -427,38 +418,12 @@ class BCELoss(fluid.dygraph.Layer):
super(BCELoss, self).__init__()
self.weight = weight
self.reduction = reduction
self.name = name
def forward(self, input, label):
out = paddle.nn.functional.binary_cross_entropy(
input, label, self.weight, self.reduction, self.name)
return out
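Since forward now delegates to the functional form, the layer and functional entry points should agree; a quick parity sketch with hypothetical values, mirroring the tests above:
# import numpy as np
# import paddle
#
# paddle.disable_static()
# input = paddle.to_tensor(np.array([[0.2, 0.9]], dtype='float32'))
# label = paddle.to_tensor(np.array([[0.0, 1.0]], dtype='float32'))
# layer_out = paddle.nn.loss.BCELoss(reduction='sum')(input, label)
# func_out = paddle.nn.functional.binary_cross_entropy(
#     input, label, reduction='sum')
# assert np.allclose(layer_out.numpy(), func_out.numpy())
# paddle.enable_static()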
class NLLLoss(fluid.dygraph.Layer):
......@@ -468,18 +433,18 @@ class NLLLoss(fluid.dygraph.Layer):
This class accepts input and target label and returns negative log likelihood
cross error. It is useful to train a classification problem with C classes.
The input for the loss is expected to contain log-probabilities of
each classes. It has to be a Tensor of size either (batch_size, C) or
(batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
The label for the loss should be a class index in the range [0, C-1]
where C is the number of classes. If ignore_index is specified, the
specified target value does not contribute to the input gradient.
If the optional argument `weight` is provided, it should be a 1D Tensor
assigning weight to each of the classes. This is particularly useful
when you have an unbalanced training set.
The loss is calculated as follows.
The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
......@@ -502,11 +467,11 @@ class NLLLoss(fluid.dygraph.Layer):
Parameters:
weight (Tensor, optional): Weight tensor, a manual rescaling weight given
to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise,
it is treated as if having all ones. The data type is
float32, float64, Default is ``'None'``.
ignore_index (int64, optional): Specifies a target value that is ignored
and does not contribute to the input gradient.
reduction (str, optional): Indicate how to average the loss,
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If `reduction` is ``'mean'``, the reduced mean loss is returned;
if `reduction` is ``'sum'``, the reduced sum loss is returned;
......@@ -587,9 +552,9 @@ class KLDivLoss(fluid.dygraph.Layer):
$$l(x, y) = y * (\log(y) - x)$$
Parameters:
reduction (str, optional): Indicate how to average the loss,
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
Default is ``'mean'``.
Shape:
......@@ -604,7 +569,7 @@ class KLDivLoss(fluid.dygraph.Layer):
import paddle
import numpy as np
import paddle.nn as nn
paddle.enable_imperative()
shape = (5, 20)
......@@ -616,7 +581,7 @@ class KLDivLoss(fluid.dygraph.Layer):
pred_loss = kldiv_criterion(paddle.to_variable(x),
paddle.to_variable(target))
# shape=[5]
# 'mean' reduction, loss shape will be [1]
kldiv_criterion = nn.KLDivLoss(reduction='mean')
pred_loss = kldiv_criterion(paddle.to_variable(x),
......@@ -649,10 +614,10 @@ class MarginRankingLoss(fluid.dygraph.Layer):
"""
This interface is used to construct a callable object of the ``MarginRankingLoss`` class.
The MarginRankingLoss layer calculates the margin rank loss between the input, other and label,
using the math function as follows.
.. math::
margin\_rank\_loss = max(0, -label * (input - other) + margin)
If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:
......@@ -672,7 +637,7 @@ class MarginRankingLoss(fluid.dygraph.Layer):
reduction (str, optional): Indicate the reduction to apply to the loss, the candidates are ``'none'``, ``'mean'``, ``'sum'``. If :attr:`reduction` is ``'none'``, the unreduced loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shape:
input: N-D Tensor, the shape is [N, *], N is batch size and `*` means any number of additional dimensions, available dtype is float32, float64.
other: N-D Tensor, `other` has the same shape and dtype as `input`.
label: N-D Tensor, `label` has the same shape and dtype as `input`.
......@@ -685,16 +650,16 @@ class MarginRankingLoss(fluid.dygraph.Layer):
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
input = paddle.to_variable(np.array([[1, 2], [3, 4]]).astype("float32"))
other = paddle.to_variable(np.array([[2, 1], [2, 4]]).astype("float32"))
label = paddle.to_variable(np.array([[1, -1], [-1, -1]]).astype("float32"))
margin_rank_loss = paddle.nn.MarginRankingLoss()
loss = margin_rank_loss(input, other, label)
print(loss.numpy()) # [0.75]
"""
......@@ -741,7 +706,7 @@ class SmoothL1Loss(fluid.dygraph.Layer):
If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
Default is ``'mean'``.
delta (float, optional): Specifies the hyperparameter delta to be used.
The value determines how large the errors need to be to use L1. Errors
smaller than delta are minimized with L2. Parameter is ignored for
negative/zero values. Default = 1.0
......