From d49daff06f89a7854c039e6dd21be18d4e852160 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 26 Dec 2021 20:03:02 +0800 Subject: [PATCH] restore test for min,max labels --- .../unittests/test_cross_entropy_loss.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index 29676fcff12..adf11e815fa 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1465,6 +1465,34 @@ class TestCrossEntropyFAPIError(unittest.TestCase): self.assertRaises(ValueError, test_WeightLength_NotEqual) + def test_LabelValue_ExceedMax(): + input_data = paddle.rand(shape=[20, 100]) + label_data = paddle.randint( + 0, 100, shape=[20, 1], dtype="int64") # hard label + label_data[0] = 100 + weight_data = paddle.rand([100]) # provide weight + paddle.nn.functional.cross_entropy( + input=input_data, + label=label_data, + weight=weight_data, + ignore_index=-100) + + self.assertRaises(ValueError, test_LabelValue_ExceedMax) + + def test_LabelValue_ExceedMin(): + input_data = paddle.rand(shape=[20, 100]) + label_data = paddle.randint( + 0, 100, shape=[20, 1], dtype="int64") # hard label + label_data[0] = -1 + weight_data = paddle.rand([100]) # provide weight + paddle.nn.functional.cross_entropy( + input=input_data, + label=label_data, + weight=weight_data, + ignore_index=-100) + + self.assertRaises(ValueError, test_LabelValue_ExceedMin) + def static_test_WeightLength_NotEqual(): input_np = np.random.random([2, 4]).astype('float32') label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) -- GitLab