未验证 提交 23f18f46 编写于 作者: H Hui Zhang 提交者: GitHub

[jit] save multi program into one param and seperate model (#43686)

* save multi program into one param and seperate model

* export class property
上级 d29a1214
......@@ -252,9 +252,11 @@ class StaticFunction(object):
**kwargs(dict): other arguments like `build_strategy` et.al.
"""
# save the instance `self` while decorating a method of class.
if inspect.ismethod(function):
self._dygraph_function = getattr(function, '__func__')
self._class_instance = getattr(function, '__self__')
self._class_instance._original_funcs[
function.__name__] = self._dygraph_function
else:
......@@ -272,6 +274,13 @@ class StaticFunction(object):
self._cuda_graph_capture_mode = ""
self._cuda_graph_pool_id = 0
self._property = kwargs.get("property", False)
@property
def is_property(self):
# whether is class proproty to be exported.
return self._property
def train(self):
if isinstance(self._class_instance,
layers.Layer) and self._class_instance.training == False:
......@@ -325,7 +334,8 @@ class StaticFunction(object):
return self._descriptor_cache[instance]
def _clone(self):
return self.__class__(self._dygraph_function, self._input_spec)
return self.__class__(self._dygraph_function, self._input_spec,
**self._kwargs)
def __call__(self, *args, **kwargs):
"""
......@@ -338,6 +348,8 @@ class StaticFunction(object):
Return:
Outputs of decorated function.
"""
if self._property:
return self._call_dygraph_function(*args, **kwargs)
# 1. call dygraph function directly if not enable `declarative`
if not self._program_trans.enable_to_static:
......@@ -417,6 +429,15 @@ class StaticFunction(object):
return dygraph_function(*args, **kwargs)
def _raise_when_property(self):
"""raise RuntimeError when property=True
Raises:
RuntimeError: can not call this func when property=True
"""
if self.is_property:
raise RuntimeError("Can not call the func when property=True.")
def get_concrete_program(self, *args, **kwargs):
"""
Returns traced concrete program and inner executable partial layer.
......@@ -428,6 +449,7 @@ class StaticFunction(object):
Returns:
Traced ConcreteProgram and executable translated Layer.
"""
self._raise_when_property()
with_hook = kwargs.get("with_hook", False)
is_train = kwargs.get("is_train", True)
......@@ -518,6 +540,7 @@ class StaticFunction(object):
input_spec (list[InputSpec], optional): Describes the input of
the translate function.
"""
self._raise_when_property()
# if specific the `input_spec`, the length of program_cache will always 1,
# else, return the last one.
cached_program_len = len(self._program_cache)
......@@ -670,6 +693,7 @@ class StaticFunction(object):
"""
Returns input tensors of recent converted static program.
"""
self._raise_when_property()
concrete_program = self.concrete_program
inputs = [
var for var in flatten(concrete_program.inputs)
......@@ -682,6 +706,7 @@ class StaticFunction(object):
"""
Returns output tensors of recent converted static program.
"""
self._raise_when_property()
concrete_program = self.concrete_program
outputs = [
var for var in flatten(concrete_program.outputs)
......@@ -695,6 +720,7 @@ class StaticFunction(object):
"""
Returns recent converted static main program.
"""
self._raise_when_property()
concrete_program = self.concrete_program
main_program = concrete_program.main_program
return main_program
......
......@@ -160,7 +160,10 @@ def copy_decorator_attrs(original_func, decorated_obj):
return decorated_obj
def declarative(function=None, input_spec=None, build_strategy=None):
def declarative(function=None,
input_spec=None,
build_strategy=None,
property=False):
"""
Converts imperative dygraph APIs into declarative function APIs. Decorator
@declarative handles the Program and Executor of static mode and returns
......@@ -178,6 +181,7 @@ def declarative(function=None, input_spec=None, build_strategy=None):
in the computational graph and memory optimization during the execution
of the computational graph. For more information about build_strategy,
please refer to :code:`paddle.static.BuildStrategy`. The default is None.
property(bool, Optional): whether the fucntion is python property. The default is False.
Returns:
......@@ -215,7 +219,8 @@ def declarative(function=None, input_spec=None, build_strategy=None):
decorated_obj=StaticFunction(
function=python_func,
input_spec=input_spec,
build_strategy=build_strategy))
build_strategy=build_strategy,
property=property))
return static_layer
......@@ -304,6 +309,9 @@ class _SaveLoadConfig(object):
self._program_only = False
self.with_hook = False
# if True, multi `StaticFunction` will share params in one file.
self.combine_params = False
@property
def output_spec(self):
return self._output_spec
......@@ -371,7 +379,7 @@ class _SaveLoadConfig(object):
def _parse_save_configs(configs):
supported_configs = ['output_spec', "with_hook"]
supported_configs = ['output_spec', "with_hook", "use_combine"]
# input check
for key in configs:
......@@ -384,6 +392,7 @@ def _parse_save_configs(configs):
inner_config = _SaveLoadConfig()
inner_config.output_spec = configs.get('output_spec', None)
inner_config.with_hook = configs.get('with_hook', False)
inner_config.combine_params = configs.get("use_combine", False)
return inner_config
......@@ -840,6 +849,9 @@ def save(layer, path, input_spec=None, **configs):
# whether outermost layer has pre/post hook, if does, we need also save
# these operators in program.
with_hook = configs.with_hook
combine_params = configs.combine_params
if combine_params:
configs._program_only = True
scope = core.Scope()
extra_var_info = dict()
......@@ -852,10 +864,21 @@ def save(layer, path, input_spec=None, **configs):
functions = [
layer,
]
all_vars = set()
property_vals = [] # (value, key)
for attr_func in functions:
if isinstance(layer, Layer):
static_func = getattr(inner_layer, attr_func, None)
if isinstance(static_func, StaticFunction):
if static_func.is_property:
# property method to be exported
immediate_val = static_func()
property_vals.append(
(immediate_val,
layer.__class__.__name__ + '.' + attr_func))
continue
concrete_program = static_func.concrete_program_specify_input_spec(
inner_input_spec, with_hook=with_hook)
elif 'forward' == attr_func:
......@@ -875,10 +898,15 @@ def save(layer, path, input_spec=None, **configs):
inner_input_spec = None
else:
continue
else:
# When layer is a function
if isinstance(attr_func, StaticFunction):
if attr_func.is_property:
# property method to be exported
immediate_val = attr_func()
property_vals.append((immediate_val, attr_func))
continue
concrete_program = attr_func.concrete_program_specify_input_spec(
inner_input_spec)
else:
......@@ -894,6 +922,7 @@ def save(layer, path, input_spec=None, **configs):
'`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`'
.format(layer))
# when save multi `StaticFunction`, all `StaticFunction` share params.
dygraph_state_dict = None
if isinstance(inner_layer, Layer):
dygraph_state_dict = inner_layer.to_static_state_dict()
......@@ -932,15 +961,12 @@ def save(layer, path, input_spec=None, **configs):
if param_or_buffer.name not in extra_var_info:
extra_info_dict = dict()
if param_or_buffer.name in state_names_dict:
extra_info_dict[
'structured_name'] = state_names_dict[
extra_info_dict['structured_name'] = state_names_dict[
param_or_buffer.name]
extra_info_dict[
'stop_gradient'] = param_or_buffer.stop_gradient
if isinstance(param_or_buffer,
(ParamBase, EagerParamBase)):
extra_info_dict[
'trainable'] = param_or_buffer.trainable
if isinstance(param_or_buffer, (ParamBase, EagerParamBase)):
extra_info_dict['trainable'] = param_or_buffer.trainable
extra_var_info[param_or_buffer.name] = extra_info_dict
# 4. build input & output of save_infernece_model
......@@ -991,6 +1017,22 @@ def save(layer, path, input_spec=None, **configs):
program_only=configs._program_only,
clip_extra=False)
# collect all vars
for var in concrete_program.main_program.list_vars():
all_vars.add(var)
# save shared params
if combine_params:
params_filename = file_prefix + INFER_PARAMS_SUFFIX
with scope_guard(scope):
paddle.static.save_vars(Executor(_current_expected_place()),
dirname=model_path,
vars=list(
filter(paddle.fluid.io.is_persistable,
all_vars)),
filename=params_filename)
# TODO: save property
# NOTE(chenweihang): [ Save extra variable info ]
# save_inference_model will lose some important variable information, including:
# - Variable name and correspondence (when saved variables as one file)
......
......@@ -1153,6 +1153,65 @@ class LayerSaved(paddle.nn.Layer):
return self._linear_2(y)
class Net(paddle.nn.Layer):
def __init__(self):
super(Net, self).__init__()
self.fc1 = paddle.nn.Linear(4, 4)
self.fc2 = paddle.nn.Linear(4, 4)
self.bias = 0.4
self.flag = paddle.ones([2], dtype="int32")
@paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
def log_softmax(self, input):
return paddle.nn.functional.log_softmax(input, axis=-1)
@paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
def forward(self, x):
out = self.fc1(x)
out = paddle.nn.functional.relu(out)
out = paddle.mean(out)
return out
@paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
def infer(self, input):
out = self.fc2(input)
out = out + self.bias
out = paddle.mean(out)
return out
# For extra Python float
@paddle.jit.to_static(property=True)
def fbias(self):
return self.bias + 1
# For extra Tensor
@paddle.jit.to_static(property=True)
def fflag(self):
return self.flag
class TestJitSaveCombine(unittest.TestCase):
def setUp(self):
# enable dygraph mode
paddle.disable_static()
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_save_load_finetune_load(self):
model_path = os.path.join(self.temp_dir.name,
"test_jit_save_combine/model")
# Use new namespace
with unique_name.guard():
net = Net()
#save
paddle.jit.save(net, model_path, use_combine=True)
class LayerLoadFinetune(paddle.nn.Layer):
def __init__(self, in_size, out_size, load_path):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册