未验证 提交 93d34f83 编写于 作者: W WeiXin 提交者: GitHub

'jit.save/load' support save/load function without parameters. (#32430) (#32613)

* jit.save/load support function.

* delete unnittest test_jit_load_model_incomplete.

* edit code according to CI

* Modify the documentation.

* add note to doc.
上级 263710c9
...@@ -650,6 +650,7 @@ def _construct_params_and_buffers(model_path, ...@@ -650,6 +650,7 @@ def _construct_params_and_buffers(model_path,
append_suffix=True): append_suffix=True):
var_info_filename = str(params_filename) + ".info" var_info_filename = str(params_filename) + ".info"
var_info_path = os.path.join(model_path, var_info_filename) var_info_path = os.path.join(model_path, var_info_filename)
params_path = os.path.join(model_path, str(params_filename))
if os.path.exists(var_info_path): if os.path.exists(var_info_path):
var_dict = _load_persistable_vars(model_path, var_info_path, var_dict = _load_persistable_vars(model_path, var_info_path,
...@@ -671,6 +672,9 @@ def _construct_params_and_buffers(model_path, ...@@ -671,6 +672,9 @@ def _construct_params_and_buffers(model_path,
var_dict.update( var_dict.update(
_load_persistable_vars(model_path, var_info_path, programs[ _load_persistable_vars(model_path, var_info_path, programs[
func_name], file_name)) func_name], file_name))
elif params_filename is not None and not os.path.exists(params_path):
# When saving XX, there is only '*.pdmodel'
return dict()
else: else:
var_dict = _load_persistable_vars_by_program( var_dict = _load_persistable_vars_by_program(
model_path, programs['forward'], params_filename) model_path, programs['forward'], params_filename)
......
...@@ -19,6 +19,7 @@ import pickle ...@@ -19,6 +19,7 @@ import pickle
import warnings import warnings
import functools import functools
from collections import OrderedDict from collections import OrderedDict
import inspect
import six import six
import paddle import paddle
...@@ -506,7 +507,7 @@ def _build_load_path_and_config(path, config): ...@@ -506,7 +507,7 @@ def _build_load_path_and_config(path, config):
@switch_to_static_graph @switch_to_static_graph
def save(layer, path, input_spec=None, **configs): def save(layer, path, input_spec=None, **configs):
""" """
Saves input Layer as ``paddle.jit.TranslatedLayer`` Saves input Layer or function as ``paddle.jit.TranslatedLayer``
format model, which can be used for inference or fine-tuning after loading. format model, which can be used for inference or fine-tuning after loading.
It will save the translated program and all related persistable It will save the translated program and all related persistable
...@@ -522,8 +523,12 @@ def save(layer, path, input_spec=None, **configs): ...@@ -522,8 +523,12 @@ def save(layer, path, input_spec=None, **configs):
- ``paddle.static.load_inference_model`` - ``paddle.static.load_inference_model``
- Other C++ inference APIs - Other C++ inference APIs
.. note::
When using ``paddle.jit.save`` to save a function, parameters will not be saved. If you have to
save the parameter, please pass the Layer containing function and parameter to ``paddle.jit.save``.
Args: Args:
layer (Layer): The Layer to be saved. layer (Layer|function): The Layer or function to be saved.
path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``. path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward
method, which can be described by InputSpec or example Tensor. If None, all input variables of method, which can be described by InputSpec or example Tensor. If None, all input variables of
...@@ -543,6 +548,7 @@ def save(layer, path, input_spec=None, **configs): ...@@ -543,6 +548,7 @@ def save(layer, path, input_spec=None, **configs):
Examples: Examples:
.. code-block:: python .. code-block:: python
# example 1: save layer
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
...@@ -609,6 +615,28 @@ def save(layer, path, input_spec=None, **configs): ...@@ -609,6 +615,28 @@ def save(layer, path, input_spec=None, **configs):
# save # save
path = "example_model/linear" path = "example_model/linear"
paddle.jit.save(layer, path) paddle.jit.save(layer, path)
# example 2: save function
import paddle
from paddle.static import InputSpec
def save_function():
@paddle.jit.to_static
def fun(inputs):
return paddle.tanh(inputs)
path = 'test_jit_save_load_function_1/func'
inps = paddle.rand([3, 6])
origin = fun(inps)
paddle.jit.save(fun, path)
load_func = paddle.jit.load(path)
load_result = load_func(inps)
print((load_result - origin).abs().max() < 1e-10)
save_function()
""" """
# 1. input build & check # 1. input build & check
...@@ -617,9 +645,11 @@ def save(layer, path, input_spec=None, **configs): ...@@ -617,9 +645,11 @@ def save(layer, path, input_spec=None, **configs):
raise RuntimeError( raise RuntimeError(
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False." "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
) )
if not isinstance(layer, Layer):
if not (isinstance(layer, Layer) or inspect.isfunction(layer) or isinstance(
layer, StaticFunction)):
raise TypeError( raise TypeError(
"The input layer of paddle.jit.save should be 'Layer', but received layer type is %s." "The input of paddle.jit.save should be 'Layer' or 'Function', but received input type is %s."
% type(layer)) % type(layer))
# NOTE(chenweihang): If the input layer be wrapped by DataParallel, # NOTE(chenweihang): If the input layer be wrapped by DataParallel,
...@@ -647,6 +677,7 @@ def save(layer, path, input_spec=None, **configs): ...@@ -647,6 +677,7 @@ def save(layer, path, input_spec=None, **configs):
# avoid change user given input_spec # avoid change user given input_spec
inner_input_spec = None inner_input_spec = None
if input_spec is not None: if input_spec is not None:
if isinstance(layer, Layer):
for attr_func in dir(inner_layer): for attr_func in dir(inner_layer):
static_func = getattr(inner_layer, attr_func, None) static_func = getattr(inner_layer, attr_func, None)
if isinstance(static_func, if isinstance(static_func,
...@@ -654,6 +685,7 @@ def save(layer, path, input_spec=None, **configs): ...@@ -654,6 +685,7 @@ def save(layer, path, input_spec=None, **configs):
raise ValueError( raise ValueError(
"If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s." "If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s."
% type(input_spec)) % type(input_spec))
if not isinstance(input_spec, (list, tuple)): if not isinstance(input_spec, (list, tuple)):
raise TypeError( raise TypeError(
"The input input_spec should be 'list', but received input_spec's type is %s." "The input input_spec should be 'list', but received input_spec's type is %s."
...@@ -674,7 +706,13 @@ def save(layer, path, input_spec=None, **configs): ...@@ -674,7 +706,13 @@ def save(layer, path, input_spec=None, **configs):
configs = _parse_save_configs(configs) configs = _parse_save_configs(configs)
scope = core.Scope() scope = core.Scope()
extra_var_info = dict() extra_var_info = dict()
for attr_func in dir(inner_layer): if isinstance(layer, Layer):
functions = dir(inner_layer)
else:
# layer is function
functions = [layer, ]
for attr_func in functions:
if isinstance(layer, Layer):
static_func = getattr(inner_layer, attr_func, None) static_func = getattr(inner_layer, attr_func, None)
if isinstance(static_func, StaticFunction): if isinstance(static_func, StaticFunction):
concrete_program = static_func.concrete_program_specify_input_spec( concrete_program = static_func.concrete_program_specify_input_spec(
...@@ -696,25 +734,6 @@ def save(layer, path, input_spec=None, **configs): ...@@ -696,25 +734,6 @@ def save(layer, path, input_spec=None, **configs):
else: else:
continue continue
# 3. build input & output of save_infernece_model
# NOTE(chenweihang): [ Get input variables name ]
# There are two cases, whether to prune the inputs or not
# - not prune inputs (recommend):
# - the len(input_spec) == len((concrete_program.inputs) - 1
# - here can use concrete_program.inputs directly
# - prune inputs:
# - the input_spec length < len((concrete_program.inputs) - 1
# - the input_spec's name should be in concrete_program.inputs
input_var_names = _get_input_var_names(concrete_program.inputs,
inner_input_spec)
# NOTE(chenweihang): [ Get output variables ]
# the rule is like [ Get input variables name ]. For output var,
# we only support VarBase spec, and actually, we only need the
# var name of output, and we don't recommended to use output_spec
output_vars = _get_output_vars(concrete_program.outputs,
configs.output_spec)
# NOTE(chenweihang): we maintain the mapping of variable name to # NOTE(chenweihang): we maintain the mapping of variable name to
# structured name, the buffer variable (non-persistable) # structured name, the buffer variable (non-persistable)
# saved to inference program may not need by dygraph Layer, # saved to inference program may not need by dygraph Layer,
...@@ -723,11 +742,11 @@ def save(layer, path, input_spec=None, **configs): ...@@ -723,11 +742,11 @@ def save(layer, path, input_spec=None, **configs):
for structured_name, var in six.iteritems(inner_layer.state_dict()): for structured_name, var in six.iteritems(inner_layer.state_dict()):
state_names_dict[var.name] = structured_name state_names_dict[var.name] = structured_name
# 4. share parameters from Layer to scope & record var info # 3. share parameters from Layer to scope & record var info
for param_or_buffer in concrete_program.parameters: for param_or_buffer in concrete_program.parameters:
# share to scope # share to scope
param_or_buffer_tensor = scope.var(param_or_buffer.name).get_tensor( param_or_buffer_tensor = scope.var(
) param_or_buffer.name).get_tensor()
src_tensor = param_or_buffer.value().get_tensor() src_tensor = param_or_buffer.value().get_tensor()
param_or_buffer_tensor._share_data_with(src_tensor) param_or_buffer_tensor._share_data_with(src_tensor)
# record var info # record var info
...@@ -736,10 +755,42 @@ def save(layer, path, input_spec=None, **configs): ...@@ -736,10 +755,42 @@ def save(layer, path, input_spec=None, **configs):
if param_or_buffer.name in state_names_dict: if param_or_buffer.name in state_names_dict:
extra_info_dict['structured_name'] = state_names_dict[ extra_info_dict['structured_name'] = state_names_dict[
param_or_buffer.name] param_or_buffer.name]
extra_info_dict['stop_gradient'] = param_or_buffer.stop_gradient extra_info_dict[
'stop_gradient'] = param_or_buffer.stop_gradient
if isinstance(param_or_buffer, ParamBase): if isinstance(param_or_buffer, ParamBase):
extra_info_dict['trainable'] = param_or_buffer.trainable extra_info_dict['trainable'] = param_or_buffer.trainable
extra_var_info[param_or_buffer.name] = extra_info_dict extra_var_info[param_or_buffer.name] = extra_info_dict
else:
# When layer is a function
if isinstance(attr_func, StaticFunction):
concrete_program = attr_func.concrete_program_specify_input_spec(
inner_input_spec)
else:
if inner_input_spec:
inner_input_spec = pack_sequence_as(input_spec,
inner_input_spec)
static_function = declarative(
attr_func, input_spec=inner_input_spec)
concrete_program = static_function.concrete_program
# 4. build input & output of save_infernece_model
# NOTE(chenweihang): [ Get input variables name ]
# There are two cases, whether to prune the inputs or not
# - not prune inputs (recommend):
# - the len(input_spec) == len((concrete_program.inputs) - 1
# - here can use concrete_program.inputs directly
# - prune inputs:
# - the input_spec length < len((concrete_program.inputs) - 1
# - the input_spec's name should be in concrete_program.inputs
input_var_names = _get_input_var_names(concrete_program.inputs,
inner_input_spec)
# NOTE(chenweihang): [ Get output variables ]
# the rule is like [ Get input variables name ]. For output var,
# we only support VarBase spec, and actually, we only need the
# var name of output, and we don't recommended to use output_spec
output_vars = _get_output_vars(concrete_program.outputs,
configs.output_spec)
# 5. save inference model # 5. save inference model
from paddle.fluid.io import save_inference_model from paddle.fluid.io import save_inference_model
...@@ -748,7 +799,7 @@ def save(layer, path, input_spec=None, **configs): ...@@ -748,7 +799,7 @@ def save(layer, path, input_spec=None, **configs):
model_path = dirname model_path = dirname
# NOTE(chenweihang): because prefix contains model and params filename, # NOTE(chenweihang): because prefix contains model and params filename,
# so we don't support set model_filename & params_filename # so we don't support set model_filename & params_filename
if 'forward' == attr_func: if 'forward' == attr_func or not isinstance(layer, Layer):
model_filename = file_prefix + INFER_MODEL_SUFFIX model_filename = file_prefix + INFER_MODEL_SUFFIX
params_filename = file_prefix + INFER_PARAMS_SUFFIX params_filename = file_prefix + INFER_PARAMS_SUFFIX
else: else:
...@@ -782,6 +833,7 @@ def save(layer, path, input_spec=None, **configs): ...@@ -782,6 +833,7 @@ def save(layer, path, input_spec=None, **configs):
# but we can save these information in `jit.save` without changing the original # but we can save these information in `jit.save` without changing the original
# storage to improve user experience. So we save extra information into # storage to improve user experience. So we save extra information into
# file `***.pdiparams.info` # file `***.pdiparams.info`
if isinstance(layer, Layer) and extra_var_info:
with scope_guard(scope): with scope_guard(scope):
extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
with open(extra_var_info_path, 'wb') as f: with open(extra_var_info_path, 'wb') as f:
......
...@@ -399,15 +399,6 @@ class TestJitSaveLoad(unittest.TestCase): ...@@ -399,15 +399,6 @@ class TestJitSaveLoad(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
model_dict, _ = fluid.dygraph.load_dygraph(model_path) model_dict, _ = fluid.dygraph.load_dygraph(model_path)
def test_jit_load_model_incomplete(self):
model_path = "test_jit_save_load.remove_variables/model"
self.train_and_save_model(model_path)
# remove `.pdiparams`
var_path = model_path + INFER_PARAMS_SUFFIX
os.remove(var_path)
with self.assertRaises(ValueError):
paddle.jit.load(model_path)
def test_jit_load_no_path(self): def test_jit_load_no_path(self):
path = "test_jit_save_load.no_path/model_path" path = "test_jit_save_load.no_path/model_path"
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -1164,6 +1155,63 @@ class TestJitSaveLoadFinetuneLoad(unittest.TestCase): ...@@ -1164,6 +1155,63 @@ class TestJitSaveLoadFinetuneLoad(unittest.TestCase):
self.assertTrue(float(((result_01 - result_11)).abs().max()) < 1e-5) self.assertTrue(float(((result_01 - result_11)).abs().max()) < 1e-5)
class TestJitSaveLoadFunction(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def test_jit_save_load_static_function(self):
@paddle.jit.to_static
def fun(inputs):
return paddle.tanh(inputs)
path = 'test_jit_save_load_function_1/func'
inps = paddle.rand([3, 6])
origin = fun(inps)
paddle.jit.save(fun, path)
load_func = paddle.jit.load(path)
load_result = load_func(inps)
self.assertTrue((load_result - origin).abs().max() < 1e-10)
def test_jit_save_load_function_input_spec(self):
@paddle.jit.to_static(input_spec=[
InputSpec(
shape=[None, 6], dtype='float32', name='x'),
])
def fun(inputs):
return paddle.nn.functional.relu(inputs)
path = 'test_jit_save_load_function_2/func'
inps = paddle.rand([3, 6])
origin = fun(inps)
paddle.jit.save(fun, path)
load_func = paddle.jit.load(path)
load_result = load_func(inps)
self.assertTrue((load_result - origin).abs().max() < 1e-10)
def test_jit_save_load_function_function(self):
def fun(inputs):
return paddle.tanh(inputs)
path = 'test_jit_save_load_function_3/func'
inps = paddle.rand([3, 6])
origin = fun(inps)
paddle.jit.save(
fun,
path,
input_spec=[
InputSpec(
shape=[None, 6], dtype='float32', name='x'),
])
load_func = paddle.jit.load(path)
load_result = load_func(inps)
self.assertTrue((load_result - origin).abs().max() < 1e-10)
class TestJitSaveLoadDataParallel(unittest.TestCase): class TestJitSaveLoadDataParallel(unittest.TestCase):
def verify_inference_correctness(self, layer, path): def verify_inference_correctness(self, layer, path):
layer.eval() layer.eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册