未验证 提交 ac8afe18 编写于 作者: C Chen Weihang 提交者: GitHub

use structured name in loaded dict (#27242)

上级 5e0dde02
...@@ -25,7 +25,7 @@ import warnings ...@@ -25,7 +25,7 @@ import warnings
from .. import core from .. import core
from .base import guard from .base import guard
from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers, EXTRA_VAR_INFO_FILENAME
__all__ = [ __all__ = [
'save_dygraph', 'save_dygraph',
...@@ -233,6 +233,19 @@ def load_dygraph(model_path, config=None): ...@@ -233,6 +233,19 @@ def load_dygraph(model_path, config=None):
para_dict = dict() para_dict = dict()
for var_name in persistable_var_dict: for var_name in persistable_var_dict:
para_dict[var_name] = persistable_var_dict[var_name].numpy() para_dict[var_name] = persistable_var_dict[var_name].numpy()
# if __variables.info__ exists, we can recover structured_name
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
if os.path.exists(var_info_path):
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
structured_para_dict = dict()
for var_name in para_dict:
structured_name = extra_var_info[var_name].get(
'structured_name', None)
assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name
structured_para_dict[structured_name] = para_dict[var_name]
para_dict = structured_para_dict
else: else:
# Load state dict by `save_dygraph` save format # Load state dict by `save_dygraph` save format
para_dict = {} para_dict = {}
......
...@@ -255,8 +255,11 @@ class TestJitSaveLoad(unittest.TestCase): ...@@ -255,8 +255,11 @@ class TestJitSaveLoad(unittest.TestCase):
train_layer.eval() train_layer.eval()
# construct new model # construct new model
new_layer = LinearNet(784, 1) new_layer = LinearNet(784, 1)
model_dict, _ = fluid.dygraph.load_dygraph(self.model_path) orig_state_dict = new_layer.state_dict()
new_layer.set_dict(model_dict) load_state_dict, _ = fluid.dygraph.load_dygraph(self.model_path)
for structured_name in orig_state_dict:
self.assertTrue(structured_name in load_state_dict)
new_layer.set_state_dict(load_state_dict)
new_layer.eval() new_layer.eval()
# inference & compare # inference & compare
x = fluid.dygraph.to_variable( x = fluid.dygraph.to_variable(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册