From 5407e327a472c1c9144463b16ad83206a0568c1a Mon Sep 17 00:00:00 2001
From: chajchaj <57249073+chajchaj@users.noreply.github.com>
Date: Fri, 21 Aug 2020 14:05:35 +0800
Subject: [PATCH] add cross_entropy to nn/layer and nn/functional,
 test=develop (#26478)

* add cross_entropy to nn/layer and nn/functional, test=develop
* use functional/cross_entropy in layer/CrossEntropy
* use functional/cross_entropy in layer/CrossEntropy, test=develop
---
 .../unittests/test_cross_entropy_loss.py | 438 ++++++++++++++++++
 python/paddle/nn/functional/loss.py      | 130 +++++-
 python/paddle/nn/layer/loss.py           |  19 +-
 3 files changed, 573 insertions(+), 14 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
index 7f667d6b71c..4982cd19582 100644
--- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
@@ -535,5 +535,443 @@ class CrossEntropyLoss(unittest.TestCase):
         self.assertTrue(np.allclose(dy_ret_value, expected))
 
 
+class FuncCrossEntropyLoss(unittest.TestCase):
+    #1
+    def test_cross_entropy_loss_1d_with_weight_mean(self):
+        input_np = np.random.random([100, 200]).astype(np.float64)
+        label_np = np.random.randint(0, 100, size=(100, )).astype(np.int64)
+        weight_np = np.random.random([200]).astype(np.float64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(name='input', shape=[100, 200], dtype='float64')
+            label = fluid.data(name='label', shape=[100], dtype='int64')
+            weight = fluid.data(name='weight', shape=[200], dtype='float64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, weight=weight)
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                     "weight": weight_np
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                weight=fluid.dygraph.to_variable(weight_np))
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_1d(
+            input_np, label_np, weight=weight_np)[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #2
+    def test_cross_entropy_loss_1d_with_weight_sum(self):
+        input_np = np.random.random([100, 200]).astype(np.float64)
+        label_np = np.random.randint(0, 100, size=(100, )).astype(np.int64)
+        weight_np = np.random.random([200]).astype(np.float64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(name='input', shape=[100, 200], dtype='float64')
+            label = fluid.data(name='label', shape=[100], dtype='int64')
+            weight = fluid.data(name='weight', shape=[200], dtype='float64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, weight=weight, reduction='sum')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                     "weight": weight_np
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                weight=fluid.dygraph.to_variable(weight_np),
+                reduction='sum')
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_1d(
+            input_np, label_np, weight=weight_np, reduction='sum')[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #3
+    def test_cross_entropy_loss_1d_with_weight_none(self):
+        input_np = np.random.random([100, 200]).astype(np.float64)
+        label_np = np.random.randint(0, 100, size=(100, )).astype(np.int64)
+        weight_np = np.random.random([200]).astype(np.float64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(name='input', shape=[100, 200], dtype='float64')
+            label = fluid.data(name='label', shape=[100], dtype='int64')
+            weight = fluid.data(name='weight', shape=[200], dtype='float64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, weight=weight, reduction='none')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                     "weight": weight_np
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                weight=fluid.dygraph.to_variable(weight_np),
+                reduction='none')
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_1d(
+            input_np, label_np, weight=weight_np, reduction='none')
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #4
+    def test_cross_entropy_loss_1d_mean(self):
+        input_np = np.random.random([100, 200]).astype(np.float64)
+        label_np = np.random.randint(0, 100, size=(100, )).astype(np.int64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(name='input', shape=[100, 200], dtype='float64')
+            label = fluid.data(name='label', shape=[100], dtype='int64')
+            ret = paddle.nn.functional.cross_entropy(input, label)
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={'input': input_np,
+                                       'label': label_np},
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np))
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_1d(input_np, label_np)[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #5
+    def test_cross_entropy_loss_1d_sum(self):
+        input_np = np.random.random([100, 200]).astype(np.float64)
+        label_np = np.random.randint(0, 100, size=(100, )).astype(np.int64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(name='input', shape=[100, 200], dtype='float64')
+            label = fluid.data(name='label', shape=[100], dtype='int64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, reduction='sum')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={'input': input_np,
+                                       'label': label_np},
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                reduction='sum')
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_1d(input_np, label_np, reduction='sum')[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #6
+    def test_cross_entropy_loss_1d_none(self):
+        input_np = np.random.random([100, 200]).astype(np.float64)
+        label_np = np.random.randint(0, 100, size=(100, )).astype(np.int64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(name='input', shape=[100, 200], dtype='float64')
+            label = fluid.data(name='label', shape=[100], dtype='int64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, reduction='none')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={'input': input_np,
+                                       'label': label_np},
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                reduction='none')
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_1d(input_np, label_np, reduction='none')
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #7
+    def test_cross_entropy_loss_2d_with_weight_none(self):
+        input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64)
+        label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64)
+        weight_np = np.random.random(size=(3, )).astype(np.float64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(
+                name='input', shape=[5, 3, 5, 5], dtype='float64')
+            label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
+            weight = fluid.data(name='weight', shape=[3], dtype='float64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, weight=weight, reduction='none')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                     "weight": weight_np
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                weight=fluid.dygraph.to_variable(weight_np),
+                reduction='none')
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_2d(
+            input_np, label_np, weight=weight_np, reduction='none')
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #8
+    def test_cross_entropy_loss_2d_with_weight_mean(self):
+        input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64)
+        label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64)
+        weight_np = np.random.random(size=(3, )).astype(np.float64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(
+                name='input', shape=[5, 3, 5, 5], dtype='float64')
+            label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
+            weight = fluid.data(name='weight', shape=[3], dtype='float64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, weight=weight, reduction='mean')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                     "weight": weight_np
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                weight=fluid.dygraph.to_variable(weight_np),
+                reduction='mean')
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_2d(
+            input_np, label_np, weight=weight_np, reduction='mean')[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #9
+    def test_cross_entropy_loss_2d_with_weight_sum(self):
+        input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64)
+        label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64)
+        weight_np = np.random.random(size=(3, )).astype(np.float64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(
+                name='input', shape=[5, 3, 5, 5], dtype='float64')
+            label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
+            weight = fluid.data(name='weight', shape=[3], dtype='float64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, weight=weight, reduction='sum')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                     "weight": weight_np
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                weight=fluid.dygraph.to_variable(weight_np),
+                reduction='sum')
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_2d(
+            input_np, label_np, weight=weight_np, reduction='sum')[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #10
+    def test_cross_entropy_loss_2d_none(self):
+        input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64)
+        label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(
+                name='input', shape=[5, 3, 5, 5], dtype='float64')
+            label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, reduction='none')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                reduction='none')
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_2d(input_np, label_np, reduction='none')
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #11
+    def test_cross_entropy_loss_2d_mean(self):
+        input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64)
+        label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(
+                name='input', shape=[5, 3, 5, 5], dtype='float64')
+            label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, reduction='mean')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                reduction='mean')
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_2d(
+            input_np, label_np, reduction='mean')[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    #12
+    def test_cross_entropy_loss_2d_sum(self):
+        input_np = np.random.random(size=(5, 3, 5, 5)).astype(np.float64)
+        label_np = np.random.randint(0, 3, size=(5, 5, 5)).astype(np.int64)
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(
+                name='input', shape=[5, 3, 5, 5], dtype='float64')
+            label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
+            ret = paddle.nn.functional.cross_entropy(
+                input, label, reduction='sum')
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            dy_ret = paddle.nn.functional.cross_entropy(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np),
+                reduction='sum')
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_2d(input_np, label_np, reduction='sum')[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 1bc8f0f94a6..9a214d3982a 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -22,7 +22,6 @@ from ...fluid.framework import core, in_dygraph_mode
 from ...fluid.layers.nn import _elementwise_op_in_dygraph
 from ...fluid.layers import bpr_loss  #DEFINE_ALIAS
 from ...fluid.layers import center_loss  #DEFINE_ALIAS
-from ...fluid.layers import cross_entropy  #DEFINE_ALIAS
 from ...fluid.layers import dice_loss  #DEFINE_ALIAS
 from ...fluid.layers import iou_similarity  #DEFINE_ALIAS
 from ...fluid.layers import log_loss  #DEFINE_ALIAS
@@ -786,3 +785,132 @@ def mse_loss(input, label, reduction='mean', name=None):
     return paddle.sum(paddle.fluid.layers.square(
         paddle.fluid.layers.elementwise_sub(input, label)),
                       name=name)
+
+
+def cross_entropy(input,
+                  label,
+                  weight=None,
+                  ignore_index=-100,
+                  reduction='mean'):
+    """
+    This operator implements the cross entropy loss function. This OP combines
+    ``LogSoftmax`` and ``NLLLoss`` together.
+
+    It is useful when training a classification problem with ``C`` classes.
+    If provided, the optional argument ``weight`` should be a 1D Variable
+    assigning a weight to each of the classes.
+
+    For a prediction ``input`` and a target ``label``, the loss is calculated
+    as follows.
+
+    .. math::
+
+        loss_j = -\\text{input[class]} +
+        \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right), j = 1,..., K
+
+    If weight is not ``None``:
+
+    .. math::
+
+        loss_j = \\text{weight[class]}(-\\text{input[class]} +
+        \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right)), j = 1,..., K
+
+    Parameters:
+        input (Tensor): Input tensor, the data type is float32, float64. Shape is
+            (N, C), where C is the number of classes; if the input has more than
+            2 dimensions, the shape is (N, C, D1, D2,..., Dk), k >= 1.
+        label (Tensor): Label tensor, the data type is int64. Shape is (N), where
+            each value satisfies 0 <= label[i] <= C-1; if the label has more than
+            1 dimension, the shape is (N, D1, D2,..., Dk), k >= 1.
+        weight (Tensor, optional): Weight tensor, a manual rescaling weight given
+            to each class. Its shape is (C), matching the class number, and its
+            data type is float32, float64. Default is ``None``.
+        ignore_index (int64, optional): Specifies a target value that is ignored
+            and does not contribute to the input gradient. Default is ``-100``.
+        reduction (str, optional): Indicates how to reduce the loss over the batch,
+            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
+            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            if :attr:`reduction` is ``'sum'``, the reduced sum loss is returned;
+            if :attr:`reduction` is ``'none'``, the unreduced loss is returned.
+            Default is ``'mean'``.
+
+    Returns:
+        Tensor. The tensor storing the cross entropy loss of ``input`` and ``label``.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+
+            paddle.disable_static()
+            input_data = np.random.random([5, 100]).astype("float64")
+            label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
+            weight_data = np.random.random([100]).astype("float64")
+            input = paddle.to_tensor(input_data)
+            label = paddle.to_tensor(label_data)
+            weight = paddle.to_tensor(weight_data)
+            loss = paddle.nn.functional.cross_entropy(input=input, label=label, weight=weight)
+            print(loss.numpy())
+
+    """
+    if not in_dygraph_mode():
+        fluid.data_feeder.check_variable_and_dtype(
+            input, 'input', ['float32', 'float64'], 'cross_entropy_loss')
+        fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
+                                                   'cross_entropy_loss')
+
+    if reduction not in ['sum', 'mean', 'none']:
+        raise ValueError(
+            "The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or"
+            " 'none', but received %s, which is not allowed." % reduction)
+
+    #step 1. log_softmax
+    log_softmax_out = paddle.nn.functional.log_softmax(input)
+    if weight is not None and not isinstance(weight, Variable):
+        raise ValueError(
+            "The 'weight' is not a Variable, please convert it to a Variable.")
+
+    #step 2. nll_loss
+    input = log_softmax_out
+    helper = LayerHelper('nll_loss', **locals())
+    dtype = helper.input_dtype(input)
+
+    if not in_dygraph_mode():
+        fluid.data_feeder.check_variable_and_dtype(
+            input, 'input', ['float32', 'float64'], 'nll_loss')
+        fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
+                                                   'nll_loss')
+
+    x_shape = list(input.shape)
+    n = x_shape[0]
+    c = x_shape[1]
+    x_dims = len(x_shape)
+    if x_dims < 2:
+        raise ValueError('Expected 2 or more dimensions (got {})'.format(
+            x_dims))
+    if x_dims != 2 and x_dims != 4:
+        input = reshape(input, shape=[n, c, 1, -1])
+        label = reshape(label, shape=[n, 1, -1])
+        out_shape = [n] + x_shape[2:]
+
+    inputs = {'X': input, 'Label': label}
+    attrs = {'reduction': reduction, 'ignore_index': ignore_index}
+    if weight is not None:
+        if isinstance(weight, Variable):
+            inputs['Weight'] = weight
+
+    out = helper.create_variable_for_type_inference(dtype=input.dtype)
+    total_weight = helper.create_variable_for_type_inference(dtype=input.dtype)
+    outputs = {'Out': out, 'Total_weight': total_weight}
+
+    helper.append_op(
+        type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs)
+    if x_dims != 2 and x_dims != 4 and reduction == 'none':
+        out = reshape(out, shape=out_shape)
+
+    return out
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index e1c323ebf3e..8f5c4cf8459 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -21,7 +21,6 @@ from .. import functional as F
 from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator
 
 __all__ = [
-    # 'NCELoss',
     'CrossEntropyLoss',
     'MSELoss',
     'L1Loss',
@@ -119,7 +118,7 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
             print(output.numpy())
     """
 
-    def __init__(self, weight=None, reduction='mean', ignore_index=-100):
+    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
         super(CrossEntropyLoss, self).__init__()
         self.weight = weight
         self.reduction = reduction
@@ -137,18 +136,12 @@ class CrossEntropyLoss(fluid.dygraph.Layer):
                 " 'none', but received %s, which is not allowed."
                 % self.reduction)
 
-        log_softmax = paddle.nn.LogSoftmax()
-        log_softmax_out = log_softmax(input)
-        if self.weight is not None and not isinstance(self.weight,
-                                                       fluid.framework.Variable):
-            raise ValueError(
-                "The weight' is not a Variable, please convert to Variable.")
-        nll_loss = paddle.nn.loss.NLLLoss(
+        return paddle.nn.functional.cross_entropy(
+            input,
+            label,
             weight=self.weight,
-            reduction=self.reduction,
-            ignore_index=self.ignore_index)
-
-        return nll_loss(log_softmax_out, label)
+            ignore_index=self.ignore_index,
+            reduction=self.reduction)
 
 
 class MSELoss(fluid.dygraph.layers.Layer):
-- 
GitLab
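
Reviewer note (not part of the patch): the new functional op is just a numerically stable log_softmax followed by the nll_loss op, so its output can be sanity-checked with a few lines of NumPy. The sketch below is a minimal 1D reference, assuming the tests' `cross_entropy_loss_1d` helper normalizes the `'mean'` reduction by the summed per-sample weights (the usual NLL convention, matching the op's `Total_weight` output); the name `reference_cross_entropy_1d` is illustrative, not from the patch.

```python
import numpy as np

def reference_cross_entropy_1d(input, label, weight=None, reduction='mean'):
    # Step 1: numerically stable log_softmax over the class axis (axis=1).
    shifted = input - input.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Step 2: negative log likelihood of the target class for each sample,
    # rescaled by the per-class weight when one is given.
    n = input.shape[0]
    nll = -log_softmax[np.arange(n), label]
    w = np.ones(input.shape[1]) if weight is None else weight
    sample_w = w[label]
    loss = sample_w * nll
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'mean':
        # Weighted mean: total loss divided by the total sample weight,
        # assumed to match how the nll_loss op reduces with 'mean'.
        return loss.sum() / sample_w.sum()
    return loss  # 'none': per-sample losses

# Quick check on random data; in dygraph mode the same arrays can be fed to
# paddle.nn.functional.cross_entropy for a direct comparison.
x = np.random.random([4, 3]).astype(np.float64)
y = np.random.randint(0, 3, size=(4, )).astype(np.int64)
print(reference_cross_entropy_1d(x, y))
```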