From 51398ab90ad7146a546e62576320988c4f54f67f Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 27 Dec 2021 15:56:14 +0800 Subject: [PATCH] remove hard labels check --- .../unittests/test_cross_entropy_loss.py | 28 ------------------- python/paddle/nn/functional/loss.py | 17 ----------- 2 files changed, 45 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index 12dc47785d2..a30e5741bc8 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1466,34 +1466,6 @@ class TestCrossEntropyFAPIError(unittest.TestCase): self.assertRaises(ValueError, test_WeightLength_NotEqual) - def test_LabelValue_ExceedMax(): - input_data = paddle.rand(shape=[20, 100]) - label_data = paddle.randint( - 0, 100, shape=[20, 1], dtype="int64") # hard label - label_data[0] = 100 - weight_data = paddle.rand([100]) # provide weight - paddle.nn.functional.cross_entropy( - input=input_data, - label=label_data, - weight=weight_data, - ignore_index=-100) - - self.assertRaises(IndexError, test_LabelValue_ExceedMax) - - def test_LabelValue_ExceedMin(): - input_data = paddle.rand(shape=[20, 100]) - label_data = paddle.randint( - 0, 100, shape=[20, 1], dtype="int64") # hard label - label_data[0] = -1 - weight_data = paddle.rand([100]) # provide weight - paddle.nn.functional.cross_entropy( - input=input_data, - label=label_data, - weight=weight_data, - ignore_index=-100) - - self.assertRaises(IndexError, test_LabelValue_ExceedMin) - def static_test_WeightLength_NotEqual(): input_np = np.random.random([2, 4]).astype('float32') label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index cdf80fb58d7..c1800a781d4 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1705,23 +1705,6 @@ def cross_entropy(input, valid_label = paddle.where(label == ignore_index, paddle.zeros_like(label), label) - # TODO: Temporarily use paddle.nonzero instead of paddle.max - # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label < 0)) > 0: - invalid_label = paddle.gather_nd( - valid_label, paddle.nonzero(valid_label < 0)) - raise IndexError( - "Target({}) is out of class_dimension's lower bound({})". - format(invalid_label[0], 0)) - # TODO: Temporarily use paddle.nonzero instead of paddle.max - # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0: - invalid_label = paddle.gather_nd( - valid_label, - paddle.nonzero(valid_label >= input.shape[axis])) - raise IndexError( - "Target({}) is out of class_dimension's upper bound({})". - format(invalid_label[0], input.shape[axis] - 1)) ignore_weight_mask = paddle.cast((label != ignore_index), out.dtype) -- GitLab