未验证 提交 9317e51f 编写于 作者: Z zhupengyang 提交者: GitHub

Fix

上级 cd48bdad
...@@ -86,6 +86,7 @@ class TestMeanAPI(unittest.TestCase): ...@@ -86,6 +86,7 @@ class TestMeanAPI(unittest.TestCase):
else paddle.CPUPlace() else paddle.CPUPlace()
def test_api_static(self): def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_shape) x = paddle.data('X', self.x_shape)
out1 = paddle.mean(x) out1 = paddle.mean(x)
...@@ -102,7 +103,9 @@ class TestMeanAPI(unittest.TestCase): ...@@ -102,7 +103,9 @@ class TestMeanAPI(unittest.TestCase):
for out in res: for out in res:
self.assertEqual(np.allclose(out, out_ref), True) self.assertEqual(np.allclose(out, out_ref), True)
def test_api_imperative(self): def test_api_dygraph(self):
paddle.disable_static(self.place)
def test_case(x, axis=None, keepdim=False): def test_case(x, axis=None, keepdim=False):
x_tensor = paddle.to_variable(x) x_tensor = paddle.to_variable(x)
out = paddle.mean(x_tensor, axis, keepdim) out = paddle.mean(x_tensor, axis, keepdim)
...@@ -113,7 +116,6 @@ class TestMeanAPI(unittest.TestCase): ...@@ -113,7 +116,6 @@ class TestMeanAPI(unittest.TestCase):
out_ref = np.mean(x, axis, keepdims=keepdim) out_ref = np.mean(x, axis, keepdims=keepdim)
self.assertEqual(np.allclose(out.numpy(), out_ref), True) self.assertEqual(np.allclose(out.numpy(), out_ref), True)
paddle.disable_static(self.place)
test_case(self.x) test_case(self.x)
test_case(self.x, []) test_case(self.x, [])
test_case(self.x, -1) test_case(self.x, -1)
...@@ -125,6 +127,7 @@ class TestMeanAPI(unittest.TestCase): ...@@ -125,6 +127,7 @@ class TestMeanAPI(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
def test_errors(self): def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12], 'int8') x = paddle.data('X', [10, 12], 'int8')
self.assertRaises(TypeError, paddle.mean, x) self.assertRaises(TypeError, paddle.mean, x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册