未验证 提交 39d85bfb 编写于 作者: H hong 提交者: GitHub

Enable load program state in imperative mode (#24998) (#25441)

* enable load_program_state run in imperative mode; test=develop

* remove useless code; test=develop
上级 6e1d0efd
...@@ -191,7 +191,8 @@ def _load_program_scope(main=None, startup=None, scope=None): ...@@ -191,7 +191,8 @@ def _load_program_scope(main=None, startup=None, scope=None):
with paddle.fluid.scope_guard(scope): with paddle.fluid.scope_guard(scope):
with paddle.fluid.program_guard(prog, startup_prog): with paddle.fluid.program_guard(prog, startup_prog):
with paddle.fluid.unique_name.guard(): with paddle.fluid.unique_name.guard():
yield with paddle.fluid.framework._dygraph_guard(None):
yield
def _get_valid_program(main_program): def _get_valid_program(main_program):
......
...@@ -1153,7 +1153,6 @@ class TestProgramStateOldSave(unittest.TestCase): ...@@ -1153,7 +1153,6 @@ class TestProgramStateOldSave(unittest.TestCase):
# make sure all the paramerter or optimizer var have been set to zero # make sure all the paramerter or optimizer var have been set to zero
self.assertTrue(np.sum(np.abs(new_t)) == 0) self.assertTrue(np.sum(np.abs(new_t)) == 0)
#fluid.load(test_program, "./test_1", None )
program_state = fluid.load_program_state("test_program_1") program_state = fluid.load_program_state("test_program_1")
fluid.set_program_state(main_program, program_state) fluid.set_program_state(main_program, program_state)
...@@ -1164,6 +1163,11 @@ class TestProgramStateOldSave(unittest.TestCase): ...@@ -1164,6 +1163,11 @@ class TestProgramStateOldSave(unittest.TestCase):
base_t = base_map[var.name] base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t)) self.assertTrue(np.array_equal(new_t, base_t))
with fluid.dygraph.guard(place):
load_state = fluid.load_program_state("test_program_1")
for k, v in load_state.items():
self.assertTrue(np.array_equal(base_map[k], v))
class TestProgramStateOldSaveSingleModel(unittest.TestCase): class TestProgramStateOldSaveSingleModel(unittest.TestCase):
def test_ptb_rnn_cpu_float32(self): def test_ptb_rnn_cpu_float32(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册