Unverified commit 1476e1f9 authored by W WeiXin, committed by GitHub

save model after jit.load (#28748)

* Fixed a variable name error

* Add comments

* Move member functions of TranslatedLayer out of function

* edit code according to review

* Edit input argument of '_run_static_graph'

* reset due to Segmentation fault

* rename variables when stitching graph

* modify code according to CI

* Add comments to '__i_m_p_l__'

* remove blanks before 'Get...'

* edit code according to review

* Add a comment to '_execution_method_creator'

* Edit a comment to '_execution_method_creator'
Parent 0239f796
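
In short, this change lets a `TranslatedLayer` returned by `paddle.jit.load` be fine-tuned and saved again with `paddle.jit.save`. A minimal sketch of the round trip this enables (hypothetical paths; `net` stands for any trained Layer — the actual test added by this commit appears below):

    import paddle

    paddle.jit.save(net, 'saved/model')        # first export (needs an input_spec or a traced forward)
    loaded = paddle.jit.load('saved/model')    # returns a TranslatedLayer
    loaded.train()                             # fine-tune the loaded layer
    # ... run a few optimizer steps on `loaded` ...
    paddle.jit.save(loaded, 'saved/model_ft')  # re-saving after jit.load is what this PR enables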
@@ -25,8 +25,10 @@ from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_varargs_name
from paddle.fluid.dygraph.dygraph_to_static.utils import type_name
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.fluid.dygraph.io import TranslatedLayer
class FunctionSpec(object):
@@ -45,6 +47,11 @@ class FunctionSpec(object):
        # parse full argument names list.
        self._arg_names, self._default_kwargs = parse_arg_and_kwargs(function)
        # parse *args
        self.varargs_name = parse_varargs_name(function)
        if self.varargs_name is not None and isinstance(function.__self__,
                                                        TranslatedLayer):
            self._arg_names += function.__self__._input_args_names

    def unified_args_and_kwargs(self, args, kwargs):
        """
......
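
A minimal sketch (hypothetical names) of what the new branch above does: a TranslatedLayer's generated forward takes a varargs parameter, so the placeholder varargs name is expanded into the real input names recorded when the model was saved.

    arg_names = ['self']               # from parse_arg_and_kwargs(forward)
    varargs_name = 'args'              # from parse_varargs_name(forward)
    input_args_names = ['x']           # function.__self__._input_args_names
    if varargs_name is not None:
        arg_names += input_args_names  # arg_names == ['self', 'x']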
@@ -113,6 +113,15 @@ def parse_arg_and_kwargs(function):
    return arg_names, default_kwargs


def parse_varargs_name(function):
    """
    Returns the varargs name of `function`, e.g. 'input' from `foo(x, *input)`.
    """
    fullargspec = getfullargspec(function)
    varargs = fullargspec.varargs
    return varargs


def type_name(v):
    return type(v).__name__
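
For reference, `parse_varargs_name` is a thin wrapper over the standard library's `getfullargspec`; a quick standalone check of the behavior it relies on:

    from inspect import getfullargspec

    def foo(x, *input):
        pass

    print(getfullargspec(foo).varargs)          # 'input'
    print(getfullargspec(lambda x: x).varargs)  # None (no *args)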
@@ -478,11 +487,17 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
    else:
        module = SourceFileLoader(module_name, f.name).load_module()
    func_name = dyfunc.__name__
    # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
    # through 'func_name', so the generated code uses the special function name
    # '__i_m_p_l__' instead.
    if hasattr(module, '__i_m_p_l__'):
        callable_func = getattr(module, '__i_m_p_l__')
        callable_func.__name__ = func_name
    elif hasattr(module, func_name):
        callable_func = getattr(module, func_name)
    else:
        raise ValueError(
            "Function: %s doesn't exist in the Module transformed from AST." %
            func_name)
    # After transforming the dygraph function into callable_func saved in a tmp
    # file, the global variables from import statements or those defined in the
    # source file are lost. Recover the necessary variables via `__globals__`.
......
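
To see the placeholder lookup above in isolation (a toy module standing in for the code generated from the AST, not Paddle's actual codegen):

    import types

    module = types.ModuleType('demo')
    exec("def __i_m_p_l__(x):\n    return x + 1", module.__dict__)

    callable_func = getattr(module, '__i_m_p_l__')
    callable_func.__name__ = 'forward'  # restore the user-facing name
    print(callable_func(1))             # 2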
This diff is collapsed.
@@ -25,6 +25,7 @@ from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph import Linear
from paddle.fluid.dygraph import declarative, ProgramTranslator
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
from paddle.fluid import unique_name
BATCH_SIZE = 32
BATCH_NUM = 10
@@ -863,6 +864,94 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase):
layer, model_path, input_spec=[InputSpec(shape=[None, 784])])
class LayerSaved(paddle.nn.Layer):
    def __init__(self, in_size, out_size):
        super(LayerSaved, self).__init__()
        self.hidden = 100
        self._linear_0 = Linear(in_size, self.hidden)
        self._linear_1_0 = Linear(self.hidden, self.hidden)
        self._linear_1_1 = Linear(self.hidden, self.hidden)
        self._linear_2 = Linear(self.hidden, out_size)
        self._scale = paddle.to_tensor(9.9)

    @paddle.jit.to_static
    def forward(self, x):
        y = self._linear_0(x)
        # Multiple blocks
        if x.shape[0] == 1:
            y = self._linear_1_0(y)
        else:
            y += self._linear_1_1(y + self._scale)
        return self._linear_2(y)
class LayerLoadFinetune(paddle.nn.Layer):
    def __init__(self, in_size, out_size, load_path):
        super(LayerLoadFinetune, self).__init__()
        # Test duplicate name
        self._linear_0 = Linear(in_size, in_size)
        self._linear_1_0 = Linear(out_size, in_size)
        self._linear_1_1 = Linear(out_size, in_size)
        self._linear_2 = Linear(out_size, out_size)
        self._scale = paddle.to_tensor(9.9)
        # Load multiple times
        self._load_l1 = paddle.jit.load(load_path)
        self._load_l2 = paddle.jit.load(load_path)

    @paddle.jit.to_static
    def forward(self, x):
        y = self._linear_0(x)
        y = self._load_l1(y)
        # Multiple blocks
        if x.shape[0] == 1:
            y = self._linear_1_0(y)
            y = self._load_l1(y)
        else:
            y += self._linear_1_1(x + self._scale)
            y = self._load_l2(y)
            y = self._linear_1_0(y)
            y = self._load_l1(y)
            y = self._linear_1_0(y)
        # Use the same layer multiple times.
        y = self._load_l1(y)
        return y
class TestJitSaveLoadFinetuneLoad(unittest.TestCase):
    def setUp(self):
        # enable dygraph mode
        paddle.disable_static()

    def test_save_load_finetune_load(self):
        model_path = "test_jit_save_load_finetune_load/model"
        IMAGE_SIZE = 224
        inps0 = paddle.randn([1, IMAGE_SIZE])
        inps1 = paddle.randn([2, IMAGE_SIZE])
        # Use new namespace
        with unique_name.guard():
            layer_save = LayerSaved(IMAGE_SIZE, IMAGE_SIZE)
        layer_save(inps0)
        # save
        paddle.jit.save(layer_save, model_path)
        # load
        with unique_name.guard():
            layer_load = LayerLoadFinetune(IMAGE_SIZE, IMAGE_SIZE, model_path)
        # train
        train(layer_load, input_size=IMAGE_SIZE)
        result_00 = layer_load(inps0)
        result_01 = layer_load(inps1)
        # save
        paddle.jit.save(layer_load, model_path)
        # load
        layer_finetune = paddle.jit.load(model_path)
        result_10 = layer_finetune(inps0)
        result_11 = layer_finetune(inps1)

        self.assertTrue(float((result_00 - result_10).abs().max()) < 1e-5)
        self.assertTrue(float((result_01 - result_11).abs().max()) < 1e-5)
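
An equivalent tolerance check could also be written with numpy's testing helpers (an alternative sketch, assuming numpy is available as it is elsewhere in Paddle's tests):

    import numpy as np

    np.testing.assert_allclose(result_00.numpy(), result_10.numpy(), rtol=0, atol=1e-5)
    np.testing.assert_allclose(result_01.numpy(), result_11.numpy(), rtol=0, atol=1e-5)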
class TestJitSaveLoadDataParallel(unittest.TestCase):
    def verify_inference_correctness(self, layer, path):
        layer.eval()
......