未验证 提交 c0fb03a0 编写于 作者: W WeiXin 提交者: GitHub
上级 9fec1618
...@@ -2180,6 +2180,7 @@ def load_program_state(model_path, var_list=None): ...@@ -2180,6 +2180,7 @@ def load_program_state(model_path, var_list=None):
with open(parameter_file_name, 'rb') as f: with open(parameter_file_name, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load( para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1') f, encoding='latin1')
para_dict = _pack_loaded_dict(para_dict)
opt_file_name = model_prefix + ".pdopt" opt_file_name = model_prefix + ".pdopt"
if os.path.exists(opt_file_name): if os.path.exists(opt_file_name):
...@@ -2231,6 +2232,7 @@ def set_program_state(program, state_dict): ...@@ -2231,6 +2232,7 @@ def set_program_state(program, state_dict):
static.set_program_state(prog, program_state) static.set_program_state(prog, program_state)
""" """
state_dict = _pack_loaded_dict(state_dict)
parameter_list = list(filter(is_persistable, program.list_vars())) parameter_list = list(filter(is_persistable, program.list_vars()))
used_para_list = {} used_para_list = {}
......
...@@ -1365,6 +1365,25 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): ...@@ -1365,6 +1365,25 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
base_t = base_map[var.name] base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t)) self.assertTrue(np.array_equal(new_t, base_t))
# set var to zero
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)
program_state = fluid.load_program_state(path)
fluid.set_program_state(prog, program_state)
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
class TestProgramStateOldSaveSingleModel(unittest.TestCase): class TestProgramStateOldSaveSingleModel(unittest.TestCase):
def test_ptb_rnn_cpu_float32(self): def test_ptb_rnn_cpu_float32(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册