diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 9876fc620b870f47b10e9f99e4de34f5cb81fde1..93cb0bafc847b897816636f92255bd06b7e67321 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -195,58 +195,11 @@ def load_dygraph(model_path, config=None): params_file_path = model_prefix + ".pdparams" opti_file_path = model_prefix + ".pdopt" - # deal with argument `configs` - configs = config - if configs is None: - configs = SaveLoadConfig() - - if not os.path.exists(params_file_path) and not os.path.exists( - opti_file_path): - # Load state dict by `jit.save/io.save_inference_model` save format - # NOTE(chenweihang): [ Compatibility of save_inference_model save format ] - # The model saved by `save_inference_model` does not completely correspond to - # the information required by the `state_dict` under the dygraph. - # `save_inference_model` not save structured name, we need to remind - # the user to configure the `use_structured_name` argument when `set_state_dict` - # NOTE(chenweihang): `jit.save` doesn't save optimizer state - - # 1. check model path - if not os.path.isdir(model_prefix): - raise ValueError("Model saved directory '%s' is not exists." % - model_prefix) + # deal with argument `config` + if config is None: + config = SaveLoadConfig() - # 2. load program desc & construct _ProgramHolder - programs = _construct_program_holders(model_path, - configs.model_filename) - - # 3. load layer parameters & buffers - # NOTE: using fluid.dygraph.guard() here will cause import error in py2 - with guard(): - persistable_var_dict = _construct_params_and_buffers( - model_prefix, - programs, - configs.separate_params, - configs.params_filename, - append_suffix=False) - - # 4. construct state_dict - para_dict = dict() - for var_name in persistable_var_dict: - para_dict[var_name] = persistable_var_dict[var_name].numpy() - - # if __variables.info__ exists, we can recover structured_name - var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME) - if os.path.exists(var_info_path): - with open(var_info_path, 'rb') as f: - extra_var_info = pickle.load(f) - structured_para_dict = dict() - for var_name in para_dict: - structured_name = extra_var_info[var_name].get( - 'structured_name', None) - assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name - structured_para_dict[structured_name] = para_dict[var_name] - para_dict = structured_para_dict - else: + if os.path.exists(params_file_path) or os.path.exists(opti_file_path): # Load state dict by `save_dygraph` save format para_dict = {} if os.path.exists(params_file_path): @@ -254,12 +207,103 @@ def load_dygraph(model_path, config=None): para_dict = pickle.load(f) if six.PY2 else pickle.load( f, encoding='latin1') - if not configs.keep_name_table and "StructuredToParameterName@@" in para_dict: + if not config.keep_name_table and "StructuredToParameterName@@" in para_dict: del para_dict["StructuredToParameterName@@"] if os.path.exists(opti_file_path): with open(opti_file_path, 'rb') as f: opti_dict = pickle.load(f) if six.PY2 else pickle.load( f, encoding='latin1') + else: + # check model path + if not os.path.isdir(model_prefix): + raise ValueError("Model saved directory '%s' is not exists." % + model_prefix) + + # check whether model file exists + if config.model_filename is None: + model_filename = '__model__' + else: + model_filename = config.model_filename + model_file_path = os.path.join(model_path, model_filename) + + if os.path.exists(model_file_path): + # Load state dict by `jit.save/io.save_inference_model` save format + # NOTE(chenweihang): [ Compatibility of save_inference_model save format ] + # The model saved by `save_inference_model` does not completely correspond to + # the information required by the `state_dict` under the dygraph. + # `save_inference_model` not save structured name, we need to remind + # the user to configure the `use_structured_name` argument when `set_state_dict` + # NOTE(chenweihang): `jit.save` doesn't save optimizer state + + # 1. load program desc & construct _ProgramHolder + programs = _construct_program_holders(model_path, + config.model_filename) + + # 2. load layer parameters & buffers + # NOTE: using fluid.dygraph.guard() here will cause import error in py2 + with guard(): + persistable_var_dict = _construct_params_and_buffers( + model_prefix, + programs, + config.separate_params, + config.params_filename, + append_suffix=False) + + # 3. construct state_dict + para_dict = dict() + for var_name in persistable_var_dict: + para_dict[var_name] = persistable_var_dict[var_name].numpy() + + # if __variables.info__ exists, we can recover structured_name + var_info_path = os.path.join(model_prefix, + EXTRA_VAR_INFO_FILENAME) + if os.path.exists(var_info_path): + with open(var_info_path, 'rb') as f: + extra_var_info = pickle.load(f) + structured_para_dict = dict() + for var_name in para_dict: + structured_name = extra_var_info[var_name].get( + 'structured_name', None) + assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name + structured_para_dict[structured_name] = para_dict[ + var_name] + para_dict = structured_para_dict + else: + # load state dict by `io.save_params/persistables` save format + # TODO(chenweihang): [ Now only supports loading parameters seperately ] + # If users save all parameters as one file, the [ variable.name -> variable ] + # mapping info will lost, so users need to give variable list, but users build + # variable list in dygraph mode is difficult, we recommend users to use + # paddle.io.load_program_state in this case + + # Try to load all the files in the directory in VarBase format, + # the file name is used as the name of VarBase + load_var_list = [] + + # 1. load file names + var_name_list = [] + for root, _, files in os.walk(model_path): + for filename in files: + file_path = os.path.join(root, filename) + tmp_var_name = os.path.relpath(file_path, model_path) + var_name = tmp_var_name.replace("\\", "/") + var_name_list.append(var_name) + + # 2. create and load VarBase + with guard(): + for name in var_name_list: + new_var = _varbase_creator(name=name, persistable=True) + _dygraph_tracer().trace_op( + type='load', + inputs={}, + outputs={'Out': new_var}, + attrs={'file_path': os.path.join(model_path, name)}) + load_var_list.append(new_var) + + # 3. construct state_dict + para_dict = dict() + for var in load_var_list: + para_dict[var.name] = var.numpy() return para_dict, opti_dict diff --git a/python/paddle/fluid/tests/unittests/test_load_state_dict_from_old_format.py b/python/paddle/fluid/tests/unittests/test_load_state_dict_from_old_format.py index ed1939dbe279f28883d9e33178f1cfa256140e33..a1a9b3f444fa411f90e869f5265fa0933393ff56 100644 --- a/python/paddle/fluid/tests/unittests/test_load_state_dict_from_old_format.py +++ b/python/paddle/fluid/tests/unittests/test_load_state_dict_from_old_format.py @@ -64,7 +64,7 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase): self.batch_size = 128 self.batch_num = 10 - def train_and_save_model(self): + def train_and_save_model(self, only_params=False): with new_program_scope(): startup_program = fluid.default_startup_program() main_program = fluid.default_main_program() @@ -102,11 +102,15 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase): static_param_dict[param.name] = fluid.executor._fetch_var( param.name) - fluid.io.save_inference_model( - self.save_dirname, ["img"], [prediction], - exe, - model_filename=self.model_filename, - params_filename=self.params_filename) + if only_params: + fluid.io.save_params( + exe, self.save_dirname, filename=self.params_filename) + else: + fluid.io.save_inference_model( + self.save_dirname, ["img"], [prediction], + exe, + model_filename=self.model_filename, + params_filename=self.params_filename) return static_param_dict @@ -120,9 +124,7 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase): self.params_filename = None orig_param_dict = self.train_and_save_model() - configs = paddle.SaveLoadConfig() - configs.separate_params = True - load_param_dict, _ = paddle.load(self.save_dirname, configs) + load_param_dict, _ = paddle.load(self.save_dirname) self.check_load_state_dict(orig_param_dict, load_param_dict) def test_load_with_model_filename(self): @@ -160,6 +162,14 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase): load_param_dict, _ = paddle.load(self.save_dirname, configs) self.check_load_state_dict(orig_param_dict, load_param_dict) + def test_load_state_dict_from_save_params(self): + self.save_dirname = "static_mnist.load_state_dict.save_params" + self.params_filename = None + orig_param_dict = self.train_and_save_model(True) + + load_param_dict, _ = paddle.load(self.save_dirname) + self.check_load_state_dict(orig_param_dict, load_param_dict) + if __name__ == '__main__': unittest.main()