From 12bcd023023e4b09448f5e039fbd130e7923a498 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 15 Aug 2021 13:43:19 +0800 Subject: [PATCH] fix weighted CE loss's bug --- .../unittests/test_cross_entropy_loss.py | 264 +++++++++++++++++- python/paddle/nn/functional/loss.py | 189 ++++++++----- 2 files changed, 377 insertions(+), 76 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index a4e676901c5..689515511e7 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -50,7 +50,7 @@ def cross_entropy_loss_1d(input, total_weight += cur_weight out[i] = -log_softmax_out[i][cur_target] * cur_weight - ###2. deal with reduction + ###2. deal with reduction if reduction == 'sum': return np.sum(out), np.array([total_weight]).astype('float64') elif reduction == 'mean': @@ -434,7 +434,7 @@ class CrossEntropyLoss(unittest.TestCase): paddle.set_device("cpu") - #2 dygraph + #2 dygraph paddle.disable_static() paddle_loss_mean = paddle.nn.functional.cross_entropy( fluid.dygraph.to_variable(self.logits), @@ -841,6 +841,55 @@ class CrossEntropyLoss(unittest.TestCase): self.assertTrue(np.allclose(static_ret, expected)) self.assertTrue(np.allclose(dy_ret_value, expected)) + def test_cross_entropy_loss_1d_with_weight_mean_ignore_exceedlabel(self): + N = 100 + C = 200 + input_np = np.random.random([N, C]).astype(self.dtype) + label_np = np.random.randint(0, C, size=(N)).astype(np.int64) + label_np[0] = 255 + weight_np = np.random.random([C]).astype(self.dtype) + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data(name='input', shape=[N, C], dtype=self.dtype) + label = fluid.data(name='label', shape=[N], dtype='int64') + weight = fluid.data( + name='weight', shape=[C], + dtype=self.dtype) #weight for each class + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight, ignore_index=255) + ret = cross_entropy_loss(input, label) + + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': input_np, + 'label': label_np, + "weight": weight_np + }, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + + with fluid.dygraph.guard(): + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=fluid.dygraph.to_variable(weight_np), + axis=1, + ignore_index=255) + dy_ret = cross_entropy_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_ret_value = dy_ret.numpy() + self.assertIsNotNone(dy_ret_value) + expected = cross_entropy_loss_1d( + input_np, label_np, weight=weight_np, ignore_index=255)[0] + + self.assertTrue(np.allclose(static_ret, dy_ret_value)) + self.assertTrue(np.allclose(static_ret, expected)) + self.assertTrue(np.allclose(dy_ret_value, expected)) + def test_cross_entropy_loss_1d_with_weight_mean(self): input_np = np.random.random([2, 4]).astype(self.dtype) label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) @@ -1013,7 +1062,7 @@ class CrossEntropyLoss(unittest.TestCase): def test_cross_entropy_loss_1d_mean(self): input_np = np.random.random([100, 200]).astype(self.dtype) #N,C label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) #N,1 - weight_np = np.random.random([200]).astype(self.dtype) #C + # weight_np = np.random.random([200]).astype(self.dtype) #C paddle.enable_static() prog = fluid.Program() startup_prog = fluid.Program() @@ -1022,7 +1071,7 @@ class CrossEntropyLoss(unittest.TestCase): with fluid.program_guard(prog, startup_prog): input = fluid.data(name='input', shape=[100, 200], dtype=self.dtype) label = fluid.data(name='label', shape=[100], dtype='int64') - weight = fluid.data(name='weight', shape=[100], dtype=self.dtype) + # weight = fluid.data(name='weight', shape=[100], dtype=self.dtype) cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss() ret = cross_entropy_loss(input, label) exe = fluid.Executor(place) @@ -1156,6 +1205,58 @@ class CrossEntropyLoss(unittest.TestCase): self.assertTrue(np.allclose(static_ret, expected)) self.assertTrue(np.allclose(dy_ret_value, expected)) + def test_cross_entropy_loss_2d_with_weight_mean_ignore_exceedlabel(self): + N = 4 + C = 3 + H = 512 + W = 512 + input_np = np.random.random([N, H, W, C]).astype(self.dtype) + label_np = np.random.randint(0, C, size=(N, H, W)).astype(np.int64) + label_np[0, 0, 0] = 255 + weight_np = np.random.random([C]).astype(self.dtype) + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[N, H, W, C], dtype=self.dtype) + label = fluid.data(name='label', shape=[N, H, W], dtype='int64') + weight = fluid.data( + name='weight', shape=[C], + dtype=self.dtype) #weight for each class + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight, ignore_index=255) + ret = cross_entropy_loss(input, label) + + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': input_np, + 'label': label_np, + "weight": weight_np + }, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + + with fluid.dygraph.guard(): + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=fluid.dygraph.to_variable(weight_np), + axis=1, + ignore_index=255) + dy_ret = cross_entropy_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_ret_value = dy_ret.numpy() + self.assertIsNotNone(dy_ret_value) + expected = cross_entropy_loss_2d( + input_np, label_np, weight=weight_np, ignore_index=255)[0] + + self.assertTrue(np.allclose(static_ret, dy_ret_value)) + self.assertTrue(np.allclose(static_ret, expected)) + self.assertTrue(np.allclose(dy_ret_value, expected)) + def test_cross_entropy_loss_2d_with_weight_mean(self): input_np = np.random.random(size=(2, 2, 2, 3)).astype(self.dtype) #NHWC label_np = np.random.randint( @@ -1362,21 +1463,62 @@ class TestCrossEntropyFAPIError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): - def test_LabelValue(): + # def test_LabelValue(): + # input_data = paddle.rand(shape=[20, 100]) + # label_data = paddle.randint( + # 0, 100, shape=[20, 1], dtype="int64") + # label_data[0] = 255 + # weight_data = paddle.rand([100]) + # paddle.nn.functional.cross_entropy( + # input=input_data, + # label=label_data, + # weight=weight_data, + # ignore_index=255) + + # self.assertRaises(ValueError, test_LabelValue) + + # def test_LabelValueNeg(): + # input_data = paddle.rand(shape=[20, 100]) + # label_data = paddle.randint( + # 0, 100, shape=[20, 1], dtype="int64") + # label_data[0] = -1 + # weight_data = paddle.rand([100]) + # paddle.nn.functional.cross_entropy( + # input=input_data, + # label=label_data, + # weight=weight_data, + # ignore_index=-1) + + # self.assertRaises(ValueError, test_LabelValueNeg) + + def test_WeightLength_NotEqual(): input_data = paddle.rand(shape=[20, 100]) label_data = paddle.randint( 0, 100, shape=[20, 1], dtype="int64") - label_data[0] = 255 + weight_data = paddle.rand([100 + 1]) + paddle.nn.functional.cross_entropy( + input=input_data, + label=label_data, + weight=weight_data, + ignore_index=-100) + + self.assertRaises(ValueError, test_WeightLength_NotEqual) + + def test_LabelValue_ExceedMax(): + input_data = paddle.rand(shape=[20, 100]) + label_data = paddle.randint( + 0, 100, shape=[20, 1], dtype="int64") + label_data[0] = 100 weight_data = paddle.rand([100]) paddle.nn.functional.cross_entropy( input=input_data, label=label_data, weight=weight_data, - ignore_index=255) + ignore_index=-100) - self.assertRaises(ValueError, test_LabelValue) + self.assertRaises(ValueError, test_LabelValue_ExceedMax) - def test_LabelValueNeg(): + def test_LabelValue_ExceedMin(): input_data = paddle.rand(shape=[20, 100]) label_data = paddle.randint( 0, 100, shape=[20, 1], dtype="int64") @@ -1386,9 +1528,107 @@ class TestCrossEntropyFAPIError(unittest.TestCase): input=input_data, label=label_data, weight=weight_data, - ignore_index=-1) - - self.assertRaises(ValueError, test_LabelValueNeg) + ignore_index=-100) + + self.assertRaises(ValueError, test_LabelValue_ExceedMin) + + def static_test_WeightLength_NotEqual(): + input_np = np.random.random([2, 4]).astype(self.dtype) + label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) + weight_np = np.random.random([3]).astype(self.dtype) #shape:C + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[2, 4], dtype=self.dtype) + label = fluid.data(name='label', shape=[2], dtype='int64') + weight = fluid.data( + name='weight', shape=[3], + dtype=self.dtype) #weight for each class + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight) + ret = cross_entropy_loss(input, label) + + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': input_np, + 'label': label_np, + "weight": weight_np + }, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + + self.assertRaises(ValueError, static_test_WeightLength_NotEqual) + + def static_test_LabelValue_ExceedMax(): + input_np = np.random.random([2, 4]).astype(self.dtype) + label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) + label_np[0] = 255 + weight_np = np.random.random([4]).astype(self.dtype) #shape:C + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[2, 4], dtype=self.dtype) + label = fluid.data(name='label', shape=[2], dtype='int64') + weight = fluid.data( + name='weight', shape=[4], + dtype=self.dtype) #weight for each class + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight) + ret = cross_entropy_loss(input, label) + + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': input_np, + 'label': label_np, + "weight": weight_np + }, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + + self.assertRaises(ValueError, static_test_LabelValue_ExceedMax) + + def static_test_LabelValue_ExceedMin(): + input_np = np.random.random([2, 4]).astype(self.dtype) + label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) + label_np[0] = -1 + weight_np = np.random.random([4]).astype(self.dtype) #shape:C + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[2, 4], dtype=self.dtype) + label = fluid.data(name='label', shape=[2], dtype='int64') + weight = fluid.data( + name='weight', shape=[4], + dtype=self.dtype) #weight for each class + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight) + ret = cross_entropy_loss(input, label) + + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': input_np, + 'label': label_np, + "weight": weight_np + }, + fetch_list=[ret]) + self.assertIsNotNone(static_ret) + + self.assertRaises(ValueError, static_test_LabelValue_ExceedMin) if __name__ == "__main__": diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index a1cd80e42f7..270d4e71db4 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1389,18 +1389,18 @@ def cross_entropy(input, use_softmax=True, name=None): r""" - By default, this operator implements the cross entropy loss function with softmax. This function - combines the calculation of the softmax operation and the cross entropy loss function - to provide a more numerically stable computing. + By default, this operator implements the cross entropy loss function with softmax. This function + combines the calculation of the softmax operation and the cross entropy loss function + to provide a more numerically stable computing. This operator will calculate the cross entropy loss function without softmax when use_softmax=False. - By default, this operator will calculate the mean of the result, and you can also affect - the default behavior by using the reduction parameter. Please refer to the part of + By default, this operator will calculate the mean of the result, and you can also affect + the default behavior by using the reduction parameter. Please refer to the part of parameters for details. This operator can be used to calculate the softmax cross entropy loss with soft and hard labels. - Where, the hard labels mean the actual label value, 0, 1, 2, etc. And the soft labels + Where, the hard labels mean the actual label value, 0, 1, 2, etc. And the soft labels mean the probability of the actual label, 0.6, 0.8, 0.2, etc. The calculation of this operator includes the following two steps. @@ -1455,7 +1455,7 @@ def cross_entropy(input, 1.1. Hard labels (soft_label = False) .. math:: - \\loss_j=loss_j*weight[label_j] + \\loss_j=loss_j*weight[label_j] 1.2. Soft labels (soft_label = True) @@ -1465,21 +1465,21 @@ def cross_entropy(input, 2. reduction - 2.1 if the ``reduction`` parameter is ``none`` + 2.1 if the ``reduction`` parameter is ``none`` Return the previous result directly - 2.2 if the ``reduction`` parameter is ``sum`` + 2.2 if the ``reduction`` parameter is ``sum`` Return the sum of the previous results .. math:: \\loss=\sum_{j}loss_j - 2.3 if the ``reduction`` parameter is ``mean`` , it will be processed according to - the ``weight`` parameter as follows. + 2.3 if the ``reduction`` parameter is ``mean`` , it will be processed according to + the ``weight`` parameter as follows. - 2.3.1. If the ``weight`` parameter is ``None`` + 2.3.1. If the ``weight`` parameter is ``None`` Return the average value of the previous results @@ -1493,48 +1493,48 @@ def cross_entropy(input, 1. Hard labels (soft_label = False) .. math:: - \\loss=\sum_{j}loss_j/\sum_{j}weight[label_j] + \\loss=\sum_{j}loss_j/\sum_{j}weight[label_j] 2. Soft labels (soft_label = True) .. math:: \\loss=\sum_{j}loss_j/\sum_{j}\left(\sum_{i}weight[label_i]\right) - - + + Parameters: - **input** (Tensor) Input tensor, the data type is float32, float64. Shape is - :math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes , ``k >= 1`` . + :math:`[N_1, N_2, ..., N_k, C]`, where C is number of classes , ``k >= 1`` . - Note: + Note: - 1. when use_softmax=True, it expects unscaled logits. This operator should not be used with the + 1. when use_softmax=True, it expects unscaled logits. This operator should not be used with the output of softmax operator, which will produce incorrect results. 2. when use_softmax=False, it expects the output of softmax operator. - + - **label** (Tensor) 1. If soft_label=False, the shape is :math:`[N_1, N_2, ..., N_k]` or :math:`[N_1, N_2, ..., N_k, 1]`, k >= 1. the data type is int32, int64, float32, float64, where each value is [0, C-1]. - 2. If soft_label=True, the shape and data type should be same with ``input`` , + 2. If soft_label=True, the shape and data type should be same with ``input`` , and the sum of the labels for each sample should be 1. - **weight** (Tensor, optional) - a manual rescaling weight given to each class. - If given, has to be a Tensor of size C and the data type is float32, float64. + a manual rescaling weight given to each class. + If given, has to be a Tensor of size C and the data type is float32, float64. Default is ``'None'`` . - **ignore_index** (int64, optional) Specifies a target value that is ignored - and does not contribute to the loss. A negative value means that no label - value needs to be ignored. Only valid when soft_label = False. + and does not contribute to the loss. A negative value means that no label + value needs to be ignored. Only valid when soft_label = False. Default is ``-100`` . - **reduction** (str, optional) @@ -1548,14 +1548,14 @@ def cross_entropy(input, - **soft_label** (bool, optional) - Indicate whether label is soft. + Indicate whether label is soft. Default is ``False``. - **axis** (int, optional) - The index of dimension to perform softmax calculations. - It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the - number of dimensions of input :attr:`input`. + The index of dimension to perform softmax calculations. + It should be in range :math:`[-1, rank - 1]`, where :math:`rank` is the + number of dimensions of input :attr:`input`. Default is ``-1`` . - **use_softmax** (bool, optional) @@ -1577,24 +1577,24 @@ def cross_entropy(input, If :attr:`reduction` is ``'none'``: - 1. If soft_label = False, the dimension of return value is the same with ``label`` . + 1. If soft_label = False, the dimension of return value is the same with ``label`` . - 2. if soft_label = True, the dimension of return value is :math:`[N_1, N_2, ..., N_k, 1]` . + 2. if soft_label = True, the dimension of return value is :math:`[N_1, N_2, ..., N_k, 1]` . Example1(hard labels): .. code-block:: python - + import paddle paddle.seed(99999) N=100 C=200 reduction='mean' - input = paddle.rand([N, C], dtype='float64') + input = paddle.rand([N, C], dtype='float64') label = paddle.randint(0, C, shape=[N], dtype='int64') - weight = paddle.rand([C], dtype='float64') - + weight = paddle.rand([C], dtype='float64') + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( weight=weight, reduction=reduction) dy_ret = cross_entropy_loss( @@ -1606,7 +1606,7 @@ def cross_entropy(input, Example2(soft labels): .. code-block:: python - + import paddle paddle.seed(99999) axis = -1 @@ -1620,9 +1620,9 @@ def cross_entropy(input, labels = paddle.uniform(shape, dtype='float64', min=0.1, max=1.0) labels /= paddle.sum(labels, axis=axis, keepdim=True) paddle_loss_mean = paddle.nn.functional.cross_entropy( - logits, - labels, - soft_label=True, + logits, + labels, + soft_label=True, axis=axis, weight=weight, reduction=reduction) @@ -1657,7 +1657,7 @@ def cross_entropy(input, if weight is not None: - #trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases. + # trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases. if soft_label == True: # chajchaj: # weight's shape is C, where C is class num. @@ -1675,14 +1675,43 @@ def cross_entropy(input, out = _C_ops.elementwise_mul(out, weight_gather_reshape) else: - label_min = paddle.min(label) - label_max = paddle.max(label) - if label_min < 0 or label_max >= input.shape[-1]: + if input.shape[-1] != weight.shape[-1]: raise ValueError( - 'Expected 0 <= label_value < class_dimension({}), but got {} <= label_value <= {} '. - format(input.shape[-1], - label_min.numpy(), label_max.numpy())) - weight_gather = _C_ops.gather_nd(weight, label) + "input's class_dimension({}) must equal to \ + weight's class_dimension({}) \ + when weight is provided" + .format(input.shape[-1], weight.shape[-1])) + valid_label = paddle.where( + label == ignore_index, + paddle.to_tensor( + 0, dtype=label.dtype), + label) + + if (len(paddle.nonzero(valid_label < 0)) > 0) or ( + len(paddle.nonzero(valid_label >= input.shape[-1])) > 0 + ): + invalid_label = paddle.gather_nd( + input, paddle.nonzero(valid_label < 0)) + if invalid_label.numel() > 0: + raise ValueError( + "Target({}) is out of class_dimension's lower bound({})". + format(invalid_label[0], 0)) + invalid_label = paddle.gather_nd( + input, paddle.nonzero(valid_label >= input.shape[-1])) + if invalid_label.numel() > 0: + raise ValueError( + "Target({}) is out of class_dimension's upper bound({})". + format(invalid_label[0], input.shape[-1])) + + ignore_weight_mask = paddle.cast((label != ignore_index), + out.dtype) + if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ + -1] == 1: + ignore_weight_mask.squeeze_(-1) + weight_gather = _C_ops.gather_nd( + weight, valid_label) # ignore的位置暂时用label0的权重代替 + weight_gather = _C_ops.elementwise_mul(weight_gather, + ignore_weight_mask) input_shape = list(label.shape) weight_gather_reshape = reshape( weight_gather, shape=input_shape) @@ -1690,22 +1719,22 @@ def cross_entropy(input, out = _C_ops.elementwise_mul(out, weight_gather_reshape) if reduction == "sum": - # because of fluid_softmax_with_cross_entropy op's inner logic, + # because of fluid_softmax_with_cross_entropy op's inner logic, # in the out tensor of this op, the loss of sample with class_index==ignore_index is 0 # so, reduce_sum all directly is ok return _C_ops.reduce_sum(out, 'reduce_all', True) elif reduction == "mean": - #1. if weight==none, - # numerator: reduce_sum all loss directly is ok causeof fluid_softmax_with_cross_entropy's inner logic - # denominator: count sample num with class_index!=ignore_index - #2. else - # numerator: loss's weighted sum - # denominator: cal the sum of weight where the sample's class_index!=ignore_index + # 1. if weight==none, + # numerator: reduce_sum all loss directly is ok causeof fluid_softmax_with_cross_entropy's inner logic + # denominator: count sample num with class_index!=ignore_index + # 2. else + # numerator: loss's weighted sum + # denominator: cal the sum of weight where the sample's class_index!=ignore_index if ignore_index != -100: out_sum = _C_ops.reduce_sum(out, 'reduce_all', True) - #for each label[i],set 1 or 0, according to ignore_index - #mask[i]=0, if label[i]==ignore_index - #mask[i]=1, otherwise + # for each label[i],set 1 or 0, according to ignore_index + # mask[i]=0, if label[i]==ignore_index + # mask[i]=1, otherwise mask = (label != ignore_index) if weight is None: mask = paddle.cast(mask, dtype=out_sum.dtype) @@ -1761,7 +1790,7 @@ def cross_entropy(input, weight_name = name if reduction == 'none' else None if soft_label == True: # chajchaj: - #trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases. + # trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases. # weight's shape is C, where C is class num. # for 1d case: label's shape is [N,C], weight_gather's shape is N. # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W]. @@ -1775,8 +1804,40 @@ def cross_entropy(input, weight_gather_reshape = reshape(weight_gather, shape=out_shape) out = paddle.cast(out, weight_gather_reshape.dtype) else: + if input.shape[-1] != weight.shape[-1]: + raise ValueError("input's class_dimension({}) must equal to \ + weight's class_dimension({}) \ + when weight is provided" + .format(input.shape[-1], weight.shape[-1])) + valid_label = paddle.where( + label == ignore_index, + paddle.to_tensor( + 0, dtype=label.dtype), + label) + if (len(paddle.nonzero(valid_label < 0)) > 0) or ( + len(paddle.nonzero(valid_label >= input.shape[-1])) > 0): + invalid_label = paddle.gather_nd( + input, paddle.nonzero(valid_label < 0)) + if invalid_label.numel() > 0: + raise ValueError( + "Target({}) is out of class_dimension's lower bound({})". + format(invalid_label[0], 0)) + invalid_label = paddle.gather_nd( + input, paddle.nonzero(valid_label >= input.shape[-1])) + if invalid_label.numel() > 0: + raise ValueError( + "Target({}) is out of class_dimension's upper bound({})". + format(invalid_label[0], input.shape[-1])) + + ignore_weight_mask = paddle.cast((label != ignore_index), out.dtype) + if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ + -1] == 1: + ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1) + weight_gather = paddle.gather_nd( - weight, label) #trans weight from class to sample, shape:N + weight, + valid_label) #trans weight from class to sample, shape:N + weight_gather = paddle.multiply(weight_gather, ignore_weight_mask) input_shape = list(label.shape) weight_gather_reshape = reshape(weight_gather, shape=input_shape) out = paddle.multiply(out, weight_gather_reshape, name=weight_name) @@ -1786,9 +1847,9 @@ def cross_entropy(input, elif reduction == "mean": if ignore_index != -100: out_sum = paddle.sum(out, name=name) - #for each label[i],set 1 or 0, according to ignore_index - #mask[i]=0, if label[i]==ignore_index - #mask[i]=1, otherwise + # for each label[i],set 1 or 0, according to ignore_index + # mask[i]=0, if label[i]==ignore_index + # mask[i]=1, otherwise mask = (label != ignore_index) if (weight is None): mask = paddle.cast(mask, dtype=out_sum.dtype) @@ -1828,12 +1889,12 @@ def sigmoid_focal_loss(logit, it is used in one-stage object detection where the foreground-background class imbalance is extremely high. - This operator measures focal loss function as follows: + This operator measures focal loss function as follows: .. math:: Out = -Labels * alpha * {(1 - \sigma(Logit))}^{gamma}\log(\sigma(Logit)) - (1 - Labels) * (1 - alpha) * {\sigma(Logit)}^{gamma}\log(1 - \sigma(Logit)) - We know that :math:`\sigma(Logit) = \frac{1}{1 + \exp(-Logit)}`. + We know that :math:`\sigma(Logit) = \frac{1}{1 + \exp(-Logit)}`. Then, if :attr:`normalizer` is not None, this operator divides the normalizer tensor on the loss `Out`: @@ -1860,7 +1921,7 @@ def sigmoid_focal_loss(logit, For object detection task, it is the the number of positive samples. If set to None, the focal loss will not be normalized. Default is None. alpha(int|float, optional): Hyper-parameter to balance the positive and negative example, - it should be between 0 and 1. Default value is set to 0.25. + it should be between 0 and 1. Default value is set to 0.25. gamma(int|float, optional): Hyper-parameter to modulate the easy and hard examples. Default value is set to 2.0. reduction (str, optional): Indicate how to average the loss by batch_size, -- GitLab