未验证 提交 d4b4357b 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Change the Global Switch Name of ProgramTranslator for API 2.0 (#27203)

Change ProgramTranslator.enable_declarative to ProgramTranslator.enable_to_static to meet API 2.0
上级 8d05c00c
...@@ -246,7 +246,7 @@ class StaticLayer(object): ...@@ -246,7 +246,7 @@ class StaticLayer(object):
self._function_spec = FunctionSpec(function, input_spec) self._function_spec = FunctionSpec(function, input_spec)
self._program_cache = ProgramCache() self._program_cache = ProgramCache()
self._descriptor_cache = weakref.WeakKeyDictionary() self._descriptor_cache = weakref.WeakKeyDictionary()
# Note: Hold a reference to ProgramTranslator for switching `enable_declarative`. # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
self._program_trans = ProgramTranslator() self._program_trans = ProgramTranslator()
def __get__(self, instance, owner): def __get__(self, instance, owner):
...@@ -299,16 +299,17 @@ class StaticLayer(object): ...@@ -299,16 +299,17 @@ class StaticLayer(object):
""" """
# 1. call dygraph function directly if not enable `declarative` # 1. call dygraph function directly if not enable `declarative`
if not self._program_trans.enable_declarative: if not self._program_trans.enable_to_static:
logging_utils.warn( logging_utils.warn(
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. " "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output.") "We will just return dygraph output. If you would like to get static graph output, please call API "
"ProgramTranslator.enable(True)")
return self._call_dygraph_function(*args, **kwargs) return self._call_dygraph_function(*args, **kwargs)
if not in_dygraph_mode() and self._program_trans.enable_declarative: if not in_dygraph_mode():
raise RuntimeError( raise RuntimeError(
"Failed to run the callable object {} decorated by '@paddle.jit.to_static', " "Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
"because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the " "because it is NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
"following API: paddle.disable_static().".format( "following API: paddle.disable_static().".format(
self.dygraph_function)) self.dygraph_function))
...@@ -723,15 +724,15 @@ class ProgramTranslator(object): ...@@ -723,15 +724,15 @@ class ProgramTranslator(object):
return return
self._initialized = True self._initialized = True
self._program_cache = ProgramCache() self._program_cache = ProgramCache()
self.enable_declarative = True self.enable_to_static = True
def enable(self, enable_declarative): def enable(self, enable_to_static):
""" """
Enable or disable the converting from imperative to declarative by Enable or disable the converting from imperative to declarative by
ProgramTranslator globally. ProgramTranslator globally.
Args: Args:
enable_declarative (bool): True or False to enable or disable declarative. enable_to_static (bool): True or False to enable or disable declarative.
Returns: Returns:
None. None.
...@@ -760,9 +761,9 @@ class ProgramTranslator(object): ...@@ -760,9 +761,9 @@ class ProgramTranslator(object):
print(func(x).numpy()) # [[2. 2.]] print(func(x).numpy()) # [[2. 2.]]
""" """
check_type(enable_declarative, "enable_declarative", bool, check_type(enable_to_static, "enable_to_static", bool,
"ProgramTranslator.enable") "ProgramTranslator.enable")
self.enable_declarative = enable_declarative self.enable_to_static = enable_to_static
def get_output(self, dygraph_func, *args, **kwargs): def get_output(self, dygraph_func, *args, **kwargs):
""" """
...@@ -803,10 +804,12 @@ class ProgramTranslator(object): ...@@ -803,10 +804,12 @@ class ProgramTranslator(object):
assert callable( assert callable(
dygraph_func dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_output" ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
if not self.enable_declarative: if not self.enable_to_static:
warnings.warn( warnings.warn(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. " "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output.") "We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
)
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
try: try:
function_spec = FunctionSpec(dygraph_func) function_spec = FunctionSpec(dygraph_func)
...@@ -876,10 +879,11 @@ class ProgramTranslator(object): ...@@ -876,10 +879,11 @@ class ProgramTranslator(object):
assert callable( assert callable(
dygraph_func dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_func" ), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
if not self.enable_declarative: if not self.enable_to_static:
warnings.warn( warnings.warn(
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable=False. We will " "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will "
"just return dygraph output.") "just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output."
)
return dygraph_func return dygraph_func
static_func = convert_to_static(dygraph_func) static_func = convert_to_static(dygraph_func)
...@@ -929,10 +933,12 @@ class ProgramTranslator(object): ...@@ -929,10 +933,12 @@ class ProgramTranslator(object):
assert callable( assert callable(
dygraph_func dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_program" ), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
if not self.enable_declarative: if not self.enable_to_static:
warnings.warn( warnings.warn(
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False." "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False."
"We will just return dygraph output.") "We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
)
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
function_spec = FunctionSpec(dygraph_func) function_spec = FunctionSpec(dygraph_func)
......
...@@ -119,7 +119,7 @@ def _dygraph_to_static_func_(dygraph_func): ...@@ -119,7 +119,7 @@ def _dygraph_to_static_func_(dygraph_func):
# TODO: remove this decorator after we finalize training API # TODO: remove this decorator after we finalize training API
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
if in_dygraph_mode() or not program_translator.enable_declarative: if in_dygraph_mode() or not program_translator.enable_to_static:
warnings.warn( warnings.warn(
"The decorator 'dygraph_to_static_func' doesn't work in " "The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode or set ProgramTranslator.enable to False. " "dygraph mode or set ProgramTranslator.enable to False. "
...@@ -832,9 +832,9 @@ def save(layer, model_path, input_spec=None, config=None): ...@@ -832,9 +832,9 @@ def save(layer, model_path, input_spec=None, config=None):
# 1. input check # 1. input check
prog_translator = ProgramTranslator() prog_translator = ProgramTranslator()
if not prog_translator.enable: if not prog_translator.enable_to_static:
raise RuntimeError( raise RuntimeError(
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable=False." "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
) )
if not isinstance(layer, Layer): if not isinstance(layer, Layer):
raise TypeError( raise TypeError(
......
...@@ -1680,7 +1680,7 @@ class Model(object): ...@@ -1680,7 +1680,7 @@ class Model(object):
# TODO: # TODO:
# 1. Make it Unnecessary to run model before calling `save_inference_model` for users in dygraph. # 1. Make it Unnecessary to run model before calling `save_inference_model` for users in dygraph.
# 2. Save correct shape of input, now the interface stores the shape that the user sent to # 2. Save correct shape of input, now the interface stores the shape that the user sent to
# the inputs of the model in running. # the inputs of the model in running.
# 3. Make it Unnecessary to add `@paddle.jit.to_static` for users in dynamic mode. # 3. Make it Unnecessary to add `@paddle.jit.to_static` for users in dynamic mode.
if fluid.in_dygraph_mode(): if fluid.in_dygraph_mode():
...@@ -1689,9 +1689,9 @@ class Model(object): ...@@ -1689,9 +1689,9 @@ class Model(object):
# 1. input check # 1. input check
prog_translator = ProgramTranslator() prog_translator = ProgramTranslator()
if not prog_translator.enable_declarative: if not prog_translator.enable_to_static:
raise RuntimeError( raise RuntimeError(
"save_inference_model doesn't work when setting ProgramTranslator.enable=False." "save_inference_model doesn't work when setting ProgramTranslator.enable to False."
) )
if not isinstance(layer, Layer): if not isinstance(layer, Layer):
raise TypeError( raise TypeError(
...@@ -1902,8 +1902,8 @@ class Model(object): ...@@ -1902,8 +1902,8 @@ class Model(object):
assert isinstance(spec, Input) assert isinstance(spec, Input)
if spec.name is None: if spec.name is None:
raise ValueError( raise ValueError(
"Requires Input[{}].name != None, but receive `None` with {}.". "Requires Input[{}].name != None, but receive `None` with {}."
format(i, spec)) .format(i, spec))
return out_specs return out_specs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册