diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc
index 7412eede118d122b14c69ab663836c156eb740e2..b32f5e8847d30fc785587541ccdc74d99d2b025c 100644
--- a/paddle/fluid/pybind/op_function_generator.cc
+++ b/paddle/fluid/pybind/op_function_generator.cc
@@ -40,6 +40,7 @@ std::map<std::string, std::vector<std::string>> op_ins_map = {
   {"assign", {"X"}},
   {"fake_quantize_dequantize_moving_average_abs_max",
    {"X", "InScale", "InAccum", "InState"}},
+  {"nll_loss", {"X", "Label", "Weight"}},
 };
 
 // NOTE(zhiqiu): Like op_ins_map.
diff --git a/python/paddle/fluid/tests/unittests/test_nll_loss.py b/python/paddle/fluid/tests/unittests/test_nll_loss.py
index b14e3a15d979c6f66c2ffeeeec6536d5a8ab3b47..c25f8832807bc9a9da84ee44ee8172e8d1d0dd94 100644
--- a/python/paddle/fluid/tests/unittests/test_nll_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_nll_loss.py
@@ -445,7 +445,6 @@ class TestNLLLoss(unittest.TestCase):
         startup_prog = fluid.Program()
         place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
-        #place = fluid.CPUPlace()
         with fluid.program_guard(prog, startup_prog):
             input = fluid.data(
                 name='input', shape=[5, 3, 5, 5], dtype='float64')
@@ -879,5 +878,105 @@ class TestNLLLossOp2DNoReduce(OpTest):
         self.label_shape = [5, 5, 5]
 
 
+class TestNLLLossName(unittest.TestCase):
+    def test_name(self):
+        prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        place = paddle.CPUPlace()
+        with paddle.static.program_guard(prog, startup_prog):
+            x = paddle.data(name='x', shape=[10, 10], dtype='float64')
+            label = paddle.data(name='label', shape=[10], dtype='int64')
+            nll_loss = paddle.nn.loss.NLLLoss(name='nll_loss')
+            res = nll_loss(x, label)
+            self.assertTrue(res.name.startswith('nll_loss'))
+
+
+class TestNLLLossInvalidArgs(unittest.TestCase):
+    def test_x_dim_value_error(self):
+        def test_x_dim_lt_2():
+            prog = paddle.static.Program()
+            startup_prog = paddle.static.Program()
+            place = paddle.CPUPlace()
+            with paddle.static.program_guard(prog, startup_prog):
+                x = paddle.data(name='x', shape=[10, ], dtype='float64')
+                label = paddle.data(
+                    name='label', shape=[10, ], dtype='float64')
+                nll_loss = paddle.nn.loss.NLLLoss()
+                res = nll_loss(x, label)
+
+        self.assertRaises(ValueError, test_x_dim_lt_2)
+
+        def test_x_dim_imperative_lt_2():
+            with fluid.dygraph.guard():
+                x_np = np.array(
+                    [0.88103855, 0.9908683, 0.6226845, 0.53331435,
+                     0.07999352]).astype(np.float32)
+                label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)
+                x = paddle.to_variable(x_np)
+                label = paddle.to_variable(label_np)
+                nll_loss = paddle.nn.loss.NLLLoss()
+                res = nll_loss(x, label)
+
+        self.assertRaises(ValueError, test_x_dim_imperative_lt_2)
+
+    def test_reduction_value_error(self):
+        def test_NLLLoss_reduction_not_sum_mean_none():
+            prog = paddle.static.Program()
+            startup_prog = paddle.static.Program()
+            place = paddle.CPUPlace()
+            with paddle.static.program_guard(prog, startup_prog):
+                x = paddle.data(name='x', shape=[10, 10], dtype='float64')
+                label = paddle.data(name='label', shape=[10], dtype='int64')
+                nll_loss = paddle.nn.loss.NLLLoss(reduction='')
+                res = nll_loss(x, label)
+
+        self.assertRaises(ValueError, test_NLLLoss_reduction_not_sum_mean_none)
+
+        def test_NLLLoss_reduction_imperative_not_sum_mean_none():
+            with fluid.dygraph.guard():
+                x_np = np.array(
+                    [[0.88103855, 0.9908683, 0.6226845],
+                     [0.53331435, 0.07999352, 0.8549948],
+                     [0.25879037, 0.39530203, 0.698465],
+                     [0.73427284, 0.63575995, 0.18827209],
+                     [0.05689114, 0.0862954, 0.6325046]]).astype(np.float32)
+                label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)
+                x = paddle.to_variable(x_np)
+                label = paddle.to_variable(label_np)
+                nll_loss = paddle.nn.loss.NLLLoss(reduction='')
+                res = nll_loss(x, label)
+
+        self.assertRaises(ValueError,
+                          test_NLLLoss_reduction_imperative_not_sum_mean_none)
+
+        def test_nll_loss_function_reduction_not_sum_mean_none():
+            prog = paddle.static.Program()
+            startup_prog = paddle.static.Program()
+            place = paddle.CPUPlace()
+            with paddle.static.program_guard(prog, startup_prog):
+                x = paddle.data(name='x', shape=[10, 10], dtype='float64')
+                label = paddle.data(name='label', shape=[10], dtype='int64')
+                res = paddle.nn.functional.nll_loss(x, label, reduction='')
+
+        self.assertRaises(ValueError,
+                          test_nll_loss_function_reduction_not_sum_mean_none)
+
+        def test_nll_loss_function_reduction_imperative_not_sum_mean_none():
+            with fluid.dygraph.guard():
+                x_np = np.array(
+                    [[0.88103855, 0.9908683, 0.6226845],
+                     [0.53331435, 0.07999352, 0.8549948],
+                     [0.25879037, 0.39530203, 0.698465],
+                     [0.73427284, 0.63575995, 0.18827209],
+                     [0.05689114, 0.0862954, 0.6325046]]).astype(np.float32)
+                label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)
+                x = paddle.to_variable(x_np)
+                label = paddle.to_variable(label_np)
+                res = paddle.nn.functional.nll_loss(x, label, reduction='')
+
+        self.assertRaises(
+            ValueError,
+            test_nll_loss_function_reduction_imperative_not_sum_mean_none)
+
+
 if __name__ == "__main__":
     unittest.main()
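# Reviewer note, not part of the patch: the 5x3 inputs and the labels
# [0, 2, 1, 1, 0] reused by these tests (and by the docstring examples added
# below) imply a mean loss of about 1.0720209. A minimal NumPy cross-check of
# what nll_loss computes, assuming the input already holds log-probabilities:
import numpy as np

x_np = np.array([[0.88103855, 0.9908683, 0.6226845],
                 [0.53331435, 0.07999352, 0.8549948],
                 [0.25879037, 0.39530203, 0.698465],
                 [0.73427284, 0.63575995, 0.18827209],
                 [0.05689114, 0.0862954, 0.6325046]]).astype(np.float32)
label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)

# log_softmax along the class axis, then pick each row's labelled entry.
log_probs = x_np - np.log(np.exp(x_np).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(len(label_np)), label_np].mean()  # reduction='mean'
print(loss)  # ~1.0720209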
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 5ba5c154b4692622827bd47a12ace365b6ac3a9a..d6b88e741c6a88d215397d3383afd203b50fbee5 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -131,6 +131,7 @@ from .loss import l1_loss  #DEFINE_ALIAS
 from .loss import log_loss  #DEFINE_ALIAS
 from .loss import margin_rank_loss  #DEFINE_ALIAS
 from .loss import mse_loss  #DEFINE_ALIAS
+from .loss import nll_loss  #DEFINE_ALIAS
 # from .loss import nce  #DEFINE_ALIAS
 from .loss import npair_loss  #DEFINE_ALIAS
 from .loss import rank_loss  #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 913d3d99710e1b5e50e8cea1be1d61c234cc12c7..4bbfaed81ea24b4e04e0e55a4a7b15c767dd3e6a 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -27,6 +27,7 @@ from ...fluid.layers import log_loss  #DEFINE_ALIAS
 from ...fluid.layers import mse_loss  #DEFINE_ALIAS
 from ...fluid.layers import npair_loss  #DEFINE_ALIAS
 from ...fluid.layers import rank_loss  #DEFINE_ALIAS
+from ...fluid.layers import reshape
 from ...fluid.layers import sigmoid_cross_entropy_with_logits  #DEFINE_ALIAS
 from ...fluid.layers import sigmoid_focal_loss  #DEFINE_ALIAS
 from ...fluid.layers import smooth_l1  #DEFINE_ALIAS
@@ -39,6 +40,9 @@ from ...fluid.layers import edit_distance  #DEFINE_ALIAS
 from ...fluid.layers import huber_loss  #DEFINE_ALIAS
 from ...fluid.layers import margin_rank_loss  #DEFINE_ALIAS
 from ...fluid.layers import sampled_softmax_with_cross_entropy  #DEFINE_ALIAS
+from ...fluid.layer_helper import LayerHelper
+from ...fluid.framework import in_dygraph_mode
+from ...fluid.framework import Variable
 
 __all__ = [
     'bpr_loss',
@@ -54,6 +58,7 @@ __all__ = [
     'margin_rank_loss',
     'mse_loss',
     # 'nce',
+    'nll_loss',
     'npair_loss',
     'rank_loss',
     'sampled_softmax_with_cross_entropy',
@@ -154,3 +159,112 @@ def l1_loss(x, label, reduction='mean', name=None):
         return paddle.mean(unreduced, name=name)
     else:
         return paddle.elementwise_sub(x, label, act='abs', name=name)
+
+
+def nll_loss(input,
+             label,
+             weight=None,
+             ignore_index=-100,
+             reduction='mean',
+             name=None):
+    """
+    This API returns the negative log likelihood loss.
+    See more details in :ref:`api_nn_loss_NLLLoss` .
+
+    Parameters:
+        input (Tensor): Input tensor, the shape is :math:`[N, C]`, `C` is the number of classes.
+            But in K-dimension situation, the shape is :math:`[N, C, d_1, d_2, ..., d_K]`.
+            The data type is float32, float64.
+        label (Tensor): Label tensor, the shape is :math:`[N,]` or :math:`[N, d_1, d_2, ..., d_K]`.
+            The data type is int64.
+        weight (Tensor, optional): Weight tensor, a manual rescaling weight given
+            to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise,
+            it is treated as if having all ones. The data type is
+            float32, float64. Default is ``None``.
+        ignore_index (int64, optional): Specifies a target value that is ignored
+            and does not contribute to the input gradient. Default is -100.
+        reduction (str, optional): Indicates how to reduce the loss,
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If `reduction` is ``'mean'``, the reduced mean loss is returned;
+            if `reduction` is ``'sum'``, the reduced sum loss is returned;
+            if `reduction` is ``'none'``, no reduction will be applied.
+            Default is ``'mean'``.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        `Tensor`, the value of the negative log likelihood loss.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+            from paddle.nn.functional import nll_loss
+            log_softmax = paddle.nn.LogSoftmax(axis=1)
+
+            input_np = np.array([[0.88103855, 0.9908683 , 0.6226845 ],
+                                 [0.53331435, 0.07999352, 0.8549948 ],
+                                 [0.25879037, 0.39530203, 0.698465  ],
+                                 [0.73427284, 0.63575995, 0.18827209],
+                                 [0.05689114, 0.0862954 , 0.6325046 ]]).astype(np.float32)
+            label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)
+
+            place = paddle.CPUPlace()
+            paddle.disable_static(place)
+            input = paddle.to_variable(input_np)
+            log_out = log_softmax(input)
+            label = paddle.to_variable(label_np)
+            result = nll_loss(log_out, label)
+            print(result.numpy())  # [1.0720209]
+    """
+    if reduction not in ['sum', 'mean', 'none']:
+        raise ValueError(
+            "The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
+            "'none', but received %s, which is not allowed." % reduction)
+
+    input_shape = list(input.shape)
+    input_dims = len(input_shape)
+    if input_dims < 2:
+        raise ValueError('Expected 2 or more dimensions (got {})'.format(
+            input_dims))
+    n = input_shape[0]
+    c = input_shape[1]
+    if in_dygraph_mode():
+        if input_dims != 2 and input_dims != 4:
+            input, _ = core.ops.reshape2(input, 'shape', [n, c, 1, -1])
+            label, _ = core.ops.reshape2(label, 'shape', [n, 1, -1])
+            out_shape = [n] + input_shape[2:]
+        out, total_weight = core.ops.nll_loss(input, label, weight,
+                                              'ignore_index', ignore_index,
+                                              'reduction', reduction)
+        if input_dims != 2 and input_dims != 4 and reduction == 'none':
+            out, _ = core.ops.reshape2(out, 'shape', out_shape)
+        return out
+
+    helper = LayerHelper('nll_loss', **locals())
+
+    if input_dims != 2 and input_dims != 4:
+        input = reshape(input, shape=[n, c, 1, -1])
+        label = reshape(label, shape=[n, 1, -1])
+        out_shape = [n] + input_shape[2:]
+
+    fluid.data_feeder.check_variable_and_dtype(
+        input, 'input', ['float32', 'float64'], 'nll_loss')
+    fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
+                                               'nll_loss')
+    inputs = {'X': input, 'Label': label}
+    attrs = {'reduction': reduction, 'ignore_index': ignore_index}
+    if weight is not None:
+        if isinstance(weight, Variable):
+            inputs['Weight'] = weight
+
+    out = helper.create_variable_for_type_inference(dtype=input.dtype)
+    total_weight = helper.create_variable_for_type_inference(dtype=input.dtype)
+    outputs = {'Out': out, 'Total_weight': total_weight}
+
+    helper.append_op(
+        type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs)
+    if input_dims != 2 and input_dims != 4 and reduction == 'none':
+        out = reshape(out, shape=out_shape)
+
+    return out
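# Reviewer note, not part of the patch: both branches of nll_loss above use the
# same folding trick. The Python wrapper folds any input that is not 2-D or 4-D
# into a 4-D view, which suggests the underlying kernel only handles [N, C] and
# [N, C, H, W]; the per-element output is unfolded again only when
# reduction='none'. A shape-only sketch with hypothetical sizes:
n, c, d1, d2, d3 = 2, 3, 4, 5, 6        # a 5-D input of shape [N, C, d1, d2, d3]
folded_input = [n, c, 1, d1 * d2 * d3]  # reshape(input, shape=[n, c, 1, -1])
folded_label = [n, 1, d1 * d2 * d3]     # reshape(label, shape=[n, 1, -1])
out_shape = [n, d1, d2, d3]             # [n] + input_shape[2:], the shape the
                                        # 'none'-reduction output is restored to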
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 11000e57f232e6c62330602fd6f5af4e750bc35e..006b81c9325221931ca6ece7f31bbaff7aaa6384 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -15,6 +15,7 @@
 # TODO: define loss functions of neural network
 import paddle.fluid as fluid
 import paddle
+from .. import functional as F
 
 __all__ = [
     #       'NCELoss',
@@ -460,11 +461,11 @@ class NLLLoss(fluid.dygraph.Layer):
     :alias_main: paddle.nn.NLLLoss
     :alias: paddle.nn.NLLLoss,paddle.nn.layer.NLLLoss,paddle.nn.layer.loss.NLLLoss
 
-    This op accepts input and target label and returns negative log likelihood
+    This class accepts input and target label and returns negative log likelihood
     cross error. It is useful to train a classification problem with C classes.
 
     The input for the loss is epected to contain log-probabilities of
-    each classes. It hs to be a Tensor of size either (batch_size, C) or
+    each classes. It has to be a Tensor of size either (batch_size, C) or
     (batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
     The label for the loss should be a class index in the range [0, C-1]
     where C is the number of classes. If ignore_index is specified, the
@@ -494,106 +495,77 @@ class NLLLoss(fluid.dygraph.Layer):
             \\end{cases}
 
     Parameters:
-        input (Variable): Input tensor, the data type is float32, float64.
-        label (Variable): Label tensor, the data type is int64_t.
-        weight (Variable, optional): Weight tensor, a manual rescaling weight given
-            to each class. If given, it has to be a Tensor of size `C`. Otherwise,
+        weight (Tensor, optional): Weight tensor, a manual rescaling weight given
+            to each class. If given, it has to be a 1D Tensor whose size is `[C, ]`. Otherwise,
             it treated as if having all ones. the data type is
             float32, float64, Default is ``'None'``.
+        ignore_index (int64, optional): Specifies a target value that is ignored
+            and does not contribute to the input gradient. Default is -100.
         reduction (str, optional): Indicate how to average the loss,
            the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
-            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            If `reduction` is ``'mean'``, the reduced mean loss is returned;
+            if `reduction` is ``'sum'``, the reduced sum loss is returned;
+            if `reduction` is ``'none'``, no reduction will be applied.
             Default is ``'mean'``.
-        ignore_index (int64, optional): Specifies a target value that is ignored
-            and does not contribute to the input gradient.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
 
-    Returns:
-        The tensor variable storing the nll_loss.
+    Shape:
+        input (Tensor): Input tensor, the shape is :math:`[N, C]`, `C` is the number of classes.
+            But in K-dimension situation, the shape is :math:`[N, C, d_1, d_2, ..., d_K]`.
+            The data type is float32, float64.
+        label (Tensor): Label tensor, the shape is :math:`[N,]` or :math:`[N, d_1, d_2, ..., d_K]`.
+            The data type is int64.
+        output (Tensor): the `negative log likelihood loss` between `input` and `label`.
+            If `reduction` is `'none'`, the shape is `[N, *]`.
+            If `reduction` is `'sum'` or `'mean'`, the shape is `[1]`.
 
-    Return type: Variable.
-
     Examples:
         .. code-block:: python
-            # declarative mode
-            import paddle.fluid as fluid
-            import numpy as np
-            import paddle
+
+            import paddle
+            import numpy as np
 
-            input_np = np.random.random(size=(10, 10)).astype(np.float32)
-            label_np = np.random.randint(0, 10, size=(10,)).astype(np.int64)
-            prog = fluid.Program()
-            startup_prog = fluid.Program()
-            place = fluid.CPUPlace()
-            with fluid.program_guard(prog, startup_prog):
-                input = fluid.data(name='input', shape=[10, 10], dtype='float32')
-                label = fluid.data(name='label', shape=[10], dtype='int64')
-                nll_loss = paddle.nn.loss.NLLLoss()
-                res = nll_loss(input, label)
-
-            exe = fluid.Executor(place)
-            static_result = exe.run(
-                prog,
-                feed={"input": input_np,
-                      "label": label_np},
-                fetch_list=[res])
-            print(static_result)
-
-            # imperative mode
-            import paddle.fluid.dygraph as dg
-            with dg.guard(place) as g:
-                input = dg.to_variable(input_np)
-                label = dg.to_variable(label_np)
-                output = nll_loss(input, label)
-                print(output.numpy())
-    """
+            nll_loss = paddle.nn.layer.NLLLoss()
+            log_softmax = paddle.nn.LogSoftmax(axis=1)
 
-    def __init__(self, weight=None, reduction='mean', ignore_index=-100):
-        super(NLLLoss, self).__init__()
-        self.weight = weight
-        self.reduction = reduction
-        self.ignore_index = ignore_index
+            input_np = np.array([[0.88103855, 0.9908683 , 0.6226845 ],
+                                 [0.53331435, 0.07999352, 0.8549948 ],
+                                 [0.25879037, 0.39530203, 0.698465  ],
+                                 [0.73427284, 0.63575995, 0.18827209],
+                                 [0.05689114, 0.0862954 , 0.6325046 ]]).astype(np.float32)
+            label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)
 
-    def forward(self, input, label):
-        dtype = self._helper.input_dtype(input)
+            place = paddle.CPUPlace()
+            paddle.disable_static(place)
+            input = paddle.to_variable(input_np)
+            log_out = log_softmax(input)
+            label = paddle.to_variable(label_np)
+            result = nll_loss(log_out, label)
+            print(result.numpy())  # [1.0720209]
 
-        fluid.data_feeder.check_variable_and_dtype(
-            input, 'input', ['float32', 'float64'], 'nll_loss')
-        fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
-                                                   'nll_loss')
+    """
""" - if self.reduction not in ['sum', 'mean', 'none']: + def __init__(self, + weight=None, + ignore_index=-100, + reduction='mean', + name=None): + if reduction not in ['sum', 'mean', 'none']: raise ValueError( - "The value of 'reduction' in nll_loss should be 'sum', 'mean' or 'none', but " - "received %s, which is not allowed." % self.reduction) - - x_shape = list(input.shape) - n = x_shape[0] - c = x_shape[1] - x_dims = len(x_shape) - if x_dims < 2: - raise ValueError('Expected 2 or more dimensions (got {})'.format( - x_dims)) - if x_dims != 2 and x_dims != 4: - input = fluid.layers.reshape(input, shape=[n, c, 1, -1]) - label = fluid.layers.reshape(label, shape=[n, 1, -1]) - out_shape = [n] + x_shape[2:] - - inputs = {'X': input, 'Label': label} - attrs = {'reduction': self.reduction, 'ignore_index': self.ignore_index} - if self.weight is not None: - if isinstance(self.weight, fluid.framework.Variable): - inputs['Weight'] = self.weight - - out = self._helper.create_variable_for_type_inference(dtype=input.dtype) - total_weight = self._helper.create_variable_for_type_inference( - dtype=input.dtype) - outputs = {'Out': out, 'Total_weight': total_weight} - - self._helper.append_op( - type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs) - if x_dims != 2 and x_dims != 4 and self.reduction == 'none': - out = fluid.layers.reshape(out, shape=out_shape) + "The value of 'reduction' in nll_loss should be 'sum', 'mean' or " + "'none', but received %s, which is not allowed." % reduction) + super(NLLLoss, self).__init__() + self._weight = weight + self._ignore_index = ignore_index + self._reduction = reduction + self._name = name - return out + def forward(self, input, label): + return F.nll_loss( + input, + label, + weight=self._weight, + ignore_index=self._ignore_index, + reduction=self._reduction, + name=self._name)