未验证 提交 a1486091 编写于 作者: C Chen Weihang 提交者: GitHub

Polish load_program_state design to loading file onebyone (#29041)

* change load dict file one by one to warn

* add unittests for coverage

* polish error message

* fix cond error
上级 a049dff7
...@@ -2025,35 +2025,63 @@ def load_program_state(model_path, var_list=None): ...@@ -2025,35 +2025,63 @@ def load_program_state(model_path, var_list=None):
None, None,
persistable=True) persistable=True)
def _load_vars_with_try_catch(exe,
dirname,
vars,
filename,
raise_error=True):
try:
load_vars(
executor=exe,
dirname=dirname,
vars=vars,
filename=filename)
return True
except:
error_str = "Failed to load model/variables `%s`, please make sure " \
"model/variables file is saved with the following APIs: " \
"save_params, save_persistables, save_vars."
filenames = [var.name for var in vars
] if filename is None else filename
if raise_error:
raise RuntimeError(error_str % filenames)
else:
warnings.warn(error_str % filenames, RuntimeWarning)
return False
place = paddle.fluid.CPUPlace()
exe = paddle.fluid.Executor(place)
loaded_var_list = [] loaded_var_list = []
if var_list is not None: if os.path.isfile(model_path):
# when model_path is file, var_list cannot be None
dir_name, file_name = os.path.split(model_path)
for var in var_list: for var in var_list:
loaded_var_list.append(clone_var_to_block(load_block, var)) loaded_var_list.append(clone_var_to_block(load_block, var))
_load_vars_with_try_catch(exe, dir_name, loaded_var_list,
file_name)
else: else:
for var_name in var_name_list: # var_list can be None or not None
loaded_var_list.append( if var_list is not None:
load_block.create_var( for var in var_list:
name=var_name, persistable=True)) loaded_var_list.append(
clone_var_to_block(load_block, var))
place = paddle.fluid.CPUPlace() _load_vars_with_try_catch(exe, model_path, loaded_var_list,
exe = paddle.fluid.Executor(place) None)
try:
if os.path.isfile(model_path):
dir_name, file_name = os.path.split(model_path)
else: else:
dir_name = model_path for var_name in var_name_list:
file_name = None # NOTE(chenweihang): If identify which files the user wants
load_vars( # to load from the disk, we load these variables one by one.
executor=exe, # If a file does not exist, we only warn the user that the
dirname=dir_name, # file may be an irrelevant file, but does not throw an error
vars=loaded_var_list, # to ensure that other legal variables can be loaded.
filename=file_name) temp_var = load_block.create_var(
except: name=var_name, persistable=True)
raise RuntimeError( if _load_vars_with_try_catch(exe, model_path,
"Failed to load model file , please make sure model file is saved with the " [temp_var], None, False):
"following APIs: save_params, save_persistables, save_vars") loaded_var_list.append(temp_var)
res_dict = {} res_dict = {}
for var in loaded_var_list: for var in loaded_var_list:
res_dict[var.name] = np.asarray(paddle.fluid.global_scope( res_dict[var.name] = np.asarray(paddle.fluid.global_scope(
......
...@@ -1153,21 +1153,41 @@ class TestProgramStateOldSave(unittest.TestCase): ...@@ -1153,21 +1153,41 @@ class TestProgramStateOldSave(unittest.TestCase):
# make sure all the paramerter or optimizer var have been set to zero # make sure all the paramerter or optimizer var have been set to zero
self.assertTrue(np.sum(np.abs(new_t)) == 0) self.assertTrue(np.sum(np.abs(new_t)) == 0)
# case 1: load basic
program_state = fluid.load_program_state("test_program_1") program_state = fluid.load_program_state("test_program_1")
fluid.set_program_state(main_program, program_state) fluid.set_program_state(main_program, program_state)
self.check_in_static(main_program, base_map)
# case 2: load with no need file
orig_filepath = './test_program_1/fc_0.w_0'
symlink_filepath = './test_program_1/link_fc_0.w_0'
if os.path.exists(symlink_filepath):
os.remove(symlink_filepath)
os.symlink(orig_filepath, symlink_filepath)
program_state = fluid.load_program_state("test_program_1")
fluid.set_program_state(main_program, program_state)
self.check_in_static(main_program, base_map)
for var in main_program.list_vars(): # case 3: load with var_list
if isinstance(var, framework.Parameter) or var.persistable: program_state = fluid.load_program_state(
new_t = np.array(fluid.global_scope().find_var(var.name) "test_program_1", main_program.all_parameters())
.get_tensor()) fluid.set_program_state(main_program, program_state)
base_t = base_map[var.name] self.check_in_static(main_program, base_map)
self.assertTrue(np.array_equal(new_t, base_t))
# make sure `load_program_state` can be used in dynamic graph mode
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
load_state = fluid.load_program_state("test_program_1") load_state = fluid.load_program_state("test_program_1")
for k, v in load_state.items(): for k, v in load_state.items():
self.assertTrue(np.array_equal(base_map[k], v)) self.assertTrue(np.array_equal(base_map[k], v))
def check_in_static(self, main_program, base_map):
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
class TestProgramStateOldSaveSingleModel(unittest.TestCase): class TestProgramStateOldSaveSingleModel(unittest.TestCase):
def test_ptb_rnn_cpu_float32(self): def test_ptb_rnn_cpu_float32(self):
...@@ -1301,4 +1321,5 @@ class TestProgramStateOldSaveSingleModel(unittest.TestCase): ...@@ -1301,4 +1321,5 @@ class TestProgramStateOldSaveSingleModel(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册