未验证 提交 9317e51f 编写于 作者: Z zhupengyang 提交者: GitHub

Fix

上级 cd48bdad
......@@ -86,6 +86,7 @@ class TestMeanAPI(unittest.TestCase):
else paddle.CPUPlace()
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_shape)
out1 = paddle.mean(x)
......@@ -102,7 +103,9 @@ class TestMeanAPI(unittest.TestCase):
for out in res:
self.assertEqual(np.allclose(out, out_ref), True)
def test_api_imperative(self):
def test_api_dygraph(self):
paddle.disable_static(self.place)
def test_case(x, axis=None, keepdim=False):
x_tensor = paddle.to_variable(x)
out = paddle.mean(x_tensor, axis, keepdim)
......@@ -113,7 +116,6 @@ class TestMeanAPI(unittest.TestCase):
out_ref = np.mean(x, axis, keepdims=keepdim)
self.assertEqual(np.allclose(out.numpy(), out_ref), True)
paddle.disable_static(self.place)
test_case(self.x)
test_case(self.x, [])
test_case(self.x, -1)
......@@ -125,6 +127,7 @@ class TestMeanAPI(unittest.TestCase):
paddle.enable_static()
def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12], 'int8')
self.assertRaises(TypeError, paddle.mean, x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册