From 46e856c7c795a2b1ef42770efab64b820d4b4621 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 26 Dec 2021 13:30:54 +0800 Subject: [PATCH] Remove the labels range check under the dynamic graph --- .../unittests/test_cross_entropy_loss.py | 28 ------------- python/paddle/nn/functional/loss.py | 39 +++++++++---------- 2 files changed, 19 insertions(+), 48 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index d3ed76e34a..29676fcff1 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1465,34 +1465,6 @@ class TestCrossEntropyFAPIError(unittest.TestCase): self.assertRaises(ValueError, test_WeightLength_NotEqual) - def test_LabelValue_ExceedMax(): - input_data = paddle.rand(shape=[20, 100]) - label_data = paddle.randint( - 0, 100, shape=[20, 1], dtype="int64") - label_data[0] = 100 - weight_data = paddle.rand([100]) - paddle.nn.functional.cross_entropy( - input=input_data, - label=label_data, - weight=weight_data, - ignore_index=-100) - - self.assertRaises(ValueError, test_LabelValue_ExceedMax) - - def test_LabelValue_ExceedMin(): - input_data = paddle.rand(shape=[20, 100]) - label_data = paddle.randint( - 0, 100, shape=[20, 1], dtype="int64") - label_data[0] = -1 - weight_data = paddle.rand([100]) - paddle.nn.functional.cross_entropy( - input=input_data, - label=label_data, - weight=weight_data, - ignore_index=-100) - - self.assertRaises(ValueError, test_LabelValue_ExceedMin) - def static_test_WeightLength_NotEqual(): input_np = np.random.random([2, 4]).astype('float32') label_np = np.random.randint(0, 4, size=(2)).astype(np.int64) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 554651ea13..05f06ef534 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1665,26 +1665,6 @@ def cross_entropy(input, if input_dims - 1 == label_dims: label = paddle.unsqueeze(label, axis=axis) if in_dygraph_mode(): - if soft_label == False: - valid_label = paddle.where(label == ignore_index, - paddle.zeros_like(label), label) - # TODO: Temporarily use paddle.nonzero instead of paddle.max - # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label < 0)) > 0: - invalid_label = paddle.gather_nd( - valid_label, paddle.nonzero(valid_label < 0)) - raise ValueError( - "Target({}) is out of class_dimension's lower bound({})". - format(invalid_label[0], 0)) - # TODO: Temporarily use paddle.nonzero instead of paddle.max - # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0: - invalid_label = paddle.gather_nd( - valid_label, - paddle.nonzero(valid_label >= input.shape[axis])) - raise ValueError( - "Target({}) is out of class_dimension's upper bound({})". - format(invalid_label[0], input.shape[axis] - 1)) if core.is_compiled_with_npu(): _, _, out = _C_ops.softmax_with_cross_entropy( input, label, 'soft_label', soft_label, 'ignore_index', @@ -1716,6 +1696,25 @@ def cross_entropy(input, out = _C_ops.elementwise_mul(out, weight_gather_reshape) else: + valid_label = paddle.where(label == ignore_index, + paddle.zeros_like(label), label) + # TODO: Temporarily use paddle.nonzero instead of paddle.max + # to detect and find out possible illegal label values + if len(paddle.nonzero(valid_label < 0)) > 0: + invalid_label = paddle.gather_nd( + valid_label, paddle.nonzero(valid_label < 0)) + raise ValueError( + "Target({}) is out of class_dimension's lower bound({})". + format(invalid_label[0], 0)) + # TODO: Temporarily use paddle.nonzero instead of paddle.max + # to detect and find out possible illegal label values + if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0: + invalid_label = paddle.gather_nd( + valid_label, + paddle.nonzero(valid_label >= input.shape[axis])) + raise ValueError( + "Target({}) is out of class_dimension's upper bound({})". + format(invalid_label[0], input.shape[axis] - 1)) if input.shape[axis] != weight.shape[-1]: raise ValueError( "input's class_dimension({}) must equal to " -- GitLab