From 400eb9d8e6676373f6df2e1c8961e33382ed2e71 Mon Sep 17 00:00:00 2001 From: huangjun12 <12272008@bjtu.edu.cn> Date: Mon, 26 Apr 2021 13:16:02 +0000 Subject: [PATCH] fix ce bug in label value, test=develop --- .../tests/unittests/test_cross_entropy_loss.py | 14 ++++++++++++++ python/paddle/nn/functional/loss.py | 5 +++++ 2 files changed, 19 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index 897d76a35d..3e8d416de1 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1363,5 +1363,19 @@ class CrossEntropyLoss(unittest.TestCase): self.assertTrue(np.allclose(dy_ret_value, expected)) +class TestCrossEntropyFAPIError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + def test_LabelValue(): + input_data = paddle.rand(shape=[20, 100]) + label_data = paddle.randint(0, 100, shape=[5, 1], dtype="int64") + label_data[0] = 255 + paddle.nn.functional.cross_entropy( + input=input_data, label=label_data) + + self.assertRaises(ValueError, test_LabelValue) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index aa0bd8a8c5..323d6fb028 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1411,6 +1411,11 @@ def cross_entropy(input, out = core.ops.elementwise_mul(out, weight_gather_reshape) else: + for label_val in label: + if label_val < 0 or label_val >= input.shape[-1]: + raise ValueError( + 'Expected 0 <= label_value < class_dimension({}), but got label_value {}'. + format(input.shape[-1], label_val.numpy())) weight_gather = core.ops.gather_nd(weight, label) input_shape = list(label.shape) weight_gather_reshape = reshape( -- GitLab