未验证 提交 f1ffd59a 编写于 作者: C Chen Weihang 提交者: GitHub

add clip_extra and change use_combine_name (#44008)

上级 b4fef397
...@@ -379,7 +379,9 @@ class _SaveLoadConfig(object): ...@@ -379,7 +379,9 @@ class _SaveLoadConfig(object):
def _parse_save_configs(configs): def _parse_save_configs(configs):
supported_configs = ['output_spec', "with_hook", "use_combine"] supported_configs = [
'output_spec', "with_hook", "combine_params", "clip_extra"
]
# input check # input check
for key in configs: for key in configs:
...@@ -392,7 +394,8 @@ def _parse_save_configs(configs): ...@@ -392,7 +394,8 @@ def _parse_save_configs(configs):
inner_config = _SaveLoadConfig() inner_config = _SaveLoadConfig()
inner_config.output_spec = configs.get('output_spec', None) inner_config.output_spec = configs.get('output_spec', None)
inner_config.with_hook = configs.get('with_hook', False) inner_config.with_hook = configs.get('with_hook', False)
inner_config.combine_params = configs.get("use_combine", False) inner_config.combine_params = configs.get("combine_params", False)
inner_config.clip_extra = configs.get("clip_extra", False)
return inner_config return inner_config
...@@ -1015,7 +1018,7 @@ def save(layer, path, input_spec=None, **configs): ...@@ -1015,7 +1018,7 @@ def save(layer, path, input_spec=None, **configs):
params_filename=params_filename, params_filename=params_filename,
export_for_deployment=configs._export_for_deployment, export_for_deployment=configs._export_for_deployment,
program_only=configs._program_only, program_only=configs._program_only,
clip_extra=False) clip_extra=configs.clip_extra)
# collect all vars # collect all vars
for var in concrete_program.main_program.list_vars(): for var in concrete_program.main_program.list_vars():
......
...@@ -1209,7 +1209,7 @@ class TestJitSaveCombine(unittest.TestCase): ...@@ -1209,7 +1209,7 @@ class TestJitSaveCombine(unittest.TestCase):
with unique_name.guard(): with unique_name.guard():
net = Net() net = Net()
#save #save
paddle.jit.save(net, model_path, use_combine=True) paddle.jit.save(net, model_path, combine_params=True)
class LayerLoadFinetune(paddle.nn.Layer): class LayerLoadFinetune(paddle.nn.Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册