From 94952111f5ab0fbce5d40c9c1ed53d2d3c1374c1 Mon Sep 17 00:00:00 2001 From: root <12272008@bjtu.edu.cn> Date: Tue, 27 Apr 2021 12:26:29 +0000 Subject: [PATCH] add weigth data to unit test --- .../paddle/fluid/tests/unittests/test_cross_entropy_loss.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index f99c3b2129..e9116ae1b4 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1373,8 +1373,9 @@ class TestCrossEntropyFAPIError(unittest.TestCase): label_data = paddle.randint( 0, 100, shape=[20, 1], dtype="int64") label_data[0] = 255 + weight_data = paddle.rand([100]) paddle.nn.functional.cross_entropy( - input=input_data, label=label_data) + input=input_data, label=label_data, weight=weight_data) self.assertRaises(ValueError, test_LabelValue) @@ -1383,8 +1384,9 @@ class TestCrossEntropyFAPIError(unittest.TestCase): label_data = paddle.randint( 0, 100, shape=[20, 1], dtype="int64") label_data[0] = -1 + weight_data = paddle.rand([100]) paddle.nn.functional.cross_entropy( - input=input_data, label=label_data) + input=input_data, label=label_data, weight=weight_data) self.assertRaises(ValueError, test_LabelValueNeg) -- GitLab