Unverified commit 93d34f83, authored by W WeiXin, committed by GitHub

'jit.save/load' supports saving/loading a function without parameters. (#32430) (#32613)

* jit.save/load: support saving/loading functions.

* Delete unittest test_jit_load_model_incomplete.

* Edit code according to CI feedback.

* Modify the documentation.

* Add a note to the docs.
Parent 263710c9
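In practice, the change enables a round trip like the following (a minimal sketch adapted from the tests added in this commit; the 'saved_fn/func' path prefix is illustrative):

    import paddle

    @paddle.jit.to_static
    def fun(inputs):
        return paddle.tanh(inputs)

    x = paddle.rand([3, 6])
    out = fun(x)  # run once so a concrete program is traced

    paddle.jit.save(fun, 'saved_fn/func')        # no parameters are saved
    loaded = paddle.jit.load('saved_fn/func')    # returns a TranslatedLayer
    print(float((loaded(x) - out).abs().max()))  # expected to be ~0.0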
@@ -650,6 +650,7 @@ def _construct_params_and_buffers(model_path,
                                   append_suffix=True):
     var_info_filename = str(params_filename) + ".info"
     var_info_path = os.path.join(model_path, var_info_filename)
+    params_path = os.path.join(model_path, str(params_filename))
     if os.path.exists(var_info_path):
         var_dict = _load_persistable_vars(model_path, var_info_path,
@@ -671,6 +672,9 @@ def _construct_params_and_buffers(model_path,
             var_dict.update(
                 _load_persistable_vars(model_path, var_info_path, programs[
                     func_name], file_name))
+    elif params_filename is not None and not os.path.exists(params_path):
+        # when saving a function (no parameters), only '*.pdmodel' exists
+        return dict()
     else:
         var_dict = _load_persistable_vars_by_program(
             model_path, programs['forward'], params_filename)
...
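The new 'elif' branch above handles loading a prefix that has a program file but no parameters file. A sketch of the on-disk layout that produces this case (the 'func_only' directory is illustrative; the suffixes are the standard '.pdmodel'/'.pdiparams'):

    import os
    import paddle

    @paddle.jit.to_static
    def fun(inputs):
        return paddle.tanh(inputs)

    fun(paddle.rand([3, 6]))  # trace once
    paddle.jit.save(fun, 'func_only/func')

    # a plain function owns no parameters, so only the program file exists,
    # which is exactly the case _construct_params_and_buffers now short-circuits
    print(os.path.exists('func_only/func.pdmodel'))    # expected: True
    print(os.path.exists('func_only/func.pdiparams'))  # expected: False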
@@ -19,6 +19,7 @@ import pickle
 import warnings
 import functools
 from collections import OrderedDict
+import inspect

 import six
 import paddle
@@ -506,7 +507,7 @@ def _build_load_path_and_config(path, config):
 @switch_to_static_graph
 def save(layer, path, input_spec=None, **configs):
     """
-    Saves input Layer as ``paddle.jit.TranslatedLayer``
+    Saves input Layer or function as ``paddle.jit.TranslatedLayer``
     format model, which can be used for inference or fine-tuning after loading.

     It will save the translated program and all related persistable
@@ -522,8 +523,12 @@ def save(layer, path, input_spec=None, **configs):
       - ``paddle.static.load_inference_model``
       - Other C++ inference APIs
+    .. note::
+        When using ``paddle.jit.save`` to save a function, its parameters will not be saved. If you need
+        the parameters to be saved, pass a Layer that holds both the function and the parameters to ``paddle.jit.save``.
     Args:
-        layer (Layer): The Layer to be saved.
+        layer (Layer|function): The Layer or function to be saved.
         path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
         input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward
             method, which can be described by InputSpec or example Tensor. If None, all input variables of
@@ -543,6 +548,7 @@ def save(layer, path, input_spec=None, **configs):
     Examples:
         .. code-block:: python

+            # example 1: save layer
             import numpy as np
             import paddle
             import paddle.nn as nn
@@ -609,6 +615,28 @@ def save(layer, path, input_spec=None, **configs):
             # save
             path = "example_model/linear"
             paddle.jit.save(layer, path)
+
+            # example 2: save function
+            import paddle
+            from paddle.static import InputSpec
+
+            def save_function():
+                @paddle.jit.to_static
+                def fun(inputs):
+                    return paddle.tanh(inputs)
+
+                path = 'test_jit_save_load_function_1/func'
+                inps = paddle.rand([3, 6])
+                origin = fun(inps)
+
+                paddle.jit.save(fun, path)
+                load_func = paddle.jit.load(path)
+
+                load_result = load_func(inps)
+                print((load_result - origin).abs().max() < 1e-10)
+
+            save_function()
     """
     # 1. input build & check
@@ -617,9 +645,11 @@ def save(layer, path, input_spec=None, **configs):
         raise RuntimeError(
             "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
         )
-    if not isinstance(layer, Layer):
+    if not (isinstance(layer, Layer) or inspect.isfunction(layer) or isinstance(
+            layer, StaticFunction)):
         raise TypeError(
-            "The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
+            "The input of paddle.jit.save should be 'Layer' or 'Function', but received input type is %s."
             % type(layer))
     # NOTE(chenweihang): If the input layer be wrapped by DataParallel,
@@ -647,13 +677,15 @@ def save(layer, path, input_spec=None, **configs):
     # avoid change user given input_spec
     inner_input_spec = None
     if input_spec is not None:
-        for attr_func in dir(inner_layer):
-            static_func = getattr(inner_layer, attr_func, None)
-            if isinstance(static_func,
-                          StaticFunction) and 'forward' != attr_func:
-                raise ValueError(
-                    "If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s."
-                    % type(input_spec))
+        if isinstance(layer, Layer):
+            for attr_func in dir(inner_layer):
+                static_func = getattr(inner_layer, attr_func, None)
+                if isinstance(static_func,
+                              StaticFunction) and 'forward' != attr_func:
+                    raise ValueError(
+                        "If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s."
+                        % type(input_spec))
+
         if not isinstance(input_spec, (list, tuple)):
             raise TypeError(
                 "The input input_spec should be 'list', but received input_spec's type is %s."
@@ -674,29 +706,74 @@ def save(layer, path, input_spec=None, **configs):
     configs = _parse_save_configs(configs)
     scope = core.Scope()
     extra_var_info = dict()
-    for attr_func in dir(inner_layer):
-        static_func = getattr(inner_layer, attr_func, None)
-        if isinstance(static_func, StaticFunction):
-            concrete_program = static_func.concrete_program_specify_input_spec(
-                inner_input_spec)
-        elif 'forward' == attr_func:
-            # transform in jit.save, if input_spec is incomplete, declarative will throw error
-            # inner_input_spec is list[InputSpec], it should be packed with same structure
-            # as original input_spec here.
-            if inner_input_spec:
-                inner_input_spec = pack_sequence_as(input_spec,
-                                                    inner_input_spec)
-            static_forward = declarative(
-                inner_layer.forward, input_spec=inner_input_spec)
-            concrete_program = static_forward.concrete_program
-            # the input_spec has been used in declarative, which is equal to
-            # @declarative with input_spec and jit.save without input_spec,
-            # avoid needless warning
-            inner_input_spec = None
-        else:
-            continue
-
-        # 3. build input & output of save_inference_model
+    if isinstance(layer, Layer):
+        functions = dir(inner_layer)
+    else:
+        # layer is a function
+        functions = [layer, ]
+    for attr_func in functions:
+        if isinstance(layer, Layer):
+            static_func = getattr(inner_layer, attr_func, None)
+            if isinstance(static_func, StaticFunction):
+                concrete_program = static_func.concrete_program_specify_input_spec(
+                    inner_input_spec)
+            elif 'forward' == attr_func:
+                # transform in jit.save, if input_spec is incomplete, declarative will throw error
+                # inner_input_spec is list[InputSpec], it should be packed with same structure
+                # as original input_spec here.
+                if inner_input_spec:
+                    inner_input_spec = pack_sequence_as(input_spec,
+                                                        inner_input_spec)
+                static_forward = declarative(
+                    inner_layer.forward, input_spec=inner_input_spec)
+                concrete_program = static_forward.concrete_program
+                # the input_spec has been used in declarative, which is equal to
+                # @declarative with input_spec and jit.save without input_spec,
+                # avoid needless warning
+                inner_input_spec = None
+            else:
+                continue
+
+            # NOTE(chenweihang): we maintain the mapping of variable name to
+            # structured name, the buffer variable (non-persistable)
+            # saved to inference program may not need by dygraph Layer,
+            # we only record the state_dict variable's structured name
+            state_names_dict = dict()
+            for structured_name, var in six.iteritems(inner_layer.state_dict()):
+                state_names_dict[var.name] = structured_name

+            # 3. share parameters from Layer to scope & record var info
+            for param_or_buffer in concrete_program.parameters:
+                # share to scope
+                param_or_buffer_tensor = scope.var(
+                    param_or_buffer.name).get_tensor()
+                src_tensor = param_or_buffer.value().get_tensor()
+                param_or_buffer_tensor._share_data_with(src_tensor)
+                # record var info
+                if param_or_buffer.name not in extra_var_info:
+                    extra_info_dict = dict()
+                    if param_or_buffer.name in state_names_dict:
+                        extra_info_dict['structured_name'] = state_names_dict[
+                            param_or_buffer.name]
+                    extra_info_dict[
+                        'stop_gradient'] = param_or_buffer.stop_gradient
+                    if isinstance(param_or_buffer, ParamBase):
+                        extra_info_dict['trainable'] = param_or_buffer.trainable
+                    extra_var_info[param_or_buffer.name] = extra_info_dict
+        else:
+            # When layer is a function
+            if isinstance(attr_func, StaticFunction):
+                concrete_program = attr_func.concrete_program_specify_input_spec(
+                    inner_input_spec)
+            else:
+                if inner_input_spec:
+                    inner_input_spec = pack_sequence_as(input_spec,
+                                                        inner_input_spec)
+                static_function = declarative(
+                    attr_func, input_spec=inner_input_spec)
+                concrete_program = static_function.concrete_program
+
+        # 4. build input & output of save_inference_model
         # NOTE(chenweihang): [ Get input variables name ]
         # There are two cases, whether to prune the inputs or not
         # - not prune inputs (recommend):
@@ -715,32 +792,6 @@ def save(layer, path, input_spec=None, **configs):
         output_vars = _get_output_vars(concrete_program.outputs,
                                        configs.output_spec)
-        # NOTE(chenweihang): we maintain the mapping of variable name to
-        # structured name, the buffer variable (non-persistable)
-        # saved to inference program may not need by dygraph Layer,
-        # we only record the state_dict variable's structured name
-        state_names_dict = dict()
-        for structured_name, var in six.iteritems(inner_layer.state_dict()):
-            state_names_dict[var.name] = structured_name
-
-        # 4. share parameters from Layer to scope & record var info
-        for param_or_buffer in concrete_program.parameters:
-            # share to scope
-            param_or_buffer_tensor = scope.var(param_or_buffer.name).get_tensor(
-            )
-            src_tensor = param_or_buffer.value().get_tensor()
-            param_or_buffer_tensor._share_data_with(src_tensor)
-            # record var info
-            if param_or_buffer.name not in extra_var_info:
-                extra_info_dict = dict()
-                if param_or_buffer.name in state_names_dict:
-                    extra_info_dict['structured_name'] = state_names_dict[
-                        param_or_buffer.name]
-                extra_info_dict['stop_gradient'] = param_or_buffer.stop_gradient
-                if isinstance(param_or_buffer, ParamBase):
-                    extra_info_dict['trainable'] = param_or_buffer.trainable
-                extra_var_info[param_or_buffer.name] = extra_info_dict
-
         # 5. save inference model
         from paddle.fluid.io import save_inference_model
@@ -748,7 +799,7 @@ def save(layer, path, input_spec=None, **configs):
         model_path = dirname
         # NOTE(chenweihang): because prefix contains model and params filename,
         # so we don't support set model_filename & params_filename
-        if 'forward' == attr_func:
+        if 'forward' == attr_func or not isinstance(layer, Layer):
             model_filename = file_prefix + INFER_MODEL_SUFFIX
             params_filename = file_prefix + INFER_PARAMS_SUFFIX
         else:
@@ -782,10 +833,11 @@ def save(layer, path, input_spec=None, **configs):
     # but we can save these information in `jit.save` without changing the original
     # storage to improve user experience. So we save extra information into
     # file `***.pdiparams.info`
-    with scope_guard(scope):
-        extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
-        with open(extra_var_info_path, 'wb') as f:
-            pickle.dump(extra_var_info, f, protocol=2)
+    if isinstance(layer, Layer) and extra_var_info:
+        with scope_guard(scope):
+            extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
+            with open(extra_var_info_path, 'wb') as f:
+                pickle.dump(extra_var_info, f, protocol=2)

 @dygraph_only
...
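As the docstring note above points out, parameters are only persisted when a Layer is passed in. A sketch of that recommended pattern (the TanhLinear class and the 'with_params/model' prefix are illustrative, not part of this commit):

    import paddle
    import paddle.nn as nn

    class TanhLinear(nn.Layer):
        # illustrative Layer that owns the parameters used by the function
        def __init__(self):
            super(TanhLinear, self).__init__()
            self.linear = nn.Linear(6, 6)

        @paddle.jit.to_static
        def forward(self, x):
            return paddle.tanh(self.linear(x))

    layer = TanhLinear()
    layer(paddle.rand([3, 6]))  # trace once
    paddle.jit.save(layer, 'with_params/model')
    # all three files are written: '.pdmodel', '.pdiparams' and
    # '.pdiparams.info' (the isinstance(layer, Layer) branch above)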
@@ -399,15 +399,6 @@ class TestJitSaveLoad(unittest.TestCase):
         with self.assertRaises(ValueError):
             model_dict, _ = fluid.dygraph.load_dygraph(model_path)

-    def test_jit_load_model_incomplete(self):
-        model_path = "test_jit_save_load.remove_variables/model"
-        self.train_and_save_model(model_path)
-        # remove `.pdiparams`
-        var_path = model_path + INFER_PARAMS_SUFFIX
-        os.remove(var_path)
-        with self.assertRaises(ValueError):
-            paddle.jit.load(model_path)
-
     def test_jit_load_no_path(self):
         path = "test_jit_save_load.no_path/model_path"
         with self.assertRaises(ValueError):
@@ -1164,6 +1155,63 @@ class TestJitSaveLoadFinetuneLoad(unittest.TestCase):
         self.assertTrue(float(((result_01 - result_11)).abs().max()) < 1e-5)


+class TestJitSaveLoadFunction(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+
+    def test_jit_save_load_static_function(self):
+        @paddle.jit.to_static
+        def fun(inputs):
+            return paddle.tanh(inputs)
+
+        path = 'test_jit_save_load_function_1/func'
+        inps = paddle.rand([3, 6])
+        origin = fun(inps)
+
+        paddle.jit.save(fun, path)
+        load_func = paddle.jit.load(path)
+
+        load_result = load_func(inps)
+        self.assertTrue((load_result - origin).abs().max() < 1e-10)
+
+    def test_jit_save_load_function_input_spec(self):
+        @paddle.jit.to_static(input_spec=[
+            InputSpec(
+                shape=[None, 6], dtype='float32', name='x'),
+        ])
+        def fun(inputs):
+            return paddle.nn.functional.relu(inputs)
+
+        path = 'test_jit_save_load_function_2/func'
+        inps = paddle.rand([3, 6])
+        origin = fun(inps)
+
+        paddle.jit.save(fun, path)
+        load_func = paddle.jit.load(path)
+        load_result = load_func(inps)
+        self.assertTrue((load_result - origin).abs().max() < 1e-10)
+
+    def test_jit_save_load_function_function(self):
+        def fun(inputs):
+            return paddle.tanh(inputs)
+
+        path = 'test_jit_save_load_function_3/func'
+        inps = paddle.rand([3, 6])
+        origin = fun(inps)
+
+        paddle.jit.save(
+            fun,
+            path,
+            input_spec=[
+                InputSpec(
+                    shape=[None, 6], dtype='float32', name='x'),
+            ])
+        load_func = paddle.jit.load(path)
+
+        load_result = load_func(inps)
+        self.assertTrue((load_result - origin).abs().max() < 1e-10)
+
+
 class TestJitSaveLoadDataParallel(unittest.TestCase):
     def verify_inference_correctness(self, layer, path):
         layer.eval()
...
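To run just the new test class in isolation (a sketch; assumes the test file above is importable as 'test_jit_save_load'):

    import unittest

    suite = unittest.defaultTestLoader.loadTestsFromName(
        'test_jit_save_load.TestJitSaveLoadFunction')
    unittest.TextTestRunner(verbosity=2).run(suite)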