未验证 提交 a1486091 编写于 作者: C Chen Weihang 提交者: GitHub

Polish load_program_state design to load files one by one (#29041)

* change load dict file one by one to warn

* add unittests for coverage

* polish error message

* fix cond error
上级 a049dff7
......@@ -2025,35 +2025,63 @@ def load_program_state(model_path, var_list=None):
None,
persistable=True)
def _load_vars_with_try_catch(exe,
                              dirname,
                              vars,
                              filename,
                              raise_error=True):
    """Call ``load_vars`` and report whether the load succeeded.

    Args:
        exe: Executor forwarded to ``load_vars`` as ``executor``.
        dirname: Forwarded to ``load_vars`` as ``dirname``.
        vars: Variables forwarded to ``load_vars``; when ``filename`` is
            None their names are listed in the failure message.
        filename: Forwarded to ``load_vars`` as ``filename``; when it is
            not None it is what the failure message reports.
        raise_error: If True (the default), a failed load raises
            ``RuntimeError``. If False, a ``RuntimeWarning`` is issued
            instead and False is returned.

    Returns:
        bool: True if ``load_vars`` completed without raising, False if
        it failed and ``raise_error`` is False.

    Raises:
        RuntimeError: If ``load_vars`` fails and ``raise_error`` is True.
    """
    try:
        load_vars(
            executor=exe,
            dirname=dirname,
            vars=vars,
            filename=filename)
    except Exception:
        # Catch Exception instead of using a bare ``except`` so that
        # KeyboardInterrupt and SystemExit still propagate rather than
        # being reported as a failed load.
        error_str = "Failed to load model/variables `%s`, please make sure " \
                    "model/variables file is saved with the following APIs: " \
                    "save_params, save_persistables, save_vars."
        # Report the combined file name if one was given, otherwise the
        # names of the individual variables.
        filenames = [var.name for var in vars
                     ] if filename is None else filename
        if raise_error:
            raise RuntimeError(error_str % filenames)
        warnings.warn(error_str % filenames, RuntimeWarning)
        return False
    return True
place = paddle.fluid.CPUPlace()
exe = paddle.fluid.Executor(place)
loaded_var_list = []
if var_list is not None:
if os.path.isfile(model_path):
# when model_path is file, var_list cannot be None
dir_name, file_name = os.path.split(model_path)
for var in var_list:
loaded_var_list.append(clone_var_to_block(load_block, var))
_load_vars_with_try_catch(exe, dir_name, loaded_var_list,
file_name)
else:
for var_name in var_name_list:
loaded_var_list.append(
load_block.create_var(
name=var_name, persistable=True))
place = paddle.fluid.CPUPlace()
exe = paddle.fluid.Executor(place)
try:
if os.path.isfile(model_path):
dir_name, file_name = os.path.split(model_path)
# var_list can be None or not None
if var_list is not None:
for var in var_list:
loaded_var_list.append(
clone_var_to_block(load_block, var))
_load_vars_with_try_catch(exe, model_path, loaded_var_list,
None)
else:
dir_name = model_path
file_name = None
load_vars(
executor=exe,
dirname=dir_name,
vars=loaded_var_list,
filename=file_name)
except:
raise RuntimeError(
"Failed to load model file , please make sure model file is saved with the "
"following APIs: save_params, save_persistables, save_vars")
for var_name in var_name_list:
# NOTE(chenweihang): If we identify which files the user wants
# to load from the disk, we load these variables one by one.
# If a file cannot be loaded, we only warn the user that the
# file may be an irrelevant file and do not throw an error,
# so that the other legal variables can still be loaded.
temp_var = load_block.create_var(
name=var_name, persistable=True)
if _load_vars_with_try_catch(exe, model_path,
[temp_var], None, False):
loaded_var_list.append(temp_var)
res_dict = {}
for var in loaded_var_list:
res_dict[var.name] = np.asarray(paddle.fluid.global_scope(
......
......@@ -1153,21 +1153,41 @@ class TestProgramStateOldSave(unittest.TestCase):
# make sure all the parameters and optimizer vars have been set to zero
self.assertTrue(np.sum(np.abs(new_t)) == 0)
# case 1: load basic
program_state = fluid.load_program_state("test_program_1")
fluid.set_program_state(main_program, program_state)
self.check_in_static(main_program, base_map)
# case 2: load when the directory also contains a file that is not needed
orig_filepath = './test_program_1/fc_0.w_0'
symlink_filepath = './test_program_1/link_fc_0.w_0'
if os.path.exists(symlink_filepath):
os.remove(symlink_filepath)
os.symlink(orig_filepath, symlink_filepath)
program_state = fluid.load_program_state("test_program_1")
fluid.set_program_state(main_program, program_state)
self.check_in_static(main_program, base_map)
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
# case 3: load with var_list
program_state = fluid.load_program_state(
"test_program_1", main_program.all_parameters())
fluid.set_program_state(main_program, program_state)
self.check_in_static(main_program, base_map)
# make sure `load_program_state` can be used in dynamic graph mode
with fluid.dygraph.guard(place):
load_state = fluid.load_program_state("test_program_1")
for k, v in load_state.items():
self.assertTrue(np.array_equal(base_map[k], v))
def check_in_static(self, main_program, base_map):
    """Assert that the program's stored tensors match their baselines.

    For every variable returned by ``main_program.list_vars()`` that is a
    ``framework.Parameter`` or is marked persistable, the tensor found
    under the variable's name in the global scope is converted to a numpy
    array and compared with ``base_map[var.name]``.

    Args:
        main_program: Program whose variables are checked.
        base_map: Mapping from variable name to the expected array.
    """
    for program_var in main_program.list_vars():
        # Skip variables that are neither parameters nor persistable.
        if not isinstance(program_var,
                          framework.Parameter) and not program_var.persistable:
            continue
        scope_var = fluid.global_scope().find_var(program_var.name)
        actual = np.array(scope_var.get_tensor())
        expected = base_map[program_var.name]
        self.assertTrue(np.array_equal(actual, expected))
class TestProgramStateOldSaveSingleModel(unittest.TestCase):
def test_ptb_rnn_cpu_float32(self):
......@@ -1301,4 +1321,5 @@ class TestProgramStateOldSaveSingleModel(unittest.TestCase):
# Script entry point: when this file is executed directly, call
# paddle.enable_static() first and then hand control to unittest.main(),
# which discovers and runs the test cases defined in this module.
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册