From 76312deb3059659aa799fd770d2945890f15bda3 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Tue, 1 Dec 2020 12:55:00 +0800 Subject: [PATCH] fix nll_loss test random fail bug test=develop (#29236) --- .../fluid/tests/unittests/test_nll_loss.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nll_loss.py b/python/paddle/fluid/tests/unittests/test_nll_loss.py index 2b741fcd079..aa64a35564b 100644 --- a/python/paddle/fluid/tests/unittests/test_nll_loss.py +++ b/python/paddle/fluid/tests/unittests/test_nll_loss.py @@ -800,13 +800,16 @@ class TestNLLLossOp2DWithReduce(OpTest): self.init_test_case() self.op_type = "nll_loss" self.with_weight = False + np.random.seed(200) input_np = np.random.uniform(0.1, 0.8, self.input_shape).astype("float64") + np.random.seed(200) label_np = np.random.randint(0, self.input_shape[1], self.label_shape).astype("int64") output_np, total_weight_np = nll_loss_2d(input_np, label_np) self.inputs = {'X': input_np, 'Label': label_np} if self.with_weight: + np.random.seed(200) weight_np = np.random.uniform(0.1, 0.8, self.input_shape[1]).astype("float64") output_np, total_weight_np = nll_loss_2d( @@ -832,8 +835,8 @@ class TestNLLLossOp2DWithReduce(OpTest): self.check_grad_with_place(place, ['X'], 'Out') def init_test_case(self): - self.input_shape = [5, 3, 5, 5] - self.label_shape = [5, 5, 5] + self.input_shape = [2, 3, 5, 5] + self.label_shape = [2, 5, 5] class TestNLLLossOp2DNoReduce(OpTest): @@ -899,7 +902,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase): place = paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): x = paddle.fluid.data(name='x', shape=[10, ], dtype='float64') - label = paddle.fluid.data(name='label', shape=[10, ], dtype='float64') + label = paddle.fluid.data( + name='label', shape=[10, ], dtype='float64') nll_loss = paddle.nn.loss.NLLLoss() res = nll_loss(x, label) @@ -923,7 +927,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase): place = paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64') - label = paddle.fluid.data(name='label', shape=[10], dtype='int64') + label = paddle.fluid.data( + name='label', shape=[10], dtype='int64') nll_loss = paddle.nn.loss.NLLLoss(reduction='') res = nll_loss(x, label) @@ -947,7 +952,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase): place = paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64') - label = paddle.fluid.data(name='label', shape=[10], dtype='int64') + label = paddle.fluid.data( + name='label', shape=[10], dtype='int64') res = paddle.nn.functional.nll_loss(x, label, reduction='') self.assertRaises(ValueError, -- GitLab