未验证 提交 77036fff 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2St]Rename in_declarative_mode into in_to_static_mode (#56881)

* [Dy2St]Renae in_declarative_mode into in_to_static_mode

[Dy2St]Renae in_declarative_mode into in_to_static_mode

fix comment
上级 a5fb72de
......@@ -143,13 +143,13 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''):
if in_dygraph_mode():
return
# NOTE: `in_declarative_mode` is used to determined whether this op is called under
# NOTE: `in_to_static_mode` is used to determined whether this op is called under
# @to_static in transformation from dygrah to static layer. We add Tensor in
# expected_type to skip checking because Tensor may be created and used in unusual way.
from .dygraph.base import in_declarative_mode
from .dygraph.base import in_to_static_mode
# Need a better design to be fix this.
if in_declarative_mode():
if in_to_static_mode():
if not isinstance(expected_type, tuple):
expected_type = (expected_type,)
expected_type += (core.eager.Tensor,)
......
......@@ -13,8 +13,6 @@
# limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import decorator
import contextlib
import functools
import inspect
import sys
import numpy as np
......@@ -23,7 +21,6 @@ from paddle.base import framework
from paddle.base.framework import global_var
from paddle.base.multiprocess_utils import CleanupFuncRegistrar
from .tracer import Tracer
import logging
from ..data_feeder import convert_dtype
import warnings
from ..framework import _get_paddle_place
......@@ -44,15 +41,19 @@ __all__ = [
NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"
def in_declarative_mode():
def in_to_static_mode():
"""
Return a bool value that indicates whether running code under `@to_static`
"""
return global_var._in_declarative_mode_
return global_var._in_to_static_mode_
def declarative_unsupport_argument_warning(
# TODO(Aurelius84): Need to remove this alias after clean usage in PaddleX
in_declarative_mode = in_to_static_mode
def to_static_unsupport_argument_warning(
func_name, input_names, inputs, support_values
):
"""
......@@ -81,12 +82,12 @@ switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
@signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True):
def _to_static_mode_guard_(is_to_static=True):
global global_var
original_val = global_var._in_declarative_mode_
global_var._in_declarative_mode_ = is_declarative
original_val = global_var._in_to_static_mode_
global_var._in_to_static_mode_ = is_to_static
yield
global_var._in_declarative_mode_ = original_val
global_var._in_to_static_mode_ = original_val
@signature_safe_contextmanager
......@@ -105,7 +106,7 @@ def program_desc_tracing_guard(enable):
@signature_safe_contextmanager
def param_guard(parameters):
# Note: parameters is a reference of self._parameters or self._buffers
if in_declarative_mode() and not paddle.in_dynamic_mode() and parameters:
if in_to_static_mode() and not paddle.in_dynamic_mode() and parameters:
try:
origin_parameters = parameters.copy()
for name, var_base in parameters.items():
......@@ -322,7 +323,7 @@ def no_grad(func=None):
test_layer()
"""
if in_declarative_mode():
if in_to_static_mode():
warnings.warn(
"paddle.no_grad is only supported for inference model, and not supported for training under @to_static."
)
......@@ -732,12 +733,12 @@ def grad(
grad_y1 = paddle.to_tensor(3.0)
print(test_dygraph_grad([grad_y1, grad_value])) # [24.]
'''
if in_declarative_mode():
if in_to_static_mode():
# In dy2static context, we call static interface `gradients`
# to calculate grads.
from paddle.static import gradients
declarative_unsupport_argument_warning(
to_static_unsupport_argument_warning(
"paddle.grad",
["retain_graph", "create_grad", "only_inputs", "allow_unused"],
[retain_graph, create_graph, only_inputs, allow_unused],
......
......@@ -84,15 +84,13 @@ class GlobalThreadLocal(threading.local):
TODO(xiongkun): how to access another thread local data ?
"""
global _dygraph_tracer_
self._in_declarative_mode_ = False
self._in_to_static_mode_ = False
self._functional_dygraph_context_manager = None
self._dygraph_tracer_ = _dygraph_tracer_
def __str__(self):
strings = []
strings.append(
"_in_declarative_mode_:" + str(self._in_declarative_mode_)
)
strings.append("_in_to_static_mode_:" + str(self._in_to_static_mode_))
strings.append(
"_functional_dygraph_context_manager:"
+ str(self._functional_dygraph_context_manager)
......@@ -528,9 +526,9 @@ def _dygraph_only_(func):
def _non_static_only_(func):
def __impl__(*args, **kwargs):
from .dygraph.base import in_declarative_mode
from .dygraph.base import in_to_static_mode
assert in_dygraph_mode() or in_declarative_mode(), (
assert in_dygraph_mode() or in_to_static_mode(), (
"We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode."
% func.__name__
)
......@@ -2371,9 +2369,9 @@ class Variable(metaclass=VariableMetaClass):
return _getitem_static(self, item)
def __setitem__(self, item, value):
from .dygraph.base import in_declarative_mode
from .dygraph.base import in_to_static_mode
if in_declarative_mode():
if in_to_static_mode():
if is_compiled_with_xpu():
# (NOTE): Currently, there is no index_put_xpu kernel.
return _setitem_impl_(self, item, value)
......@@ -4325,12 +4323,9 @@ class Block:
'while',
'while_grad',
}
from .dygraph.base import in_declarative_mode
from .dygraph.base import in_to_static_mode
if (
in_declarative_mode()
and not _stride_in_no_check_dy2st_diff_mode
):
if in_to_static_mode() and not _stride_in_no_check_dy2st_diff_mode:
check_if_to_static_diff_with_dygraph(
op_type, inplace_map, outputs
)
......@@ -4347,7 +4342,7 @@ class Block:
)
self.ops.append(op)
if in_declarative_mode():
if in_to_static_mode():
record_is_view_var(op_type, inputs, outputs)
return op
......@@ -7645,10 +7640,10 @@ def _get_var(name, program=None):
@signature_safe_contextmanager
def dygraph_guard_if_declarative():
from .dygraph.base import in_declarative_mode
from .dygraph.base import in_to_static_mode
from .dygraph import Tracer
if in_declarative_mode():
if in_to_static_mode():
# Under @paddle.jit.to_static decorator, we switch back dygraph mode temporarily.
with _dygraph_guard(tracer=Tracer()):
yield
......
......@@ -333,10 +333,10 @@ def generate_inplace_fn(inplace_op_type):
inplace_op_type, origin_op_type
)
)
from ..dygraph.base import in_declarative_mode
from ..dygraph.base import in_to_static_mode
if (
in_declarative_mode()
in_to_static_mode()
and hasattr(x, "is_view_var")
and x.is_view_var
):
......
......@@ -18,7 +18,7 @@ import inspect
from .. import core
from ..framework import Variable, unique_name, static_only
from .layer_function_generator import OpProtoHolder
from paddle.base.dygraph.base import in_declarative_mode
from paddle.base.dygraph.base import in_to_static_mode
_supported_int_dtype_ = [
core.VarDesc.VarType.BOOL,
......@@ -302,7 +302,7 @@ def monkey_patch_variable():
"""
if not isinstance(var, Variable):
if in_declarative_mode():
if in_to_static_mode():
"""in dy2static mode, x may be tensorable values such as int, float, np.array"""
from paddle.tensor.creation import to_tensor
......
......@@ -1017,9 +1017,9 @@ def _setitem_static(x, indices, values):
def get_tensor_with_basic_indexing(
x, axes, starts, ends, steps, decrease_axes, none_axes, use_strided_slice
):
from .dygraph.base import in_declarative_mode
from .dygraph.base import in_to_static_mode
if in_declarative_mode() and hasattr(x, "is_view_var"):
if in_to_static_mode() and hasattr(x, "is_view_var"):
x.is_view_var = True
if len(axes) == 0:
......@@ -1111,7 +1111,7 @@ def get_tensor_with_basic_indexing(
out = paddle.unsqueeze(out, axis=none_axes)
if in_declarative_mode() and hasattr(out, "is_view_var"):
if in_to_static_mode() and hasattr(out, "is_view_var"):
out.is_view_var = True
return out
......
......@@ -16,7 +16,7 @@ import re
import paddle
from paddle.base.data_feeder import convert_dtype
from paddle.base.dygraph.base import _convert_into_variable, in_declarative_mode
from paddle.base.dygraph.base import _convert_into_variable, in_to_static_mode
from paddle.base.framework import Variable, core, default_main_program
from .utils import (
......@@ -38,14 +38,14 @@ def convert_attr(x, attr):
def convert_load(x):
if in_declarative_mode() and isinstance(x, paddle.base.core.eager.Tensor):
if in_to_static_mode() and isinstance(x, paddle.base.core.eager.Tensor):
"""
TODO:(@xiongkun) may run convert_load in dygraph mode, which should be fixed.
"""
return _convert_into_variable(x)
# get the new output of the var
if in_declarative_mode() and isinstance(x, Variable):
if in_to_static_mode() and isinstance(x, Variable):
cur_block = default_main_program().current_block()
from paddle.jit.dy2static.program_translator import ProgramTranslator
......
......@@ -23,7 +23,7 @@ import weakref
from paddle.base import core, framework
from paddle.base.data_feeder import check_type
from paddle.base.dygraph.base import (
_switch_declarative_mode_guard_,
_to_static_mode_guard_,
param_guard,
switch_to_static_graph,
)
......@@ -1192,9 +1192,9 @@ class ConcreteProgram:
new_name_generator = UniqueNameGenerator()
with framework.program_guard(main_program, startup_program):
with _switch_declarative_mode_guard_(
is_declarative=True
), UniqueNameGuard(new_name_generator):
with _to_static_mode_guard_(is_to_static=True), UniqueNameGuard(
new_name_generator
):
# 1. Adds `paddle.static.data` layers for input if needed
static_inputs = func_spec.to_static_inputs_with_spec(
input_spec, main_program
......
......@@ -26,9 +26,10 @@ from paddle import nn, profiler
from paddle.base import core, framework, unique_name
from paddle.base.core import VarDesc
from paddle.base.dygraph import no_grad
from paddle.base.dygraph.base import in_declarative_mode # noqa F401
from paddle.base.dygraph.base import (
_convert_into_variable,
in_declarative_mode,
in_to_static_mode,
program_desc_tracing_guard,
)
from paddle.base.dygraph_utils import _append_activation_in_dygraph
......@@ -1336,7 +1337,7 @@ class Layer:
def __call__(self, *inputs, **kwargs):
if (
(not in_declarative_mode())
(not in_to_static_mode())
and (not self._forward_pre_hooks)
and (not self._forward_post_hooks)
and (not self._built)
......@@ -1561,7 +1562,7 @@ class Layer:
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in self._parameters:
if in_declarative_mode():
if in_to_static_mode():
return _convert_into_variable(self._parameters[name])
return self._parameters[name]
if '_sub_layers' in self.__dict__:
......@@ -1571,7 +1572,7 @@ class Layer:
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
if in_declarative_mode():
if in_to_static_mode():
return _convert_into_variable(_buffers[name])
return _buffers[name]
return object.__getattribute__(self, name)
......@@ -1653,7 +1654,7 @@ class Layer:
# but should all non-Variable _buffers[name] be re-assign? We
# should consider it in the future. I current wrote this as
# conservative code.
if in_declarative_mode() and _buffers[name] is None:
if in_to_static_mode() and _buffers[name] is None:
raise RuntimeError(
'In Dy2stat, self.{0} is a buffer and self.{0} is '
'not allowed to be set to Variable when self.{0} is None.'.format(
......
......@@ -419,7 +419,7 @@ class Adam(Optimizer):
>>> adam.step()
>>> adam.clear_grad()
"""
if paddle.base.dygraph.base.in_declarative_mode():
if paddle.base.dygraph.base.in_to_static_mode():
self._declarative_step()
return
......
......@@ -557,7 +557,7 @@ class AdamW(Optimizer):
>>> opt.step()
>>> opt.clear_grad()
"""
if paddle.base.dygraph.base.in_declarative_mode():
if paddle.base.dygraph.base.in_to_static_mode():
self._declarative_step()
return
......
......@@ -1603,7 +1603,7 @@ class Optimizer:
>>> adam.step()
>>> adam.clear_grad()
"""
if paddle.base.dygraph.base.in_declarative_mode():
if paddle.base.dygraph.base.in_to_static_mode():
self._declarative_step()
return
......
......@@ -33,9 +33,9 @@ def _inplace_apis_in_dygraph_only_(func):
func.__name__, origin_api_name
)
)
from ..base.dygraph.base import in_declarative_mode
from ..base.dygraph.base import in_to_static_mode
if in_declarative_mode():
if in_to_static_mode():
for arg in args:
if hasattr(arg, "is_view_var") and arg.is_view_var:
raise ValueError(
......
......@@ -39,7 +39,7 @@ class TestDy2staticException(unittest.TestCase):
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
paddle.jit.enable_to_static(True)
self.assertTrue(to_static(self.dyfunc)(self.x))
paddle.base.dygraph.base.global_var._in_declarative_mode_ = False
paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
paddle.jit.enable_to_static(False)
......
......@@ -68,7 +68,7 @@ class TestDy2staticException(unittest.TestCase):
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
paddle.jit.enable_to_static(True)
self.assertTrue(paddle.jit.to_static(self.dyfunc)(self.x))
paddle.base.dygraph.base.global_var._in_declarative_mode_ = False
paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
paddle.jit.enable_to_static(False)
......@@ -468,12 +468,12 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
with self.assertRaises(Dygraph2StaticException):
static_func = paddle.jit.to_static(self.dyfunc)
out = static_func(self.x)
# Why need set `_in_declarative_mode_` here?
# In Dy2St we use `with _switch_declarative_mode_guard_()` to indicate
# Why need set `_in_to_static_mode_` here?
# In Dy2St we use `with _to_static_mode_guard_()` to indicate
# that the code block is under @to_static, but in this UT
# an exception is thrown during Dy2St, making the `_in_declarative_mode_`
# a wrong value. So We need set `_in_declarative_mode_` to False manually.
paddle.base.dygraph.base.global_var._in_declarative_mode_ = False
# an exception is thrown during Dy2St, making the `_in_to_static_mode_`
# a wrong value. So We need set `_in_to_static_mode_` to False manually.
paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
paddle.jit.enable_to_static(False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册