未验证 提交 b5242732 编写于 作者: C chajchaj 提交者: GitHub

add soft_label and axis for CrossEntropyLoss and improve performance (#29024)

* add soft_label and axis for CrossEntropyLoss and improve performance,test=develop

* fix conflict in nn/functional/loss.py, test=develop
上级 018e1699
...@@ -128,6 +128,8 @@ from .loss import binary_cross_entropy #DEFINE_ALIAS ...@@ -128,6 +128,8 @@ from .loss import binary_cross_entropy #DEFINE_ALIAS
from .loss import binary_cross_entropy_with_logits #DEFINE_ALIAS from .loss import binary_cross_entropy_with_logits #DEFINE_ALIAS
# from .loss import bpr_loss #DEFINE_ALIAS # from .loss import bpr_loss #DEFINE_ALIAS
# from .loss import center_loss #DEFINE_ALIAS # from .loss import center_loss #DEFINE_ALIAS
#from .loss import cross_entropy #DEFINE_ALIAS
from .loss import softmax_cross_entropy #DEFINE_ALIAS
from .loss import cross_entropy #DEFINE_ALIAS from .loss import cross_entropy #DEFINE_ALIAS
from .loss import dice_loss #DEFINE_ALIAS from .loss import dice_loss #DEFINE_ALIAS
from .loss import hsigmoid_loss #DEFINE_ALIAS from .loss import hsigmoid_loss #DEFINE_ALIAS
......
...@@ -42,6 +42,7 @@ __all__ = [ ...@@ -42,6 +42,7 @@ __all__ = [
'binary_cross_entropy', 'binary_cross_entropy',
'binary_cross_entropy_with_logits', 'binary_cross_entropy_with_logits',
'cross_entropy', 'cross_entropy',
'softmax_cross_entropy',
'dice_loss', 'dice_loss',
'hsigmoid_loss', 'hsigmoid_loss',
'kl_div', 'kl_div',
...@@ -1120,39 +1121,73 @@ def cross_entropy(input, ...@@ -1120,39 +1121,73 @@ def cross_entropy(input,
label, label,
weight=None, weight=None,
ignore_index=-100, ignore_index=-100,
reduction='mean'): reduction='mean',
r""" soft_label=False,
This operator implements the cross entropy loss function. This OP combines ``LogSoftmax``, axis=-1,
and ``NLLLoss`` together. name=None):
return softmax_cross_entropy(
input=input,
label=label,
weight=weight,
ignore_index=ignore_index,
reduction=reduction,
soft_label=soft_label,
axis=axis,
name=name)
def softmax_cross_entropy(input,
label,
weight=None,
ignore_index=-100,
reduction='mean',
soft_label=False,
axis=-1,
name=None):
"""
This operator implements the cross entropy loss function with softmax. This function
combines the calculation of the softmax operation and the cross entropy loss function
to provide a more numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
It is useful when training a classification problem with ``C`` classes. When the attribute :attr:`soft_label` is set :attr:`False`, this operators
If provided, the optional argument ``weight`` should be a 1D Variable assigning expects mutually exclusive hard labels, each sample in a batch is in exactly
weight to each of the classes. one class with a probability of 1.0. Each sample in the batch will have a
single label.
For predictions label, and target label, the loss is calculated as follows. The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
.. math:: .. math::
loss_j = -\\text{input[class]} + loss_j = -\\text{logits}_{label_j} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right), j = 1,..., K \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logits}_i)\\right), j = 1,..., K
If weight is not ``None``: 2) Soft label (each sample can have a distribution over all classes)
.. math:: .. math::
loss_j = \\text{weight[class]}(-\\text{input[class]} + loss_j = -\\sum_{i=0}^{K}\\text{label}_i
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right)), j = 1,..., K \\left(\\text{logits}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logits}_i)\\right)\\right), j = 1,...,K
It is useful when training a classification problem with ``C`` classes.
Parameters: Parameters:
input (Tensor): Input tensor, the data type is float32, float64. Shape is input (Tensor): Input tensor, the data type is float32, float64. Shape is
(N, C), where C is number of classes, and if shape is more than 2D, this (N, C), where C is number of classes, and if shape is more than 2D, this
is (N, C, D1, D2,..., Dk), k >= 1. is (N, D1, D2,..., Dk, C), k >= 1.
label (Tensor): Label tensor, the data type is int64. Shape is (N), where each label (Tensor): Label tensor, the data type is int64. Shape is (N), where each
value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
(N, D1, D2,..., Dk), k >= 1. (N, D1, D2,..., Dk), k >= 1.
weight (Tensor, optional): Weight tensor, a manual rescaling weight given weight (Tensor, optional):a manual rescaling weight given to each class.
to each class and the shape is (C). It has the same dimensions as class If given, has to be a Tensor of size C and the data type is float32, float64.
number and the data type is float32, float64. Default is ``'None'``. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size, reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
...@@ -1161,88 +1196,103 @@ def cross_entropy(input, ...@@ -1161,88 +1196,103 @@ def cross_entropy(input,
Default is ``'mean'``. Default is ``'mean'``.
ignore_index (int64, optional): Specifies a target value that is ignored ignore_index (int64, optional): Specifies a target value that is ignored
and does not contribute to the input gradient. Default is ``-100``. and does not contribute to the input gradient. Default is ``-100``.
soft_label (bool): indicate whether label is soft. Default False, meaning that
the label is hard. If soft_label=True, the label is soft.
axis (int, optional): The index of dimension to perform softmax calculations. It
should be in range :math:`[-1, rank - 1]`, while :math:`rank`
is the rank of input :attr:`logits`. Default: -1.
Returns: Returns:
The tensor variable storing the cross_entropy_loss of input and label. The tensor variable storing the cross_entropy_loss of input and label.
Return type: Tensor. Return type: Variable.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
paddle.disable_static() import paddle.nn.functional as F
input_data = np.random.random([5, 100]).astype("float64") import numpy as np
label_data = np.random.randint(0, 100, size=(5)).astype(np.int64) input_np = np.random.random([2, 4]).astype(np.float64)
weight_data = np.random.random([100]).astype("float64") label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
input = paddle.to_tensor(input_data) weight_np = np.random.random([4]).astype(np.float64) #shape:C
label = paddle.to_tensor(label_data) output = F.softmax_cross_entropy(
weight = paddle.to_tensor(weight_data) paddle.to_tensor(input_np),
loss = paddle.nn.functional.cross_entropy(input=input, label=label, weight=weight) paddle.to_tensor(label_np),
print(loss.numpy()) weight=paddle.to_tensor(weight_np))
print(output.numpy()) #[1.30719427]
""" """
if not in_dygraph_mode():
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'cross_entropy_loss')
fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
'cross_entropy_loss')
if reduction not in ['sum', 'mean', 'none']: if reduction not in ['sum', 'mean', 'none']:
raise ValueError( raise ValueError(
"The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or" "The value of 'reduction' in softmax_cross_entropy"
" 'none', but received %s, which is not allowed." % reduction) "should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
% reduction)
#step 1. log_softmax input_dims = len(list(input.shape))
log_softmax_out = paddle.nn.functional.log_softmax(input, axis=1) label_dims = len(list(label.shape))
if weight is not None and not isinstance(weight, Variable): if input_dims - 1 != label_dims and input_dims != label_dims:
raise ValueError( raise ValueError(
"The weight' is not a Variable, please convert to Variable.") 'Expected nput_dims - 1 = label_dims or input_dims == label_dims\
(got nput_dims{}, label_dims{})'.format(input_dims, label_dims))
#step 2. nll_loss if input_dims - 1 == label_dims:
input = log_softmax_out label = paddle.unsqueeze(label, axis=axis)
helper = LayerHelper('nll_loss', **locals()) if in_dygraph_mode():
dtype = helper.input_dtype(input) out = softmax_with_cross_entropy(
input,
label,
soft_label=soft_label,
ignore_index=ignore_index,
axis=axis)
if weight is not None:
weight_gather = core.ops.gather_nd(weight, label) #trans to sample
input_shape = list(label.shape)
weight_gather_reshape, _ = core.ops.reshape2(weight_gather, 'shape',
input_shape)
out = core.ops.elementwise_mul(out, weight_gather_reshape)
if not in_dygraph_mode(): if reduction == "sum":
fluid.data_feeder.check_variable_and_dtype( return core.ops.reduce_sum(out, 'reduce_all', True)
input, 'input', ['float32', 'float64'], 'nll_loss') elif reduction == "mean":
fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'], if weight is not None:
'nll_loss') out_sum = core.ops.reduce_sum(out, 'reduce_all', True)
total_weight = core.ops.reduce_sum(weight_gather_reshape,
x_shape = list(input.shape) 'reduce_all', True)
n = x_shape[0] return out_sum / total_weight
c = x_shape[1] else:
x_dims = len(x_shape) return core.ops.mean(out)
if x_dims < 2: else:
raise ValueError('Expected 2 or more dimensions (got {})'.format( return out
x_dims))
if x_dims != 2 and x_dims != 4:
input = reshape(input, shape=[n, c, 1, -1])
label = reshape(label, shape=[n, 1, -1])
out_shape = [n] + x_shape[2:]
if not in_dygraph_mode(): fluid.data_feeder.check_variable_and_dtype(
fluid.data_feeder.check_variable_and_dtype( input, 'input', ['float32', 'float64'], 'softmax_cross_entropy')
input, 'input', ['float32', 'float64'], 'nll_loss') fluid.data_feeder.check_variable_and_dtype(
fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'], label, 'label', ['int32', 'int64'], 'softmax_cross_entropy')
'nll_loss') out = softmax_with_cross_entropy(
inputs = {'X': input, 'Label': label} input,
attrs = {'reduction': reduction, 'ignore_index': ignore_index} label,
soft_label=soft_label,
ignore_index=ignore_index,
axis=axis)
if weight is not None: if weight is not None:
if isinstance(weight, Variable): fluid.data_feeder.check_variable_and_dtype(
inputs['Weight'] = weight weight, 'weight', ['float32', 'float64'], 'softmax_cross_entropy')
weight_name = name if reduction == 'none' else None
out = helper.create_variable_for_type_inference(dtype=input.dtype) weight_gather = paddle.gather_nd(weight, label) #trans to sample
total_weight = helper.create_variable_for_type_inference(dtype=input.dtype) input_shape = list(label.shape)
outputs = {'Out': out, 'Total_weight': total_weight} weight_gather_reshape = reshape(weight_gather, shape=input_shape)
out = paddle.multiply(out, weight_gather_reshape, name=weight_name)
helper.append_op(
type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs)
if x_dims != 2 and x_dims != 4 and reduction == 'none':
out = reshape(out, shape=out_shape)
return out if reduction == "sum":
return paddle.sum(out, name=name)
elif reduction == "mean":
if weight is not None:
out_sum = paddle.sum(out, name=name)
total_weight = paddle.sum(weight_gather_reshape)
return out_sum / total_weight
else:
return paddle.mean(out, name=name)
else:
return out
def sigmoid_focal_loss(logit, def sigmoid_focal_loss(logit,
......
...@@ -141,30 +141,40 @@ class BCEWithLogitsLoss(fluid.dygraph.Layer): ...@@ -141,30 +141,40 @@ class BCEWithLogitsLoss(fluid.dygraph.Layer):
class CrossEntropyLoss(fluid.dygraph.Layer): class CrossEntropyLoss(fluid.dygraph.Layer):
r""" """
:alias_main: paddle.nn.CrossEntropyLoss This operator implements the cross entropy loss function with softmax. This function
:alias: paddle.nn.CrossEntropyLoss,paddle.nn.layer.CrossEntropyLoss,paddle.nn.layer.loss.CrossEntropyLoss combines the calculation of the softmax operation and the cross entropy loss function
to provide a more numerically stable gradient.
This operator implements the cross entropy loss function. This OP combines ``LogSoftmax``, Because this operator performs a softmax on logits internally, it expects
and ``NLLLoss`` together. unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
It is useful when training a classification problem with ``C`` classes. When the attribute :attr:`soft_label` is set :attr:`False`, this operators
If provided, the optional argument ``weight`` should be a 1D Variable assigning expects mutually exclusive hard labels, each sample in a batch is in exactly
weight to each of the classes. one class with a probability of 1.0. Each sample in the batch will have a
single label.
For predictions label, and target label, the loss is calculated as follows. The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
.. math:: .. math::
loss_j = -\\text{input[class]} + loss_j = -\\text{logits}_{label_j} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right), j = 1,..., K \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logits}_i)\\right), j = 1,..., K
If weight is not ``None``: 2) Soft label (each sample can have a distribution over all classes)
.. math:: .. math::
loss_j = \\text{weight[class]}(-\\text{input[class]} + loss_j = -\\sum_{i=0}^{K}\\text{label}_i
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right)), j = 1,..., K \\left(\\text{logits}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logits}_i)\\right)\\right), j = 1,...,K
It is useful when training a classification problem with ``C`` classes.
Parameters: Parameters:
input (Variable): Input tensor, the data type is float32, float64. Shape is input (Variable): Input tensor, the data type is float32, float64. Shape is
...@@ -173,9 +183,9 @@ class CrossEntropyLoss(fluid.dygraph.Layer): ...@@ -173,9 +183,9 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
label (Variable): Label tensor, the data type is int64. Shape is (N), where each label (Variable): Label tensor, the data type is int64. Shape is (N), where each
value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
(N, D1, D2,..., Dk), k >= 1. (N, D1, D2,..., Dk), k >= 1.
weight (Variable, optional): Weight tensor, a manual rescaling weight given weight (Variable, optional): Weight tensor, a manual rescaling weight for each
to each class and the shape is (C). It has the same dimensions as class sample relative to each class. It has the same shape as label.
number and the data type is float32, float64. Default is ``'None'``. and the data type is float32, float64. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size, reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
...@@ -184,6 +194,12 @@ class CrossEntropyLoss(fluid.dygraph.Layer): ...@@ -184,6 +194,12 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
Default is ``'mean'``. Default is ``'mean'``.
ignore_index (int64, optional): Specifies a target value that is ignored ignore_index (int64, optional): Specifies a target value that is ignored
and does not contribute to the input gradient. Default is ``-100``. and does not contribute to the input gradient. Default is ``-100``.
soft_label (bool): indicate whether label is soft. Default False, meaning that
the label is hard. If soft_label=True, the label is soft.
axis (int, optional): The index of dimension to perform softmax calculations. It
should be in range :math:`[-1, rank - 1]`, while :math:`rank`
is the rank of input :attr:`logits`. Default: -1.
Returns: Returns:
The tensor variable storing the cross_entropy_loss of input and label. The tensor variable storing the cross_entropy_loss of input and label.
...@@ -192,64 +208,47 @@ class CrossEntropyLoss(fluid.dygraph.Layer): ...@@ -192,64 +208,47 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
Examples: Examples:
.. code-block:: python .. code-block:: python
# declarative mode
import paddle import paddle
import paddle.fluid as fluid
import numpy as np import numpy as np
input_np = np.random.random([2, 4]).astype(np.float64)
input = fluid.data(name='input', shape=[5, 100], dtype='float64') label_np = np.random.randint(0, 4, size=(2, 1)).astype(np.int64)
label = fluid.data(name='label', shape=[5], dtype='int64') weight_np = np.random.random([4]).astype(np.float64) #shape:C
weight = fluid.data(name='weight', shape=[100], dtype='float64') weight_ce = weight_np[label_np] #shape:N,1
ce_loss = paddle.nn.loss.CrossEntropyLoss(weight=weight, reduction='mean') cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
output = ce_loss(input, label) weight=paddle.to_tensor(weight_ce))
place = fluid.CPUPlace() output = cross_entropy_loss(
exe = fluid.Executor(place) paddle.to_tensor(input_np),
exe.run(fluid.default_startup_program()) paddle.to_tensor(label_np))
input_data = np.random.random([5, 100]).astype("float64") print(output.numpy()) #[1.44375251]
label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
weight_data = np.random.random([100]).astype("float64")
output = exe.run(fluid.default_main_program(),
feed={"input": input_data, "label": label_data,"weight": weight_data},
fetch_list=[output],
return_numpy=True)
print(output)
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
label = dg.to_variable(label_data)
weight = dg.to_variable(weight_data)
ce_loss = paddle.nn.loss.CrossEntropyLoss(weight=weight, reduction='mean')
output = ce_loss(input, label)
print(output.numpy())
""" """
def __init__(self, weight=None, ignore_index=-100, reduction='mean'): def __init__(self,
weight=None,
ignore_index=-100,
reduction='mean',
soft_label=False,
axis=-1,
name=None):
super(CrossEntropyLoss, self).__init__() super(CrossEntropyLoss, self).__init__()
self.weight = weight self.weight = weight
self.reduction = reduction self.reduction = reduction
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.soft_label = soft_label
self.axis = axis
self.name = name
def forward(self, input, label): def forward(self, input, label):
fluid.data_feeder.check_variable_and_dtype( ret = paddle.nn.functional.softmax_cross_entropy(
input, 'input', ['float32', 'float64'], 'cross_entropy_loss')
fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
'cross_entropy_loss')
if self.reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or"
" 'none', but received %s, which is not allowed." %
self.reduction)
return paddle.nn.functional.cross_entropy(
input, input,
label, label,
weight=self.weight, weight=self.weight,
ignore_index=self.ignore_index, ignore_index=self.ignore_index,
reduction=self.reduction) reduction=self.reduction,
soft_label=self.soft_label,
axis=self.axis,
name=self.name)
return ret
class HSigmoidLoss(fluid.dygraph.Layer): class HSigmoidLoss(fluid.dygraph.Layer):
...@@ -491,29 +490,31 @@ class L1Loss(fluid.dygraph.Layer): ...@@ -491,29 +490,31 @@ class L1Loss(fluid.dygraph.Layer):
If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1]. If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1].
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import numpy as np
input = paddle.to_tensor([[1.5, 0.8], [0.2, 1.3]]) paddle.disable_static()
label = paddle.to_tensor([[1.7, 1.0], [0.4, 0.5]]) input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
input = paddle.to_tensor(input_data)
label = paddle.to_tensor(label_data)
l1_loss = paddle.nn.loss.L1Loss() l1_loss = paddle.nn.loss.L1Loss()
output = l1_loss(input, label) output = l1_loss(input, label)
print(output) print(output.numpy())
# [0.35] # [0.35]
l1_loss = paddle.nn.loss.L1Loss(reduction='sum') l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
output = l1_loss(input, label) output = l1_loss(input, label)
print(output) print(output.numpy())
# [1.4] # [1.4]
l1_loss = paddle.nn.loss.L1Loss(reduction='none') l1_loss = paddle.nn.loss.L1Loss(reduction='none')
output = l1_loss(input, label) output = l1_loss(input, label)
print(output) print(output.numpy())
# [[0.20000005 0.19999999] # [[0.20000005 0.19999999]
# [0.2 0.79999995]] # [0.2 0.79999995]]
""" """
def __init__(self, reduction='mean', name=None): def __init__(self, reduction='mean', name=None):
...@@ -622,7 +623,9 @@ class BCELoss(fluid.dygraph.Layer): ...@@ -622,7 +623,9 @@ class BCELoss(fluid.dygraph.Layer):
class NLLLoss(fluid.dygraph.Layer): class NLLLoss(fluid.dygraph.Layer):
r""" """
:alias_main: paddle.nn.NLLLoss
:alias: paddle.nn.NLLLoss,paddle.nn.layer.NLLLoss,paddle.nn.layer.loss.NLLLoss
This class accepts input and target label and returns negative log likelihood This class accepts input and target label and returns negative log likelihood
cross error. It is useful to train a classification problem with C classes. cross error. It is useful to train a classification problem with C classes.
...@@ -689,7 +692,7 @@ class NLLLoss(fluid.dygraph.Layer): ...@@ -689,7 +692,7 @@ class NLLLoss(fluid.dygraph.Layer):
import paddle import paddle
import numpy as np import numpy as np
nll_loss = paddle.nn.NLLLoss() nll_loss = paddle.nn.layer.NLLLoss()
log_softmax = paddle.nn.LogSoftmax(axis=1) log_softmax = paddle.nn.LogSoftmax(axis=1)
input_np = np.array([[0.88103855, 0.9908683 , 0.6226845 ], input_np = np.array([[0.88103855, 0.9908683 , 0.6226845 ],
...@@ -699,11 +702,13 @@ class NLLLoss(fluid.dygraph.Layer): ...@@ -699,11 +702,13 @@ class NLLLoss(fluid.dygraph.Layer):
[0.05689114, 0.0862954 , 0.6325046 ]]).astype(np.float32) [0.05689114, 0.0862954 , 0.6325046 ]]).astype(np.float32)
label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64) label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)
place = paddle.CPUPlace()
paddle.disable_static(place)
input = paddle.to_tensor(input_np) input = paddle.to_tensor(input_np)
log_out = log_softmax(input) log_out = log_softmax(input)
label = paddle.to_tensor(label_np) label = paddle.to_tensor(label_np)
result = nll_loss(log_out, label) result = nll_loss(log_out, label)
print(result) # [1.0720209] print(result.numpy()) # [1.0720209]
""" """
...@@ -999,7 +1004,7 @@ class SmoothL1Loss(fluid.dygraph.Layer): ...@@ -999,7 +1004,7 @@ class SmoothL1Loss(fluid.dygraph.Layer):
is the same as the shape of input. is the same as the shape of input.
Returns: Returns:
The tensor storing the smooth_l1_loss of input and label. The tensor variable storing the smooth_l1_loss of input and label.
Return type: Tensor. Return type: Tensor.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册