diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py
index de4330cf51669ebbbfb1ca7e9edcc0c82b1d0e72..82018132cc8b8600958e5cd52df5844e3d37638e 100644
--- a/python/paddle/fluid/dygraph/checkpoint.py
+++ b/python/paddle/fluid/dygraph/checkpoint.py
@@ -16,12 +16,13 @@ from __future__ import print_function
 
 import os
 import collections
-from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase
+from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
 import pickle
 import six
 from . import learning_rate_scheduler
 import warnings
 from .. import core
+from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME, _load_persistable_vars
 
 __all__ = [
     'save_dygraph',
@@ -140,22 +141,83 @@ def load_dygraph(model_path, keep_name_table=False):
     elif model_prefix.endswith(".pdopt"):
         model_prefix = model_prefix[:-6]
 
-    params_file_path = model_prefix + ".pdparams"
-    if not os.path.exists(params_file_path):
-        raise RuntimeError("Parameter file [ {} ] not exists".format(
-            params_file_path))
-
-    with open(params_file_path, 'rb') as f:
-        para_dict = pickle.load(f) if six.PY2 else pickle.load(
-            f, encoding='latin1')
-
-    if not keep_name_table and "StructuredToParameterName@@" in para_dict:
-        del para_dict["StructuredToParameterName@@"]
+    para_dict = None
     opti_dict = None
+    params_file_path = model_prefix + ".pdparams"
     opti_file_path = model_prefix + ".pdopt"
-    if os.path.exists(opti_file_path):
-        with open(opti_file_path, 'rb') as f:
-            opti_dict = pickle.load(f) if six.PY2 else pickle.load(
-                f, encoding='latin1')
+    if not os.path.exists(params_file_path) and not os.path.exists(
+            opti_file_path):
+        # Load the state_dict saved in the `jit.save` format
+        # TODO(chenweihang): [Why the `io.save_inference_model` save format is not supported here]
+        # The model saved by `save_inference_model` does not completely correspond to
+        # the information required by the `state_dict` under dygraph.
+        # Although we can reluctantly restore the `state_dict` in some scenarios,
+        # the result may be incomplete and has some limitations, so support for
+        # that format will be considered later. The limitations include:
+        # 1. `save_inference_model` does not save structured names; we would need to remind
+        # the user to configure the `use_structured_name` argument when calling `set_dict`,
+        # but this argument is currently not public
+        # 2. if `save_inference_model` saves all persistable variables in a single file,
+        # the user needs to give the variable name list to load the `state_dict`
+
+        # 1. check model path
+        if not os.path.isdir(model_prefix):
+            raise ValueError("Model saved directory '%s' does not exist." %
+                             model_prefix)
+        # 2. load `__variables.info__`
+        var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
+        if not os.path.exists(var_info_path):
+            raise RuntimeError(
+                "No target can be loaded. Now only supports loading `state_dict` from "
+                "the result saved by `imperative.save` and `imperative.jit.save`."
+            )
+        with open(var_info_path, 'rb') as f:
+            extra_var_info = pickle.load(f)
+        # 3. load `__variables__`
+        # TODO(chenweihang): now only supports loading from the default save format:
+        # - all persistable vars saved in one file named `__variables__`
+        # for other cases, we may need to modify the arguments of this API
+        var_file_path = os.path.join(model_prefix, VARIABLE_FILENAME)
+        if not os.path.exists(var_file_path):
+            raise RuntimeError(
+                "The parameter file to be loaded was not found. "
+                "Now only loading from the default save format is "
+                "supported; custom params_filename and saving "
+                "parameters in separate files are not supported.")
+        # 4. load all persistable vars
+        load_var_list = []
+        for name in sorted(extra_var_info):
+            var = _varbase_creator(name=name, persistable=True)
+            load_var_list.append(var)
+        _dygraph_tracer().trace_op(
+            type='load_combine',
+            inputs={},
+            outputs={'Out': load_var_list},
+            attrs={'file_path': var_file_path})
+        # 5. construct state_dict
+        para_dict = dict()
+        for var in load_var_list:
+            structured_name = extra_var_info[var.name].get('structured_name',
+                                                           None)
+            if structured_name is None:
+                raise RuntimeError(
+                    "Cannot find saved variable (%s)'s structured name in "
+                    "saved model." % var.name)
+            para_dict[structured_name] = var.numpy()
+        # NOTE: `jit.save` doesn't save optimizer state
+    else:
+        # Load the state_dict saved in the `save_dygraph` format
+        if os.path.exists(params_file_path):
+            with open(params_file_path, 'rb') as f:
+                para_dict = pickle.load(f) if six.PY2 else pickle.load(
+                    f, encoding='latin1')
+
+            if not keep_name_table and "StructuredToParameterName@@" in para_dict:
+                del para_dict["StructuredToParameterName@@"]
+
+        if os.path.exists(opti_file_path):
+            with open(opti_file_path, 'rb') as f:
+                opti_dict = pickle.load(f) if six.PY2 else pickle.load(
+                    f, encoding='latin1')
 
     return para_dict, opti_dict
diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py
index 38e4e517836ed8ddbeb36fb68a0c34fa9826f233..7396289392affa92e69e9f55fba622fd13fa979f 100644
--- a/python/paddle/fluid/dygraph/io.py
+++ b/python/paddle/fluid/dygraph/io.py
@@ -425,8 +425,7 @@ def _load_persistable_vars(model_path, params_filename=None):
 
     # 1. load extra var info
     with open(var_info_path, 'rb') as f:
-        extra_var_info = pickle.load(f) if six.PY2 else pickle.load(
-            f, encoding='latin1')
+        extra_var_info = pickle.load(f)
 
     # 2. construct var dict
     load_var_dict = dict()
diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py
index abc46034957cf7414310f0f593f3bcce71a6d1de..a61d31e88253d7b45efde6226fe14cf5b5b11af9 100644
--- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py
+++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py
@@ -14,13 +14,15 @@
 
 from __future__ import print_function
 
+import os
 import unittest
 import numpy as np
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.dygraph import Linear
-from paddle.fluid.dygraph import declarative
+from paddle.fluid.dygraph import declarative, ProgramTranslator
+from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME
 
 BATCH_SIZE = 32
 BATCH_NUM = 20
@@ -77,8 +79,8 @@ class LinearNetReturnLoss(fluid.dygraph.Layer):
 
 def train(layer):
     # create optimizer
-    adam = fluid.optimizer.AdamOptimizer(
-        learning_rate=0.1, parameter_list=layer.parameters())
+    adam = fluid.optimizer.SGDOptimizer(
+        learning_rate=0.01, parameter_list=layer.parameters())
     # create data loader
     train_loader = fluid.io.DataLoader.from_generator(capacity=5)
     train_loader.set_batch_generator(random_batch_reader())
@@ -111,37 +113,43 @@ class TestJitSaveLoad(unittest.TestCase):
         # config seed
         fluid.default_main_program().random_seed = SEED
 
-    def train_and_save_model(self):
+    def train_and_save_model(self, model_path=None, configs=None):
         layer = LinearNet(784, 1)
         example_inputs, layer, _ = train(layer)
+        final_model_path = model_path if model_path else self.model_path
         orig_input_types = [type(x) for x in example_inputs]
         fluid.dygraph.jit.save(
-            layer=layer, model_path=self.model_path, input_spec=example_inputs)
+            layer=layer,
+            model_path=final_model_path,
+            input_spec=example_inputs,
+            configs=configs)
         new_input_types = [type(x) for x in example_inputs]
         self.assertEqual(orig_input_types, new_input_types)
         return layer
 
