From f1ffd59a8686cda32dae65aff259d489fff83da7 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 1 Jul 2022 19:35:28 +0800 Subject: [PATCH] add clip_extra and change use_combine_name (#44008) --- python/paddle/fluid/dygraph/jit.py | 9 ++++++--- .../paddle/fluid/tests/unittests/test_jit_save_load.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 393f1c15704..e8c263fe033 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -379,7 +379,9 @@ class _SaveLoadConfig(object): def _parse_save_configs(configs): - supported_configs = ['output_spec', "with_hook", "use_combine"] + supported_configs = [ + 'output_spec', "with_hook", "combine_params", "clip_extra" + ] # input check for key in configs: @@ -392,7 +394,8 @@ def _parse_save_configs(configs): inner_config = _SaveLoadConfig() inner_config.output_spec = configs.get('output_spec', None) inner_config.with_hook = configs.get('with_hook', False) - inner_config.combine_params = configs.get("use_combine", False) + inner_config.combine_params = configs.get("combine_params", False) + inner_config.clip_extra = configs.get("clip_extra", False) return inner_config @@ -1015,7 +1018,7 @@ def save(layer, path, input_spec=None, **configs): params_filename=params_filename, export_for_deployment=configs._export_for_deployment, program_only=configs._program_only, - clip_extra=False) + clip_extra=configs.clip_extra) # collect all vars for var in concrete_program.main_program.list_vars(): diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index f467fbe4888..fd4129f47ff 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -1209,7 +1209,7 @@ class TestJitSaveCombine(unittest.TestCase): with unique_name.guard(): net = Net() #save - paddle.jit.save(net, model_path, use_combine=True) + paddle.jit.save(net, model_path, combine_params=True) class LayerLoadFinetune(paddle.nn.Layer): -- GitLab