From a1486091f1d411390d268d74d6ad7b706e5e9637 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 25 Nov 2020 20:41:30 +0800 Subject: [PATCH] Polish load_program_state design to loading file onebyone (#29041) * change load dict file one by one to warn * add unittests for coverage * polish error message * fix cond error --- python/paddle/fluid/io.py | 74 +++++++++++++------ .../tests/unittests/test_static_save_load.py | 33 +++++++-- 2 files changed, 78 insertions(+), 29 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 58601fb5851..ebaa145d400 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -2025,35 +2025,63 @@ def load_program_state(model_path, var_list=None): None, persistable=True) + def _load_vars_with_try_catch(exe, + dirname, + vars, + filename, + raise_error=True): + try: + load_vars( + executor=exe, + dirname=dirname, + vars=vars, + filename=filename) + return True + except: + error_str = "Failed to load model/variables `%s`, please make sure " \ + "model/variables file is saved with the following APIs: " \ + "save_params, save_persistables, save_vars." + filenames = [var.name for var in vars + ] if filename is None else filename + if raise_error: + raise RuntimeError(error_str % filenames) + else: + warnings.warn(error_str % filenames, RuntimeWarning) + return False + + place = paddle.fluid.CPUPlace() + exe = paddle.fluid.Executor(place) + loaded_var_list = [] - if var_list is not None: + if os.path.isfile(model_path): + # when model_path is file, var_list cannot be None + dir_name, file_name = os.path.split(model_path) for var in var_list: loaded_var_list.append(clone_var_to_block(load_block, var)) + _load_vars_with_try_catch(exe, dir_name, loaded_var_list, + file_name) else: - for var_name in var_name_list: - loaded_var_list.append( - load_block.create_var( - name=var_name, persistable=True)) - - place = paddle.fluid.CPUPlace() - exe = paddle.fluid.Executor(place) - - try: - if os.path.isfile(model_path): - dir_name, file_name = os.path.split(model_path) + # var_list can be None or not None + if var_list is not None: + for var in var_list: + loaded_var_list.append( + clone_var_to_block(load_block, var)) + _load_vars_with_try_catch(exe, model_path, loaded_var_list, + None) else: - dir_name = model_path - file_name = None - load_vars( - executor=exe, - dirname=dir_name, - vars=loaded_var_list, - filename=file_name) - except: - raise RuntimeError( - "Failed to load model file , please make sure model file is saved with the " - "following APIs: save_params, save_persistables, save_vars") + for var_name in var_name_list: + # NOTE(chenweihang): If identify which files the user wants + # to load from the disk, we load these variables one by one. + # If a file does not exist, we only warn the user that the + # file may be an irrelevant file, but does not throw an error + # to ensure that other legal variables can be loaded. + temp_var = load_block.create_var( + name=var_name, persistable=True) + if _load_vars_with_try_catch(exe, model_path, + [temp_var], None, False): + loaded_var_list.append(temp_var) + res_dict = {} for var in loaded_var_list: res_dict[var.name] = np.asarray(paddle.fluid.global_scope( diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index 72c99223433..baab747c57e 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -1153,21 +1153,41 @@ class TestProgramStateOldSave(unittest.TestCase): # make sure all the paramerter or optimizer var have been set to zero self.assertTrue(np.sum(np.abs(new_t)) == 0) + # case 1: load basic program_state = fluid.load_program_state("test_program_1") fluid.set_program_state(main_program, program_state) + self.check_in_static(main_program, base_map) + + # case 2: load with no need file + orig_filepath = './test_program_1/fc_0.w_0' + symlink_filepath = './test_program_1/link_fc_0.w_0' + if os.path.exists(symlink_filepath): + os.remove(symlink_filepath) + os.symlink(orig_filepath, symlink_filepath) + program_state = fluid.load_program_state("test_program_1") + fluid.set_program_state(main_program, program_state) + self.check_in_static(main_program, base_map) - for var in main_program.list_vars(): - if isinstance(var, framework.Parameter) or var.persistable: - new_t = np.array(fluid.global_scope().find_var(var.name) - .get_tensor()) - base_t = base_map[var.name] - self.assertTrue(np.array_equal(new_t, base_t)) + # case 3: load with var_list + program_state = fluid.load_program_state( + "test_program_1", main_program.all_parameters()) + fluid.set_program_state(main_program, program_state) + self.check_in_static(main_program, base_map) + # make sure `load_program_state` can be used in dynamic graph mode with fluid.dygraph.guard(place): load_state = fluid.load_program_state("test_program_1") for k, v in load_state.items(): self.assertTrue(np.array_equal(base_map[k], v)) + def check_in_static(self, main_program, base_map): + for var in main_program.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + base_t = base_map[var.name] + self.assertTrue(np.array_equal(new_t, base_t)) + class TestProgramStateOldSaveSingleModel(unittest.TestCase): def test_ptb_rnn_cpu_float32(self): @@ -1301,4 +1321,5 @@ class TestProgramStateOldSaveSingleModel(unittest.TestCase): if __name__ == '__main__': + paddle.enable_static() unittest.main() -- GitLab