diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 393f1c1570453ca33ffb78121af2f8e2f640aba5..e8c263fe033554df1bbce3ee41a532c347cbe726 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -379,7 +379,9 @@ class _SaveLoadConfig(object): def _parse_save_configs(configs): - supported_configs = ['output_spec', "with_hook", "use_combine"] + supported_configs = [ + 'output_spec', "with_hook", "combine_params", "clip_extra" + ] # input check for key in configs: @@ -392,7 +394,8 @@ def _parse_save_configs(configs): inner_config = _SaveLoadConfig() inner_config.output_spec = configs.get('output_spec', None) inner_config.with_hook = configs.get('with_hook', False) - inner_config.combine_params = configs.get("use_combine", False) + inner_config.combine_params = configs.get("combine_params", False) + inner_config.clip_extra = configs.get("clip_extra", False) return inner_config @@ -1015,7 +1018,7 @@ def save(layer, path, input_spec=None, **configs): params_filename=params_filename, export_for_deployment=configs._export_for_deployment, program_only=configs._program_only, - clip_extra=False) + clip_extra=configs.clip_extra) # collect all vars for var in concrete_program.main_program.list_vars(): diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index f467fbe4888e64cb3f0edc4a7275d8750e5641ab..fd4129f47ff65f1ee4e83072c6a5d7973403c6ff 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -1209,7 +1209,7 @@ class TestJitSaveCombine(unittest.TestCase): with unique_name.guard(): net = Net() #save - paddle.jit.save(net, model_path, use_combine=True) + paddle.jit.save(net, model_path, combine_params=True) class LayerLoadFinetune(paddle.nn.Layer):