未验证 提交 1080be33 编写于 作者: W wawltor 提交者: GitHub

update the test_mean test case for bug fix

update the test_mean test case
上级 c10cf6d2
...@@ -101,7 +101,7 @@ class TestMeanAPI(unittest.TestCase): ...@@ -101,7 +101,7 @@ class TestMeanAPI(unittest.TestCase):
fetch_list=[out1, out2, out3, out4, out5]) fetch_list=[out1, out2, out3, out4, out5])
out_ref = np.mean(self.x) out_ref = np.mean(self.x)
for out in res: for out in res:
self.assertEqual(np.allclose(out, out_ref), True) self.assertEqual(np.allclose(out, out_ref, rtol=1e-04), True)
def test_api_dygraph(self): def test_api_dygraph(self):
paddle.disable_static(self.place) paddle.disable_static(self.place)
...@@ -114,7 +114,9 @@ class TestMeanAPI(unittest.TestCase): ...@@ -114,7 +114,9 @@ class TestMeanAPI(unittest.TestCase):
if len(axis) == 0: if len(axis) == 0:
axis = None axis = None
out_ref = np.mean(x, axis, keepdims=keepdim) out_ref = np.mean(x, axis, keepdims=keepdim)
self.assertEqual(np.allclose(out.numpy(), out_ref), True) self.assertEqual(
np.allclose(
out.numpy(), out_ref, rtol=1e-04), True)
test_case(self.x) test_case(self.x)
test_case(self.x, []) test_case(self.x, [])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册