From 93d34f835b91af1bf94229626210afea789b7c48 Mon Sep 17 00:00:00 2001
From: WeiXin
Date: Thu, 29 Apr 2021 13:59:48 +0800
Subject: [PATCH] 'jit.save/load' supports saving/loading a function without
 parameters. (#32430) (#32613)

* jit.save/load support function.

* delete unittest test_jit_load_model_incomplete.

* edit code according to CI

* Modify the documentation.

* add note to doc.
---
 python/paddle/fluid/dygraph/io.py             |   4 +
 python/paddle/fluid/dygraph/jit.py            | 180 +++++++++++-------
 .../tests/unittests/test_jit_save_load.py     |  66 ++++++-
 3 files changed, 177 insertions(+), 73 deletions(-)

diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py
index ce40fde1630..33eb16f1b2b 100644
--- a/python/paddle/fluid/dygraph/io.py
+++ b/python/paddle/fluid/dygraph/io.py
@@ -650,6 +650,7 @@ def _construct_params_and_buffers(model_path,
                                   append_suffix=True):
     var_info_filename = str(params_filename) + ".info"
     var_info_path = os.path.join(model_path, var_info_filename)
+    params_path = os.path.join(model_path, str(params_filename))
 
     if os.path.exists(var_info_path):
         var_dict = _load_persistable_vars(model_path, var_info_path,
@@ -671,6 +672,9 @@ def _construct_params_and_buffers(model_path,
             var_dict.update(
                 _load_persistable_vars(model_path, var_info_path, programs[
                     func_name], file_name))
+    elif params_filename is not None and not os.path.exists(params_path):
+        # When a function is saved, there is only a '*.pdmodel' file.
+        return dict()
     else:
         var_dict = _load_persistable_vars_by_program(
             model_path, programs['forward'], params_filename)
diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py
index 4c7c7b17eb1..352a377fa3a 100644
--- a/python/paddle/fluid/dygraph/jit.py
+++ b/python/paddle/fluid/dygraph/jit.py
@@ -19,6 +19,7 @@ import pickle
 import warnings
 import functools
 from collections import OrderedDict
+import inspect
 
 import six
 import paddle
@@ -506,7 +507,7 @@ def _build_load_path_and_config(path, config):
 @switch_to_static_graph
 def save(layer, path, input_spec=None, **configs):
     """
-    Saves input Layer as ``paddle.jit.TranslatedLayer``
+    Saves input Layer or function as ``paddle.jit.TranslatedLayer``
     format model, which can be used for inference or fine-tuning after loading.
 
     It will save the translated program and all related persistable
@@ -522,8 +523,12 @@ def save(layer, path, input_spec=None, **configs):
       - ``paddle.static.load_inference_model``
      - Other C++ inference APIs
 
+    .. note::
+        When using ``paddle.jit.save`` to save a function, parameters will not be saved. If you need
+        to save parameters, please pass the Layer that contains both the function and the parameters to ``paddle.jit.save``.
+
     Args:
-        layer (Layer): The Layer to be saved.
+        layer (Layer|function): The Layer or function to be saved.
         path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
         input_spec (list[InputSpec|Tensor]|tuple[InputSpec|Tensor], optional): Describes the input of the saved model's forward
             method, which can be described by InputSpec or example Tensor. If None, all input variables of
@@ -543,6 +548,7 @@ def save(layer, path, input_spec=None, **configs):
     Examples:
         .. code-block:: python
 
+            # example 1: save layer
             import numpy as np
             import paddle
             import paddle.nn as nn
@@ -609,6 +615,28 @@ def save(layer, path, input_spec=None, **configs):
             # save
             path = "example_model/linear"
             paddle.jit.save(layer, path)
+
+            # example 2: save function
+            import paddle
+            from paddle.static import InputSpec
+
+
+            def save_function():
+                @paddle.jit.to_static
+                def fun(inputs):
+                    return paddle.tanh(inputs)
+
+                path = 'test_jit_save_load_function_1/func'
+                inps = paddle.rand([3, 6])
+                origin = fun(inps)
+
+                paddle.jit.save(fun, path)
+                load_func = paddle.jit.load(path)
+
+                load_result = load_func(inps)
+                print((load_result - origin).abs().max() < 1e-10)
+
+            save_function()
     """
 
     # 1. input build & check
     prog_translator = ProgramTranslator()
     if not prog_translator.enable_to_static:
         raise RuntimeError(
             "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
         )
-    if not isinstance(layer, Layer):
+
+    if not (isinstance(layer, Layer) or inspect.isfunction(layer) or isinstance(
+            layer, StaticFunction)):
         raise TypeError(
-            "The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
+            "The input of paddle.jit.save should be 'Layer' or 'Function', but received input type is %s."
             % type(layer))
 
     # NOTE(chenweihang): If the input layer is wrapped by DataParallel,
@@ -647,13 +677,15 @@ def save(layer, path, input_spec=None, **configs):
     # avoid change user given input_spec
     inner_input_spec = None
     if input_spec is not None:
-        for attr_func in dir(inner_layer):
-            static_func = getattr(inner_layer, attr_func, None)
-            if isinstance(static_func,
-                          StaticFunction) and 'forward' != attr_func:
-                raise ValueError(
-                    "If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s."
-                    % type(input_spec))
+        if isinstance(layer, Layer):
+            for attr_func in dir(inner_layer):
+                static_func = getattr(inner_layer, attr_func, None)
+                if isinstance(static_func,
+                              StaticFunction) and 'forward' != attr_func:
+                    raise ValueError(
+                        "If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s."
+                        % type(input_spec))
+
         if not isinstance(input_spec, (list, tuple)):
             raise TypeError(
                 "The input input_spec should be 'list', but received input_spec's type is %s."
@@ -674,29 +706,74 @@ def save(layer, path, input_spec=None, **configs):
     configs = _parse_save_configs(configs)
     scope = core.Scope()
     extra_var_info = dict()
-    for attr_func in dir(inner_layer):
-        static_func = getattr(inner_layer, attr_func, None)
-        if isinstance(static_func, StaticFunction):
-            concrete_program = static_func.concrete_program_specify_input_spec(
-                inner_input_spec)
-        elif 'forward' == attr_func:
-            # transform in jit.save, if input_spec is incomplete, declarative will throw error
-            # inner_input_spec is list[InputSpec], it should be packed with same structure
-            # as original input_spec here.
-            if inner_input_spec:
-                inner_input_spec = pack_sequence_as(input_spec,
-                                                    inner_input_spec)
-            static_forward = declarative(
-                inner_layer.forward, input_spec=inner_input_spec)
-            concrete_program = static_forward.concrete_program
-            # the input_spec has been used in declarative, which is equal to
-            # @declarative with input_spec and jit.save without input_spec,
-            # avoid needless warning
-            inner_input_spec = None
+    if isinstance(layer, Layer):
+        functions = dir(inner_layer)
+    else:
+        # layer is a function
+        functions = [layer, ]
+    for attr_func in functions:
+        if isinstance(layer, Layer):
+            static_func = getattr(inner_layer, attr_func, None)
+            if isinstance(static_func, StaticFunction):
+                concrete_program = static_func.concrete_program_specify_input_spec(
+                    inner_input_spec)
+            elif 'forward' == attr_func:
+                # transform in jit.save, if input_spec is incomplete, declarative will throw error
+                # inner_input_spec is list[InputSpec], it should be packed with same structure
+                # as original input_spec here.
+                if inner_input_spec:
+                    inner_input_spec = pack_sequence_as(input_spec,
+                                                        inner_input_spec)
+                static_forward = declarative(
+                    inner_layer.forward, input_spec=inner_input_spec)
+                concrete_program = static_forward.concrete_program
+                # the input_spec has been used in declarative, which is equal to
+                # @declarative with input_spec and jit.save without input_spec,
+                # avoid needless warning
+                inner_input_spec = None
+            else:
+                continue
+
+            # NOTE(chenweihang): we maintain the mapping of variable name to
+            # structured name, the buffer variable (non-persistable)
+            # saved to inference program may not need by dygraph Layer,
+            # we only record the state_dict variable's structured name
+            state_names_dict = dict()
+            for structured_name, var in six.iteritems(inner_layer.state_dict()):
+                state_names_dict[var.name] = structured_name
+
+            # 3. share parameters from Layer to scope & record var info
+            for param_or_buffer in concrete_program.parameters:
+                # share to scope
+                param_or_buffer_tensor = scope.var(
+                    param_or_buffer.name).get_tensor()
+                src_tensor = param_or_buffer.value().get_tensor()
+                param_or_buffer_tensor._share_data_with(src_tensor)
+                # record var info
+                if param_or_buffer.name not in extra_var_info:
+                    extra_info_dict = dict()
+                    if param_or_buffer.name in state_names_dict:
+                        extra_info_dict['structured_name'] = state_names_dict[
+                            param_or_buffer.name]
+                    extra_info_dict[
+                        'stop_gradient'] = param_or_buffer.stop_gradient
+                    if isinstance(param_or_buffer, ParamBase):
+                        extra_info_dict['trainable'] = param_or_buffer.trainable
+                    extra_var_info[param_or_buffer.name] = extra_info_dict
         else:
-            continue
-
-        # 3. build input & output of save_inference_model
+            # When layer is a function
+            if isinstance(attr_func, StaticFunction):
+                concrete_program = attr_func.concrete_program_specify_input_spec(
+                    inner_input_spec)
+            else:
+                if inner_input_spec:
+                    inner_input_spec = pack_sequence_as(input_spec,
+                                                        inner_input_spec)
+                static_function = declarative(
+                    attr_func, input_spec=inner_input_spec)
+                concrete_program = static_function.concrete_program
+
+        # 4. build input & output of save_inference_model
         # NOTE(chenweihang): [ Get input variables name ]
         # There are two cases, whether to prune the inputs or not
         # - not prune inputs (recommend):
@@ -715,32 +792,6 @@ def save(layer, path, input_spec=None, **configs):
         output_vars = _get_output_vars(concrete_program.outputs,
                                        configs.output_spec)
 
-        # NOTE(chenweihang): we maintain the mapping of variable name to
-        # structured name, the buffer variable (non-persistable)
-        # saved to inference program may not need by dygraph Layer,
-        # we only record the state_dict variable's structured name
-        state_names_dict = dict()
-        for structured_name, var in six.iteritems(inner_layer.state_dict()):
-            state_names_dict[var.name] = structured_name
-
-        # 4. share parameters from Layer to scope & record var info
-        for param_or_buffer in concrete_program.parameters:
-            # share to scope
-            param_or_buffer_tensor = scope.var(param_or_buffer.name).get_tensor(
-            )
-            src_tensor = param_or_buffer.value().get_tensor()
-            param_or_buffer_tensor._share_data_with(src_tensor)
-            # record var info
-            if param_or_buffer.name not in extra_var_info:
-                extra_info_dict = dict()
-                if param_or_buffer.name in state_names_dict:
-                    extra_info_dict['structured_name'] = state_names_dict[
-                        param_or_buffer.name]
-                extra_info_dict['stop_gradient'] = param_or_buffer.stop_gradient
-                if isinstance(param_or_buffer, ParamBase):
-                    extra_info_dict['trainable'] = param_or_buffer.trainable
-                extra_var_info[param_or_buffer.name] = extra_info_dict
-
         # 5. save inference model
         from paddle.fluid.io import save_inference_model
 
@@ -748,7 +799,7 @@ def save(layer, path, input_spec=None, **configs):
         model_path = dirname
         # NOTE(chenweihang): because prefix contains model and params filename,
         # so we don't support set model_filename & params_filename
-        if 'forward' == attr_func:
+        if 'forward' == attr_func or not isinstance(layer, Layer):
             model_filename = file_prefix + INFER_MODEL_SUFFIX
             params_filename = file_prefix + INFER_PARAMS_SUFFIX
         else:
@@ -782,10 +833,11 @@ def save(layer, path, input_spec=None, **configs):
     # but we can save this information in `jit.save` without changing the original
     # storage to improve user experience. So we save extra information into
     # file `***.pdiparams.info`
-    with scope_guard(scope):
-        extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
-        with open(extra_var_info_path, 'wb') as f:
-            pickle.dump(extra_var_info, f, protocol=2)
+    if isinstance(layer, Layer) and extra_var_info:
+        with scope_guard(scope):
+            extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
+            with open(extra_var_info_path, 'wb') as f:
+                pickle.dump(extra_var_info, f, protocol=2)
 
 
 @dygraph_only
diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py
index 16adcb8f241..eef38182f6e 100644
--- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py
+++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py
@@ -399,15 +399,6 @@ class TestJitSaveLoad(unittest.TestCase):
         with self.assertRaises(ValueError):
             model_dict, _ = fluid.dygraph.load_dygraph(model_path)
 
-    def test_jit_load_model_incomplete(self):
-        model_path = "test_jit_save_load.remove_variables/model"
-        self.train_and_save_model(model_path)
-        # remove `.pdiparams`
-        var_path = model_path + INFER_PARAMS_SUFFIX
-        os.remove(var_path)
-        with self.assertRaises(ValueError):
-            paddle.jit.load(model_path)
-
     def test_jit_load_no_path(self):
         path = "test_jit_save_load.no_path/model_path"
         with self.assertRaises(ValueError):
@@ -1164,6 +1155,63 @@ class TestJitSaveLoadFinetuneLoad(unittest.TestCase):
         self.assertTrue(float(((result_01 - result_11)).abs().max()) < 1e-5)
 
 
+class TestJitSaveLoadFunction(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+
+    def test_jit_save_load_static_function(self):
+        @paddle.jit.to_static
+        def fun(inputs):
+            return paddle.tanh(inputs)
+
+        path = 'test_jit_save_load_function_1/func'
+        inps = paddle.rand([3, 6])
+        origin = fun(inps)
+
+        paddle.jit.save(fun, path)
+        load_func = paddle.jit.load(path)
+
+        load_result = load_func(inps)
+        self.assertTrue((load_result - origin).abs().max() < 1e-10)
+
+    def test_jit_save_load_function_input_spec(self):
+        @paddle.jit.to_static(input_spec=[
+            InputSpec(
+                shape=[None, 6], dtype='float32', name='x'),
+        ])
+        def fun(inputs):
+            return paddle.nn.functional.relu(inputs)
+
+        path = 'test_jit_save_load_function_2/func'
+        inps = paddle.rand([3, 6])
+        origin = fun(inps)
+
+        paddle.jit.save(fun, path)
+        load_func = paddle.jit.load(path)
+        load_result = load_func(inps)
+        self.assertTrue((load_result - origin).abs().max() < 1e-10)
+
+    def test_jit_save_load_function_function(self):
+        def fun(inputs):
+            return paddle.tanh(inputs)
+
+        path = 'test_jit_save_load_function_3/func'
+        inps = paddle.rand([3, 6])
+        origin = fun(inps)
+
+        paddle.jit.save(
+            fun,
+            path,
+            input_spec=[
+                InputSpec(
+                    shape=[None, 6], dtype='float32', name='x'),
+            ])
+        load_func = paddle.jit.load(path)
+
+        load_result = load_func(inps)
+        self.assertTrue((load_result - origin).abs().max() < 1e-10)
+
+
 class TestJitSaveLoadDataParallel(unittest.TestCase):
     def verify_inference_correctness(self, layer, path):
         layer.eval()
--
GitLab
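
For reference, a minimal end-to-end sketch of the behavior this patch enables,
adapted from the docstring and unit-test examples above (the save path is
illustrative; the API calls are exactly those added and exercised by the patch):

    import paddle

    # A function decorated with @paddle.jit.to_static can now be passed to
    # paddle.jit.save directly, without wrapping it in a Layer.
    @paddle.jit.to_static
    def fun(inputs):
        return paddle.tanh(inputs)

    inps = paddle.rand([3, 6])
    origin = fun(inps)

    # A bare function has no parameters, so only 'func.pdmodel' is written;
    # no '.pdiparams' or '.pdiparams.info' files are produced. This is the
    # case the new branch in _construct_params_and_buffers handles on load.
    paddle.jit.save(fun, 'example_func/func')

    load_func = paddle.jit.load('example_func/func')
    print(float((load_func(inps) - origin).abs().max()) < 1e-10)  # True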