未验证 提交 4d6a3b9f 编写于 作者: F From00 提交者: GitHub

Fix bug for UT test_calc_gradient (#41130)

上级 4b61918d
...@@ -122,15 +122,16 @@ class TestDoubleGradient(unittest.TestCase): ...@@ -122,15 +122,16 @@ class TestDoubleGradient(unittest.TestCase):
return start_prog, main_prog, [grad_x, jvp] return start_prog, main_prog, [grad_x, jvp]
def test_calc_gradient(self): def test_calc_gradient(self):
start_prog, main_prog, fetch_list = self.build_program() with paddle.fluid.scope_guard(paddle.static.Scope()):
exe = paddle.static.Executor() start_prog, main_prog, fetch_list = self.build_program()
exe.run(start_prog) exe = paddle.static.Executor()
ans = exe.run(main_prog, exe.run(start_prog)
feed={'x': np.ones([2, 2]).astype(np.float32)}, ans = exe.run(main_prog,
fetch_list=fetch_list) feed={'x': np.ones([2, 2]).astype(np.float32)},
self.assertEqual(len(ans), 2) fetch_list=fetch_list)
self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]]) self.assertEqual(len(ans), 2)
self.assertListEqual(ans[1].tolist(), [[2., 2.], [2., 2.]]) self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]])
self.assertListEqual(ans[1].tolist(), [[2., 2.], [2., 2.]])
class TestDoubleGradient2(unittest.TestCase): class TestDoubleGradient2(unittest.TestCase):
...@@ -158,15 +159,16 @@ class TestDoubleGradient2(unittest.TestCase): ...@@ -158,15 +159,16 @@ class TestDoubleGradient2(unittest.TestCase):
return start_prog, main_prog, [grad_x, jvp] return start_prog, main_prog, [grad_x, jvp]
def test_calc_gradient(self): def test_calc_gradient(self):
start_prog, main_prog, fetch_list = self.build_program() with paddle.fluid.scope_guard(paddle.static.Scope()):
exe = paddle.static.Executor() start_prog, main_prog, fetch_list = self.build_program()
exe.run(start_prog) exe = paddle.static.Executor()
ans = exe.run(main_prog, exe.run(start_prog)
feed={'x': np.ones([2, 2]).astype(np.float32)}, ans = exe.run(main_prog,
fetch_list=fetch_list) feed={'x': np.ones([2, 2]).astype(np.float32)},
self.assertEqual(len(ans), 2) fetch_list=fetch_list)
self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]]) self.assertEqual(len(ans), 2)
self.assertListEqual(ans[1].tolist(), [[5., 5.], [5., 5.]]) self.assertListEqual(ans[0].tolist(), [[0., 0.], [0., 0.]])
self.assertListEqual(ans[1].tolist(), [[5., 5.], [5., 5.]])
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册