Unverified commit 77036fff, authored by Aurelius84, committed by GitHub

[Dy2St]Rename in_declarative_mode into in_to_static_mode (#56881)

* [Dy2St]Rename in_declarative_mode into in_to_static_mode

fix comment
Parent: a5fb72de
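The heart of this commit is a pure rename plus a temporary backward-compatibility alias. Below is a minimal, self-contained sketch of that pattern; `_GlobalThreadLocal` is a simplified stand-in for `paddle.base.framework.global_var` (a `threading.local` subclass, as the framework.py hunk further down shows), not the real implementation.

```python
import threading


class _GlobalThreadLocal(threading.local):
    """Simplified stand-in for paddle.base.framework.global_var."""

    def __init__(self):
        self._in_to_static_mode_ = False


global_var = _GlobalThreadLocal()


def in_to_static_mode():
    """Return whether code is currently running under @to_static."""
    return global_var._in_to_static_mode_


# Temporary alias so downstream code (e.g. PaddleX) that still imports the
# old name keeps working until its usage is cleaned up.
in_declarative_mode = in_to_static_mode
```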
@@ -143,13 +143,13 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''):
     if in_dygraph_mode():
         return
-    # NOTE: `in_declarative_mode` is used to determine whether this op is called under
+    # NOTE: `in_to_static_mode` is used to determine whether this op is called under
     # @to_static in transformation from dygraph to static layer. We add Tensor in
     # expected_type to skip checking because Tensor may be created and used in unusual way.
-    from .dygraph.base import in_declarative_mode
+    from .dygraph.base import in_to_static_mode

     # Need a better design to fix this.
-    if in_declarative_mode():
+    if in_to_static_mode():
         if not isinstance(expected_type, tuple):
             expected_type = (expected_type,)
         expected_type += (core.eager.Tensor,)
......
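The hunk above relaxes `check_type` under `@to_static`: because the transformer may substitute real Tensors for plain Python values, the Tensor type is appended to the accepted types before checking. A hedged sketch of that logic, reusing `in_to_static_mode` from the first sketch; `check_type_sketch` and `tensor_type` are illustrative names, not Paddle APIs.

```python
def check_type_sketch(input, input_name, expected_type, tensor_type):
    """Relaxed type check; `tensor_type` stands in for core.eager.Tensor."""
    if not isinstance(expected_type, tuple):
        expected_type = (expected_type,)
    if in_to_static_mode():
        # Under @to_static a Tensor may appear where a plain Python value is
        # expected, so accept the tensor type as well instead of raising.
        expected_type += (tensor_type,)
    if not isinstance(input, expected_type):
        raise TypeError(
            f"The type of '{input_name}' must be {expected_type}, "
            f"but received {type(input)}."
        )
```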
@@ -13,8 +13,6 @@
 # limitations under the License.
 from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
 import decorator
-import contextlib
-import functools
 import inspect
 import sys
 import numpy as np
@@ -23,7 +21,6 @@ from paddle.base import framework
 from paddle.base.framework import global_var
 from paddle.base.multiprocess_utils import CleanupFuncRegistrar
 from .tracer import Tracer
-import logging
 from ..data_feeder import convert_dtype
 import warnings
 from ..framework import _get_paddle_place
@@ -44,15 +41,19 @@ __all__ = [
 NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"


-def in_declarative_mode():
+def in_to_static_mode():
     """
     Return a bool value that indicates whether running code under `@to_static`
     """
-    return global_var._in_declarative_mode_
+    return global_var._in_to_static_mode_


-def declarative_unsupport_argument_warning(
+# TODO(Aurelius84): Need to remove this alias after clean usage in PaddleX
+in_declarative_mode = in_to_static_mode
+
+
+def to_static_unsupport_argument_warning(
     func_name, input_names, inputs, support_values
 ):
     """
@@ -81,12 +82,12 @@ switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)

 @signature_safe_contextmanager
-def _switch_declarative_mode_guard_(is_declarative=True):
+def _to_static_mode_guard_(is_to_static=True):
     global global_var
-    original_val = global_var._in_declarative_mode_
-    global_var._in_declarative_mode_ = is_declarative
+    original_val = global_var._in_to_static_mode_
+    global_var._in_to_static_mode_ = is_to_static
     yield
-    global_var._in_declarative_mode_ = original_val
+    global_var._in_to_static_mode_ = original_val


 @signature_safe_contextmanager
@@ -105,7 +106,7 @@ def program_desc_tracing_guard(enable):

 @signature_safe_contextmanager
 def param_guard(parameters):
     # Note: parameters is a reference of self._parameters or self._buffers
-    if in_declarative_mode() and not paddle.in_dynamic_mode() and parameters:
+    if in_to_static_mode() and not paddle.in_dynamic_mode() and parameters:
         try:
             origin_parameters = parameters.copy()
             for name, var_base in parameters.items():
@@ -322,7 +323,7 @@ def no_grad(func=None):
         test_layer()

     """
-    if in_declarative_mode():
+    if in_to_static_mode():
         warnings.warn(
             "paddle.no_grad is only supported for inference model, and not supported for training under @to_static."
         )
@@ -732,12 +733,12 @@ def grad(
             grad_y1 = paddle.to_tensor(3.0)
             print(test_dygraph_grad([grad_y1, grad_value])) # [24.]
     '''
-    if in_declarative_mode():
+    if in_to_static_mode():
         # In dy2static context, we call static interface `gradients`
         # to calculate grads.
         from paddle.static import gradients

-        declarative_unsupport_argument_warning(
+        to_static_unsupport_argument_warning(
             "paddle.grad",
             ["retain_graph", "create_grad", "only_inputs", "allow_unused"],
             [retain_graph, create_graph, only_inputs, allow_unused],
......
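`_to_static_mode_guard_` flips the thread-local flag for the duration of a block and restores the previous value on exit. Here is a usage sketch, continuing the stand-ins from the first example. Note that the sketch adds a `try/finally`, whereas the Paddle code above restores the value after a bare `yield`; that gap is why the tests at the end of this diff reset the flag manually when an exception escapes.

```python
from contextlib import contextmanager


@contextmanager
def _to_static_mode_guard_(is_to_static=True):
    # Save, set, yield, restore; try/finally added here so an exception
    # inside the block cannot leave the flag stuck (unlike the code above).
    original_val = global_var._in_to_static_mode_
    global_var._in_to_static_mode_ = is_to_static
    try:
        yield
    finally:
        global_var._in_to_static_mode_ = original_val


assert not in_to_static_mode()
with _to_static_mode_guard_(is_to_static=True):
    assert in_to_static_mode()  # code here sees the to-static flag as set
assert not in_to_static_mode()  # previous value restored on exit
```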
@@ -84,15 +84,13 @@ class GlobalThreadLocal(threading.local):
         TODO(xiongkun): how to access another thread local data ?
         """
         global _dygraph_tracer_
-        self._in_declarative_mode_ = False
+        self._in_to_static_mode_ = False
         self._functional_dygraph_context_manager = None
         self._dygraph_tracer_ = _dygraph_tracer_

     def __str__(self):
         strings = []
-        strings.append(
-            "_in_declarative_mode_:" + str(self._in_declarative_mode_)
-        )
+        strings.append("_in_to_static_mode_:" + str(self._in_to_static_mode_))
         strings.append(
             "_functional_dygraph_context_manager:"
             + str(self._functional_dygraph_context_manager)
@@ -528,9 +526,9 @@ def _dygraph_only_(func):

 def _non_static_only_(func):
     def __impl__(*args, **kwargs):
-        from .dygraph.base import in_declarative_mode
+        from .dygraph.base import in_to_static_mode

-        assert in_dygraph_mode() or in_declarative_mode(), (
+        assert in_dygraph_mode() or in_to_static_mode(), (
             "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode."
             % func.__name__
         )
@@ -2371,9 +2369,9 @@ class Variable(metaclass=VariableMetaClass):
         return _getitem_static(self, item)

     def __setitem__(self, item, value):
-        from .dygraph.base import in_declarative_mode
+        from .dygraph.base import in_to_static_mode

-        if in_declarative_mode():
+        if in_to_static_mode():
             if is_compiled_with_xpu():
                 # (NOTE): Currently, there is no index_put_xpu kernel.
                 return _setitem_impl_(self, item, value)
@@ -4325,12 +4323,9 @@ class Block:
                 'while',
                 'while_grad',
             }
-            from .dygraph.base import in_declarative_mode
+            from .dygraph.base import in_to_static_mode

-            if (
-                in_declarative_mode()
-                and not _stride_in_no_check_dy2st_diff_mode
-            ):
+            if in_to_static_mode() and not _stride_in_no_check_dy2st_diff_mode:
                 check_if_to_static_diff_with_dygraph(
                     op_type, inplace_map, outputs
                 )
@@ -4347,7 +4342,7 @@ class Block:
             )

         self.ops.append(op)
-        if in_declarative_mode():
+        if in_to_static_mode():
             record_is_view_var(op_type, inputs, outputs)

         return op
@@ -7645,10 +7640,10 @@ def _get_var(name, program=None):

 @signature_safe_contextmanager
 def dygraph_guard_if_declarative():
-    from .dygraph.base import in_declarative_mode
+    from .dygraph.base import in_to_static_mode
     from .dygraph import Tracer

-    if in_declarative_mode():
+    if in_to_static_mode():
         # Under @paddle.jit.to_static decorator, we switch back dygraph mode temporarily.
         with _dygraph_guard(tracer=Tracer()):
             yield
......
@@ -333,10 +333,10 @@ def generate_inplace_fn(inplace_op_type):
                     inplace_op_type, origin_op_type
                 )
             )
-            from ..dygraph.base import in_declarative_mode
+            from ..dygraph.base import in_to_static_mode

             if (
-                in_declarative_mode()
+                in_to_static_mode()
                 and hasattr(x, "is_view_var")
                 and x.is_view_var
             ):
......
@@ -18,7 +18,7 @@ import inspect
 from .. import core
 from ..framework import Variable, unique_name, static_only
 from .layer_function_generator import OpProtoHolder
-from paddle.base.dygraph.base import in_declarative_mode
+from paddle.base.dygraph.base import in_to_static_mode

 _supported_int_dtype_ = [
     core.VarDesc.VarType.BOOL,
@@ -302,7 +302,7 @@ def monkey_patch_variable():
         """
         if not isinstance(var, Variable):
-            if in_declarative_mode():
+            if in_to_static_mode():
                 """in dy2static mode, x may be tensorable values such as int, float, np.array"""
                 from paddle.tensor.creation import to_tensor
......
@@ -1017,9 +1017,9 @@ def _setitem_static(x, indices, values):
 def get_tensor_with_basic_indexing(
     x, axes, starts, ends, steps, decrease_axes, none_axes, use_strided_slice
 ):
-    from .dygraph.base import in_declarative_mode
+    from .dygraph.base import in_to_static_mode

-    if in_declarative_mode() and hasattr(x, "is_view_var"):
+    if in_to_static_mode() and hasattr(x, "is_view_var"):
         x.is_view_var = True
     if len(axes) == 0:
@@ -1111,7 +1111,7 @@ def get_tensor_with_basic_indexing(
                 out = paddle.unsqueeze(out, axis=none_axes)

-    if in_declarative_mode() and hasattr(out, "is_view_var"):
+    if in_to_static_mode() and hasattr(out, "is_view_var"):
         out.is_view_var = True
     return out
......
@@ -16,7 +16,7 @@ import re

 import paddle
 from paddle.base.data_feeder import convert_dtype
-from paddle.base.dygraph.base import _convert_into_variable, in_declarative_mode
+from paddle.base.dygraph.base import _convert_into_variable, in_to_static_mode
 from paddle.base.framework import Variable, core, default_main_program

 from .utils import (
@@ -38,14 +38,14 @@ def convert_attr(x, attr):

 def convert_load(x):
-    if in_declarative_mode() and isinstance(x, paddle.base.core.eager.Tensor):
+    if in_to_static_mode() and isinstance(x, paddle.base.core.eager.Tensor):
         """
         TODO:(@xiongkun) may run convert_load in dygraph mode, which should be fixed.
         """
         return _convert_into_variable(x)

     # get the new output of the var
-    if in_declarative_mode() and isinstance(x, Variable):
+    if in_to_static_mode() and isinstance(x, Variable):
         cur_block = default_main_program().current_block()

         from paddle.jit.dy2static.program_translator import ProgramTranslator
......
@@ -23,7 +23,7 @@ import weakref
 from paddle.base import core, framework
 from paddle.base.data_feeder import check_type
 from paddle.base.dygraph.base import (
-    _switch_declarative_mode_guard_,
+    _to_static_mode_guard_,
     param_guard,
     switch_to_static_graph,
 )
@@ -1192,9 +1192,9 @@ class ConcreteProgram:
         new_name_generator = UniqueNameGenerator()

         with framework.program_guard(main_program, startup_program):
-            with _switch_declarative_mode_guard_(
-                is_declarative=True
-            ), UniqueNameGuard(new_name_generator):
+            with _to_static_mode_guard_(is_to_static=True), UniqueNameGuard(
+                new_name_generator
+            ):
                 # 1. Adds `paddle.static.data` layers for input if needed
                 static_inputs = func_spec.to_static_inputs_with_spec(
                     input_spec, main_program
......
@@ -26,9 +26,10 @@ from paddle import nn, profiler
 from paddle.base import core, framework, unique_name
 from paddle.base.core import VarDesc
 from paddle.base.dygraph import no_grad
+from paddle.base.dygraph.base import in_declarative_mode  # noqa: F401
 from paddle.base.dygraph.base import (
     _convert_into_variable,
-    in_declarative_mode,
+    in_to_static_mode,
     program_desc_tracing_guard,
 )
 from paddle.base.dygraph_utils import _append_activation_in_dygraph
@@ -1336,7 +1337,7 @@ class Layer:

     def __call__(self, *inputs, **kwargs):
         if (
-            (not in_declarative_mode())
+            (not in_to_static_mode())
             and (not self._forward_pre_hooks)
             and (not self._forward_post_hooks)
             and (not self._built)
@@ -1561,7 +1562,7 @@ class Layer:
         if '_parameters' in self.__dict__:
             _parameters = self.__dict__['_parameters']
             if name in self._parameters:
-                if in_declarative_mode():
+                if in_to_static_mode():
                     return _convert_into_variable(self._parameters[name])
                 return self._parameters[name]
         if '_sub_layers' in self.__dict__:
@@ -1571,7 +1572,7 @@ class Layer:
         if '_buffers' in self.__dict__:
             _buffers = self.__dict__['_buffers']
             if name in _buffers:
-                if in_declarative_mode():
+                if in_to_static_mode():
                     return _convert_into_variable(_buffers[name])
                 return _buffers[name]
         return object.__getattribute__(self, name)
@@ -1653,7 +1654,7 @@ class Layer:
                     # but should all non-Variable _buffers[name] be re-assigned? We
                     # should consider it in the future. I currently wrote this as
                     # conservative code.
-                    if in_declarative_mode() and _buffers[name] is None:
+                    if in_to_static_mode() and _buffers[name] is None:
                         raise RuntimeError(
                             'In Dy2stat, self.{0} is a buffer and self.{0} is '
                             'not allowed to be set to Variable when self.{0} is None.'.format(
......
@@ -419,7 +419,7 @@ class Adam(Optimizer):
                 >>> adam.step()
                 >>> adam.clear_grad()
         """
-        if paddle.base.dygraph.base.in_declarative_mode():
+        if paddle.base.dygraph.base.in_to_static_mode():
             self._declarative_step()
             return
......
@@ -557,7 +557,7 @@ class AdamW(Optimizer):
                 >>> opt.step()
                 >>> opt.clear_grad()
         """
-        if paddle.base.dygraph.base.in_declarative_mode():
+        if paddle.base.dygraph.base.in_to_static_mode():
             self._declarative_step()
             return
......
@@ -1603,7 +1603,7 @@ class Optimizer:
                 >>> adam.step()
                 >>> adam.clear_grad()
         """
-        if paddle.base.dygraph.base.in_declarative_mode():
+        if paddle.base.dygraph.base.in_to_static_mode():
             self._declarative_step()
             return
......
@@ -33,9 +33,9 @@ def _inplace_apis_in_dygraph_only_(func):
                     func.__name__, origin_api_name
                 )
             )
-            from ..base.dygraph.base import in_declarative_mode
+            from ..base.dygraph.base import in_to_static_mode

-            if in_declarative_mode():
+            if in_to_static_mode():
                 for arg in args:
                     if hasattr(arg, "is_view_var") and arg.is_view_var:
                         raise ValueError(
......
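The hunk above is the consumer side of the `is_view_var` flag set in the indexing hunks earlier: tensors produced by view-like ops are flagged during to-static tracing, and inplace APIs refuse to mutate them. A hedged sketch of that guard, reusing `in_to_static_mode` from the first sketch; `_check_not_view` and `api_name` are illustrative names, not Paddle APIs.

```python
def _check_not_view(args, api_name):
    """Reject inplace mutation of view tensors under @to_static."""
    if not in_to_static_mode():
        return
    for arg in args:
        # `is_view_var` is set during to-static tracing for view outputs.
        if hasattr(arg, "is_view_var") and arg.is_view_var:
            raise ValueError(
                f"{api_name} cannot modify a Tensor inplace under @to_static "
                "because it is a view of another Tensor."
            )
```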
@@ -39,7 +39,7 @@ class TestDy2staticException(unittest.TestCase):
             with self.assertRaisesRegex(Dygraph2StaticException, self.error):
                 paddle.jit.enable_to_static(True)
                 self.assertTrue(to_static(self.dyfunc)(self.x))
-        paddle.base.dygraph.base.global_var._in_declarative_mode_ = False
+        paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
         paddle.jit.enable_to_static(False)
......
@@ -68,7 +68,7 @@ class TestDy2staticException(unittest.TestCase):
             with self.assertRaisesRegex(Dygraph2StaticException, self.error):
                 paddle.jit.enable_to_static(True)
                 self.assertTrue(paddle.jit.to_static(self.dyfunc)(self.x))
-        paddle.base.dygraph.base.global_var._in_declarative_mode_ = False
+        paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
         paddle.jit.enable_to_static(False)
@@ -468,12 +468,12 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
         with self.assertRaises(Dygraph2StaticException):
             static_func = paddle.jit.to_static(self.dyfunc)
             out = static_func(self.x)
-        # Why do we need to set `_in_declarative_mode_` here?
-        # In Dy2St we use `with _switch_declarative_mode_guard_()` to indicate
+        # Why do we need to set `_in_to_static_mode_` here?
+        # In Dy2St we use `with _to_static_mode_guard_()` to indicate
         # that the code block is under @to_static, but in this UT
-        # an exception is thrown during Dy2St, making the `_in_declarative_mode_`
-        # a wrong value. So we need to set `_in_declarative_mode_` to False manually.
-        paddle.base.dygraph.base.global_var._in_declarative_mode_ = False
+        # an exception is thrown during Dy2St, making the `_in_to_static_mode_`
+        # a wrong value. So we need to set `_in_to_static_mode_` to False manually.
+        paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
         paddle.jit.enable_to_static(False)
......
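Because an exception during Dy2St can escape `_to_static_mode_guard_` before the flag is restored, the tests above force the thread-local state back to its defaults by hand. A short sketch of that cleanup pattern, assuming a Paddle build containing this commit:

```python
import paddle


def _reset_dy2st_state():
    # Restore the thread-local to-static flag and disable to-static
    # conversion, mirroring the manual reset in the tests above.
    paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
    paddle.jit.enable_to_static(False)
```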