From c85c7b2264b402a7779beb6112098ed6418572bc Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Fri, 12 Jun 2020 10:13:11 +0800 Subject: [PATCH] Enable load program state in imperative mode (#24998) * enable load_program_state run in imperative mode; test=develop * remove useless code; test=develop --- python/paddle/fluid/io.py | 5 ++--- .../paddle/fluid/tests/unittests/test_static_save_load.py | 6 +++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 3e32c4ec55..fadd247e0d 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -198,7 +198,8 @@ def _load_program_scope(main=None, startup=None, scope=None): with paddle.fluid.scope_guard(scope): with paddle.fluid.program_guard(prog, startup_prog): with paddle.fluid.unique_name.guard(): - yield + with paddle.fluid.framework._dygraph_guard(None): + yield def _get_valid_program(main_program): @@ -663,7 +664,6 @@ def save_persistables(executor, dirname, main_program=None, filename=None): filename=filename) -@dygraph_not_support def load_vars(executor, dirname, main_program=None, @@ -1830,7 +1830,6 @@ def load(program, model_path, executor=None, var_list=None): set_var(v, load_dict[v.name]) -@dygraph_not_support def load_program_state(model_path, var_list=None): """ :api_attr: Static Graph diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index ac61ab756f..72c9922343 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -1153,7 +1153,6 @@ class TestProgramStateOldSave(unittest.TestCase): # make sure all the paramerter or optimizer var have been set to zero self.assertTrue(np.sum(np.abs(new_t)) == 0) - #fluid.load(test_program, "./test_1", None ) program_state = fluid.load_program_state("test_program_1") fluid.set_program_state(main_program, program_state) @@ -1164,6 +1163,11 @@ class TestProgramStateOldSave(unittest.TestCase): base_t = base_map[var.name] self.assertTrue(np.array_equal(new_t, base_t)) + with fluid.dygraph.guard(place): + load_state = fluid.load_program_state("test_program_1") + for k, v in load_state.items(): + self.assertTrue(np.array_equal(base_map[k], v)) + class TestProgramStateOldSaveSingleModel(unittest.TestCase): def test_ptb_rnn_cpu_float32(self): -- GitLab