未验证 提交 dbf66dd0 编写于 作者: L lujun 提交者: GitHub

Merge pull request #16954 from junjun315/fix-dygraph-checkpoint

Fix dygraph checkpoint bug
...@@ -464,7 +464,11 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext ...@@ -464,7 +464,11 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
void SetType(const std::string& name, void SetType(const std::string& name,
framework::proto::VarType::Type type) override { framework::proto::VarType::Type type) override {
var_set_[name]->SetType(type); if (name == "kLookupTablePath") {
VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
} else {
var_set_[name]->SetType(type);
}
} }
framework::proto::VarType::Type GetDataType( framework::proto::VarType::Type GetDataType(
......
...@@ -113,14 +113,17 @@ def load_persistables(vardict, dirname, filename=None): ...@@ -113,14 +113,17 @@ def load_persistables(vardict, dirname, filename=None):
def _save_var_to_file(stat_dict, file_dir, file_name): def _save_var_to_file(stat_dict, file_dir, file_name):
save_block = default_main_program().global_block() save_block = default_main_program().global_block()
save_var_map = {} save_var_map = {}
for each_var in stat_dict.items(): for var_key, each_var in stat_dict.items():
save_var_map[each_var.name] = each_var save_var_map[each_var.name] = each_var
if file_name is None: if file_name is None:
save_block.append_op( save_block.append_op(
type='save', type='save',
inputs={'X': [each_var]}, inputs={'X': [each_var]},
outputs={}, outputs={},
attrs={'file_path': os.path.join(file_dir, each_var.name)}) attrs={
'file_path': os.path.join(file_dir,
os.path.normpath(each_var.name))
})
if file_name is not None: if file_name is not None:
save_var_list = [] save_var_list = []
...@@ -131,14 +134,16 @@ def _save_var_to_file(stat_dict, file_dir, file_name): ...@@ -131,14 +134,16 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
type='save_combine', type='save_combine',
inputs={'X': save_var_list}, inputs={'X': save_var_list},
outputs={}, outputs={},
attrs={'file_path': os.path.join(file_dir, file_name)}) attrs={
'file_path': os.path.join(file_dir, os.path.normpath(file_name))
})
def _load_var_from_file(stat_dict, file_dir, file_name): def _load_var_from_file(stat_dict, file_dir, file_name):
load_block = default_main_program().global_block() load_block = default_main_program().global_block()
load_var_map = {} load_var_map = {}
for each_var in stat_dict.items(): for var_key, each_var in stat_dict.items():
assert isinstance(each_var, Variable) assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW: if each_var.type == core.VarDesc.VarType.RAW:
continue continue
...@@ -148,7 +153,10 @@ def _load_var_from_file(stat_dict, file_dir, file_name): ...@@ -148,7 +153,10 @@ def _load_var_from_file(stat_dict, file_dir, file_name):
type='load', type='load',
inputs={}, inputs={},
outputs={'Out': [new_var]}, outputs={'Out': [new_var]},
attrs={'file_path': os.path.join(file_dir, each_var.name)}) attrs={
'file_path': os.path.join(file_dir,
os.path.normpath(each_var.name))
})
load_var_map[new_var.name] = new_var load_var_map[new_var.name] = new_var
...@@ -161,7 +169,9 @@ def _load_var_from_file(stat_dict, file_dir, file_name): ...@@ -161,7 +169,9 @@ def _load_var_from_file(stat_dict, file_dir, file_name):
type='load_combine', type='load_combine',
inputs={}, inputs={},
outputs={"Out": load_var_list}, outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(file_dir, file_name)}) attrs={
'file_path': os.path.join(file_dir, os.path.normpath(file_name))
})
for res_var in load_var_list: for res_var in load_var_list:
load_var_map[res_var.name] = res_var load_var_map[res_var.name] = res_var
...@@ -175,5 +185,5 @@ def _clone_var_in_block_(block, var): ...@@ -175,5 +185,5 @@ def _clone_var_in_block_(block, var):
shape=var.shape, shape=var.shape,
dtype=var.dtype, dtype=var.dtype,
type=var.type, type=var.type,
lod_level=var.lod_level, lod_level=0,
persistable=True) persistable=True)
...@@ -246,7 +246,10 @@ class Layer(core.Layer): ...@@ -246,7 +246,10 @@ class Layer(core.Layer):
def load_dict(self, stat_dict, include_sublayers=True): def load_dict(self, stat_dict, include_sublayers=True):
for name, item in self.__dict__.get('_parameters', None).items(): for name, item in self.__dict__.get('_parameters', None).items():
if item.name in stat_dict: if item.name in stat_dict:
self.__setattr__(name, stat_dict[item.name]) var = item._ivar.value()
tensor = var.get_tensor()
tensor.set(stat_dict[item.name].numpy(),
framework._current_expected_place())
if include_sublayers: if include_sublayers:
for layer_name, layer_item in self._sub_layers.items(): for layer_name, layer_item in self._sub_layers.items():
......
...@@ -99,7 +99,7 @@ class MNIST(fluid.Layer): ...@@ -99,7 +99,7 @@ class MNIST(fluid.Layer):
class TestDygraphCheckpoint(unittest.TestCase): class TestDygraphCheckpoint(unittest.TestCase):
def save_load_persistables(self): def test_save_load_persistables(self):
seed = 90 seed = 90
epoch_num = 1 epoch_num = 1
...@@ -135,23 +135,26 @@ class TestDygraphCheckpoint(unittest.TestCase): ...@@ -135,23 +135,26 @@ class TestDygraphCheckpoint(unittest.TestCase):
avg_loss.backward() avg_loss.backward()
sgd.minimize(avg_loss) sgd.minimize(avg_loss)
fluid.dygraph.save_persistables(mnist, "save_dir") fluid.dygraph.save_persistables(mnist.state_dict(),
"save_dir")
mnist.clear_gradients() mnist.clear_gradients()
for param in mnist.parameters(): for param in mnist.parameters():
dy_param_init_value[param.name] = param.numpy() dy_param_init_value[param.name] = param.numpy()
mnist.load_dict( mnist.load_dict(
fluid.dygraph.load_persistables(mnist, "save_dir")) fluid.dygraph.load_persistables(mnist.state_dict(),
"save_dir"))
restore = mnist.parameters() restore = mnist.parameters()
self.assertEqual(len(dy_param_init_value), len(restore)) self.assertEqual(len(dy_param_init_value), len(restore))
for value in restore: for value in restore:
self.assertTrue( self.assertTrue(
np.allclose(value, dy_param_init_value[value.name])) np.allclose(value.numpy(), dy_param_init_value[
self.assertTrue(np.isfinite(value.all())) value.name]))
self.assertFalse(np.isnan(value.any())) self.assertTrue(np.isfinite(value.numpy().all()))
self.assertFalse(np.isnan(value.numpy().any()))
step += 1 step += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册