Unverified commit 23f18f46, authored by Hui Zhang and committed by GitHub

[jit] save multiple programs into one param file and separate model files (#43686)

* save multiple programs into one param file and separate model files

* export class property
Parent d29a1214
@@ -252,9 +252,11 @@ class StaticFunction(object):
            **kwargs(dict): other arguments like `build_strategy` et al.
        """
        # save the instance `self` while decorating a method of a class.
        if inspect.ismethod(function):
            self._dygraph_function = getattr(function, '__func__')
            self._class_instance = getattr(function, '__self__')
            self._class_instance._original_funcs[
                function.__name__] = self._dygraph_function
        else:
@@ -272,6 +274,13 @@ class StaticFunction(object):
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0
        self._property = kwargs.get("property", False)

    @property
    def is_property(self):
        # whether this is a class property to be exported.
        return self._property

    def train(self):
        if isinstance(self._class_instance,
                      layers.Layer) and self._class_instance.training == False:
@@ -325,7 +334,8 @@ class StaticFunction(object):
        return self._descriptor_cache[instance]

    def _clone(self):
        return self.__class__(self._dygraph_function, self._input_spec,
                              **self._kwargs)

    def __call__(self, *args, **kwargs):
        """
@@ -338,6 +348,8 @@ class StaticFunction(object):
        Return:
            Outputs of decorated function.
        """
        if self._property:
            return self._call_dygraph_function(*args, **kwargs)

        # 1. call dygraph function directly if not enable `declarative`
        if not self._program_trans.enable_to_static:
@@ -417,6 +429,15 @@ class StaticFunction(object):
        return dygraph_function(*args, **kwargs)

    def _raise_when_property(self):
        """Raise a RuntimeError when property=True.

        Raises:
            RuntimeError: this function must not be called when property=True.
        """
        if self.is_property:
            raise RuntimeError("Cannot call this function when property=True.")

    def get_concrete_program(self, *args, **kwargs):
        """
        Returns traced concrete program and inner executable partial layer.
@@ -428,6 +449,7 @@ class StaticFunction(object):
        Returns:
            Traced ConcreteProgram and executable translated Layer.
        """
        self._raise_when_property()
        with_hook = kwargs.get("with_hook", False)
        is_train = kwargs.get("is_train", True)
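To illustrate the new property path end to end, here is a minimal sketch assuming this diff is applied (`ExampleNet` and `version` are hypothetical names used only for illustration): calling a `property=True` StaticFunction dispatches straight to the original dygraph function, `is_property` reports the flag, and program-related APIs such as `get_concrete_program` now raise via `_raise_when_property`.

    import paddle

    class ExampleNet(paddle.nn.Layer):

        @paddle.jit.to_static(property=True)
        def version(self):
            # plain Python value to be exported, never traced into a Program
            return 1.0

    net = ExampleNet()
    print(net.version())            # 1.0, runs the original dygraph function
    print(net.version.is_property)  # True
    # net.version.get_concrete_program()  # would raise RuntimeError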
@@ -518,6 +540,7 @@ class StaticFunction(object):
            input_spec (list[InputSpec], optional): Describes the input of
                the translated function.
        """
        self._raise_when_property()
        # if `input_spec` is specified, the length of program_cache will always be 1;
        # otherwise, return the last one.
        cached_program_len = len(self._program_cache)
@@ -670,6 +693,7 @@ class StaticFunction(object):
        """
        Returns input tensors of the recently converted static program.
        """
        self._raise_when_property()
        concrete_program = self.concrete_program
        inputs = [
            var for var in flatten(concrete_program.inputs)
@@ -682,6 +706,7 @@ class StaticFunction(object):
        """
        Returns output tensors of the recently converted static program.
        """
        self._raise_when_property()
        concrete_program = self.concrete_program
        outputs = [
            var for var in flatten(concrete_program.outputs)
@@ -695,6 +720,7 @@ class StaticFunction(object):
        """
        Returns the recently converted static main program.
        """
        self._raise_when_property()
        concrete_program = self.concrete_program
        main_program = concrete_program.main_program
        return main_program
@@ -160,7 +160,10 @@ def copy_decorator_attrs(original_func, decorated_obj):
    return decorated_obj


def declarative(function=None,
                input_spec=None,
                build_strategy=None,
                property=False):
    """
    Converts imperative dygraph APIs into declarative function APIs. The decorator
    @declarative handles the Program and Executor of static mode and returns
@@ -178,6 +181,7 @@ def declarative(function=None, input_spec=None, build_strategy=None):
            in the computational graph and memory optimization during the execution
            of the computational graph. For more information about build_strategy,
            please refer to :code:`paddle.static.BuildStrategy`. The default is None.
        property(bool, Optional): whether the function is a Python property. The default is False.

    Returns:
@@ -215,7 +219,8 @@ def declarative(function=None, input_spec=None, build_strategy=None):
            decorated_obj=StaticFunction(
                function=python_func,
                input_spec=input_spec,
                build_strategy=build_strategy,
                property=property))

    return static_layer
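The extended decorator also covers plain functions, which `save` handles separately further below. A minimal sketch assuming this diff (`default_scale` is a made-up name for illustration):

    import paddle

    @paddle.jit.to_static(property=True)
    def default_scale():
        # exported as a value, not as a traced Program
        return 0.5

    print(default_scale())            # 0.5, via _call_dygraph_function
    print(default_scale.is_property)  # True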
@@ -304,6 +309,9 @@ class _SaveLoadConfig(object):
        self._program_only = False
        self.with_hook = False

        # if True, multiple `StaticFunction`s will share parameters in one file.
        self.combine_params = False

    @property
    def output_spec(self):
        return self._output_spec
@@ -371,7 +379,7 @@ class _SaveLoadConfig(object):


def _parse_save_configs(configs):
    supported_configs = ['output_spec', "with_hook", "use_combine"]

    # input check
    for key in configs:
@@ -384,6 +392,7 @@ def _parse_save_configs(configs):
    inner_config = _SaveLoadConfig()
    inner_config.output_spec = configs.get('output_spec', None)
    inner_config.with_hook = configs.get('with_hook', False)
    inner_config.combine_params = configs.get("use_combine", False)

    return inner_config
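A quick sketch of how the public key maps onto the config object, assuming the private helper stays importable from `paddle.fluid.dygraph.jit` (for illustration only):

    from paddle.fluid.dygraph.jit import _parse_save_configs

    cfg = _parse_save_configs({"use_combine": True})
    assert cfg.combine_params is True   # mapped from the public `use_combine` key
    assert cfg.with_hook is False       # unrelated keys keep their defaults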
@@ -840,6 +849,9 @@ def save(layer, path, input_spec=None, **configs):
    # whether the outermost layer has pre/post hooks; if it does, we also need
    # to save these operators in the program.
    with_hook = configs.with_hook
    combine_params = configs.combine_params
    if combine_params:
        configs._program_only = True

    scope = core.Scope()
    extra_var_info = dict()
@@ -852,10 +864,21 @@ def save(layer, path, input_spec=None, **configs):
        functions = [
            layer,
        ]

    all_vars = set()
    property_vals = []  # (value, key)
    for attr_func in functions:
        if isinstance(layer, Layer):
            static_func = getattr(inner_layer, attr_func, None)
            if isinstance(static_func, StaticFunction):
                if static_func.is_property:
                    # property method to be exported
                    immediate_val = static_func()
                    property_vals.append(
                        (immediate_val,
                         layer.__class__.__name__ + '.' + attr_func))
                    continue

                concrete_program = static_func.concrete_program_specify_input_spec(
                    inner_input_spec, with_hook=with_hook)
            elif 'forward' == attr_func:
@@ -875,10 +898,15 @@ def save(layer, path, input_spec=None, **configs):
                inner_input_spec = None
            else:
                continue

        else:
            # When layer is a function
            if isinstance(attr_func, StaticFunction):
                if attr_func.is_property:
                    # property method to be exported
                    immediate_val = attr_func()
                    property_vals.append((immediate_val, attr_func))
                    continue

                concrete_program = attr_func.concrete_program_specify_input_spec(
                    inner_input_spec)
            else:
@@ -894,6 +922,7 @@ def save(layer, path, input_spec=None, **configs):
                '`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`'
                .format(layer))

        # when saving multiple `StaticFunction`s, all of them share the parameters.
        dygraph_state_dict = None
        if isinstance(inner_layer, Layer):
            dygraph_state_dict = inner_layer.to_static_state_dict()
@@ -932,15 +961,12 @@ def save(layer, path, input_spec=None, **configs):
                if param_or_buffer.name not in extra_var_info:
                    extra_info_dict = dict()
                    if param_or_buffer.name in state_names_dict:
                        extra_info_dict['structured_name'] = state_names_dict[
                            param_or_buffer.name]
                    extra_info_dict[
                        'stop_gradient'] = param_or_buffer.stop_gradient
                    if isinstance(param_or_buffer, (ParamBase, EagerParamBase)):
                        extra_info_dict['trainable'] = param_or_buffer.trainable
                    extra_var_info[param_or_buffer.name] = extra_info_dict

    # 4. build input & output of save_inference_model
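For reference, a sketch of the `extra_var_info` record assembled in this hunk; the keys come from the code above, while the concrete names and values are illustrative:

    # extra_var_info maps a variable name to its bookkeeping dict, e.g.:
    # {
    #     "linear_0.w_0": {
    #         "structured_name": "fc1.weight",  # from state_names_dict
    #         "stop_gradient": False,
    #         "trainable": True,  # only recorded for ParamBase/EagerParamBase
    #     },
    # }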
@@ -991,6 +1017,22 @@ def save(layer, path, input_spec=None, **configs):
            program_only=configs._program_only,
            clip_extra=False)

        # collect all vars
        for var in concrete_program.main_program.list_vars():
            all_vars.add(var)

    # save shared params
    if combine_params:
        params_filename = file_prefix + INFER_PARAMS_SUFFIX
        with scope_guard(scope):
            paddle.static.save_vars(Executor(_current_expected_place()),
                                    dirname=model_path,
                                    vars=list(
                                        filter(paddle.fluid.io.is_persistable,
                                               all_vars)),
                                    filename=params_filename)
    # TODO: save property

    # NOTE(chenweihang): [ Save extra variable info ]
    # save_inference_model will lose some important variable information, including:
    #   - Variable name and correspondence (when saved variables as one file)
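Based on this hunk, the expected on-disk layout for `paddle.jit.save(net, "dir/model", use_combine=True)` is roughly the following; only the single parameter file is guaranteed by this diff, and the per-function model file names are an assumption from `paddle.jit.save`'s usual prefix scheme:

    dir/model.pdmodel              # program of `forward`, saved program-only
    dir/model.log_softmax.pdmodel  # one program file per extra StaticFunction
    dir/model.infer.pdmodel
    dir/model.pdiparams            # combined parameter file shared by all programs

Values collected in `property_vals` are not yet written anywhere; the diff leaves that as "TODO: save property".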
@@ -1153,6 +1153,65 @@ class LayerSaved(paddle.nn.Layer):
        return self._linear_2(y)


class Net(paddle.nn.Layer):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = paddle.nn.Linear(4, 4)
        self.fc2 = paddle.nn.Linear(4, 4)
        self.bias = 0.4
        self.flag = paddle.ones([2], dtype="int32")

    @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
    def log_softmax(self, input):
        return paddle.nn.functional.log_softmax(input, axis=-1)

    @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
    def forward(self, x):
        out = self.fc1(x)
        out = paddle.nn.functional.relu(out)
        out = paddle.mean(out)
        return out

    @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
    def infer(self, input):
        out = self.fc2(input)
        out = out + self.bias
        out = paddle.mean(out)
        return out

    # For extra Python float
    @paddle.jit.to_static(property=True)
    def fbias(self):
        return self.bias + 1

    # For extra Tensor
    @paddle.jit.to_static(property=True)
    def fflag(self):
        return self.flag


class TestJitSaveCombine(unittest.TestCase):

    def setUp(self):
        # enable dygraph mode
        paddle.disable_static()
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_save_load_finetune_load(self):
        model_path = os.path.join(self.temp_dir.name,
                                  "test_jit_save_combine/model")

        # Use a new namespace
        with unique_name.guard():
            net = Net()

        # save
        paddle.jit.save(net, model_path, use_combine=True)
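        # Hedged follow-up check (not part of the original test): with
        # use_combine=True a single shared parameter file is expected next to
        # the per-function model files; ".pdiparams" matches the
        # INFER_PARAMS_SUFFIX used by `save`.
        self.assertTrue(os.path.exists(model_path + ".pdiparams"))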


class LayerLoadFinetune(paddle.nn.Layer):

    def __init__(self, in_size, out_size, load_path):