未验证 提交 c1a88687 编写于 作者: C Chen Weihang 提交者: GitHub

Change jit.save/load configs to config & update code examples (#27056)

* change configs to config & update examples

* fix deprecate decorator conflict
上级 0443b480
...@@ -24,7 +24,7 @@ from . import learning_rate_scheduler ...@@ -24,7 +24,7 @@ from . import learning_rate_scheduler
import warnings import warnings
from .. import core from .. import core
from .base import guard from .base import guard
from paddle.fluid.dygraph.jit import SaveLoadConfig from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers
__all__ = [ __all__ = [
...@@ -42,9 +42,9 @@ def deprecate_keep_name_table(func): ...@@ -42,9 +42,9 @@ def deprecate_keep_name_table(func):
warnings.warn( warnings.warn(
"The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.", "The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.",
DeprecationWarning) DeprecationWarning)
configs = SaveLoadConfig() config = SaveLoadConfig()
configs.keep_name_table = keep_name_table config.keep_name_table = keep_name_table
return configs return config
# deal with arg `keep_name_table` # deal with arg `keep_name_table`
if len(args) > 1 and isinstance(args[1], bool): if len(args) > 1 and isinstance(args[1], bool):
...@@ -52,7 +52,7 @@ def deprecate_keep_name_table(func): ...@@ -52,7 +52,7 @@ def deprecate_keep_name_table(func):
args[1] = __warn_and_build_configs__(args[1]) args[1] = __warn_and_build_configs__(args[1])
# deal with kwargs # deal with kwargs
elif 'keep_name_table' in kwargs: elif 'keep_name_table' in kwargs:
kwargs['configs'] = __warn_and_build_configs__(kwargs[ kwargs['config'] = __warn_and_build_configs__(kwargs[
'keep_name_table']) 'keep_name_table'])
kwargs.pop('keep_name_table') kwargs.pop('keep_name_table')
else: else:
...@@ -135,8 +135,9 @@ def save_dygraph(state_dict, model_path): ...@@ -135,8 +135,9 @@ def save_dygraph(state_dict, model_path):
# TODO(qingqing01): remove dygraph_only to support loading static model. # TODO(qingqing01): remove dygraph_only to support loading static model.
# maybe need to unify the loading interface after 2.0 API is ready. # maybe need to unify the loading interface after 2.0 API is ready.
# @dygraph_only # @dygraph_only
@deprecate_save_load_configs
@deprecate_keep_name_table @deprecate_keep_name_table
def load_dygraph(model_path, configs=None): def load_dygraph(model_path, config=None):
''' '''
:api_attr: imperative :api_attr: imperative
...@@ -151,7 +152,7 @@ def load_dygraph(model_path, configs=None): ...@@ -151,7 +152,7 @@ def load_dygraph(model_path, configs=None):
Args: Args:
model_path(str) : The file prefix store the state_dict. model_path(str) : The file prefix store the state_dict.
(The path should Not contain suffix '.pdparams') (The path should Not contain suffix '.pdparams')
configs (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig` config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
object that specifies additional configuration options, these options object that specifies additional configuration options, these options
are for compatibility with ``jit.save/io.save_inference_model`` formats. are for compatibility with ``jit.save/io.save_inference_model`` formats.
Default None. Default None.
...@@ -195,6 +196,7 @@ def load_dygraph(model_path, configs=None): ...@@ -195,6 +196,7 @@ def load_dygraph(model_path, configs=None):
opti_file_path = model_prefix + ".pdopt" opti_file_path = model_prefix + ".pdopt"
# deal with argument `configs` # deal with argument `configs`
configs = config
if configs is None: if configs is None:
configs = SaveLoadConfig() configs = SaveLoadConfig()
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册