From 4d6a3b9f3c6fa4a3e08d324e32f57b63292e04cf Mon Sep 17 00:00:00 2001 From: From00 Date: Wed, 30 Mar 2022 21:10:53 +0800 Subject: [PATCH] Fix bug for UT test_calc_gradient (#41130) --- .../tests/unittests/test_calc_gradient.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_calc_gradient.py b/python/paddle/fluid/tests/unittests/test_calc_gradient.py index 40e5abccb2..63ba16c57e 100644 --- a/python/paddle/fluid/tests/unittests/test_calc_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_calc_gradient.py @@ -122,15 +122,16 @@ class TestDoubleGradient(unittest.TestCase): return start_prog, main_prog, [grad_x, jvp] def test_calc_gradient(self): - start_prog, main_prog, fetch_list = self.build_program() - exe = paddle.static.Executor() - exe.run(start_prog) - ans = exe.run(main_prog, - feed={'x': np.ones([2, 2]).astype(np.float32)}, - fetch_list=fetch_list) - self.assertEqual(len(ans), 2) - self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]]) - self.assertListEqual(ans[1].tolist(), [[2., 2.], [2., 2.]]) + with paddle.fluid.scope_guard(paddle.static.Scope()): + start_prog, main_prog, fetch_list = self.build_program() + exe = paddle.static.Executor() + exe.run(start_prog) + ans = exe.run(main_prog, + feed={'x': np.ones([2, 2]).astype(np.float32)}, + fetch_list=fetch_list) + self.assertEqual(len(ans), 2) + self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]]) + self.assertListEqual(ans[1].tolist(), [[2., 2.], [2., 2.]]) class TestDoubleGradient2(unittest.TestCase): @@ -158,15 +159,16 @@ class TestDoubleGradient2(unittest.TestCase): return start_prog, main_prog, [grad_x, jvp] def test_calc_gradient(self): - start_prog, main_prog, fetch_list = self.build_program() - exe = paddle.static.Executor() - exe.run(start_prog) - ans = exe.run(main_prog, - feed={'x': np.ones([2, 2]).astype(np.float32)}, - fetch_list=fetch_list) - self.assertEqual(len(ans), 2) - self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]]) - self.assertListEqual(ans[1].tolist(), [[5., 5.], [5., 5.]]) + with paddle.fluid.scope_guard(paddle.static.Scope()): + start_prog, main_prog, fetch_list = self.build_program() + exe = paddle.static.Executor() + exe.run(start_prog) + ans = exe.run(main_prog, + feed={'x': np.ones([2, 2]).astype(np.float32)}, + fetch_list=fetch_list) + self.assertEqual(len(ans), 2) + self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]]) + self.assertListEqual(ans[1].tolist(), [[5., 5.], [5., 5.]]) if __name__ == "__main__": -- GitLab