From 483ba282e7e84c48177243ad605cddb971516f7a Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 15 Sep 2022 20:43:44 +0800 Subject: [PATCH] [jit] skip forward save (#45901) * skip forward save * fix bug * more ci for jit skip forward --- python/paddle/fluid/dygraph/jit.py | 20 ++++++++--- .../tests/unittests/test_jit_save_load.py | 34 +++++++++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index a0275ac57c..c793515379 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -381,7 +381,8 @@ class _SaveLoadConfig(object): def _parse_save_configs(configs): supported_configs = [ - 'output_spec', "with_hook", "combine_params", "clip_extra" + 'output_spec', "with_hook", "combine_params", "clip_extra", + "skip_forward" ] # input check @@ -397,6 +398,7 @@ def _parse_save_configs(configs): inner_config.with_hook = configs.get('with_hook', False) inner_config.combine_params = configs.get("combine_params", False) inner_config.clip_extra = configs.get("clip_extra", False) + inner_config.skip_forward = configs.get("skip_forward", False) return inner_config @@ -522,7 +524,10 @@ def _build_load_path_and_config(path, config): "don't know which one to load, please make sure that the specified target " "of ``path`` is unique." % (path, path)) elif not prefix_format_exist and not directory_format_exist: - raise ValueError("The ``path`` (%s) to load model not exists." % path) + raise ValueError("The ``path`` (%s) to load model not exists. " + "Please make sure that *.pdmodel exists or " + "don't using ``skip_forward=True`` to jit.save." % + path) else: if prefix_format_exist: file_prefix = os.path.basename(path) @@ -906,6 +911,7 @@ def save(layer, path, input_spec=None, **configs): combine_vars = {} property_vals = [] # (value, key) + concrete_program = None for attr_func in functions: if isinstance(layer, Layer): static_func = getattr(inner_layer, attr_func, None) @@ -921,6 +927,10 @@ def save(layer, path, input_spec=None, **configs): concrete_program = static_func.concrete_program_specify_input_spec( inner_input_spec, with_hook=with_hook) elif 'forward' == attr_func: + if configs.skip_forward: + # do not jit.save forward function + continue + # transform in jit.save, if input_spec is incomplete, declarative will throw error # inner_input_spec is list[InputSpec], it should be packed with same structure # as original input_spec here. @@ -1100,10 +1110,10 @@ def save(layer, path, input_spec=None, **configs): # file `***.pdiparams.info` # "layer" can only be Layer or function or StaticFunction. - contain_parameter = False - for var in concrete_program.main_program.list_vars(): - contain_parameter |= isinstance(var, Parameter) + if concrete_program is not None: + for var in concrete_program.main_program.list_vars(): + contain_parameter |= isinstance(var, Parameter) if (isinstance(layer, Layer) or contain_parameter) and extra_var_info: with scope_guard(scope): diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index 507083755c..da89cbf33c 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -1740,6 +1740,40 @@ class TestInputSpecCompatibility(unittest.TestCase): shutil.rmtree(save_dir) +class NotJitForward(paddle.nn.Layer): + + def __init__(self): + super(NotJitForward, self).__init__() + + def forward(self, x, y): + return x + y + + +class TestNotJitForward(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_jit_not_save_forward(self): + layer = NotJitForward() + + save_dir = os.path.join(self.temp_dir.name, "jit_not_save_forward") + path = save_dir + "/model" + + paddle.jit.save(layer=layer, path=path, skip_forward=True) + + self.assertTrue(not os.path.exists(path + ".pdmodel")) + self.assertTrue(not os.path.exists(path + ".pdparam")) + + with self.assertRaises(ValueError): + paddle.jit.load(path=path) + + shutil.rmtree(save_dir) + + if __name__ == '__main__': with fluid.framework._test_eager_guard(): unittest.main() -- GitLab