diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
index b8086eaf4a1ea32ce126fc262f46b1675680034b..81e2160a556d2fddf0e970e5a68315a7ec39f724 100644
--- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
@@ -93,6 +93,90 @@ def cross_entropy_loss_2d(input,
 class CrossEntropyLoss(unittest.TestCase):
+    def test_cross_entropy_loss_1d_with_mean_ignore(self):
+        input_np = np.random.random([2, 4]).astype(np.float64)
+        label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
+        paddle.enable_static()
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(name='input', shape=[2, 4], dtype='float64')
+            label = fluid.data(name='label', shape=[2], dtype='int64')
+            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
+                ignore_index=0)
+            ret = cross_entropy_loss(input, label)
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        expected = cross_entropy_loss_1d(input_np, label_np)[0]
+
+        with fluid.dygraph.guard():
+            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
+                axis=1, ignore_index=0)
+            dy_ret = cross_entropy_loss(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np))
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_1d(input_np, label_np, ignore_index=0)[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    def test_cross_entropy_loss_1d_with_weight_mean_ignore(self):
+        input_np = np.random.random([2, 4]).astype(np.float64)
+        label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
+        weight_np = np.random.random([4]).astype(np.float64)  # shape: C
+        paddle.enable_static()
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(name='input', shape=[2, 4], dtype='float64')
+            label = fluid.data(name='label', shape=[2], dtype='int64')
+            weight = fluid.data(
+                name='weight', shape=[4],
+                dtype='float64')  # weight for each class
+            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
+                weight=weight, ignore_index=0)
+            ret = cross_entropy_loss(input, label)
+
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                     "weight": weight_np
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        expected = cross_entropy_loss_1d(
+            input_np, label_np, weight=weight_np)[0]
+
+        with fluid.dygraph.guard():
+            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
+                weight=fluid.dygraph.to_variable(weight_np),
+                axis=1,
+                ignore_index=0)
+            dy_ret = cross_entropy_loss(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np))
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_1d(
+            input_np, label_np, weight=weight_np, ignore_index=0)[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))
+
     def test_cross_entropy_loss_1d_with_weight_mean(self):
         input_np = np.random.random([2, 4]).astype(np.float64)
         label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
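The two new tests pin down the semantics the rest of this patch implements: under reduction='mean', a sample whose label equals ignore_index contributes nothing to the numerator, and the denominator shrinks to the count of kept samples (or, when weight is given, to the sum of their per-class weights). The reference helper cross_entropy_loss_1d is defined earlier in the test file and is not part of this diff; the sketch below is a hypothetical NumPy equivalent of the quantity the tests compare against, under that assumption.

import numpy as np

def ref_cross_entropy_1d(input, label, weight=None, ignore_index=-100):
    # Hypothetical stand-in for the test file's cross_entropy_loss_1d helper
    # (not shown in this diff); returns only the mean-reduced loss.
    shifted = input - input.max(axis=1, keepdims=True)      # stable log-softmax
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    n = input.shape[0]
    mask = (label != ignore_index).astype(np.float64)       # 0 for ignored samples
    safe_label = np.where(label == ignore_index, 0, label)  # keep indexing in range
    nll = -log_softmax[np.arange(n), safe_label]            # per-sample loss
    w = weight[safe_label] if weight is not None else np.ones(n)
    # numerator: (weighted) loss over kept samples only;
    # denominator: count of kept samples, or the sum of their weights
    return (nll * w * mask).sum() / (w * mask).sum()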
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index e1f050a57ed7d89e35ba9a436f6bd93c36402f0e..90a3ebc679cf7dc7ea8e391067152cc5175dc9f4 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1223,22 +1223,50 @@ def cross_entropy(input,
             ignore_index=ignore_index,
             axis=axis)
         if weight is not None:
-            weight_gather = core.ops.gather_nd(weight, label)  #trans to sample
+            weight_gather = core.ops.gather_nd(
+                weight, label)  # transform weight from per-class to per-sample, shape: N
            input_shape = list(label.shape)
-            weight_gather_reshape, _ = core.ops.reshape2(weight_gather, None,
-                                                         'shape', input_shape)
+            weight_gather_reshape = reshape(weight_gather, shape=input_shape)
             out = core.ops.elementwise_mul(out, weight_gather_reshape)
 
         if reduction == "sum":
+            # softmax_with_cross_entropy already writes 0 into `out` for every
+            # sample whose class_index == ignore_index, so summing everything
+            # directly is correct
             return core.ops.reduce_sum(out, 'reduce_all', True)
         elif reduction == "mean":
+            # 1. if weight is None:
+            #    numerator: summing all losses is correct, since ignored
+            #               samples already carry a loss of 0 (see above)
+            #    denominator: the number of samples with class_index != ignore_index
+            # 2. else:
+            #    numerator: the weighted sum of the losses
+            #    denominator: the sum of weights over samples with class_index != ignore_index
+            if ignore_index != -100:
+                out_sum = core.ops.reduce_sum(out, 'reduce_all', True)
+                # build a 0/1 mask from ignore_index:
+                # mask[i] = 0 if label[i] == ignore_index, else 1
+                mask = (label != ignore_index)
+                if weight is None:
+                    mask = paddle.cast(mask, dtype=out_sum.dtype)
+                    count = core.ops.reduce_sum(mask, 'reduce_all', True)
+                    ret = out_sum / count
+                else:
+                    mask = paddle.cast(mask, weight_gather_reshape.dtype)
+                    weight_ignored = core.ops.elementwise_mul(
+                        mask, weight_gather_reshape)
+                    weight_sum = core.ops.reduce_sum(weight_ignored,
+                                                     'reduce_all', True)
+                    ret = out_sum / weight_sum
+                return ret
+            elif weight is not None:
                 out_sum = core.ops.reduce_sum(out, 'reduce_all', True)
                 total_weight = core.ops.reduce_sum(weight_gather_reshape,
                                                    'reduce_all', True)
                 return out_sum / total_weight
             else:
                 return core.ops.mean(out)
+
     else:
         if input_dims - 1 == label_dims:
             out = paddle.squeeze(out, axis=axis)
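The new ignore_index != -100 branch computes a masked mean: the numerator is the summed loss (already multiplied by the per-sample weight when one is given), and the denominator switches between the kept-sample count and the kept-sample weight sum. Guarding on the default sentinel -100 leaves the pre-existing fast paths untouched when the caller never sets ignore_index. A small NumPy check of that arithmetic, with invented values:

import numpy as np

# Invented per-sample losses; softmax_with_cross_entropy already emits 0 for
# the sample whose label equals ignore_index (here ignore_index == 0).
loss = np.array([0.9, 0.0, 1.5])
label = np.array([2, 0, 3])
mask = (label != 0).astype(np.float64)  # [1., 0., 1.]

# weight is None: divide the loss sum by the number of kept samples
print(loss.sum() / mask.sum())          # 2.4 / 2 = 1.2

# weight given: `out` is first multiplied by the per-sample weight, and the
# denominator is the weight sum over kept samples (weight_ignored above)
weight = np.array([0.5, 1.0, 2.0, 4.0])  # one weight per class, shape C
w_sample = weight[label]                 # [2. , 0.5, 4. ]
print((loss * w_sample).sum() / (w_sample * mask).sum())  # 7.8 / 6 = 1.3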
@@ -1258,7 +1286,8 @@ def cross_entropy(input,
         fluid.data_feeder.check_variable_and_dtype(
             weight, 'weight', ['float32', 'float64'], 'softmax_cross_entropy')
         weight_name = name if reduction == 'none' else None
-        weight_gather = paddle.gather_nd(weight, label)  #trans to sample
+        weight_gather = paddle.gather_nd(
+            weight, label)  # transform weight from per-class to per-sample, shape: N
         input_shape = list(label.shape)
         weight_gather_reshape = reshape(weight_gather, shape=input_shape)
         out = paddle.multiply(out, weight_gather_reshape, name=weight_name)
@@ -1266,12 +1295,29 @@ def cross_entropy(input,
     if reduction == "sum":
         return paddle.sum(out, name=name)
     elif reduction == "mean":
-        if weight is not None:
+        if ignore_index != -100:
+            out_sum = paddle.sum(out, name=name)
+            # build a 0/1 mask from ignore_index:
+            # mask[i] = 0 if label[i] == ignore_index, else 1
+            mask = (label != ignore_index)
+            if weight is None:
+                mask = paddle.cast(mask, dtype=out_sum.dtype)
+                count = paddle.sum(mask, name=name)
+                ret = out_sum / count
+            else:
+                mask = paddle.cast(mask, weight_gather_reshape.dtype)
+                weight_ignored = paddle.multiply(mask, weight_gather_reshape)
+                weight_sum = paddle.sum(weight_ignored, name=name)
+                ret = out_sum / weight_sum
+            return ret
+        elif weight is not None:
             out_sum = paddle.sum(out, name=name)
             total_weight = paddle.sum(weight_gather_reshape)
             return out_sum / total_weight
         else:
             return paddle.mean(out, name=name)
+
     else:
         if input_dims - 1 == label_dims:
             out = paddle.squeeze(out, axis=axis)
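The static-graph branch mirrors the dygraph logic, with paddle.sum and paddle.multiply standing in for the core.ops calls. End to end, the new behavior can be exercised the same way the tests do; a minimal dygraph sketch (values invented, assuming a Paddle build that includes this change):

import numpy as np
import paddle
import paddle.fluid as fluid

input_np = np.random.random([2, 4]).astype(np.float64)
label_np = np.array([0, 3], dtype=np.int64)  # the first sample is ignored

with fluid.dygraph.guard():
    loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=0)  # mean reduction
    loss = loss_fn(fluid.dygraph.to_variable(input_np),
                   fluid.dygraph.to_variable(label_np))
    # only the second sample contributes, and the mean divides by 1, not 2
    print(loss.numpy())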