diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index 897d76a35dcabfc5b39c9cfadd123e4ad4c27b6f..3e8d416de18fcf58b19eff38e290d9e67803e9df 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1363,5 +1363,19 @@ class CrossEntropyLoss(unittest.TestCase): self.assertTrue(np.allclose(dy_ret_value, expected)) +class TestCrossEntropyFAPIError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_LabelValue(): + input_data = paddle.rand(shape=[20, 100]) + label_data = paddle.randint(0, 100, shape=[5, 1], dtype="int64") + label_data[0] = 255 + paddle.nn.functional.cross_entropy( + input=input_data, label=label_data) + + self.assertRaises(ValueError, test_LabelValue) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index aa0bd8a8c5e3d5dcae7c26018d28ce93637947bd..323d6fb0288df45f1da7221f61c412da35f680f0 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1411,6 +1411,11 @@ def cross_entropy(input, out = core.ops.elementwise_mul(out, weight_gather_reshape) else: + for label_val in label: + if label_val < 0 or label_val >= input.shape[-1]: + raise ValueError( + 'Expected 0 <= label_value < class_dimension({}), but got label_value {}'. + format(input.shape[-1], label_val.numpy())) weight_gather = core.ops.gather_nd(weight, label) input_shape = list(label.shape) weight_gather_reshape = reshape(