Unverified commit 5407e327 authored by chajchaj, committed by GitHub

add cross_entropy to nn/layer and nn/functional, test=develop (#26478)

* add cross_entropy to nn/layer and nn/functional, test=develop

* use functional/cross_entropy in layer/CrossEntropy

* use functional/cross_entropy in layer/CrossEntropy, test=develop
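
In short, the commit exposes cross entropy both as a functional op (paddle.nn.functional.cross_entropy) and as a layer that delegates to it. A minimal usage sketch of the two entry points (a sketch assuming the 2.0-style dygraph API used in the diff below; shapes and values are arbitrary):

    import numpy as np
    import paddle

    paddle.disable_static()
    logits = paddle.to_tensor(np.random.random([5, 100]).astype("float64"))
    labels = paddle.to_tensor(np.random.randint(0, 100, size=(5, )).astype(np.int64))

    # functional entry point added in nn/functional/loss.py
    loss = paddle.nn.functional.cross_entropy(input=logits, label=labels)

    # layer entry point, now a thin wrapper over the functional op
    loss_layer = paddle.nn.loss.CrossEntropyLoss()
    loss2 = loss_layer(logits, labels)
    print(loss.numpy(), loss2.numpy())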
Parent e167e879
@@ -22,7 +22,6 @@ from ...fluid.framework import core, in_dygraph_mode
from ...fluid.layers.nn import _elementwise_op_in_dygraph
from ...fluid.layers import bpr_loss #DEFINE_ALIAS
from ...fluid.layers import center_loss #DEFINE_ALIAS
-from ...fluid.layers import cross_entropy #DEFINE_ALIAS
from ...fluid.layers import dice_loss #DEFINE_ALIAS
from ...fluid.layers import iou_similarity #DEFINE_ALIAS
from ...fluid.layers import log_loss #DEFINE_ALIAS
@@ -786,3 +785,132 @@ def mse_loss(input, label, reduction='mean', name=None):
return paddle.sum(paddle.fluid.layers.square(
paddle.fluid.layers.elementwise_sub(input, label)),
name=name)
def cross_entropy(input,
label,
weight=None,
ignore_index=-100,
reduction='mean'):
"""
    This operator implements the cross entropy loss function, combining
    ``LogSoftmax`` and ``NLLLoss`` in a single OP.
It is useful when training a classification problem with ``C`` classes.
If provided, the optional argument ``weight`` should be a 1D Variable assigning
weight to each of the classes.
    For the prediction ``input`` and the target ``label``, the loss is calculated as follows.

    .. math::

        loss_j = -\\text{input[class]} +
        \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right), j = 1,..., K

    If ``weight`` is not ``None``:

    .. math::

        loss_j = \\text{weight[class]}(-\\text{input[class]} +
        \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right)), j = 1,..., K
    Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. Shape is
            (N, C), where C is the number of classes; if the input has more than 2
            dimensions, the shape is (N, C, D1, D2, ..., Dk), k >= 1.
        label (Tensor): Label tensor, the data type is int64. Shape is (N), where each
            value satisfies 0 <= label[i] <= C-1; if the label has more than 1
            dimension, the shape is (N, D1, D2, ..., Dk), k >= 1.
        weight (Tensor, optional): Weight tensor, a manual rescaling weight given
            to each class, with shape (C), matching the number of classes, and
            data type float32 or float64. Default is ``None``.
        reduction (str, optional): Indicates how to reduce the loss over the batch;
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned;
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
            Default is ``'mean'``.
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default is ``-100``.
    Returns:
        The tensor storing the cross entropy loss of ``input`` and ``label``.

    Return type: Tensor.
    Examples:
        .. code-block:: python

            import numpy as np
            import paddle

            paddle.disable_static()
            input_data = np.random.random([5, 100]).astype("float64")
            label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
            weight_data = np.random.random([100]).astype("float64")
            input = paddle.to_tensor(input_data)
            label = paddle.to_tensor(label_data)
            weight = paddle.to_tensor(weight_data)
            loss = paddle.nn.functional.cross_entropy(input=input, label=label, weight=weight)
            print(loss.numpy())
"""
if not in_dygraph_mode():
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'cross_entropy_loss')
fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
'cross_entropy_loss')
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or"
" 'none', but received %s, which is not allowed." % reduction)
#step 1. log_softmax
log_softmax_out = paddle.nn.functional.log_softmax(input)
if weight is not None and not isinstance(weight, Variable):
        raise ValueError(
            "The weight is not a Variable, please convert it to a Variable.")
#step 2. nll_loss
input = log_softmax_out
helper = LayerHelper('nll_loss', **locals())
dtype = helper.input_dtype(input)
if not in_dygraph_mode():
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'nll_loss')
fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
'nll_loss')
x_shape = list(input.shape)
n = x_shape[0]
c = x_shape[1]
x_dims = len(x_shape)
if x_dims < 2:
raise ValueError('Expected 2 or more dimensions (got {})'.format(
x_dims))
if x_dims != 2 and x_dims != 4:
input = reshape(input, shape=[n, c, 1, -1])
label = reshape(label, shape=[n, 1, -1])
out_shape = [n] + x_shape[2:]
if not in_dygraph_mode():
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'nll_loss')
fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
'nll_loss')
inputs = {'X': input, 'Label': label}
attrs = {'reduction': reduction, 'ignore_index': ignore_index}
if weight is not None:
if isinstance(weight, Variable):
inputs['Weight'] = weight
out = helper.create_variable_for_type_inference(dtype=input.dtype)
total_weight = helper.create_variable_for_type_inference(dtype=input.dtype)
outputs = {'Out': out, 'Total_weight': total_weight}
helper.append_op(
type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs)
if x_dims != 2 and x_dims != 4 and reduction == 'none':
out = reshape(out, shape=out_shape)
return out
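
For illustration, a short sketch of how ``reduction`` and ``ignore_index`` behave in the functional form above (hypothetical values; the behavior follows from the attrs passed to the nll_loss op):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    x = paddle.to_tensor(np.random.random([4, 3]).astype("float32"))
    # the last label equals the default ignore_index (-100),
    # so it should not contribute to the loss
    y = paddle.to_tensor(np.array([0, 2, 1, -100]).astype(np.int64))

    per_sample = F.cross_entropy(x, y, reduction='none')  # per-sample loss, shape [4]
    summed = F.cross_entropy(x, y, reduction='sum')       # reduced sum
    averaged = F.cross_entropy(x, y)                      # default 'mean'
    print(per_sample.numpy(), summed.numpy(), averaged.numpy())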
@@ -21,7 +21,6 @@ from .. import functional as F
from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator
__all__ = [
# 'NCELoss',
'CrossEntropyLoss',
'MSELoss',
'L1Loss',
@@ -119,7 +118,7 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
print(output.numpy())
"""
-    def __init__(self, weight=None, reduction='mean', ignore_index=-100):
+    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
super(CrossEntropyLoss, self).__init__()
self.weight = weight
self.reduction = reduction
@@ -137,18 +136,12 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
" 'none', but received %s, which is not allowed." %
self.reduction)
-        log_softmax = paddle.nn.LogSoftmax()
-        log_softmax_out = log_softmax(input)
-        if self.weight is not None and not isinstance(self.weight,
-                                                       fluid.framework.Variable):
-            raise ValueError(
-                "The weight' is not a Variable, please convert to Variable.")
-        nll_loss = paddle.nn.loss.NLLLoss(
-            weight=self.weight,
-            reduction=self.reduction,
-            ignore_index=self.ignore_index)
-        return nll_loss(log_softmax_out, label)
+        return paddle.nn.functional.cross_entropy(
+            input,
+            label,
+            weight=self.weight,
+            ignore_index=self.ignore_index,
+            reduction=self.reduction)
class MSELoss(fluid.dygraph.layers.Layer):
......
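
The layer-side change follows a common refactoring pattern: keep all computation in the functional op and reduce the layer to a thin, state-holding wrapper. A simplified re-creation of that pattern (hypothetical class name, mirroring the diff above):

    import paddle
    import paddle.fluid as fluid

    class ThinCrossEntropyLoss(fluid.dygraph.Layer):
        # Stores weight / ignore_index / reduction and defers the math to
        # paddle.nn.functional.cross_entropy, as the refactored layer does.
        def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
            super(ThinCrossEntropyLoss, self).__init__()
            self.weight = weight
            self.ignore_index = ignore_index
            self.reduction = reduction

        def forward(self, input, label):
            return paddle.nn.functional.cross_entropy(
                input,
                label,
                weight=self.weight,
                ignore_index=self.ignore_index,
                reduction=self.reduction)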