From 09d4a3a4737152b9fac4b105dcc2f389c3e6be2a Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Tue, 28 Dec 2021 17:32:21 +0800
Subject: [PATCH] add static label check

---
 .../unittests/test_cross_entropy_loss.py | 28 ++++++++++++++++
 python/paddle/nn/functional/loss.py      | 32 +++++++++++++++++++
 2 files changed, 60 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
index 29676fcff1..d3ed76e34a 100644
--- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
@@ -1465,6 +1465,34 @@ class TestCrossEntropyFAPIError(unittest.TestCase):
 
         self.assertRaises(ValueError, test_WeightLength_NotEqual)
 
+        def test_LabelValue_ExceedMax():
+            input_data = paddle.rand(shape=[20, 100])
+            label_data = paddle.randint(
+                0, 100, shape=[20, 1], dtype="int64")
+            label_data[0] = 100
+            weight_data = paddle.rand([100])
+            paddle.nn.functional.cross_entropy(
+                input=input_data,
+                label=label_data,
+                weight=weight_data,
+                ignore_index=-100)
+
+        self.assertRaises(ValueError, test_LabelValue_ExceedMax)
+
+        def test_LabelValue_ExceedMin():
+            input_data = paddle.rand(shape=[20, 100])
+            label_data = paddle.randint(
+                0, 100, shape=[20, 1], dtype="int64")
+            label_data[0] = -1
+            weight_data = paddle.rand([100])
+            paddle.nn.functional.cross_entropy(
+                input=input_data,
+                label=label_data,
+                weight=weight_data,
+                ignore_index=-100)
+
+        self.assertRaises(ValueError, test_LabelValue_ExceedMin)
+
         def static_test_WeightLength_NotEqual():
             input_np = np.random.random([2, 4]).astype('float32')
             label_np = np.random.randint(0, 4, size=(2)).astype(np.int64)
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 4d09f1d5c3..aee0366ab3 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1710,6 +1710,22 @@ def cross_entropy(input,
                 ignore_weight_mask, dtype=label.dtype) * label
             # ignored position will be 0
 
+            if len(paddle.nonzero(valid_label < 0)) > 0:
+                invalid_label = paddle.gather_nd(
+                    valid_label, paddle.nonzero(valid_label < 0))
+                raise ValueError(
+                    "Target({}) is out of class_dimension's lower bound({})".
+                    format(invalid_label[0], 0))
+            # TODO: Temporarily use paddle.nonzero instead of paddle.max
+            # to detect and find out possible illegal label values
+            if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
+                invalid_label = paddle.gather_nd(
+                    valid_label,
+                    paddle.nonzero(valid_label >= input.shape[axis]))
+                raise ValueError(
+                    "Target({}) is out of class_dimension's upper bound({})".
+                    format(invalid_label[0], input.shape[axis] - 1))
+
             ignore_weight_mask = paddle.cast(
                 ignore_weight_mask, out.dtype)  # convert from 0 to 0.0
 
@@ -1833,6 +1849,22 @@ def cross_entropy(input,
                 ignore_weight_mask, dtype=label.dtype) * label
             # ignored position will be 0
 
+            if len(paddle.nonzero(valid_label < 0)) > 0:
+                invalid_label = paddle.gather_nd(
+                    valid_label, paddle.nonzero(valid_label < 0))
+                raise ValueError(
+                    "Target({}) is out of class_dimension's lower bound({})".
+                    format(invalid_label[0], 0))
+            # TODO: Temporarily use paddle.nonzero instead of paddle.max
+            # to detect and find out possible illegal label values
+            if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
+                invalid_label = paddle.gather_nd(
+                    valid_label,
+                    paddle.nonzero(valid_label >= input.shape[axis]))
+                raise ValueError(
+                    "Target({}) is out of class_dimension's upper bound({})".
+                    format(invalid_label[0], input.shape[axis] - 1))
+
             ignore_weight_mask = paddle.cast(ignore_weight_mask,
                                              out.dtype)  # convert from 0 to 0.0
 
-- 
GitLab
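Editor's note: below is a minimal, hedged sketch (not part of the patch) of how the new check is expected to surface in user code once the change above is applied and dygraph mode is active. The shapes and variable names (logits, labels, weight) are illustrative, not taken from the patch; the behavior assumed is only what the added code implies, namely that a label value at or above input.shape[axis] raises ValueError before the kernel runs.

import paddle
import paddle.nn.functional as F

# Illustrative shapes (assumption, not from the patch): 4 samples, 10 classes.
logits = paddle.rand(shape=[4, 10])
labels = paddle.randint(0, 10, shape=[4, 1], dtype="int64")
labels[0] = 10  # outside the valid range [0, 9]; should trip the new upper-bound check
weight = paddle.rand([10])

try:
    F.cross_entropy(
        input=logits, label=labels, weight=weight, ignore_index=-100)
except ValueError as err:
    # With the patch applied, the invalid label is reported up front
    # with the "Target(...) is out of class_dimension's upper bound(...)" message.
    print("caught:", err)

The same pattern with labels[0] = -1 exercises the lower-bound check; this mirrors the two unit tests added in test_cross_entropy_loss.py.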