Unverified commit 1b377635, authored by WeiXin, committed by GitHub

jit.save/load: support methods with parameters. (#34070)

* jit.save/load: support methods with parameters.

* Add unit tests and a warning.

* Polish the warning message.
Parent 52c1a950
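In short, this change lets `paddle.jit.save` persist a method other than `forward` together with the parameters it uses. A minimal sketch of the intended usage, mirroring the unit tests added below (the `LinearNet` class and paths are illustrative):

import paddle

class LinearNet(paddle.nn.Layer):
    def __init__(self):
        super(LinearNet, self).__init__()
        self._linear = paddle.nn.Linear(5, 6)

    def forward(self, x):
        return paddle.tanh(x)

    def another_forward(self, x):
        return self._linear(x)

layer = LinearNet()
# Convert the bound method and save it; the parameters of self._linear go with it.
func = paddle.jit.to_static(
    layer.another_forward, [paddle.static.InputSpec(shape=[-1, 5])])
paddle.jit.save(func, 'example/func')
loaded = paddle.jit.load('example/func')   # a TranslatedLayer
out = loaded(paddle.rand([3, 5]))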
@@ -35,7 +35,7 @@ from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTra
from paddle.fluid.dygraph.io import TranslatedLayer, INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.framework import Block, ParamBase, Program, Variable, Parameter
from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dygraph_tracer
from paddle.fluid.framework import dygraph_only, in_dygraph_mode
from paddle.fluid.wrapped_decorator import wrap_decorator
@@ -659,6 +659,10 @@ def save(layer, path, input_spec=None, **configs):
        raise TypeError(
            "The input of paddle.jit.save should be 'Layer' or 'Function', but received input type is %s."
            % type(layer))
    elif inspect.isfunction(layer) or isinstance(layer, StaticFunction):
        warnings.warn(
            'You are saving a function, and `jit.save` will generate the model file names from the `path` you specify. When these files are loaded with `jit.load`, you get a `TranslatedLayer` whose inference result is the same as that of the saved function.'
        )

    # NOTE(chenweihang): If the input layer is wrapped by DataParallel,
    # the args and kwargs of the forward method can't be parsed by
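For context, the new `elif` branch above fires when the object passed to `paddle.jit.save` is a plain function (or a `StaticFunction`) rather than a `Layer`. A small sketch of a call that triggers it, assuming an illustrative parameter-free function:

import paddle

def double_tanh(x):
    return paddle.tanh(paddle.tanh(x))

func = paddle.jit.to_static(
    double_tanh, input_spec=[paddle.static.InputSpec(shape=[-1, 5])])
paddle.jit.save(func, 'example/double_tanh')     # emits the warning above
loaded = paddle.jit.load('example/double_tanh')  # TranslatedLayer, same inference result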
@@ -741,12 +745,38 @@ def save(layer, path, input_spec=None, **configs):
            else:
                continue
        else:
            # When layer is a function
            if isinstance(attr_func, StaticFunction):
                concrete_program = attr_func.concrete_program_specify_input_spec(
                    inner_input_spec)
            else:
                if inner_input_spec:
                    inner_input_spec = pack_sequence_as(input_spec,
                                                        inner_input_spec)
                static_function = declarative(
                    attr_func, input_spec=inner_input_spec)
                concrete_program = static_function.concrete_program

                if static_function._class_instance is None:
                    warnings.warn(
                        '`jit.save` will only save the `Program`, not the parameters. If you want to save the parameters, please make sure that {} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`.'
                        .format(layer))

        dygraph_state_dict = None
        if isinstance(inner_layer, Layer):
            dygraph_state_dict = inner_layer.state_dict()
        elif isinstance(attr_func, StaticFunction):
            if attr_func._class_instance:
                dygraph_state_dict = attr_func._class_instance.state_dict()

        if dygraph_state_dict:
            # NOTE(chenweihang): we maintain the mapping of variable name to
            # structured name; the buffer variables (non-persistable) saved to
            # the inference program may not be needed by the dygraph Layer,
            # so we only record the structured names of state_dict variables
            state_names_dict = dict()
            for structured_name, var in six.iteritems(dygraph_state_dict):
                state_names_dict[var.name] = structured_name

        # 3. share parameters from Layer to scope & record var info
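To make the mapping concrete: `state_dict()` is keyed by structured names (attribute paths), while `var.name` holds the runtime variable name, so `state_names_dict` inverts that relation. A hypothetical illustration (the exact printed names depend on layer creation order):

layer = LinearNet()   # from the sketch above
for structured_name, var in layer.state_dict().items():
    print(var.name, '->', structured_name)
# e.g. linear_0.w_0 -> _linear.weight
#      linear_0.b_0 -> _linear.bias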
@@ -767,18 +797,6 @@ def save(layer, path, input_spec=None, **configs):
                if isinstance(param_or_buffer, ParamBase):
                    extra_info_dict['trainable'] = param_or_buffer.trainable
                extra_var_info[param_or_buffer.name] = extra_info_dict
    # 4. build input & output of save_inference_model
    # NOTE(chenweihang): [ Get input variables name ]
@@ -840,7 +858,14 @@ def save(layer, path, input_spec=None, **configs):
    # but we can save this information in `jit.save` without changing the original
    # storage to improve user experience. So we save the extra information into
    # the file `***.pdiparams.info`

    # "layer" can only be a Layer, a function, or a StaticFunction.
    contain_parameter = False
    for var in concrete_program.main_program.list_vars():
        contain_parameter |= isinstance(var, Parameter)

    if (isinstance(layer, Layer) or contain_parameter) and extra_var_info:
        with scope_guard(scope):
            extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
            with open(extra_var_info_path, 'wb') as f:
......
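One consequence of this last hunk: saving a bound method whose program contains `Parameter` variables now also writes the `***.pdiparams.info` file, which previously happened only for `Layer` objects. A quick sketch, reusing the illustrative `func` from the first example:

import os

paddle.jit.save(func, 'example/func')
assert os.path.exists('example/func' + '.pdiparams')        # the parameters themselves
assert os.path.exists('example/func' + '.pdiparams.info')   # the extra variable info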
@@ -1227,6 +1227,99 @@ class TestJitSaveLoadFunctionCase3(unittest.TestCase):
        self.assertTrue((load_result - origin).abs().max() < 1e-10)

class TestJitSaveLoadFunctionWithParamCase1(unittest.TestCase):
    def setUp(self):
        paddle.disable_static()

    def test_jit_save_load_function(self):
        class LinearNet(paddle.nn.Layer):
            def __init__(self):
                super(LinearNet, self).__init__()
                self._linear = paddle.nn.Linear(5, 6)

            def forward(self, x):
                return paddle.tanh(x)

            def another_forward(self, x):
                return self._linear(x)

        layer = LinearNet()
        inps = paddle.rand([3, 5])
        origin = layer.another_forward(inps)

        func = paddle.jit.to_static(
            layer.another_forward, [paddle.static.InputSpec(shape=[-1, 5])])
        path = 'test_jit_save_load_function_with_params_case1/func'
        paddle.jit.save(func, path)

        load_func = paddle.jit.load(path)
        load_result = load_func(inps)
        self.assertTrue(np.array_equal(load_result.numpy(), origin.numpy()))


class TestJitSaveLoadFunctionWithParamCase2(unittest.TestCase):
    def setUp(self):
        paddle.disable_static()

    def test_jit_save_load_function(self):
        class LinearNet(paddle.nn.Layer):
            def __init__(self):
                super(LinearNet, self).__init__()
                self._linear = paddle.nn.Linear(5, 6)

            def forward(self, x):
                return paddle.tanh(x)

            @paddle.jit.to_static(input_spec=[InputSpec(shape=[-1, 5])])
            def another_forward(self, x):
                return self._linear(x)

        layer = LinearNet()
        inps = paddle.rand([3, 5])

        path = 'test_jit_save_load_function_with_params_case2/func'
        paddle.jit.save(layer.another_forward, path)

        origin_result = layer.another_forward(inps)
        load_func = paddle.jit.load(path)
        load_result = load_func(inps)
        self.assertTrue(
            np.array_equal(origin_result.numpy(), load_result.numpy()))


class TestJitSaveLoadFunctionWithParamCase3(unittest.TestCase):
    def setUp(self):
        paddle.disable_static()

    def test_jit_save_load_function(self):
        class LinearNet(paddle.nn.Layer):
            def __init__(self):
                super(LinearNet, self).__init__()
                self._linear = paddle.nn.Linear(5, 6)

            def forward(self, x):
                return paddle.tanh(x)

            @paddle.jit.to_static
            def another_forward(self, x):
                return self._linear(x)

        layer = LinearNet()
        inps = paddle.rand([3, 5])
        origin = layer.another_forward(inps)

        path = 'test_jit_save_load_function_with_params_case3/func'
        paddle.jit.save(layer.another_forward, path)

        load_func = paddle.jit.load(path)
        load_result = load_func(inps)
        self.assertTrue(np.array_equal(load_result.numpy(), origin.numpy()))

class TestJitSaveLoadDataParallel(unittest.TestCase):
    def verify_inference_correctness(self, layer, path):
        layer.eval()
......