From 9317e51fa60a2f778fed8710d852ee764485d66f Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Tue, 18 Aug 2020 10:07:25 +0800 Subject: [PATCH] Fix --- python/paddle/fluid/tests/unittests/test_mean_op.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_mean_op.py b/python/paddle/fluid/tests/unittests/test_mean_op.py index 3799640b988..a2befb4a29a 100644 --- a/python/paddle/fluid/tests/unittests/test_mean_op.py +++ b/python/paddle/fluid/tests/unittests/test_mean_op.py @@ -86,6 +86,7 @@ class TestMeanAPI(unittest.TestCase): else paddle.CPUPlace() def test_api_static(self): + paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.data('X', self.x_shape) out1 = paddle.mean(x) @@ -102,7 +103,9 @@ class TestMeanAPI(unittest.TestCase): for out in res: self.assertEqual(np.allclose(out, out_ref), True) - def test_api_imperative(self): + def test_api_dygraph(self): + paddle.disable_static(self.place) + def test_case(x, axis=None, keepdim=False): x_tensor = paddle.to_variable(x) out = paddle.mean(x_tensor, axis, keepdim) @@ -113,7 +116,6 @@ class TestMeanAPI(unittest.TestCase): out_ref = np.mean(x, axis, keepdims=keepdim) self.assertEqual(np.allclose(out.numpy(), out_ref), True) - paddle.disable_static(self.place) test_case(self.x) test_case(self.x, []) test_case(self.x, -1) @@ -125,6 +127,7 @@ class TestMeanAPI(unittest.TestCase): paddle.enable_static() def test_errors(self): + paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.data('X', [10, 12], 'int8') self.assertRaises(TypeError, paddle.mean, x) -- GitLab