From d72db90194a640f78aa24db5ff9ad192e0ff0fd7 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 19 Apr 2019 20:29:58 +0800 Subject: [PATCH] Fix dygraph save load problem test=release/1.4 --- python/paddle/fluid/dygraph/checkpoint.py | 67 ++++++++----------- python/paddle/fluid/dygraph/layers.py | 10 ++- .../unittests/test_imperative_checkpoint.py | 11 ++- 3 files changed, 42 insertions(+), 46 deletions(-) diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 9d29b5317..f96b53e8c 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -75,7 +75,7 @@ def save_persistables(vardict, dirname, filename=None): _save_var_to_file(vardict, dirname, filename) -def load_persistables(vardict, dirname, filename=None): +def load_persistables(dirname): """ This function trys to load persistable variables from the folder `dirname` or the file `filename`. @@ -86,11 +86,7 @@ def load_persistables(vardict, dirname, filename=None): the file name. Args: - vardict(dict of Parameters): The parameters will be loaded. dirname(str): The directory path. - filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were - saved in differnet files, set it to None. - Default: None Returns: dict: The parameter-dict resumed from file @@ -104,10 +100,7 @@ def load_persistables(vardict, dirname, filename=None): param_1 = param_dict['PtbModel_0.w_1'] """ - if isinstance(vardict, collections.OrderedDict): - return _load_var_from_file(vardict, dirname, filename) - - return {} + return _load_var_from_file(dirname) def _save_var_to_file(stat_dict, file_dir, file_name): @@ -139,41 +132,39 @@ def _save_var_to_file(stat_dict, file_dir, file_name): }) -def _load_var_from_file(stat_dict, file_dir, file_name): +def _load_var_from_file(file_dir): + def walk_filename(file_dir): + base_path = os.path.join(file_dir) + var_name_list = [] + if os.path.exists(base_path): + for dirpath, dirnames, filenames in os.walk(base_path): + pt = dirpath.replace(base_path, "", 1) + if pt.startswith("/") or pt.startswith("\\"): + pt = pt[1:] + for fth_name in filenames: + if fth_name[0] != '.': + name_path = os.path.join(pt, fth_name) + if "\\" in name_path: + name_path = name_path.replace("\\", "/") + var_name_list.append(name_path) + + return var_name_list + load_block = default_main_program().global_block() load_var_map = {} - - for var_key, each_var in stat_dict.items(): - assert isinstance(each_var, Variable) - if each_var.type == core.VarDesc.VarType.RAW: - continue - new_var = _clone_var_in_block_(load_block, each_var) - if file_name is None: - load_block.append_op( - type='load', - inputs={}, - outputs={'Out': [new_var]}, - attrs={ - 'file_path': os.path.join(file_dir, - os.path.normpath(each_var.name)) - }) - - load_var_map[new_var.name] = new_var - - if file_name is not None: - load_var_list = [] - for name in sorted(load_var_map.keys()): - load_var_list.append(load_var_map[name]) - + file_var_list = walk_filename(file_dir) + for var_name in file_var_list: + new_var = Variable(block=load_block, name=var_name) load_block.append_op( - type='load_combine', + type='load', inputs={}, - outputs={"Out": load_var_list}, + outputs={'Out': [new_var]}, attrs={ - 'file_path': os.path.join(file_dir, os.path.normpath(file_name)) + 'file_path': os.path.join(file_dir, + os.path.normpath(new_var.name)) }) - for res_var in load_var_list: - load_var_map[res_var.name] = res_var + + load_var_map[new_var.name] = new_var return load_var_map diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index fb0604ebc..f856bdfa7 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -45,6 +45,7 @@ class Layer(core.Layer): self._dtype = dtype self._parameters = collections.OrderedDict() self._sub_layers = collections.OrderedDict() + self._loaddict_holder = collections.OrderedDict() self._helper = LayerObjectHelper(self._full_name) @@ -193,6 +194,9 @@ class Layer(core.Layer): """ assert isinstance(parameter, framework.Parameter) self._parameters[name] = parameter + if parameter.name in self._loaddict_holder: + self._parameters[name] = self._loaddict_holder[parameter.name] + parameter = self._loaddict_holder[parameter.name] return parameter def __getattr__(self, name): @@ -207,7 +211,10 @@ class Layer(core.Layer): if params is None: raise ValueError( "super(YourLayer, self).__init__() should be called first") - params[name] = value + if value.name in self._loaddict_holder: + params[name] = self._loaddict_holder[value.name] + else: + params[name] = value elif isinstance(value, core.Layer): layers = self.__dict__.get('_sub_layers', None) if layers is None: @@ -244,6 +251,7 @@ class Layer(core.Layer): return destination def load_dict(self, stat_dict, include_sublayers=True): + self._loaddict_holder = stat_dict for name, item in self.__dict__.get('_parameters', None).items(): if item.name in stat_dict: var = item._ivar.value() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py b/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py index 000659174..889e7c0fa 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py @@ -142,14 +142,11 @@ class TestDygraphCheckpoint(unittest.TestCase): for param in mnist.parameters(): dy_param_init_value[param.name] = param.numpy() - mnist.load_dict( - fluid.dygraph.load_persistables(mnist.state_dict(), - "save_dir")) - - restore = mnist.parameters() + restore = fluid.dygraph.load_persistables("save_dir") + mnist.load_dict(restore) self.assertEqual(len(dy_param_init_value), len(restore)) - for value in restore: + for ky, value in restore.items(): self.assertTrue( np.allclose(value.numpy(), dy_param_init_value[ value.name])) @@ -158,7 +155,7 @@ class TestDygraphCheckpoint(unittest.TestCase): step += 1 - if step > 20: + if step > 10: break -- GitLab