未验证 提交 93d34f83 编写于 作者: W WeiXin 提交者: GitHub

'jit.save/load' support save/load function without parameters. (#32430) (#32613)

* jit.save/load support function.

* delete unnittest test_jit_load_model_incomplete.

* edit code according to CI

* Modify the documentation.

* add note to doc.
上级 263710c9
......@@ -650,6 +650,7 @@ def _construct_params_and_buffers(model_path,
append_suffix=True):
var_info_filename = str(params_filename) + ".info"
var_info_path = os.path.join(model_path, var_info_filename)
params_path = os.path.join(model_path, str(params_filename))
if os.path.exists(var_info_path):
var_dict = _load_persistable_vars(model_path, var_info_path,
......@@ -671,6 +672,9 @@ def _construct_params_and_buffers(model_path,
var_dict.update(
_load_persistable_vars(model_path, var_info_path, programs[
func_name], file_name))
elif params_filename is not None and not os.path.exists(params_path):
# When saving XX, there is only '*.pdmodel'
return dict()
else:
var_dict = _load_persistable_vars_by_program(
model_path, programs['forward'], params_filename)
......
......@@ -19,6 +19,7 @@ import pickle
import warnings
import functools
from collections import OrderedDict
import inspect
import six
import paddle
......@@ -506,7 +507,7 @@ def _build_load_path_and_config(path, config):
@switch_to_static_graph
def save(layer, path, input_spec=None, **configs):
"""
Saves input Layer as ``paddle.jit.TranslatedLayer``
Saves input Layer or function as ``paddle.jit.TranslatedLayer``
format model, which can be used for inference or fine-tuning after loading.
It will save the translated program and all related persistable
......@@ -522,8 +523,12 @@ def save(layer, path, input_spec=None, **configs):
- ``paddle.static.load_inference_model``
- Other C++ inference APIs
.. note::
When using ``paddle.jit.save`` to save a function, parameters will not be saved. If you have to
save the parameter, please pass the Layer containing function and parameter to ``paddle.jit.save``.
Args:
layer (Layer): The Layer to be saved.
layer (Layer|function): The Layer or function to be saved.
path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward
method, which can be described by InputSpec or example Tensor. If None, all input variables of
......@@ -543,6 +548,7 @@ def save(layer, path, input_spec=None, **configs):
Examples:
.. code-block:: python
# example 1: save layer
import numpy as np
import paddle
import paddle.nn as nn
......@@ -609,6 +615,28 @@ def save(layer, path, input_spec=None, **configs):
# save
path = "example_model/linear"
paddle.jit.save(layer, path)
# example 2: save function
import paddle
from paddle.static import InputSpec
def save_function():
@paddle.jit.to_static
def fun(inputs):
return paddle.tanh(inputs)
path = 'test_jit_save_load_function_1/func'
inps = paddle.rand([3, 6])
origin = fun(inps)
paddle.jit.save(fun, path)
load_func = paddle.jit.load(path)
load_result = load_func(inps)
print((load_result - origin).abs().max() < 1e-10)
save_function()
"""
# 1. input build & check
......@@ -617,9 +645,11 @@ def save(layer, path, input_spec=None, **configs):
raise RuntimeError(
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
)
if not isinstance(layer, Layer):
if not (isinstance(layer, Layer) or inspect.isfunction(layer) or isinstance(
layer, StaticFunction)):
raise TypeError(
"The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
"The input of paddle.jit.save should be 'Layer' or 'Function', but received input type is %s."
% type(layer))
# NOTE(chenweihang): If the input layer be wrapped by DataParallel,
......@@ -647,6 +677,7 @@ def save(layer, path, input_spec=None, **configs):
# avoid change user given input_spec
inner_input_spec = None
if input_spec is not None:
if isinstance(layer, Layer):
for attr_func in dir(inner_layer):
static_func = getattr(inner_layer, attr_func, None)
if isinstance(static_func,
......@@ -654,6 +685,7 @@ def save(layer, path, input_spec=None, **configs):
raise ValueError(
"If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s."
% type(input_spec))
if not isinstance(input_spec, (list, tuple)):
raise TypeError(
"The input input_spec should be 'list', but received input_spec's type is %s."
......@@ -674,7 +706,13 @@ def save(layer, path, input_spec=None, **configs):
configs = _parse_save_configs(configs)
scope = core.Scope()
extra_var_info = dict()
for attr_func in dir(inner_layer):
if isinstance(layer, Layer):
functions = dir(inner_layer)
else:
# layer is function
functions = [layer, ]
for attr_func in functions:
if isinstance(layer, Layer):
static_func = getattr(inner_layer, attr_func, None)
if isinstance(static_func, StaticFunction):
concrete_program = static_func.concrete_program_specify_input_spec(
......@@ -696,25 +734,6 @@ def save(layer, path, input_spec=None, **configs):
else:
continue
# 3. build input & output of save_infernece_model
# NOTE(chenweihang): [ Get input variables name ]
# There are two cases, whether to prune the inputs or not
# - not prune inputs (recommend):
# - the len(input_spec) == len((concrete_program.inputs) - 1
# - here can use concrete_program.inputs directly
# - prune inputs:
# - the input_spec length < len((concrete_program.inputs) - 1
# - the input_spec's name should be in concrete_program.inputs
input_var_names = _get_input_var_names(concrete_program.inputs,
inner_input_spec)
# NOTE(chenweihang): [ Get output variables ]
# the rule is like [ Get input variables name ]. For output var,
# we only support VarBase spec, and actually, we only need the
# var name of output, and we don't recommended to use output_spec
output_vars = _get_output_vars(concrete_program.outputs,
configs.output_spec)
# NOTE(chenweihang): we maintain the mapping of variable name to
# structured name, the buffer variable (non-persistable)
# saved to inference program may not need by dygraph Layer,
......@@ -723,11 +742,11 @@ def save(layer, path, input_spec=None, **configs):
for structured_name, var in six.iteritems(inner_layer.state_dict()):
state_names_dict[var.name] = structured_name
# 4. share parameters from Layer to scope & record var info
# 3. share parameters from Layer to scope & record var info
for param_or_buffer in concrete_program.parameters:
# share to scope
param_or_buffer_tensor = scope.var(param_or_buffer.name).get_tensor(
)
param_or_buffer_tensor = scope.var(
param_or_buffer.name).get_tensor()
src_tensor = param_or_buffer.value().get_tensor()
param_or_buffer_tensor._share_data_with(src_tensor)
# record var info
......@@ -736,10 +755,42 @@ def save(layer, path, input_spec=None, **configs):
if param_or_buffer.name in state_names_dict:
extra_info_dict['structured_name'] = state_names_dict[
param_or_buffer.name]
extra_info_dict['stop_gradient'] = param_or_buffer.stop_gradient
extra_info_dict[
'stop_gradient'] = param_or_buffer.stop_gradient
if isinstance(param_or_buffer, ParamBase):
extra_info_dict['trainable'] = param_or_buffer.trainable
extra_var_info[param_or_buffer.name] = extra_info_dict
else:
# When layer is a function
if isinstance(attr_func, StaticFunction):
concrete_program = attr_func.concrete_program_specify_input_spec(
inner_input_spec)
else:
if inner_input_spec:
inner_input_spec = pack_sequence_as(input_spec,
inner_input_spec)
static_function = declarative(
attr_func, input_spec=inner_input_spec)
concrete_program = static_function.concrete_program
# 4. build input & output of save_infernece_model
# NOTE(chenweihang): [ Get input variables name ]
# There are two cases, whether to prune the inputs or not
# - not prune inputs (recommend):
# - the len(input_spec) == len((concrete_program.inputs) - 1
# - here can use concrete_program.inputs directly
# - prune inputs:
# - the input_spec length < len((concrete_program.inputs) - 1
# - the input_spec's name should be in concrete_program.inputs
input_var_names = _get_input_var_names(concrete_program.inputs,
inner_input_spec)
# NOTE(chenweihang): [ Get output variables ]
# the rule is like [ Get input variables name ]. For output var,
# we only support VarBase spec, and actually, we only need the
# var name of output, and we don't recommended to use output_spec
output_vars = _get_output_vars(concrete_program.outputs,
configs.output_spec)
# 5. save inference model
from paddle.fluid.io import save_inference_model
......@@ -748,7 +799,7 @@ def save(layer, path, input_spec=None, **configs):
model_path = dirname
# NOTE(chenweihang): because prefix contains model and params filename,
# so we don't support set model_filename & params_filename
if 'forward' == attr_func:
if 'forward' == attr_func or not isinstance(layer, Layer):
model_filename = file_prefix + INFER_MODEL_SUFFIX
params_filename = file_prefix + INFER_PARAMS_SUFFIX
else:
......@@ -782,6 +833,7 @@ def save(layer, path, input_spec=None, **configs):
# but we can save these information in `jit.save` without changing the original
# storage to improve user experience. So we save extra information into
# file `***.pdiparams.info`
if isinstance(layer, Layer) and extra_var_info:
with scope_guard(scope):
extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
with open(extra_var_info_path, 'wb') as f:
......
......@@ -399,15 +399,6 @@ class TestJitSaveLoad(unittest.TestCase):
with self.assertRaises(ValueError):
model_dict, _ = fluid.dygraph.load_dygraph(model_path)
def test_jit_load_model_incomplete(self):
model_path = "test_jit_save_load.remove_variables/model"
self.train_and_save_model(model_path)
# remove `.pdiparams`
var_path = model_path + INFER_PARAMS_SUFFIX
os.remove(var_path)
with self.assertRaises(ValueError):
paddle.jit.load(model_path)
def test_jit_load_no_path(self):
path = "test_jit_save_load.no_path/model_path"
with self.assertRaises(ValueError):
......@@ -1164,6 +1155,63 @@ class TestJitSaveLoadFinetuneLoad(unittest.TestCase):
self.assertTrue(float(((result_01 - result_11)).abs().max()) < 1e-5)
class TestJitSaveLoadFunction(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def test_jit_save_load_static_function(self):
@paddle.jit.to_static
def fun(inputs):
return paddle.tanh(inputs)
path = 'test_jit_save_load_function_1/func'
inps = paddle.rand([3, 6])
origin = fun(inps)
paddle.jit.save(fun, path)
load_func = paddle.jit.load(path)
load_result = load_func(inps)
self.assertTrue((load_result - origin).abs().max() < 1e-10)
def test_jit_save_load_function_input_spec(self):
@paddle.jit.to_static(input_spec=[
InputSpec(
shape=[None, 6], dtype='float32', name='x'),
])
def fun(inputs):
return paddle.nn.functional.relu(inputs)
path = 'test_jit_save_load_function_2/func'
inps = paddle.rand([3, 6])
origin = fun(inps)
paddle.jit.save(fun, path)
load_func = paddle.jit.load(path)
load_result = load_func(inps)
self.assertTrue((load_result - origin).abs().max() < 1e-10)
def test_jit_save_load_function_function(self):
def fun(inputs):
return paddle.tanh(inputs)
path = 'test_jit_save_load_function_3/func'
inps = paddle.rand([3, 6])
origin = fun(inps)
paddle.jit.save(
fun,
path,
input_spec=[
InputSpec(
shape=[None, 6], dtype='float32', name='x'),
])
load_func = paddle.jit.load(path)
load_result = load_func(inps)
self.assertTrue((load_result - origin).abs().max() < 1e-10)
class TestJitSaveLoadDataParallel(unittest.TestCase):
def verify_inference_correctness(self, layer, path):
layer.eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册