diff --git a/python/paddle/fluid/tests/unittests/test_calc_gradient.py b/python/paddle/fluid/tests/unittests/test_calc_gradient.py index 40e5abccb2d5785b5d46e709ce38ae35853d88cf..63ba16c57e09b67d4c4368e72c1dc79d0d0ed312 100644 --- a/python/paddle/fluid/tests/unittests/test_calc_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_calc_gradient.py @@ -122,15 +122,16 @@ class TestDoubleGradient(unittest.TestCase): return start_prog, main_prog, [grad_x, jvp] def test_calc_gradient(self): - start_prog, main_prog, fetch_list = self.build_program() - exe = paddle.static.Executor() - exe.run(start_prog) - ans = exe.run(main_prog, - feed={'x': np.ones([2, 2]).astype(np.float32)}, - fetch_list=fetch_list) - self.assertEqual(len(ans), 2) - self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]]) - self.assertListEqual(ans[1].tolist(), [[2., 2.], [2., 2.]]) + with paddle.fluid.scope_guard(paddle.static.Scope()): + start_prog, main_prog, fetch_list = self.build_program() + exe = paddle.static.Executor() + exe.run(start_prog) + ans = exe.run(main_prog, + feed={'x': np.ones([2, 2]).astype(np.float32)}, + fetch_list=fetch_list) + self.assertEqual(len(ans), 2) + self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]]) + self.assertListEqual(ans[1].tolist(), [[2., 2.], [2., 2.]]) class TestDoubleGradient2(unittest.TestCase): @@ -158,15 +159,16 @@ class TestDoubleGradient2(unittest.TestCase): return start_prog, main_prog, [grad_x, jvp] def test_calc_gradient(self): - start_prog, main_prog, fetch_list = self.build_program() - exe = paddle.static.Executor() - exe.run(start_prog) - ans = exe.run(main_prog, - feed={'x': np.ones([2, 2]).astype(np.float32)}, - fetch_list=fetch_list) - self.assertEqual(len(ans), 2) - self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]]) - self.assertListEqual(ans[1].tolist(), [[5., 5.], [5., 5.]]) + with paddle.fluid.scope_guard(paddle.static.Scope()): + start_prog, main_prog, fetch_list = self.build_program() + exe = paddle.static.Executor() + exe.run(start_prog) + ans = exe.run(main_prog, + feed={'x': np.ones([2, 2]).astype(np.float32)}, + fetch_list=fetch_list) + self.assertEqual(len(ans), 2) + self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]]) + self.assertListEqual(ans[1].tolist(), [[5., 5.], [5., 5.]]) if __name__ == "__main__":