未验证 提交 76312deb 编写于 作者: L lijianshe02 提交者: GitHub

fix nll_loss test random fail bug test=develop (#29236)

上级 8a2dd34a
...@@ -800,13 +800,16 @@ class TestNLLLossOp2DWithReduce(OpTest): ...@@ -800,13 +800,16 @@ class TestNLLLossOp2DWithReduce(OpTest):
self.init_test_case() self.init_test_case()
self.op_type = "nll_loss" self.op_type = "nll_loss"
self.with_weight = False self.with_weight = False
np.random.seed(200)
input_np = np.random.uniform(0.1, 0.8, input_np = np.random.uniform(0.1, 0.8,
self.input_shape).astype("float64") self.input_shape).astype("float64")
np.random.seed(200)
label_np = np.random.randint(0, self.input_shape[1], label_np = np.random.randint(0, self.input_shape[1],
self.label_shape).astype("int64") self.label_shape).astype("int64")
output_np, total_weight_np = nll_loss_2d(input_np, label_np) output_np, total_weight_np = nll_loss_2d(input_np, label_np)
self.inputs = {'X': input_np, 'Label': label_np} self.inputs = {'X': input_np, 'Label': label_np}
if self.with_weight: if self.with_weight:
np.random.seed(200)
weight_np = np.random.uniform(0.1, 0.8, weight_np = np.random.uniform(0.1, 0.8,
self.input_shape[1]).astype("float64") self.input_shape[1]).astype("float64")
output_np, total_weight_np = nll_loss_2d( output_np, total_weight_np = nll_loss_2d(
...@@ -832,8 +835,8 @@ class TestNLLLossOp2DWithReduce(OpTest): ...@@ -832,8 +835,8 @@ class TestNLLLossOp2DWithReduce(OpTest):
self.check_grad_with_place(place, ['X'], 'Out') self.check_grad_with_place(place, ['X'], 'Out')
def init_test_case(self): def init_test_case(self):
self.input_shape = [5, 3, 5, 5] self.input_shape = [2, 3, 5, 5]
self.label_shape = [5, 5, 5] self.label_shape = [2, 5, 5]
class TestNLLLossOp2DNoReduce(OpTest): class TestNLLLossOp2DNoReduce(OpTest):
...@@ -899,7 +902,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase): ...@@ -899,7 +902,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase):
place = paddle.CPUPlace() place = paddle.CPUPlace()
with paddle.static.program_guard(prog, startup_prog): with paddle.static.program_guard(prog, startup_prog):
x = paddle.fluid.data(name='x', shape=[10, ], dtype='float64') x = paddle.fluid.data(name='x', shape=[10, ], dtype='float64')
label = paddle.fluid.data(name='label', shape=[10, ], dtype='float64') label = paddle.fluid.data(
name='label', shape=[10, ], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss() nll_loss = paddle.nn.loss.NLLLoss()
res = nll_loss(x, label) res = nll_loss(x, label)
...@@ -923,7 +927,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase): ...@@ -923,7 +927,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase):
place = paddle.CPUPlace() place = paddle.CPUPlace()
with paddle.static.program_guard(prog, startup_prog): with paddle.static.program_guard(prog, startup_prog):
x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64') x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64')
label = paddle.fluid.data(name='label', shape=[10], dtype='int64') label = paddle.fluid.data(
name='label', shape=[10], dtype='int64')
nll_loss = paddle.nn.loss.NLLLoss(reduction='') nll_loss = paddle.nn.loss.NLLLoss(reduction='')
res = nll_loss(x, label) res = nll_loss(x, label)
...@@ -947,7 +952,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase): ...@@ -947,7 +952,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase):
place = paddle.CPUPlace() place = paddle.CPUPlace()
with paddle.static.program_guard(prog, startup_prog): with paddle.static.program_guard(prog, startup_prog):
x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64') x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64')
label = paddle.fluid.data(name='label', shape=[10], dtype='int64') label = paddle.fluid.data(
name='label', shape=[10], dtype='int64')
res = paddle.nn.functional.nll_loss(x, label, reduction='') res = paddle.nn.functional.nll_loss(x, label, reduction='')
self.assertRaises(ValueError, self.assertRaises(ValueError,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册