From 23f18f46b21d5ebaa4ce6b529bfe80d959284b3e Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Thu, 30 Jun 2022 19:17:49 +0800
Subject: [PATCH] [jit] save multi program into one param and separate model
 (#43686)

* save multi program into one param and separate model

* export class property
---
 .../dygraph_to_static/program_translator.py   |  28 ++++-
 python/paddle/fluid/dygraph/jit.py            | 108 ++++++++++++------
 .../tests/unittests/test_jit_save_load.py     |  59 ++++++++++
 3 files changed, 161 insertions(+), 34 deletions(-)

diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
index 49a218412c9..43ce1fae16f 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -252,9 +252,11 @@ class StaticFunction(object):
             **kwargs(dict): other arguments like `build_strategy` et.al.
         """
         # save the instance `self` while decorating a method of class.
+
         if inspect.ismethod(function):
             self._dygraph_function = getattr(function, '__func__')
             self._class_instance = getattr(function, '__self__')
+
             self._class_instance._original_funcs[
                 function.__name__] = self._dygraph_function
         else:
@@ -272,6 +274,13 @@ class StaticFunction(object):
         self._cuda_graph_capture_mode = ""
         self._cuda_graph_pool_id = 0
 
+        self._property = kwargs.get("property", False)
+
+    @property
+    def is_property(self):
+        # whether this is a class property to be exported.
+        return self._property
+
     def train(self):
         if isinstance(self._class_instance,
                       layers.Layer) and self._class_instance.training == False:
@@ -325,7 +334,8 @@ class StaticFunction(object):
         return self._descriptor_cache[instance]
 
     def _clone(self):
-        return self.__class__(self._dygraph_function, self._input_spec)
+        return self.__class__(self._dygraph_function, self._input_spec,
+                              **self._kwargs)
 
     def __call__(self, *args, **kwargs):
         """
@@ -338,6 +348,8 @@ class StaticFunction(object):
         Return:
             Outputs of decorated function.
         """
+        if self._property:
+            return self._call_dygraph_function(*args, **kwargs)
 
         # 1. call dygraph function directly if not enable `declarative`
         if not self._program_trans.enable_to_static:
@@ -417,6 +429,15 @@ class StaticFunction(object):
 
         return dygraph_function(*args, **kwargs)
 
+    def _raise_when_property(self):
+        """Raise RuntimeError when property=True.
+
+        Raises:
+            RuntimeError: this function cannot be called when property=True.
+        """
+        if self.is_property:
+            raise RuntimeError("Can not call the func when property=True.")
+
     def get_concrete_program(self, *args, **kwargs):
         """
         Returns traced concrete program and inner executable partial layer.
@@ -428,6 +449,7 @@ class StaticFunction(object):
         Returns:
             Traced ConcreteProgram and executable translated Layer.
         """
+        self._raise_when_property()
         with_hook = kwargs.get("with_hook", False)
         is_train = kwargs.get("is_train", True)
 
@@ -518,6 +540,7 @@ class StaticFunction(object):
             input_spec (list[InputSpec], optional): Describes the input of
                 the translate function.
         """
+        self._raise_when_property()
         # if specific the `input_spec`, the length of program_cache will always 1,
         # else, return the last one.
         cached_program_len = len(self._program_cache)
@@ -670,6 +693,7 @@ class StaticFunction(object):
         """
         Returns input tensors of recent converted static program.
""" + self._raise_when_property() concrete_program = self.concrete_program inputs = [ var for var in flatten(concrete_program.inputs) @@ -682,6 +706,7 @@ class StaticFunction(object): """ Returns output tensors of recent converted static program. """ + self._raise_when_property() concrete_program = self.concrete_program outputs = [ var for var in flatten(concrete_program.outputs) @@ -695,6 +720,7 @@ class StaticFunction(object): """ Returns recent converted static main program. """ + self._raise_when_property() concrete_program = self.concrete_program main_program = concrete_program.main_program return main_program diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index b6847efab1d..393f1c15704 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -160,7 +160,10 @@ def copy_decorator_attrs(original_func, decorated_obj): return decorated_obj -def declarative(function=None, input_spec=None, build_strategy=None): +def declarative(function=None, + input_spec=None, + build_strategy=None, + property=False): """ Converts imperative dygraph APIs into declarative function APIs. Decorator @declarative handles the Program and Executor of static mode and returns @@ -178,6 +181,7 @@ def declarative(function=None, input_spec=None, build_strategy=None): in the computational graph and memory optimization during the execution of the computational graph. For more information about build_strategy, please refer to :code:`paddle.static.BuildStrategy`. The default is None. + property(bool, Optional): whether the fucntion is python property. The default is False. Returns: @@ -215,7 +219,8 @@ def declarative(function=None, input_spec=None, build_strategy=None): decorated_obj=StaticFunction( function=python_func, input_spec=input_spec, - build_strategy=build_strategy)) + build_strategy=build_strategy, + property=property)) return static_layer @@ -304,6 +309,9 @@ class _SaveLoadConfig(object): self._program_only = False self.with_hook = False + # if True, multi `StaticFunction` will share params in one file. + self.combine_params = False + @property def output_spec(self): return self._output_spec @@ -371,7 +379,7 @@ class _SaveLoadConfig(object): def _parse_save_configs(configs): - supported_configs = ['output_spec', "with_hook"] + supported_configs = ['output_spec', "with_hook", "use_combine"] # input check for key in configs: @@ -384,6 +392,7 @@ def _parse_save_configs(configs): inner_config = _SaveLoadConfig() inner_config.output_spec = configs.get('output_spec', None) inner_config.with_hook = configs.get('with_hook', False) + inner_config.combine_params = configs.get("use_combine", False) return inner_config @@ -840,6 +849,9 @@ def save(layer, path, input_spec=None, **configs): # whether outermost layer has pre/post hook, if does, we need also save # these operators in program. with_hook = configs.with_hook + combine_params = configs.combine_params + if combine_params: + configs._program_only = True scope = core.Scope() extra_var_info = dict() @@ -852,10 +864,21 @@ def save(layer, path, input_spec=None, **configs): functions = [ layer, ] + + all_vars = set() + property_vals = [] # (value, key) for attr_func in functions: if isinstance(layer, Layer): static_func = getattr(inner_layer, attr_func, None) if isinstance(static_func, StaticFunction): + if static_func.is_property: + # property method to be exported + immediate_val = static_func() + property_vals.append( + (immediate_val, + layer.__class__.__name__ + '.' 
+                    continue
+
                 concrete_program = static_func.concrete_program_specify_input_spec(
                     inner_input_spec, with_hook=with_hook)
             elif 'forward' == attr_func:
@@ -875,10 +898,15 @@ def save(layer, path, input_spec=None, **configs):
                     inner_input_spec = None
             else:
                 continue
-
         else:
             # When layer is a function
             if isinstance(attr_func, StaticFunction):
+                if attr_func.is_property:
+                    # property method to be exported
+                    immediate_val = attr_func()
+                    property_vals.append((immediate_val, attr_func))
+                    continue
+
                 concrete_program = attr_func.concrete_program_specify_input_spec(
                     inner_input_spec)
             else:
@@ -894,6 +922,7 @@ def save(layer, path, input_spec=None, **configs):
                     '`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`'
                     .format(layer))
 
+        # when saving multiple `StaticFunction`s, all of them share params.
         dygraph_state_dict = None
         if isinstance(inner_layer, Layer):
             dygraph_state_dict = inner_layer.to_static_state_dict()
@@ -913,35 +942,32 @@ def save(layer, path, input_spec=None, **configs):
                 state_names_dict[var.name] = structured_name
                 state_var_dict[var.name] = var
 
-            # 3. share parameters from Layer to scope & record var info
-            with dygraph.guard():
-                for param_or_buffer in concrete_program.parameters:
-                    # share to scope
-                    if param_or_buffer.type == core.VarDesc.VarType.VOCAB:
-                        scr_tensor = param_or_buffer.value().get_map_tensor()
-                        tgt_var = scope.var(param_or_buffer.name)
-                        tgt_var.set_vocab(scr_tensor)
-                    else:
-                        param_or_buffer_tensor = scope.var(
-                            param_or_buffer.name).get_tensor()
-                        #src_tensor = param_or_buffer.value().get_tensor()
-                        src_tensor = state_var_dict[
-                            param_or_buffer.name].value().get_tensor()
-                        param_or_buffer_tensor._share_data_with(src_tensor)
-                    # record var info
-                    if param_or_buffer.name not in extra_var_info:
-                        extra_info_dict = dict()
-                        if param_or_buffer.name in state_names_dict:
-                            extra_info_dict[
-                                'structured_name'] = state_names_dict[
-                                    param_or_buffer.name]
-                        extra_info_dict[
-                            'stop_gradient'] = param_or_buffer.stop_gradient
-                        if isinstance(param_or_buffer,
-                                      (ParamBase, EagerParamBase)):
-                            extra_info_dict[
-                                'trainable'] = param_or_buffer.trainable
-                        extra_var_info[param_or_buffer.name] = extra_info_dict
+        # 3. share parameters from Layer to scope & record var info
+        with dygraph.guard():
+            for param_or_buffer in concrete_program.parameters:
+                # share to scope
+                if param_or_buffer.type == core.VarDesc.VarType.VOCAB:
+                    scr_tensor = param_or_buffer.value().get_map_tensor()
+                    tgt_var = scope.var(param_or_buffer.name)
+                    tgt_var.set_vocab(scr_tensor)
+                else:
+                    param_or_buffer_tensor = scope.var(
+                        param_or_buffer.name).get_tensor()
+                    #src_tensor = param_or_buffer.value().get_tensor()
+                    src_tensor = state_var_dict[
+                        param_or_buffer.name].value().get_tensor()
+                    param_or_buffer_tensor._share_data_with(src_tensor)
+                # record var info
+                if param_or_buffer.name not in extra_var_info:
+                    extra_info_dict = dict()
+                    if param_or_buffer.name in state_names_dict:
+                        extra_info_dict['structured_name'] = state_names_dict[
+                            param_or_buffer.name]
+                    extra_info_dict[
+                        'stop_gradient'] = param_or_buffer.stop_gradient
+                    if isinstance(param_or_buffer, (ParamBase, EagerParamBase)):
+                        extra_info_dict['trainable'] = param_or_buffer.trainable
+                    extra_var_info[param_or_buffer.name] = extra_info_dict
 
     # 4. build input & output of save_infernece_model
     # NOTE(chenweihang): [ Get input variables name ]
@@ -991,6 +1017,22 @@ def save(layer, path, input_spec=None, **configs):
             program_only=configs._program_only,
             clip_extra=False)
 
+    # collect all vars
+    for var in concrete_program.main_program.list_vars():
+        all_vars.add(var)
+
+    # save shared params
+    if combine_params:
+        params_filename = file_prefix + INFER_PARAMS_SUFFIX
+        with scope_guard(scope):
+            paddle.static.save_vars(Executor(_current_expected_place()),
+                                    dirname=model_path,
+                                    vars=list(
+                                        filter(paddle.fluid.io.is_persistable,
+                                               all_vars)),
+                                    filename=params_filename)
+    # TODO: save property
+
     # NOTE(chenweihang): [ Save extra variable info ]
     # save_inference_model will lose some important variable information, including:
     #   - Variable name and correspondence (when saved variables as one file)
diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py
index bf5ccf1a854..f467fbe4888 100644
--- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py
+++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py
@@ -1153,6 +1153,65 @@ class LayerSaved(paddle.nn.Layer):
         return self._linear_2(y)
 
 
+class Net(paddle.nn.Layer):
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.fc1 = paddle.nn.Linear(4, 4)
+        self.fc2 = paddle.nn.Linear(4, 4)
+        self.bias = 0.4
+        self.flag = paddle.ones([2], dtype="int32")
+
+    @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
+    def log_softmax(self, input):
+        return paddle.nn.functional.log_softmax(input, axis=-1)
+
+    @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
+    def forward(self, x):
+        out = self.fc1(x)
+        out = paddle.nn.functional.relu(out)
+        out = paddle.mean(out)
+        return out
+
+    @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
+    def infer(self, input):
+        out = self.fc2(input)
+        out = out + self.bias
+        out = paddle.mean(out)
+        return out
+
+    # For extra Python float
+    @paddle.jit.to_static(property=True)
+    def fbias(self):
+        return self.bias + 1
+
+    # For extra Tensor
+    @paddle.jit.to_static(property=True)
+    def fflag(self):
+        return self.flag
+
+
+class TestJitSaveCombine(unittest.TestCase):
+
+    def setUp(self):
+        # enable dygraph mode
+        paddle.disable_static()
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        self.temp_dir.cleanup()
+
+    def test_save_load_finetune_load(self):
+        model_path = os.path.join(self.temp_dir.name,
+                                  "test_jit_save_combine/model")
+
+        # Use new namespace
+        with unique_name.guard():
+            net = Net()
+        # save
+        paddle.jit.save(net, model_path, use_combine=True)
+
+
 class LayerLoadFinetune(paddle.nn.Layer):
 
     def __init__(self, in_size, out_size, load_path):
-- 
GitLab
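
For reference, a minimal usage sketch of the two options this patch introduces: decorate extra methods with `@paddle.jit.to_static(property=True)` to export them as immediate values rather than traced Programs, and pass `use_combine=True` to `paddle.jit.save` so all decorated entry functions share one parameter file. This mirrors the `TestJitSaveCombine` case added above; the `TinyNet` layer and the output path are illustrative assumptions, not part of the patch.

```python
import os

import paddle
from paddle.static import InputSpec


class TinyNet(paddle.nn.Layer):
    """Illustrative layer with two jit entry points and one exported property."""

    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = paddle.nn.Linear(4, 4)
        self.bias = 0.4

    @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
    def forward(self, x):
        return paddle.mean(self.fc(x))

    @paddle.jit.to_static(input_spec=[InputSpec([None, 4], dtype='float32')])
    def infer(self, x):
        return paddle.mean(self.fc(x) + self.bias)

    # property=True: exported as an immediate value instead of a traced Program.
    @paddle.jit.to_static(property=True)
    def fbias(self):
        return self.bias + 1


if __name__ == '__main__':
    paddle.disable_static()
    net = TinyNet()
    # use_combine=True keeps one model file per decorated function
    # (forward, infer) while all of them share a single combined param file.
    paddle.jit.save(net, os.path.join("saved", "model"), use_combine=True)
```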