未验证 提交 988fbf82 编写于 作者: G Guanghua Yu 提交者: GitHub

Fix bug with wrong calculation result in `nn.loss.CrossEntropyLoss` (#24352)

* fix bug of cross_entropy_loss,test=develop
* fix log_softmax and some comment,test=develop
上级 8d0bae2d
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# TODO: define loss functions of neural network # TODO: define loss functions of neural network
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle
__all__ = [ __all__ = [
# 'NCELoss', # 'NCELoss',
...@@ -27,8 +28,8 @@ __all__ = [ ...@@ -27,8 +28,8 @@ __all__ = [
class CrossEntropyLoss(fluid.dygraph.Layer): class CrossEntropyLoss(fluid.dygraph.Layer):
""" """
This operator implements the cross entropy loss function. This OP combines ``softmax``, This operator implements the cross entropy loss function. This OP combines ``LogSoftmax``,
``cross_entropy``, and ``reduce_sum``/``reduce_mean`` together. and ``NLLLoss`` together.
It is useful when training a classification problem with ``C`` classes. It is useful when training a classification problem with ``C`` classes.
If provided, the optional argument ``weight`` should be a 1D Variable assigning If provided, the optional argument ``weight`` should be a 1D Variable assigning
...@@ -49,19 +50,23 @@ class CrossEntropyLoss(fluid.dygraph.Layer): ...@@ -49,19 +50,23 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right)), j = 1,..., K \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right)), j = 1,..., K
Parameters: Parameters:
input (Variable): Input tensor, the data type is float32, input (Variable): Input tensor, the data type is float32, float64. Shape is
float64, int32, int64. (N, C), where C is number of classes, and if shape is more than 2D, this
label (Variable): Label tensor, the data type is float32, is (N, C, D1, D2,..., Dk), k >= 1.
float64, int32, int64. label (Variable): Label tensor, the data type is int64. Shape is (N), where each
value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
(N, D1, D2,..., Dk), k >= 1.
weight (Variable, optional): Weight tensor, a manual rescaling weight given weight (Variable, optional): Weight tensor, a manual rescaling weight given
to each class. It has the same dimensions as class number and the data type to each class and the shape is (C). It has the same dimensions as class
is float32, float64, int32, int64. Default is ``'None'``. number and the data type is float32, float64. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size, reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned. If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned. If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
Default is ``'mean'``. Default is ``'mean'``.
ignore_index (int64, optional): Specifies a target value that is ignored
and does not contribute to the input gradient. Default is ``-100``.
Returns: Returns:
The tensor variable storing the cross_entropy_loss of input and label. The tensor variable storing the cross_entropy_loss of input and label.
...@@ -76,17 +81,17 @@ class CrossEntropyLoss(fluid.dygraph.Layer): ...@@ -76,17 +81,17 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
input = fluid.layers.data(name='input', shape=[5, 100], dtype='float32') input = fluid.data(name='input', shape=[5, 100], dtype='float64')
label = fluid.layers.data(name='label', shape=[5, 1], dtype='int64') label = fluid.data(name='label', shape=[5], dtype='int64')
weight = fluid.layers.data(name='weight', shape=[100], dtype='float32') weight = fluid.data(name='weight', shape=[100], dtype='float64')
ce_loss = paddle.nn.loss.CrossEntropyLoss(weight=weight, reduction='mean') ce_loss = paddle.nn.loss.CrossEntropyLoss(weight=weight, reduction='mean')
output = ce_loss(input,label) output = ce_loss(input, label)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
input_data = np.random.random([5, 100]).astype("float32") input_data = np.random.random([5, 100]).astype("float64")
label_data = np.array([[1], [9], [40], [50], [90]]).astype("int64") label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
weight_data = np.random.random([100]).astype("float32") weight_data = np.random.random([100]).astype("float64")
output = exe.run(fluid.default_main_program(), output = exe.run(fluid.default_main_program(),
feed={"input": input_data, "label": label_data,"weight": weight_data}, feed={"input": input_data, "label": label_data,"weight": weight_data},
fetch_list=[output], fetch_list=[output],
...@@ -104,41 +109,36 @@ class CrossEntropyLoss(fluid.dygraph.Layer): ...@@ -104,41 +109,36 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
print(output.numpy()) print(output.numpy())
""" """
def __init__(self, weight=None, reduction='mean'): def __init__(self, weight=None, reduction='mean', ignore_index=-100):
super(CrossEntropyLoss, self).__init__() super(CrossEntropyLoss, self).__init__()
self.weight = weight self.weight = weight
self.reduction = reduction self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, input, label): def forward(self, input, label):
fluid.data_feeder.check_variable_and_dtype( fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64', 'int32', 'int64'], input, 'input', ['float32', 'float64'], 'cross_entropy_loss')
'cross_entropy_loss') fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
fluid.data_feeder.check_variable_and_dtype( 'cross_entropy_loss')
label, 'label', ['float32', 'float64', 'int32', 'int64'],
'cross_entropy_loss')
if self.reduction not in ['sum', 'mean', 'none']: if self.reduction not in ['sum', 'mean', 'none']:
raise ValueError( raise ValueError(
"The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or 'none'," "The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or"
" but received %s, which is not allowed." % self.reduction) " 'none', but received %s, which is not allowed." %
self.reduction)
softmax_out = fluid.layers.softmax(input)
if self.weight is not None: log_softmax = paddle.nn.LogSoftmax()
if isinstance(self.weight, fluid.framework.Variable): log_softmax_out = log_softmax(input)
softmax_out = fluid.layers.elementwise_pow( if self.weight is not None and not isinstance(self.weight,
softmax_out, self.weight, axis=-1) fluid.framework.Variable):
else: raise ValueError(
raise ValueError( "The weight' is not a Variable, please convert to Variable.")
"The weight' is not a Variable, please convert to Variable.") nll_loss = paddle.nn.loss.NLLLoss(
weight=self.weight,
reduction=self.reduction,
ignore_index=self.ignore_index)
out = fluid.layers.cross_entropy(softmax_out, label) return nll_loss(log_softmax_out, label)
if self.reduction == 'sum':
return fluid.layers.reduce_sum(out)
elif self.reduction == 'mean':
return fluid.layers.reduce_mean(out)
else:
return out
class MSELoss(fluid.dygraph.layers.Layer): class MSELoss(fluid.dygraph.layers.Layer):
...@@ -578,7 +578,6 @@ class NLLLoss(fluid.dygraph.Layer): ...@@ -578,7 +578,6 @@ class NLLLoss(fluid.dygraph.Layer):
inputs = {'X': input, 'Label': label} inputs = {'X': input, 'Label': label}
attrs = {'reduction': self.reduction, 'ignore_index': self.ignore_index} attrs = {'reduction': self.reduction, 'ignore_index': self.ignore_index}
if self.weight is not None: if self.weight is not None:
if isinstance(self.weight, fluid.framework.Variable): if isinstance(self.weight, fluid.framework.Variable):
inputs['Weight'] = self.weight inputs['Weight'] = self.weight
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册