diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 30ded1f7eda295bab5567a082ba1fa3989b55fa2..9876fc620b870f47b10e9f99e4de34f5cb81fde1 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -25,7 +25,7 @@ import warnings from .. import core from .base import guard from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs -from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers +from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers, EXTRA_VAR_INFO_FILENAME __all__ = [ 'save_dygraph', @@ -233,6 +233,19 @@ def load_dygraph(model_path, config=None): para_dict = dict() for var_name in persistable_var_dict: para_dict[var_name] = persistable_var_dict[var_name].numpy() + + # if __variables.info__ exists, we can recover structured_name + var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME) + if os.path.exists(var_info_path): + with open(var_info_path, 'rb') as f: + extra_var_info = pickle.load(f) + structured_para_dict = dict() + for var_name in para_dict: + structured_name = extra_var_info[var_name].get( + 'structured_name', None) + assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name + structured_para_dict[structured_name] = para_dict[var_name] + para_dict = structured_para_dict else: # Load state dict by `save_dygraph` save format para_dict = {} diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index 7bf806bab557e7e84eb76e3a9876745e6107a5ab..f0680206de210a1090f04f5dfb8bf99f47839386 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -255,8 +255,11 @@ class TestJitSaveLoad(unittest.TestCase): train_layer.eval() # construct new model new_layer = LinearNet(784, 1) - model_dict, _ = fluid.dygraph.load_dygraph(self.model_path) - new_layer.set_dict(model_dict) + orig_state_dict = new_layer.state_dict() + load_state_dict, _ = fluid.dygraph.load_dygraph(self.model_path) + for structured_name in orig_state_dict: + self.assertTrue(structured_name in load_state_dict) + new_layer.set_state_dict(load_state_dict) new_layer.eval() # inference & compare x = fluid.dygraph.to_variable(