Unverified commit 3eee0467, authored by Chen Weihang, committed by GitHub

Add limit support for load_dygraph loading jit.save result (#25935)

* add limit support for load_dygraph loading jit.save result

* simplify unittest

* add unittests for coverage

* remove encoding limit of loading extra var info
Parent 12bf9d71
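With this change, `load_dygraph` can also read the directory produced by `fluid.dygraph.jit.save`. A minimal usage sketch, assuming the model was saved in the default format; the `LinearNet` definition and the `saved_infer_model` directory name are hypothetical placeholders:

```python
import numpy as np
import paddle.fluid as fluid

# hypothetical directory produced earlier by fluid.dygraph.jit.save
model_path = "saved_infer_model"

# a simple dygraph layer matching the saved parameters (assumed definition)
class LinearNet(fluid.dygraph.Layer):
    def __init__(self, in_size, out_size):
        super(LinearNet, self).__init__()
        self._linear = fluid.dygraph.Linear(in_size, out_size)

    def forward(self, x):
        return self._linear(x)

with fluid.dygraph.guard():
    layer = LinearNet(784, 1)
    # after this change, load_dygraph also accepts a jit.save directory:
    # parameters come back keyed by structured name, optimizer state is None
    param_dict, opt_dict = fluid.dygraph.load_dygraph(model_path)
    layer.set_dict(param_dict)
    layer.eval()
    x = fluid.dygraph.to_variable(
        np.random.random((1, 784)).astype('float32'))
    out = layer(x)
```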
@@ -16,12 +16,13 @@ from __future__ import print_function
import os
import collections
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
import pickle
import six
from . import learning_rate_scheduler
import warnings
from .. import core
from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME, _load_persistable_vars
__all__ = [
'save_dygraph',
@@ -140,22 +141,83 @@ def load_dygraph(model_path, keep_name_table=False):
elif model_prefix.endswith(".pdopt"):
model_prefix = model_prefix[:-6]
params_file_path = model_prefix + ".pdparams"
if not os.path.exists(params_file_path):
raise RuntimeError("Parameter file [ {} ] not exists".format(
params_file_path))
with open(params_file_path, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
if not keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"]
para_dict = None
opti_dict = None
params_file_path = model_prefix + ".pdparams"
opti_file_path = model_prefix + ".pdopt"
if os.path.exists(opti_file_path):
with open(opti_file_path, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
if not os.path.exists(params_file_path) and not os.path.exists(
opti_file_path):
# Load state dict by `jit.save` save format
# TODO(chenweihang): [Why the `io.save_inference_model` save format is not supported here]
# The model saved by `save_inference_model` does not completely correspond to
# the information required by the `state_dict` under the dygraph.
# Although we could reluctantly restore the `state_dict` in some scenarios,
# the result may be incomplete and has some limitations, so supporting this
# format will be considered later. The limitations include:
# 1. `save_inference_model` does not save structured names, so we would need to
# remind the user to configure the `use_structured_name` argument when calling
# `set_dict`, but this argument is currently not public
# 2. if `save_inference_model` saves all persistable variables in a single file,
# the user needs to give the variable name list to load the `state_dict`
# 1. check model path
if not os.path.isdir(model_prefix):
raise ValueError("Model saved directory '%s' is not exists." %
model_prefix)
# 2. load `__variables.info__`
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
if not os.path.exists(var_info_path):
raise RuntimeError(
"No target can be loaded. Now only supports loading `state_dict` from "
"the result saved by `imperative.save` and `imperative.jit.save`."
)
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
# 3. load `__variables__`
# TODO(chenweihang): now only supports loading from default save format:
# - all persistable vars saved in one file named `__variables__`
# for other case, we may need to modify the arguments of this API
var_file_path = os.path.join(model_prefix, VARIABLE_FILENAME)
if not os.path.exists(var_file_path):
raise RuntimeError(
"The parameter file to be loaded was not found. "
"Now only supports loading from the default save format, "
"and does not support custom params_filename and "
"save parameters separately.")
# 4. load all persistable vars
load_var_list = []
for name in sorted(extra_var_info):
var = _varbase_creator(name=name, persistable=True)
load_var_list.append(var)
_dygraph_tracer().trace_op(
type='load_combine',
inputs={},
outputs={'Out': load_var_list},
attrs={'file_path': var_file_path})
# 5. construct state_dict
para_dict = dict()
for var in load_var_list:
structured_name = extra_var_info[var.name].get('structured_name',
None)
if structured_name is None:
raise RuntimeError(
"Cannot find saved variable (%s)'s structured name in saved model."
% var.name)
para_dict[structured_name] = var.numpy()
# NOTE: `jit.save` doesn't save optimizer state
else:
# Load state dict by `save_dygraph` save format
if os.path.exists(params_file_path):
with open(params_file_path, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
if not keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"]
if os.path.exists(opti_file_path):
with open(opti_file_path, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
return para_dict, opti_dict
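For contrast with the new `jit.save` branch above, a minimal sketch of the plain `save_dygraph` round trip, where the optimizer state is also returned; the `checkpoint` prefix and layer sizes are arbitrary placeholders:

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    layer = fluid.dygraph.Linear(4, 2)
    adam = fluid.optimizer.AdamOptimizer(
        learning_rate=0.01, parameter_list=layer.parameters())

    # one training step so the optimizer has state worth saving
    x = fluid.dygraph.to_variable(np.random.random((8, 4)).astype('float32'))
    loss = fluid.layers.reduce_mean(layer(x))
    loss.backward()
    adam.minimize(loss)

    # save_dygraph writes checkpoint.pdparams and checkpoint.pdopt
    fluid.dygraph.save_dygraph(layer.state_dict(), "checkpoint")
    fluid.dygraph.save_dygraph(adam.state_dict(), "checkpoint")

    # for this format load_dygraph returns both dicts; for a jit.save
    # directory it returns the parameter dict and None for the optimizer state
    para_dict, opti_dict = fluid.dygraph.load_dygraph("checkpoint")
    layer.set_dict(para_dict)
    adam.set_dict(opti_dict)
```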
@@ -425,8 +425,7 @@ def _load_persistable_vars(model_path,
params_filename=None):
# 1. load extra var info
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
extra_var_info = pickle.load(f)
# 2. construct var dict
load_var_dict = dict()
......
@@ -14,13 +14,15 @@
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph import declarative, ProgramTranslator
from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME
BATCH_SIZE = 32
BATCH_NUM = 20
@@ -77,8 +79,8 @@ class LinearNetReturnLoss(fluid.dygraph.Layer):
def train(layer):
# create optimizer
adam = fluid.optimizer.AdamOptimizer(
learning_rate=0.1, parameter_list=layer.parameters())
adam = fluid.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=layer.parameters())
# create data loader
train_loader = fluid.io.DataLoader.from_generator(capacity=5)
train_loader.set_batch_generator(random_batch_reader())
@@ -111,37 +113,43 @@ class TestJitSaveLoad(unittest.TestCase):
# config seed
fluid.default_main_program().random_seed = SEED
def train_and_save_model(self):
def train_and_save_model(self, model_path=None, configs=None):
layer = LinearNet(784, 1)
example_inputs, layer, _ = train(layer)
final_model_path = model_path if model_path else self.model_path
orig_input_types = [type(x) for x in example_inputs]
fluid.dygraph.jit.save(
layer=layer, model_path=self.model_path, input_spec=example_inputs)
layer=layer,
model_path=final_model_path,
input_spec=example_inputs,
configs=configs)
new_input_types = [type(x) for x in example_inputs]
self.assertEqual(orig_input_types, new_input_types)
return layer
def test_save(self):
# train and save model
self.train_and_save_model()
def test_load_infernece(self):
def test_save_load(self):
# train and save model
train_layer = self.train_and_save_model()
# load model
infer_layer = fluid.dygraph.jit.load(self.model_path)
program_translator = ProgramTranslator()
program_translator.enable(False)
loaded_layer = fluid.dygraph.jit.load(self.model_path)
self.load_and_inference(train_layer, loaded_layer)
self.load_dygraph_state_dict(train_layer)
self.load_and_finetune(train_layer, loaded_layer)
program_translator.enable(True)
def load_and_inference(self, train_layer, infer_layer):
train_layer.eval()
infer_layer.eval()
# inference & compare
x = fluid.dygraph.to_variable(
np.random.random((1, 784)).astype('float32'))
self.assertTrue(
np.array_equal(train_layer(x).numpy(), infer_layer(x).numpy()))
def test_load_finetune(self):
# train and save model
train_layer = self.train_and_save_model()
# load model
load_train_layer = fluid.dygraph.jit.load(self.model_path)
def load_and_finetune(self, train_layer, load_train_layer):
train_layer.train()
load_train_layer.train()
# train & compare
_, _, train_loss = train(train_layer)
@@ -149,6 +157,19 @@ class TestJitSaveLoad(unittest.TestCase):
self.assertTrue(
np.array_equal(train_loss.numpy(), load_train_loss.numpy()))
def load_dygraph_state_dict(self, train_layer):
train_layer.eval()
# construct new model
new_layer = LinearNet(784, 1)
model_dict, _ = fluid.dygraph.load_dygraph(self.model_path)
new_layer.set_dict(model_dict)
new_layer.eval()
# inference & compare
x = fluid.dygraph.to_variable(
np.random.random((1, 784)).astype('float32'))
self.assertTrue(
np.array_equal(train_layer(x).numpy(), new_layer(x).numpy()))
def test_save_get_program_failed(self):
layer = LinearNetNotDeclarative(784, 1)
example_inputs, layer, _ = train(layer)
@@ -158,6 +179,31 @@ class TestJitSaveLoad(unittest.TestCase):
model_path=self.model_path,
input_spec=example_inputs)
def test_load_dygraph_no_path(self):
model_path = "model.test_jit_save_load.no_path"
new_layer = LinearNet(784, 1)
with self.assertRaises(ValueError):
model_dict, _ = fluid.dygraph.load_dygraph(model_path)
def test_load_dygraph_no_var_info(self):
model_path = "model.test_jit_save_load.no_var_info"
self.train_and_save_model(model_path=model_path)
# remove `__variables.info__`
var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
os.remove(var_info_path)
new_layer = LinearNet(784, 1)
with self.assertRaises(RuntimeError):
model_dict, _ = fluid.dygraph.load_dygraph(model_path)
def test_load_dygraph_not_var_file(self):
model_path = "model.test_jit_save_load.no_var_file"
configs = fluid.dygraph.jit.SaveLoadConfig()
configs.params_filename = "__params__"
self.train_and_save_model(model_path=model_path, configs=configs)
new_layer = LinearNet(784, 1)
with self.assertRaises(RuntimeError):
model_dict, _ = fluid.dygraph.load_dygraph(model_path)
class TestJitSaveLoadConfig(unittest.TestCase):
def setUp(self):
......