From 9cdf6bd6fc97a7985e934e0305922d92e373aa82 Mon Sep 17 00:00:00 2001 From: root <12272008@bjtu.edu.cn> Date: Wed, 28 Apr 2021 11:42:05 +0000 Subject: [PATCH] add ignore_index for test case --- .../fluid/tests/unittests/test_cross_entropy_loss.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index e9116ae1b44..01710d579a6 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1375,7 +1375,10 @@ class TestCrossEntropyFAPIError(unittest.TestCase): label_data[0] = 255 weight_data = paddle.rand([100]) paddle.nn.functional.cross_entropy( - input=input_data, label=label_data, weight=weight_data) + input=input_data, + label=label_data, + weight=weight_data, + ignore_index=0) self.assertRaises(ValueError, test_LabelValue) @@ -1386,7 +1389,10 @@ class TestCrossEntropyFAPIError(unittest.TestCase): label_data[0] = -1 weight_data = paddle.rand([100]) paddle.nn.functional.cross_entropy( - input=input_data, label=label_data, weight=weight_data) + input=input_data, + label=label_data, + weight=weight_data, + ignore_index=0) self.assertRaises(ValueError, test_LabelValueNeg) -- GitLab