diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index 81e2160a556d2fddf0e970e5a68315a7ec39f724..1a5e4b28355674010ba8f92b176d5cabca3e1a8d 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -18,6 +18,8 @@ import paddle import paddle.fluid as fluid import numpy as np import unittest +from test_softmax_op import stable_softmax +from test_softmax_with_cross_entropy_op import cross_entropy def stable_softmax(x): @@ -42,6 +44,8 @@ def cross_entropy_loss_1d(input, C = input_shape[1] out = np.zeros_like(label).astype(np.float64) total_weight = 0 + ###1. compute softmax cross_entropy (with weight) + ### Note: only support hard labels. for i in range(N): cur_target = label[i] if cur_target == ignore_index: @@ -50,6 +54,8 @@ def cross_entropy_loss_1d(input, cur_weight = weight[cur_target] if weight is not None else 1 total_weight += cur_weight out[i] = -log_softmax_out[i][cur_target] * cur_weight + + ###2. deal with reduction if reduction == 'sum': return np.sum(out), np.array([total_weight]).astype('float64') elif reduction == 'mean': @@ -92,7 +98,620 @@ def cross_entropy_loss_2d(input, return out +def cross_entropy_soft(softmax, + label, + axis, + N, + weight=None, + reduction='mean', + ignore_index=-100): + #1.loss + loss = cross_entropy( + softmax, + label, + True, #soft_label, + axis, + ignore_index) + + if weight is None and reduction == 'none': + return loss + + #2.weight + weighted_loss = loss + total_weight = N #for weight is None + if weight is not None: + weighted_loss = np.zeros_like(loss).astype(np.float64) + total_weight = 0 + for i in range(N): + cur_soft_label = label[i] + cur_weight = np.dot(weight, cur_soft_label) + total_weight += cur_weight + weighted_loss[i] = loss[i] * cur_weight + + #3.reduce + if reduction == 'none': + return weighted_loss + + elif reduction == 'mean': + weighted_loss_sum = np.sum(weighted_loss) + weighted_loss_mean = weighted_loss_sum / total_weight + return weighted_loss_mean + + else: + weighted_loss_sum = np.sum(weighted_loss) + return weighted_loss_sum + + +def cross_entropy_soft_2d(softmax, + label, + axis, + N, + H, + W, + weight=None, + reduction='mean', + ignore_index=-100): + #1.loss + loss = cross_entropy( + softmax, + label, + True, #soft_label, + axis, + ignore_index) + + if weight is None and reduction == 'none': + return loss + + #2.weight + weighted_loss = loss + total_weight = N #for weight is None + if weight is not None: + weighted_loss = np.zeros_like(loss).astype(np.float64) + total_weight = 0 + for i in range(N): + for h in range(H): + for w in range(W): + cur_soft_label = label[i][h][w] + cur_weight = np.dot(weight, cur_soft_label) + total_weight += cur_weight + weighted_loss[i][h][w] = loss[i][h][w] * cur_weight + + #3.reduce + if reduction == 'none': + return weighted_loss + + elif reduction == 'mean': + weighted_loss_sum = np.sum(weighted_loss) + weighted_loss_mean = weighted_loss_sum / total_weight + return weighted_loss_mean + + else: + weighted_loss_sum = np.sum(weighted_loss) + return weighted_loss_sum + + class CrossEntropyLoss(unittest.TestCase): + + ###test for deprecated softmax_with_cross_entropy + def test_softmax_with_cross_entropy(self): + self.numeric_stable_mode = False + self.soft_label = True + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -100 #should not be changed + self.N = 4 + self.C = 3 + self.shape = [self.N, 
self.C] + self.use_softmax = True + self.reduction = 'none' + self.weight = None + self.logits = getattr( + self, "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) + softmax = np.apply_along_axis(stable_softmax, self.axis, self.logits) + + self.labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype) + self.labels /= np.sum(self.labels, axis=self.axis, keepdims=True) + + expected = cross_entropy_soft( + softmax, + self.labels, + self.axis, + self.N, + weight=self.weight, + reduction=self.reduction, + ignore_index=self.ignore_index) + + paddle.set_device("cpu") + + paddle.disable_static() + paddle_loss_swce = paddle.nn.functional.softmax_with_cross_entropy( + fluid.dygraph.to_variable(self.logits), + fluid.dygraph.to_variable(self.labels), + soft_label=True, + axis=self.axis) + + paddle_loss_ce = paddle.nn.functional.cross_entropy( + fluid.dygraph.to_variable(self.logits), + fluid.dygraph.to_variable(self.labels), + soft_label=True, + axis=self.axis, + weight=fluid.dygraph.to_variable(self.weight) + if self.weight is not None else None, + reduction=self.reduction) + + self.assertTrue(np.allclose(paddle_loss_swce.numpy(), expected)) + self.assertTrue(np.allclose(paddle_loss_ce.numpy(), expected)) + + ###soft_label test start + ###soft_label test 1 + def test_cross_entropy_loss_soft_1d(self): + self.numeric_stable_mode = False + self.soft_label = True + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -100 #should not be changed + self.N = 4 + self.C = 3 + self.shape = [self.N, self.C] + self.use_softmax = True + self.reduction = 'none' + self.weight = None + self.logits = getattr( + self, "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) + softmax = np.apply_along_axis(stable_softmax, self.axis, self.logits) + + self.labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype) + self.labels /= np.sum(self.labels, axis=self.axis, keepdims=True) + + expected = cross_entropy_soft( + softmax, + self.labels, + self.axis, + self.N, + weight=self.weight, + reduction=self.reduction, + ignore_index=self.ignore_index) + + paddle.set_device("cpu") + + #2. dygraph + paddle.disable_static() + paddle_loss_none_weight = paddle.nn.functional.cross_entropy( + fluid.dygraph.to_variable(self.logits), + fluid.dygraph.to_variable(self.labels), + soft_label=True, + axis=self.axis, + weight=fluid.dygraph.to_variable(self.weight) + if self.weight is not None else None, + reduction=self.reduction) + dy_ret_value = paddle_loss_none_weight.numpy() + + #3. 
static + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[self.N, self.C], dtype='float64') + label = fluid.data( + name='label', shape=[self.N, self.C], dtype='float64') + + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + reduction=self.reduction, soft_label=True) + ret = cross_entropy_loss(input, label) + + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': self.logits, + 'label': self.labels, + }, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + paddle.disable_static() + + self.assertTrue(np.allclose(static_ret, expected)) + self.assertTrue(np.allclose(dy_ret_value, expected)) + + ###soft_label test 2 + def test_cross_entropy_loss_soft_1d_weight(self): + self.numeric_stable_mode = False + self.soft_label = True + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -100 #should not be changed + self.N = 4 + self.C = 3 + self.shape = [self.N, self.C] + self.use_softmax = True + self.reduction = 'none' + self.weight = np.random.uniform(0.1, 1.0, self.C).astype(self.dtype) + self.logits = getattr( + self, "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) + softmax = np.apply_along_axis(stable_softmax, self.axis, self.logits) + + if self.soft_label: + self.labels = np.random.uniform(0.1, 1.0, + self.shape).astype(self.dtype) + self.labels /= np.sum(self.labels, axis=self.axis, keepdims=True) + else: + axis_dim = self.shape[self.axis] + self.shape[self.axis] = 1 + self.labels = np.random.randint( + 0, axis_dim, self.shape, dtype="int64") + + #1. numpy + expected = cross_entropy_soft( + softmax, + self.labels, + self.axis, + self.N, + weight=self.weight, + reduction=self.reduction, + ignore_index=self.ignore_index) + + paddle.set_device("cpu") + + #2. 
dygraph + paddle.disable_static() + paddle_loss_none_weight = paddle.nn.functional.cross_entropy( + fluid.dygraph.to_variable(self.logits), + fluid.dygraph.to_variable(self.labels), + soft_label=True, + axis=self.axis, + weight=fluid.dygraph.to_variable(self.weight), + reduction=self.reduction) + dy_ret_value = paddle_loss_none_weight.numpy() + + # 3.static + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[self.N, self.C], dtype='float64') + label = fluid.data( + name='label', shape=[self.N, self.C], dtype='float64') + weight = fluid.data(name='weight', shape=[self.C], dtype='float64') + + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight, reduction=self.reduction, soft_label=True) + ret = cross_entropy_loss(input, label) + + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': self.logits, + 'label': self.labels, + "weight": self.weight + }, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + paddle.disable_static() + + self.assertTrue(np.allclose(static_ret, expected)) + self.assertTrue(np.allclose(dy_ret_value, expected)) + + ###soft_label test 3 + def test_cross_entropy_loss_soft_1d_mean(self): + self.numeric_stable_mode = False + self.soft_label = True + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -100 #should not be changed + self.N = 4 + self.C = 3 + self.shape = [self.N, self.C] + self.use_softmax = True + self.reduction = 'mean' + self.weight = None + self.logits = getattr( + self, "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) + softmax = np.apply_along_axis(stable_softmax, self.axis, self.logits) + + self.labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype) + self.labels /= np.sum(self.labels, axis=self.axis, keepdims=True) + + #1. numpy + expected = cross_entropy_soft( + softmax, + self.labels, + self.axis, + self.N, + weight=self.weight, + reduction=self.reduction, + ignore_index=self.ignore_index) + + paddle.set_device("cpu") + + #2 dygraph + paddle.disable_static() + paddle_loss_mean = paddle.nn.functional.cross_entropy( + fluid.dygraph.to_variable(self.logits), + fluid.dygraph.to_variable(self.labels), + soft_label=True, + axis=self.axis, + weight=self.weight, + reduction=self.reduction) + dy_ret_value = paddle_loss_mean.numpy() + + #3. 
static + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[self.N, self.C], dtype='float64') + label = fluid.data( + name='label', shape=[self.N, self.C], dtype='float64') + + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + reduction=self.reduction, soft_label=True) + ret = cross_entropy_loss(input, label) + + exe = fluid.Executor(place) + static_ret = exe.run( + prog, + feed={'input': self.logits, + 'label': self.labels}, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + paddle.disable_static() + + self.assertTrue(np.allclose(static_ret, expected)) + self.assertTrue(np.allclose(dy_ret_value, expected)) + + ###soft_label test 4 + def test_cross_entropy_loss_soft_1d_weight_mean(self): + self.numeric_stable_mode = False + self.soft_label = True + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -100 #should not be changed + self.N = 4 + self.C = 3 + self.shape = [self.N, self.C] + self.use_softmax = True + self.reduction = 'mean' + self.weight = np.random.uniform(0.1, 1.0, self.C).astype(self.dtype) + self.logits = getattr( + self, "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) + softmax = np.apply_along_axis(stable_softmax, self.axis, self.logits) + + self.labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype) + self.labels /= np.sum(self.labels, axis=self.axis, keepdims=True) + + #1. numpy + expected = cross_entropy_soft( + softmax, + self.labels, + self.axis, + self.N, + weight=self.weight, + reduction=self.reduction, + ignore_index=self.ignore_index) + + paddle.set_device("cpu") + paddle.disable_static() + + #2. dygraph + paddle_loss_none_weight = paddle.nn.functional.cross_entropy( + fluid.dygraph.to_variable(self.logits), + fluid.dygraph.to_variable(self.labels), + soft_label=True, + axis=self.axis, + weight=fluid.dygraph.to_variable(self.weight), + reduction=self.reduction) + dy_ret_value = paddle_loss_none_weight.numpy() + + #3. 
static + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[self.N, self.C], dtype='float64') + label = fluid.data( + name='label', shape=[self.N, self.C], dtype='float64') + weight = fluid.data(name='weight', shape=[self.C], dtype='float64') + + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight, reduction=self.reduction, soft_label=True) + ret = cross_entropy_loss(input, label) + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': self.logits, + 'label': self.labels, + "weight": self.weight + }, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + paddle.disable_static() + + self.assertTrue(np.allclose(static_ret, expected)) + self.assertTrue(np.allclose(dy_ret_value, expected)) + + ###soft_label test 5 + def test_cross_entropy_loss_soft_2d(self): + self.numeric_stable_mode = False + self.soft_label = True + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -100 #should not be changed + self.N = 3 + self.H = 2 + self.W = 2 + self.C = 5 + self.shape = [self.N, self.H, self.W, self.C] + self.use_softmax = True + self.reduction = 'none' + self.weight = None + self.logits = getattr( + self, "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) + softmax = np.apply_along_axis(stable_softmax, self.axis, self.logits) + + self.labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype) + self.labels /= np.sum(self.labels, axis=self.axis, keepdims=True) + + #1. numpy + expected = cross_entropy_soft_2d( + softmax, + self.labels, + self.axis, + self.N, + self.H, + self.W, + weight=self.weight, + reduction=self.reduction, + ignore_index=self.ignore_index) + + paddle.set_device("cpu") + paddle.disable_static() + + #2. dygraph + paddle_loss_none_weight = paddle.nn.functional.cross_entropy( + fluid.dygraph.to_variable(self.logits), + fluid.dygraph.to_variable(self.labels), + soft_label=True, + axis=self.axis, + weight=fluid.dygraph.to_variable(self.weight) + if self.weight is not None else None, + reduction=self.reduction) + dy_ret_value = paddle_loss_none_weight.numpy() + + #3. 
static + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', + shape=[self.N, self.H, self.W, self.C], + dtype='float64') + label = fluid.data( + name='label', + shape=[self.N, self.H, self.W, self.C], + dtype='float64') + + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + reduction=self.reduction, soft_label=True) + ret = cross_entropy_loss(input, label) + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': self.logits, + 'label': self.labels, + }, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + paddle.disable_static() + + self.assertTrue(np.allclose(static_ret, dy_ret_value)) + self.assertTrue(np.allclose(static_ret, expected)) + self.assertTrue(np.allclose(dy_ret_value, expected)) + + ###soft_label test 6 + def test_cross_entropy_loss_soft_2d_weight_mean(self): + self.numeric_stable_mode = False + self.soft_label = True + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -100 #should not be changed + self.N = 3 + self.H = 2 + self.W = 2 + self.C = 5 + self.shape = [self.N, self.H, self.W, self.C] + self.use_softmax = True + self.reduction = 'mean' + self.weight = np.random.uniform(0.1, 1.0, self.C).astype(self.dtype) + self.logits = getattr( + self, "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) + softmax = np.apply_along_axis(stable_softmax, self.axis, self.logits) + + self.labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype) + self.labels /= np.sum(self.labels, axis=self.axis, keepdims=True) + + #1. numpy + expected = cross_entropy_soft_2d( + softmax, + self.labels, + self.axis, + self.N, + self.H, + self.W, + weight=self.weight, + reduction=self.reduction, + ignore_index=self.ignore_index) + + paddle.set_device("cpu") + paddle.disable_static() + + #2. dygraph + paddle_loss_none_weight = paddle.nn.functional.cross_entropy( + fluid.dygraph.to_variable(self.logits), + fluid.dygraph.to_variable(self.labels), + soft_label=True, + axis=self.axis, + weight=fluid.dygraph.to_variable(self.weight), + reduction=self.reduction) + dy_ret_value = paddle_loss_none_weight.numpy() + + #3. 
static + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', + shape=[self.N, self.H, self.W, self.C], + dtype='float64') + label = fluid.data( + name='label', + shape=[self.N, self.H, self.W, self.C], + dtype='float64') + weight = fluid.data(name='weight', shape=[self.C], dtype='float64') + + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight, reduction=self.reduction, soft_label=True) + ret = cross_entropy_loss(input, label) + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': self.logits, + 'label': self.labels, + "weight": self.weight + }, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + paddle.disable_static() + + self.assertTrue(np.allclose(static_ret, dy_ret_value)) + self.assertTrue(np.allclose(static_ret, expected)) + self.assertTrue(np.allclose(dy_ret_value, expected)) + + ###soft_label test end + def test_cross_entropy_loss_1d_with_mean_ignore(self): input_np = np.random.random([2, 4]).astype(np.float64) label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) @@ -131,19 +750,21 @@ class CrossEntropyLoss(unittest.TestCase): self.assertTrue(np.allclose(dy_ret_value, expected)) def test_cross_entropy_loss_1d_with_weight_mean_ignore(self): - input_np = np.random.random([2, 4]).astype(np.float64) - label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) - weight_np = np.random.random([4]).astype(np.float64) #shape:C + N = 100 + C = 200 + input_np = np.random.random([N, C]).astype(np.float64) + label_np = np.random.randint(0, C, size=(N)).astype(np.int64) + weight_np = np.random.random([C]).astype(np.float64) paddle.enable_static() prog = fluid.Program() startup_prog = fluid.Program() place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( ) else fluid.CPUPlace() with fluid.program_guard(prog, startup_prog): - input = fluid.data(name='input', shape=[2, 4], dtype='float64') - label = fluid.data(name='label', shape=[2], dtype='int64') + input = fluid.data(name='input', shape=[N, C], dtype='float64') + label = fluid.data(name='label', shape=[N], dtype='int64') weight = fluid.data( - name='weight', shape=[4], + name='weight', shape=[C], dtype='float64') #weight for each class cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( weight=weight, ignore_index=0) @@ -158,8 +779,6 @@ class CrossEntropyLoss(unittest.TestCase): }, fetch_list=[ret]) self.assertIsNotNone(static_ret) - expected = cross_entropy_loss_1d( - input_np, label_np, weight=weight_np)[0] with fluid.dygraph.guard(): cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( @@ -173,6 +792,7 @@ class CrossEntropyLoss(unittest.TestCase): self.assertIsNotNone(dy_ret_value) expected = cross_entropy_loss_1d( input_np, label_np, weight=weight_np, ignore_index=0)[0] + self.assertTrue(np.allclose(static_ret, dy_ret_value)) self.assertTrue(np.allclose(static_ret, expected)) self.assertTrue(np.allclose(dy_ret_value, expected)) @@ -265,6 +885,7 @@ class CrossEntropyLoss(unittest.TestCase): input_np = np.random.random([100, 200]).astype(np.float64) #N,C label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) #N,1 weight_np = np.random.random([200]).astype(np.float64) #C + paddle.enable_static() prog = fluid.Program() startup_prog = fluid.Program() @@ -274,6 +895,7 @@ class CrossEntropyLoss(unittest.TestCase): input = fluid.data(name='input', shape=[100, 
200], dtype='float64') label = fluid.data(name='label', shape=[100], dtype='int64') weight = fluid.data(name='weight', shape=[200], dtype='float64') + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( weight=weight, reduction='none') ret = cross_entropy_loss(input, label) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index c223addc2607bf0b169f24444aca738f557e703d..1dad1632e264a64895e79123d20697d20bfb5d34 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -* # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,7 +28,7 @@ from ...fluid.layers import dice_loss #DEFINE_ALIAS from ...fluid.layers import log_loss #DEFINE_ALIAS from ...fluid.layers import npair_loss #DEFINE_ALIAS from ...fluid.layers import reshape -from ...fluid.layers import softmax_with_cross_entropy #DEFINE_ALIAS +from ...fluid.layers import softmax_with_cross_entropy as fluid_softmax_with_cross_entropy #DEFINE_ALIAS from ...fluid.layers import square_error_cost #DEFINE_ALIAS from ...fluid.layers import edit_distance #DEFINE_ALIAS @@ -36,6 +37,7 @@ from ...fluid.layer_helper import LayerHelper from ...fluid.framework import in_dygraph_mode from ...fluid.framework import _varbase_creator from ...fluid.framework import Variable +from paddle.utils import deprecated __all__ = [ 'binary_cross_entropy', @@ -682,7 +684,6 @@ def l1_loss(input, label, reduction='mean', name=None): import paddle - paddle.disable_static() input = paddle.to_tensor([[1.5, 0.8], [0.2, 1.3]]) label = paddle.to_tensor([[1.7, 1], [0.4, 0.5]]) @@ -1112,6 +1113,19 @@ def ctc_loss(log_probs, return loss_out +@deprecated(since="2.0.0", update_to="paddle.nn.functional.cross_entropy") +def softmax_with_cross_entropy(logits, + label, + soft_label=False, + ignore_index=-100, + numeric_stable_mode=True, + return_softmax=False, + axis=-1): + return fluid_softmax_with_cross_entropy(logits, label, soft_label, + ignore_index, numeric_stable_mode, + return_softmax, axis) + + def cross_entropy(input, label, weight=None, @@ -1119,87 +1133,248 @@ def cross_entropy(input, reduction='mean', soft_label=False, axis=-1, + use_softmax=True, name=None): r""" - This operator implements the cross entropy loss function with softmax. This function + By default, this operator implements the cross entropy loss function with softmax. This function combines the calculation of the softmax operation and the cross entropy loss function - to provide a more numerically stable gradient. - Because this operator performs a softmax on logits internally, it expects - unscaled logits. This operator should not be used with the output of - softmax operator since that would produce incorrect results. + to provide a more numerically stable computing. - When the attribute :attr:`soft_label` is set :attr:`False`, this operators - expects mutually exclusive hard labels, each sample in a batch is in exactly - one class with a probability of 1.0. Each sample in the batch will have a - single label. + This operator will calculate the cross entropy loss function without softmax when use_softmax=False. - The equation is as follows: + By default, this operator will calculate the mean of the result, and you can also affect + the default behavior by using the reduction parameter. Please refer to the part of + parameters for details. 
- 1) Hard label (one-hot label, so every sample has exactly one class) + This operator can be used to calculate the softmax cross entropy loss with soft and hard labels. + Where, the hard labels mean the actual label value, 0, 1, 2, etc. And the soft labels + mean the probability of the actual label, 0.6, 0.8, 0.2, etc. - .. math:: + The calculation of this operator includes the following two steps. - loss_j = -\\text{logits}_{label_j} + - \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logits}_i)\\right), j = 1,..., K + - **1.softmax cross entropy** - 2) Soft label (each sample can have a distribution over all classes) + 1. Hard label (each sample can only be assigned into one category) - .. math:: + 1.1. when use_softmax=True - loss_j = -\\sum_{i=0}^{K}\\text{label}_i - \\left(\\text{logits}_i - \\log\\left(\\sum_{i=0}^{K} - \\exp(\\text{logits}_i)\\right)\\right), j = 1,...,K + .. math:: + \\loss_j=-\text{logits}_{label_j}+\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right) , j = 1,...,N - - It is useful when training a classification problem with ``C`` classes. + where, N is the number of samples and C is the number of categories. + + 1.2. when use_softmax=False + + .. math:: + \\loss_j=-\log\left({P}_{label_j}\right) , j = 1,...,N + + where, N is the number of samples and C is the number of categories, P is input(the output of softmax). + + + 2. Soft label (each sample is assigned to multiple categories with a certain probability, and the probability sum is 1). + + 2.1. when use_softmax=True + + .. math:: + \\loss_j=-\sum_{i=0}^{C}\text{label}_i\left(\text{logits}_i-\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right)\right) , j = 1,...,N + + where, N is the number of samples and C is the number of categories. + + 2.2. when use_softmax=False + + .. math:: + \\loss_j=-\sum_{j=0}^{C}\left({label}_j*\log\left({P}_{label_j}\right)\right) , j = 1,...,N + + where, N is the number of samples and C is the number of categories, P is input(the output of softmax). + + + + + - **2. Weight and reduction processing** + + 1. Weight + + If the ``weight`` parameter is ``None`` , go to the next step directly. + + If the ``weight`` parameter is not ``None`` , the cross entropy of each sample is weighted by weight + according to soft_label = False or True as follows. + + 1.1. Hard labels (soft_label = False) + + .. math:: + \\loss_j=loss_j*weight[label_j] + 1.2. Soft labels (soft_label = True) + + .. math:: + \\loss_j=loss_j*\sum_{i}\left(weight[label_i]*logits_i\right) + + 2. reduction + + 2.1 if the ``reduction`` parameter is ``none`` + + Return the previous result directly + + 2.2 if the ``reduction`` parameter is ``sum`` + + Return the sum of the previous results + + .. math:: + \\loss=\sum_{j}loss_j + + 2.3 if the ``reduction`` parameter is ``mean`` , it will be processed according to + the ``weight`` parameter as follows. + + 2.3.1. If the ``weight`` parameter is ``None`` + + Return the average value of the previous results + + .. math:: + \\loss=\sum_{j}loss_j/N + + where, N is the number of samples and C is the number of categories. + + 2.3.2. If the 'weight' parameter is not 'None', the weighted average value of the previous result will be returned + + 1. Hard labels (soft_label = False) + + .. math:: + \\loss=\sum_{j}loss_j/\sum_{j}weight[label_j] + + 2. Soft labels (soft_label = True) + + .. math:: + \\loss=\sum_{j}loss_j/\sum_{j}\left(\sum_{i}weight[label_i]\right) + + Parameters: - input (Tensor): Input tensor, the data type is float32, float64. 
Shape is - (N, C), where C is number of classes, and if shape is more than 2D, this - is (N, D1, D2,..., Dk, C), k >= 1. - label (Tensor): Label tensor, the data type is int64. Shape is (N), where each - value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is - (N, D1, D2,..., Dk), k >= 1. - weight (Tensor, optional):a manual rescaling weight given to each class. + + - **input** (Tensor) + + Input tensor, the data type is float32, float64. Shape is + :math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes , ``k >= 1`` . + + Note: + + 1. when use_softmax=True, it expects unscaled logits. This operator should not be used with the + output of softmax operator, which will produce incorrect results. + + 2. when use_softmax=False, it expects the output of softmax operator. + + - **label** (Tensor) + + 1. If soft_label=False, the shape is + :math:`[N_1, N_2, ..., N_k]` or :math:`[N_1, N_2, ..., N_k, 1]`, k >= 1. + the data type is int32, int64, float32, float64, where each value is [0, C-1]. + + 2. If soft_label=True, the shape and data type should be same with ``input`` , + and the sum of the labels for each sample should be 1. + + - **weight** (Tensor, optional) + + a manual rescaling weight given to each class. If given, has to be a Tensor of size C and the data type is float32, float64. - Default is ``'None'``. - reduction (str, optional): Indicate how to average the loss by batch_size, + Default is ``'None'`` . + + - **ignore_index** (int64, optional) + + Specifies a target value that is ignored + and does not contribute to the loss. A negative value means that no label + value needs to be ignored. Only valid when soft_label = False. + Default is ``-100`` . + + - **reduction** (str, optional) + + Indicate how to average the loss by batch_size, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned. If :attr:`reduction` is ``'none'``, the unreduced loss is returned. Default is ``'mean'``. - ignore_index (int64, optional): Specifies a target value that is ignored - and does not contribute to the input gradient. Default is ``-100``. - soft_label (bool): indicate whether label is soft. Default False, meaning that - the label is hard. If soft_label=True, the label is soft. - axis (int, optional): The index of dimension to perform softmax calculations. It - should be in range :math:`[-1, rank - 1]`, while :math:`rank` - is the rank of input :attr:`logits`. Default: -1. + - **soft_label** (bool, optional) + + Indicate whether label is soft. + Default is ``False``. + + - **axis** (int, optional) + + The index of dimension to perform softmax calculations. + It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the + number of dimensions of input :attr:`input`. + Default is ``-1`` . + + - **use_softmax** (bool, optional) + + Indicate whether compute softmax before cross_entropy. + Default is ``True``. + + - **name** (str,optional) + + The name of the operator. Default is ``None`` . + For more information, please refer to :ref:`api_guide_Name` . Returns: - Tensor.The tensor storing the cross_entropy_loss of input and label. + Tensor. Return the softmax cross_entropy loss of ``input`` and ``label``. + The data type is the same as input. - Examples: - .. code-block:: python + If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the dimension of return value is ``1``. 
- import paddle - import numpy as np + If :attr:`reduction` is ``'none'``: - input_data = np.random.random([5, 100]).astype("float64") - label_data = np.random.randint(0, 100, size=(5)).astype(np.int64) - weight_data = np.random.random([100]).astype("float64") + 1. If soft_label = False, the dimension of return value is the same with ``label`` . - input = paddle.to_tensor(input_data) - label = paddle.to_tensor(label_data) - weight = paddle.to_tensor(weight_data) + 2. if soft_label = True, the dimension of return value is :math:`[N_1, N_2, ..., N_k, 1]` . + + + Example1(hard labels): + + .. code-block:: python + + import paddle + paddle.seed(99999) + N=100 + C=200 + reduction='mean' + input = paddle.rand([N, C], dtype='float64') + label = paddle.randint(0, C, shape=[N], dtype='int64') + weight = paddle.rand([C], dtype='float64') + + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight, reduction=reduction) + dy_ret = cross_entropy_loss( + input, + label) + print(dy_ret.numpy()) #[5.41993642] + + + Example2(soft labels): + + .. code-block:: python + + import paddle + paddle.seed(99999) + axis = -1 + ignore_index = -100 + N = 4 + C = 3 + shape = [N, C] + reduction='mean' + weight = None + logits = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0) + labels = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0) + labels /= paddle.sum(labels, axis=axis, keepdim=True) + paddle_loss_mean = paddle.nn.functional.cross_entropy( + logits, + labels, + soft_label=True, + axis=axis, + weight=weight, + reduction=reduction) + print(paddle_loss_mean.numpy()) #[1.12908343] - loss = paddle.nn.functional.cross_entropy(input=input, label=label, weight=weight) - print(loss) - # [4.28546723] """ if reduction not in ['sum', 'mean', 'none']: @@ -1207,6 +1382,12 @@ def cross_entropy(input, "The value of 'reduction' in softmax_cross_entropy" "should be 'sum', 'mean' or 'none', but received %s, which is not allowed." % reduction) + if ignore_index > 0 and soft_label == True: + raise ValueError( + "When soft_label == True, the value of 'ignore_index' in softmax_cross_entropy" + "should be '-100', but received %s, which is not allowed." % + ignore_index) + input_dims = len(list(input.shape)) label_dims = len(list(label.shape)) if input_dims - 1 != label_dims and input_dims != label_dims: @@ -1216,27 +1397,46 @@ def cross_entropy(input, if input_dims - 1 == label_dims: label = paddle.unsqueeze(label, axis=axis) if in_dygraph_mode(): - out = softmax_with_cross_entropy( - input, - label, - soft_label=soft_label, - ignore_index=ignore_index, - axis=axis) + _, out = core.ops.softmax_with_cross_entropy( + input, label, 'soft_label', soft_label, 'ignore_index', + ignore_index, 'numeric_stable_mode', True, 'axis', axis, + 'use_softmax', use_softmax) + if weight is not None: - weight_gather = core.ops.gather_nd( - weight, label) #trans weight from class to sample, shape:N - input_shape = list(label.shape) - weight_gather_reshape = reshape(weight_gather, shape=input_shape) - out = core.ops.elementwise_mul(out, weight_gather_reshape) + + #trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases. + if soft_label == True: + # chajchaj: + # weight's shape is C, where C is class num. + # for 1d case: label's shape is [N,C], weight_gather's shape is N. + # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W]. 
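+                # Worked example of the matmul below (illustrative numbers, not part
+                # of the original change): for weight = [w_0, ..., w_{C-1}] and a soft
+                # label row label[j] = [p_0, ..., p_{C-1}], the per-sample weight is
+                #     weight_gather[j] = sum_i weight[i] * label[j][i]
+                # e.g. weight = [1.0, 2.0], label[j] = [0.2, 0.8] gives 1.8.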
+ weight_gather = paddle.matmul( + x=paddle.cast(label, weight.dtype), + y=weight, + transpose_x=False, + transpose_y=True) + out_shape = list(out.shape) + weight_gather_reshape = reshape(weight_gather, shape=out_shape) + out = paddle.cast(out, weight_gather_reshape.dtype) + + out = core.ops.elementwise_mul(out, weight_gather_reshape) + + else: + weight_gather = core.ops.gather_nd(weight, label) + input_shape = list(label.shape) + weight_gather_reshape = reshape( + weight_gather, shape=input_shape) + out = paddle.cast(out, weight_gather_reshape.dtype) + out = core.ops.elementwise_mul(out, weight_gather_reshape) if reduction == "sum": - # because of softmax_with_cross_entropy op's inner logic, + # because of fluid_softmax_with_cross_entropy op's inner logic, # in the out tensor of this op, the loss of sample with class_index==ignore_index is 0 # so, reduce_sum all directly is ok return core.ops.reduce_sum(out, 'reduce_all', True) elif reduction == "mean": #1. if weight==none, - # numerator: reduce_sum all loss directly is ok causeof softmax_with_cross_entropy's inner logic + # numerator: reduce_sum all loss directly is ok causeof fluid_softmax_with_cross_entropy's inner logic # denominator: count sample num with class_index!=ignore_index #2. else # numerator: loss's weighted sum @@ -1247,7 +1447,7 @@ def cross_entropy(input, #mask[i]=0, if label[i]==ignore_index #mask[i]=1, otherwise mask = (label != ignore_index) - if (weight is None): + if weight is None: mask = paddle.cast(mask, dtype=out_sum.dtype) count = core.ops.reduce_sum(mask, 'reduce_all', True) ret = out_sum / count @@ -1277,20 +1477,48 @@ def cross_entropy(input, fluid.data_feeder.check_variable_and_dtype( label, 'label', ['int32', 'int64', 'float32', 'float64'], 'softmax_cross_entropy') - out = softmax_with_cross_entropy( - input, - label, - soft_label=soft_label, - ignore_index=ignore_index, - axis=axis) + attrs = { + 'soft_label': soft_label, + 'ignore_index': ignore_index, + 'numeric_stable_mode': True, + 'axis': axis, + 'use_softmax': use_softmax + } + helper = LayerHelper('softmax_with_cross_entropy', **locals()) + softmax = helper.create_variable_for_type_inference(dtype=input.dtype) + out = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type='softmax_with_cross_entropy', + inputs={'Logits': input, + 'Label': label}, + outputs={'Softmax': softmax, + 'Loss': out}, + attrs=attrs) + if weight is not None: fluid.data_feeder.check_variable_and_dtype( weight, 'weight', ['float32', 'float64'], 'softmax_cross_entropy') weight_name = name if reduction == 'none' else None - weight_gather = paddle.gather_nd( - weight, label) #trans weight from class to sample, shape:N - input_shape = list(label.shape) - weight_gather_reshape = reshape(weight_gather, shape=input_shape) + if soft_label == True: + # chajchaj: + #trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases. + # weight's shape is C, where C is class num. + # for 1d case: label's shape is [N,C], weight_gather's shape is N. + # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W]. 
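+            # Same per-sample weighting as the dygraph branch above:
+            # weight_gather[j] = sum_i weight[i] * label[j][i].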
+ weight_gather = paddle.matmul( + x=paddle.cast(label, weight.dtype), + y=weight, + transpose_x=False, + transpose_y=True) + + out_shape = list(out.shape) + weight_gather_reshape = reshape(weight_gather, shape=out_shape) + out = paddle.cast(out, weight_gather_reshape.dtype) + else: + weight_gather = paddle.gather_nd( + weight, label) #trans weight from class to sample, shape:N + input_shape = list(label.shape) + weight_gather_reshape = reshape(weight_gather, shape=input_shape) out = paddle.multiply(out, weight_gather_reshape, name=weight_name) if reduction == "sum": diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index ac1cb5a8187720292ff5e942110b6af280f6f9d6..ad046b9041750999c36f74b54e9127446b6d49a8 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -* # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -108,7 +109,6 @@ class BCEWithLogitsLoss(fluid.dygraph.Layer): .. code-block:: python import paddle - paddle.disable_static() logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32") label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32") bce_logit_loss = paddle.nn.BCEWithLogitsLoss() @@ -142,85 +142,249 @@ class BCEWithLogitsLoss(fluid.dygraph.Layer): class CrossEntropyLoss(fluid.dygraph.Layer): r""" - This operator implements the cross entropy loss function with softmax. This function + By default, this operator implements the cross entropy loss function with softmax. This function combines the calculation of the softmax operation and the cross entropy loss function - to provide a more numerically stable gradient. + to provide a more numerically stable computing. - Because this operator performs a softmax on logits internally, it expects - unscaled logits. This operator should not be used with the output of - softmax operator since that would produce incorrect results. + This operator will calculate the cross entropy loss function without softmax when use_softmax=False. - When the attribute :attr:`soft_label` is set :attr:`False`, this operators - expects mutually exclusive hard labels, each sample in a batch is in exactly - one class with a probability of 1.0. Each sample in the batch will have a - single label. + By default, this operator will calculate the mean of the result, and you can also affect + the default behavior by using the reduction parameter. Please refer to the part of + parameters for details. - The equation is as follows: + This operator can be used to calculate the softmax cross entropy loss with soft and hard labels. + Where, the hard labels mean the actual label value, 0, 1, 2, etc. And the soft labels + mean the probability of the actual label, 0.6, 0.8, 0.2, etc. - 1) Hard label (one-hot label, so every sample has exactly one class) + The calculation of this operator includes the following two steps. - .. math:: + - **I.softmax cross entropy** - loss_j = -\\text{logits}_{label_j} + - \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logits}_i)\\right), j = 1,..., K + 1. Hard label (each sample can only be assigned into one category) - 2) Soft label (each sample can have a distribution over all classes) + 1.1. when use_softmax=True - .. math:: + .. 
math:: + \\loss_j=-\text{logits}_{label_j}+\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right) , j = 1,...,N - loss_j = -\\sum_{i=0}^{K}\\text{label}_i - \\left(\\text{logits}_i - \\log\\left(\\sum_{i=0}^{K} - \\exp(\\text{logits}_i)\\right)\\right), j = 1,...,K + where, N is the number of samples and C is the number of categories. + + 1.2. when use_softmax=False + + .. math:: + \\loss_j=-\log\left({P}_{label_j}\right) , j = 1,...,N + + where, N is the number of samples and C is the number of categories, P is input(the output of softmax). + + + 2. Soft label (each sample is assigned to multiple categories with a certain probability, and the probability sum is 1). + + 2.1. when use_softmax=True + + .. math:: + \\loss_j=-\sum_{i=0}^{C}\text{label}_i\left(\text{logits}_i-\log\left(\sum_{i=0}^{C}\exp(\text{logits}_i)\right)\right) , j = 1,...,N + + where, N is the number of samples and C is the number of categories. + + 2.2. when use_softmax=False + + .. math:: + \\loss_j=-\sum_{j=0}^{C}\left({label}_j*\log\left({P}_{label_j}\right)\right) , j = 1,...,N + + where, N is the number of samples and C is the number of categories, P is input(the output of softmax). + + + + - **II.Weight and reduction processing** + + 1. Weight + + If the ``weight`` parameter is ``None`` , go to the next step directly. + + If the ``weight`` parameter is not ``None`` , the cross entropy of each sample is weighted by weight + according to soft_label = False or True as follows. + + 1.1. Hard labels (soft_label = False) + + .. math:: + \\loss_j=loss_j*weight[label_j] - - It is useful when training a classification problem with ``C`` classes. + 1.2. Soft labels (soft_label = True) + .. math:: + \\loss_j=loss_j*\sum_{i}\left(weight[label_i]*logits_i\right) + + 2. reduction + + 2.1 if the ``reduction`` parameter is ``none`` + + Return the previous result directly + + 2.2 if the ``reduction`` parameter is ``sum`` + + Return the sum of the previous results + + .. math:: + \\loss=\sum_{j}loss_j + + 2.3 if the ``reduction`` parameter is ``mean`` , it will be processed according to + the ``weight`` parameter as follows. + + 2.3.1. If the ``weight`` parameter is ``None`` + + Return the average value of the previous results + + .. math:: + \\loss=\sum_{j}loss_j/N + + where, N is the number of samples and C is the number of categories. + + 2.3.2. If the 'weight' parameter is not 'None', the weighted average value of the previous result will be returned + + 1. Hard labels (soft_label = False) + + .. math:: + \\loss=\sum_{j}loss_j/\sum_{j}weight[label_j] + + 2. Soft labels (soft_label = True) + + .. math:: + \\loss=\sum_{j}loss_j/\sum_{j}\left(\sum_{i}weight[label_i]\right) + + Parameters: - input (Tensor): Input tensor, the data type is float32, float64. Shape is - (N, C), where C is number of classes, and if shape is more than 2D, this - is (N, C, D1, D2,..., Dk), k >= 1. - label (Tensor): Label tensor, the data type is int64. Shape is (N), where each - value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is - (N, D1, D2,..., Dk), k >= 1. - weight (Tensor, optional): Weight tensor, a manual rescaling weight given - to each class and the shape is (C). It has the same dimensions as class - number and the data type is float32, float64. Default is ``'None'``. - reduction (str, optional): Indicate how to average the loss by batch_size, + + - **weight** (Tensor, optional) + + a manual rescaling weight given to each class. + If given, has to be a Tensor of size C and the data type is float32, float64. 
+ Default is ``'None'`` . + + - **ignore_index** (int64, optional) + + Specifies a target value that is ignored + and does not contribute to the loss. A negative value means that no label + value needs to be ignored. Only valid when soft_label = False. + Default is ``-100`` . + + - **reduction** (str, optional) + + Indicate how to average the loss by batch_size, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned. If :attr:`reduction` is ``'none'``, the unreduced loss is returned. Default is ``'mean'``. - ignore_index (int64, optional): Specifies a target value that is ignored - and does not contribute to the input gradient. Default is ``-100``. - soft_label (bool): indicate whether label is soft. Default False, meaning that - the label is hard. If soft_label=True, the label is soft. - axis (int, optional): The index of dimension to perform softmax calculations. It - should be in range :math:`[-1, rank - 1]`, while :math:`rank` - is the rank of input :attr:`logits`. Default: -1. + - **soft_label** (bool, optional) - Returns: - Tensor. The tensor storing the cross_entropy_loss of input and label. + Indicate whether label is soft. + If soft_label=False, the label is hard. If soft_label=True, the label is soft. + Default is ``False``. + - **axis** (int, optional) + + The index of dimension to perform softmax calculations. + It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the number + of dimensions of input :attr:`input`. + Default is ``-1`` . + + - **use_softmax** (bool, optional) + + Indicate whether compute softmax before cross_entropy. + Default is ``True``. + + - **name** (str,optional) + + The name of the operator. Default is ``None`` . + For more information, please refer to :ref:`api_guide_Name` . + + + Shape: + + - **input** (Tensor) + + Input tensor, the data type is float32, float64. Shape is + :math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes , ``k >= 1`` . + + Note: + + 1. when use_softmax=True, it expects unscaled logits. This operator should not be used with the + output of softmax operator, which will produce incorrect results. + + 2. when use_softmax=False, it expects the output of softmax operator. + + + - **label** (Tensor) + + 1. If soft_label=False,the shape is + :math:`[N_1, N_2, ..., N_k]` or :math:`[N_1, N_2, ..., N_k, 1]`, k >= 1. + the data type is int32, int64, float32, float64, where each value is [0, C-1]. + + 2. If soft_label=True, the shape and data type should be same with ``input`` , + and the sum of the labels for each sample should be 1. + + - **output** (Tensor) + + Return the softmax cross_entropy loss of ``input`` and ``label``. + + The data type is the same as input. + + If :attr:`reduction` is ``'mean'`` or ``'sum'`` , the dimension of return value is ``1``. + + If :attr:`reduction` is ``'none'``: + + 1. If soft_label = False, the dimension of return value is the same with ``label`` . + + 2. if soft_label = True, the dimension of return value is :math:`[N_1, N_2, ..., N_k, 1]` . + + Example1(hard labels): - Examples: .. 
code-block:: python import paddle - import numpy as np + paddle.seed(99999) + N=100 + C=200 + reduction='mean' + input = paddle.rand([N, C], dtype='float64') + label = paddle.randint(0, C, shape=[N], dtype='int64') + weight = paddle.rand([C], dtype='float64') + + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight, reduction=reduction) + dy_ret = cross_entropy_loss( + input, + label) + print(dy_ret.numpy()) #[5.41993642] + + + Example2(soft labels): + + .. code-block:: python + + import paddle + paddle.seed(99999) + axis = -1 + ignore_index = -100 + N = 4 + C = 3 + shape = [N, C] + reduction='mean' + weight = None + logits = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0) + labels = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0) + labels /= paddle.sum(labels, axis=axis, keepdim=True) + paddle_loss_mean = paddle.nn.functional.cross_entropy( + logits, + labels, + soft_label=True, + axis=axis, + weight=weight, + reduction=reduction) + print(paddle_loss_mean.numpy()) #[1.12908343] - input_data = paddle.uniform([5, 100], dtype="float64") - label_data = np.random.randint(0, 100, size=(5)).astype(np.int64) - weight_data = np.random.random([100]).astype("float64") - input = paddle.to_tensor(input_data) - label = paddle.to_tensor(label_data) - weight = paddle.to_tensor(weight_data) - ce_loss = paddle.nn.CrossEntropyLoss(weight=weight, reduction='mean') - output = ce_loss(input, label) - print(output) - # [4.84496039] """ def __init__(self, @@ -229,6 +393,7 @@ class CrossEntropyLoss(fluid.dygraph.Layer): reduction='mean', soft_label=False, axis=-1, + use_softmax=True, name=None): super(CrossEntropyLoss, self).__init__() self.weight = weight @@ -236,6 +401,7 @@ class CrossEntropyLoss(fluid.dygraph.Layer): self.ignore_index = ignore_index self.soft_label = soft_label self.axis = axis + self.use_softmax = use_softmax self.name = name def forward(self, input, label): @@ -247,6 +413,7 @@ class CrossEntropyLoss(fluid.dygraph.Layer): reduction=self.reduction, soft_label=self.soft_label, axis=self.axis, + use_softmax=self.use_softmax, name=self.name) return ret
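
A minimal end-to-end sketch of the `use_softmax=False` path introduced by this patch (illustrative only, not part of the diff itself); it assumes Paddle 2.x dygraph mode and pre-computes the probabilities with `paddle.nn.functional.softmax`:

    import paddle
    import paddle.nn.functional as F

    paddle.seed(99999)
    N, C = 4, 3
    logits = paddle.uniform([N, C], dtype='float64', min=0.1, max=1.0)
    labels = paddle.uniform([N, C], dtype='float64', min=0.1, max=1.0)
    labels /= paddle.sum(labels, axis=-1, keepdim=True)  # soft labels sum to 1

    # Pre-compute probabilities and skip the internal softmax.
    probs = F.softmax(logits, axis=-1)
    loss = F.cross_entropy(probs, labels, soft_label=True, use_softmax=False)
    print(loss.numpy())  # mean soft-label cross entropy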