提交 a3f17280 编写于 作者: L lujun

fix dy-load bug, test=develop

上级 f0d48e2a
...@@ -75,7 +75,7 @@ def save_persistables(vardict, dirname, filename=None): ...@@ -75,7 +75,7 @@ def save_persistables(vardict, dirname, filename=None):
_save_var_to_file(vardict, dirname, filename) _save_var_to_file(vardict, dirname, filename)
def load_persistables(vardict, dirname, filename=None): def load_persistables(dirname):
""" """
This function trys to load persistable variables from the folder This function trys to load persistable variables from the folder
`dirname` or the file `filename`. `dirname` or the file `filename`.
...@@ -86,11 +86,7 @@ def load_persistables(vardict, dirname, filename=None): ...@@ -86,11 +86,7 @@ def load_persistables(vardict, dirname, filename=None):
the file name. the file name.
Args: Args:
vardict(dict of Parameters): The parameters will be loaded.
dirname(str): The directory path. dirname(str): The directory path.
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
saved in differnet files, set it to None.
Default: None
Returns: Returns:
dict: The parameter-dict resumed from file dict: The parameter-dict resumed from file
...@@ -104,10 +100,7 @@ def load_persistables(vardict, dirname, filename=None): ...@@ -104,10 +100,7 @@ def load_persistables(vardict, dirname, filename=None):
param_1 = param_dict['PtbModel_0.w_1'] param_1 = param_dict['PtbModel_0.w_1']
""" """
if isinstance(vardict, collections.OrderedDict): return _load_var_from_file(dirname)
return _load_var_from_file(vardict, dirname, filename)
return {}
def _save_var_to_file(stat_dict, file_dir, file_name): def _save_var_to_file(stat_dict, file_dir, file_name):
...@@ -139,41 +132,37 @@ def _save_var_to_file(stat_dict, file_dir, file_name): ...@@ -139,41 +132,37 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
}) })
def _load_var_from_file(stat_dict, file_dir, file_name): def _load_var_from_file(file_dir):
load_block = default_main_program().global_block()
load_var_map = {}
for var_key, each_var in stat_dict.items(): def walk_filename(file_dir):
assert isinstance(each_var, Variable) var_name_list = []
if each_var.type == core.VarDesc.VarType.RAW: if os.path.exists(file_dir) and os.path.exists(os.path.join(file_dir)):
continue base_path = os.path.join(file_dir)
new_var = _clone_var_in_block_(load_block, each_var) for dirpath, dirnames, filenames in os.walk(os.path.join(file_dir)):
if file_name is None: pt = dirpath.replace(base_path, "", 1)[1:]
load_block.append_op( for fth_name in filenames:
type='load', if fth_name[0] != '.':
inputs={}, var_name_list.append(os.path.join(pt, fth_name))
outputs={'Out': [new_var]},
attrs={
'file_path': os.path.join(file_dir,
os.path.normpath(each_var.name))
})
load_var_map[new_var.name] = new_var return var_name_list
if file_name is not None: load_block = default_main_program().global_block()
load_var_list = [] load_var_map = {}
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name]) file_var_list = walk_filename(file_dir)
for var_name in file_var_list:
new_var = Variable(block=load_block, name=var_name)
load_block.append_op( load_block.append_op(
type='load_combine', type='load',
inputs={}, inputs={},
outputs={"Out": load_var_list}, outputs={'Out': [new_var]},
attrs={ attrs={
'file_path': os.path.join(file_dir, os.path.normpath(file_name)) 'file_path': os.path.join(file_dir,
os.path.normpath(new_var.name))
}) })
for res_var in load_var_list:
load_var_map[res_var.name] = res_var load_var_map[new_var.name] = new_var
return load_var_map return load_var_map
......
...@@ -45,6 +45,7 @@ class Layer(core.Layer): ...@@ -45,6 +45,7 @@ class Layer(core.Layer):
self._dtype = dtype self._dtype = dtype
self._parameters = collections.OrderedDict() self._parameters = collections.OrderedDict()
self._sub_layers = collections.OrderedDict() self._sub_layers = collections.OrderedDict()
self._loaddict_holder = collections.OrderedDict()
self._helper = LayerObjectHelper(self._full_name) self._helper = LayerObjectHelper(self._full_name)
...@@ -193,6 +194,9 @@ class Layer(core.Layer): ...@@ -193,6 +194,9 @@ class Layer(core.Layer):
""" """
assert isinstance(parameter, framework.Parameter) assert isinstance(parameter, framework.Parameter)
self._parameters[name] = parameter self._parameters[name] = parameter
if parameter.name in self._loaddict_holder:
self._parameters[name] = self._loaddict_holder[parameter.name]
parameter = self._loaddict_holder[parameter.name]
return parameter return parameter
def __getattr__(self, name): def __getattr__(self, name):
...@@ -207,7 +211,10 @@ class Layer(core.Layer): ...@@ -207,7 +211,10 @@ class Layer(core.Layer):
if params is None: if params is None:
raise ValueError( raise ValueError(
"super(YourLayer, self).__init__() should be called first") "super(YourLayer, self).__init__() should be called first")
params[name] = value if value.name in self._loaddict_holder:
params[name] = self._loaddict_holder[value.name]
else:
params[name] = value
elif isinstance(value, core.Layer): elif isinstance(value, core.Layer):
layers = self.__dict__.get('_sub_layers', None) layers = self.__dict__.get('_sub_layers', None)
if layers is None: if layers is None:
...@@ -244,6 +251,7 @@ class Layer(core.Layer): ...@@ -244,6 +251,7 @@ class Layer(core.Layer):
return destination return destination
def load_dict(self, stat_dict, include_sublayers=True): def load_dict(self, stat_dict, include_sublayers=True):
self._loaddict_holder = stat_dict
for name, item in self.__dict__.get('_parameters', None).items(): for name, item in self.__dict__.get('_parameters', None).items():
if item.name in stat_dict: if item.name in stat_dict:
var = item._ivar.value() var = item._ivar.value()
......
...@@ -142,14 +142,11 @@ class TestDygraphCheckpoint(unittest.TestCase): ...@@ -142,14 +142,11 @@ class TestDygraphCheckpoint(unittest.TestCase):
for param in mnist.parameters(): for param in mnist.parameters():
dy_param_init_value[param.name] = param.numpy() dy_param_init_value[param.name] = param.numpy()
mnist.load_dict( restore = fluid.dygraph.load_persistables("save_dir")
fluid.dygraph.load_persistables(mnist.state_dict(), mnist.load_dict(restore)
"save_dir"))
restore = mnist.parameters()
self.assertEqual(len(dy_param_init_value), len(restore)) self.assertEqual(len(dy_param_init_value), len(restore))
for value in restore: for ky, value in restore.items():
self.assertTrue( self.assertTrue(
np.allclose(value.numpy(), dy_param_init_value[ np.allclose(value.numpy(), dy_param_init_value[
value.name])) value.name]))
...@@ -158,7 +155,7 @@ class TestDygraphCheckpoint(unittest.TestCase): ...@@ -158,7 +155,7 @@ class TestDygraphCheckpoint(unittest.TestCase):
step += 1 step += 1
if step > 20: if step > 10:
break break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册