diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index 29676fcff1216a8dc45a356a1e6e8c2ace2fe381..adf11e815faa97bb547b9b1d50f839baa56e9595 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1465,6 +1465,34 @@ class TestCrossEntropyFAPIError(unittest.TestCase): self.assertRaises(ValueError, test_WeightLength_NotEqual) + def test_LabelValue_ExceedMax(): + input_data = paddle.rand(shape=[20, 100]) + label_data = paddle.randint( + 0, 100, shape=[20, 1], dtype="int64") # hard label + label_data[0] = 100 + weight_data = paddle.rand([100]) # provide weight + paddle.nn.functional.cross_entropy( + input=input_data, + label=label_data, + weight=weight_data, + ignore_index=-100) + + self.assertRaises(ValueError, test_LabelValue_ExceedMax) + + def test_LabelValue_ExceedMin(): + input_data = paddle.rand(shape=[20, 100]) + label_data = paddle.randint( + 0, 100, shape=[20, 1], dtype="int64") # hard label + label_data[0] = -1 + weight_data = paddle.rand([100]) # provide weight + paddle.nn.functional.cross_entropy( + input=input_data, + label=label_data, + weight=weight_data, + ignore_index=-100) + + self.assertRaises(ValueError, test_LabelValue_ExceedMin) + def static_test_WeightLength_NotEqual(): input_np = np.random.random([2, 4]).astype('float32') label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)