diff --git a/python/paddle/fluid/tests/unittests/test_mean_op.py b/python/paddle/fluid/tests/unittests/test_mean_op.py index 3799640b98800f660e72e3c8b4580949d5deb12a..a2befb4a29a0f39d6d51c1a869c129889ec18015 100644 --- a/python/paddle/fluid/tests/unittests/test_mean_op.py +++ b/python/paddle/fluid/tests/unittests/test_mean_op.py @@ -86,6 +86,7 @@ class TestMeanAPI(unittest.TestCase): else paddle.CPUPlace() def test_api_static(self): + paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.data('X', self.x_shape) out1 = paddle.mean(x) @@ -102,7 +103,9 @@ class TestMeanAPI(unittest.TestCase): for out in res: self.assertEqual(np.allclose(out, out_ref), True) - def test_api_imperative(self): + def test_api_dygraph(self): + paddle.disable_static(self.place) + def test_case(x, axis=None, keepdim=False): x_tensor = paddle.to_variable(x) out = paddle.mean(x_tensor, axis, keepdim) @@ -113,7 +116,6 @@ class TestMeanAPI(unittest.TestCase): out_ref = np.mean(x, axis, keepdims=keepdim) self.assertEqual(np.allclose(out.numpy(), out_ref), True) - paddle.disable_static(self.place) test_case(self.x) test_case(self.x, []) test_case(self.x, -1) @@ -125,6 +127,7 @@ class TestMeanAPI(unittest.TestCase): paddle.enable_static() def test_errors(self): + paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.data('X', [10, 12], 'int8') self.assertRaises(TypeError, paddle.mean, x)