-    def test_save(self):
-        # train and save model
-        self.train_and_save_model()
-
-    def test_load_infernece(self):
+    def test_save_load(self):
         # train and save model
         train_layer = self.train_and_save_model()
         # load model
-        infer_layer = fluid.dygraph.jit.load(self.model_path)
+        program_translator = ProgramTranslator()
+        program_translator.enable(False)
+        loaded_layer = fluid.dygraph.jit.load(self.model_path)
+        self.load_and_inference(train_layer, loaded_layer)
+        self.load_dygraph_state_dict(train_layer)
+        self.load_and_finetune(train_layer, loaded_layer)
+        program_translator.enable(True)
+
+    def load_and_inference(self, train_layer, infer_layer):
         train_layer.eval()
+        infer_layer.eval()
         # inference & compare
         x = fluid.dygraph.to_variable(
             np.random.random((1, 784)).astype('float32'))
         self.assertTrue(
             np.array_equal(train_layer(x).numpy(), infer_layer(x).numpy()))
 
-    def test_load_finetune(self):
-        # train and save model
-        train_layer = self.train_and_save_model()
-        # load model
-        load_train_layer = fluid.dygraph.jit.load(self.model_path)
+    def load_and_finetune(self, train_layer, load_train_layer):
+        train_layer.train()
         load_train_layer.train()
         # train & compare
         _, _, train_loss = train(train_layer)
@@ -149,6 +157,19 @@ class TestJitSaveLoad(unittest.TestCase):
         self.assertTrue(
             np.array_equal(train_loss.numpy(), load_train_loss.numpy()))
 
+    def load_dygraph_state_dict(self, train_layer):
+        train_layer.eval()
+        # construct new model
+        new_layer = LinearNet(784, 1)
+        model_dict, _ = fluid.dygraph.load_dygraph(self.model_path)
+        new_layer.set_dict(model_dict)
+        new_layer.eval()
+        # inference & compare
+        x = fluid.dygraph.to_variable(
+            np.random.random((1, 784)).astype('float32'))
+        self.assertTrue(
+            np.array_equal(train_layer(x).numpy(), new_layer(x).numpy()))
+
     def test_save_get_program_failed(self):
         layer = LinearNetNotDeclarative(784, 1)
         example_inputs, layer, _ = train(layer)
@@ -158,6 +179,31 @@ class TestJitSaveLoad(unittest.TestCase):
                 model_path=self.model_path,
                 input_spec=example_inputs)
 
+    def test_load_dygraph_no_path(self):
+        model_path = "model.test_jit_save_load.no_path"
+        new_layer = LinearNet(784, 1)
+        with self.assertRaises(ValueError):
+            model_dict, _ = fluid.dygraph.load_dygraph(model_path)
+
+    def test_load_dygraph_no_var_info(self):
+        model_path = "model.test_jit_save_load.no_var_info"
+        self.train_and_save_model(model_path=model_path)
+        # remove `__variables.info__`
+        var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
+        os.remove(var_info_path)
+        new_layer = LinearNet(784, 1)
+        with self.assertRaises(RuntimeError):
+            model_dict, _ = fluid.dygraph.load_dygraph(model_path)
+
+    def test_load_dygraph_not_var_file(self):
+        model_path = "model.test_jit_save_load.no_var_file"
+        configs = fluid.dygraph.jit.SaveLoadConfig()
+        configs.params_filename = "__params__"
+        self.train_and_save_model(model_path=model_path, configs=configs)
+        new_layer = LinearNet(784, 1)
+        with self.assertRaises(RuntimeError):
+            model_dict, _ = fluid.dygraph.load_dygraph(model_path)
+
 
 class TestJitSaveLoadConfig(unittest.TestCase):
     def setUp(self):