未验证 提交 18745e6f 编写于 作者: X xiongkun 提交者: GitHub

[Dy2Static] fix switch static graph affects dataloader (#49821)

* rebase merge

* code fix

* fix bugs
上级 611da7fc
...@@ -55,7 +55,7 @@ from .framework.dtype import bool # noqa: F401 ...@@ -55,7 +55,7 @@ from .framework.dtype import bool # noqa: F401
from .framework.dtype import complex64 # noqa: F401 from .framework.dtype import complex64 # noqa: F401
from .framework.dtype import complex128 # noqa: F401 from .framework.dtype import complex128 # noqa: F401
if fluid.framework._in_eager_mode_: if fluid.framework.global_var._in_eager_mode_:
Tensor = framework.core.eager.Tensor Tensor = framework.core.eager.Tensor
else: else:
from .framework import VarBase as Tensor # noqa: F401 from .framework import VarBase as Tensor # noqa: F401
......
...@@ -107,7 +107,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False): ...@@ -107,7 +107,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
each_tensor, (paddle.Tensor, core.eager.Tensor) each_tensor, (paddle.Tensor, core.eager.Tensor)
), "The argument 'grad_tensors' of paddle.autograd.backward is invalid, it can be 'None', 'paddle.Tensor' or 'list[None/paddle.Tensor]'." ), "The argument 'grad_tensors' of paddle.autograd.backward is invalid, it can be 'None', 'paddle.Tensor' or 'list[None/paddle.Tensor]'."
else: else:
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
grad_tensors = [] grad_tensors = []
else: else:
grad_tensors = [None] * len(tensors) grad_tensors = [None] * len(tensors)
...@@ -119,7 +119,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False): ...@@ -119,7 +119,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
assert isinstance(retain_graph, bool), "retain_graph must be True or False" assert isinstance(retain_graph, bool), "retain_graph must be True or False"
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
core.eager.run_backward(tensors, grad_tensors, retain_graph) core.eager.run_backward(tensors, grad_tensors, retain_graph)
else: else:
core.dygraph_run_backward( core.dygraph_run_backward(
......
...@@ -20,6 +20,7 @@ import sys ...@@ -20,6 +20,7 @@ import sys
import numpy as np import numpy as np
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid.framework import global_var
from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar
from .tracer import Tracer from .tracer import Tracer
import logging import logging
...@@ -44,7 +45,6 @@ __all__ = [ ...@@ -44,7 +45,6 @@ __all__ = [
] ]
# Flag that indicates whether running code under `@to_static` # Flag that indicates whether running code under `@to_static`
_in_declarative_mode_ = False
def in_declarative_mode(): def in_declarative_mode():
...@@ -52,7 +52,7 @@ def in_declarative_mode(): ...@@ -52,7 +52,7 @@ def in_declarative_mode():
Return a bool value that indicates whether running code under `@to_static` Return a bool value that indicates whether running code under `@to_static`
""" """
return _in_declarative_mode_ return global_var._in_declarative_mode_
def declarative_unsupport_argument_warning( def declarative_unsupport_argument_warning(
...@@ -86,11 +86,11 @@ switch_to_static_graph = wrap_decorator(_switch_to_static_graph_) ...@@ -86,11 +86,11 @@ switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
@signature_safe_contextmanager @signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True): def _switch_declarative_mode_guard_(is_declarative=True):
global _in_declarative_mode_ global global_var
original_val = _in_declarative_mode_ original_val = global_var._in_declarative_mode_
_in_declarative_mode_ = is_declarative global_var._in_declarative_mode_ = is_declarative
yield yield
_in_declarative_mode_ = original_val global_var._in_declarative_mode_ = original_val
@signature_safe_contextmanager @signature_safe_contextmanager
...@@ -106,9 +106,6 @@ def program_desc_tracing_guard(enable): ...@@ -106,9 +106,6 @@ def program_desc_tracing_guard(enable):
tracer._enable_program_desc_tracing = original_val tracer._enable_program_desc_tracing = original_val
_functional_dygraph_context_manager = None
@signature_safe_contextmanager @signature_safe_contextmanager
def param_guard(parameters): def param_guard(parameters):
# Note: parameters is a reference of self._parameters or self._buffers # Note: parameters is a reference of self._parameters or self._buffers
...@@ -228,12 +225,12 @@ def enable_dygraph(place=None): ...@@ -228,12 +225,12 @@ def enable_dygraph(place=None):
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
""" """
global _functional_dygraph_context_manager global global_var
if _functional_dygraph_context_manager is None: if global_var._functional_dygraph_context_manager is None:
_functional_dygraph_context_manager = guard( global_var._functional_dygraph_context_manager = guard(
place=_get_paddle_place(place) place=_get_paddle_place(place)
) )
_functional_dygraph_context_manager.__enter__() global_var._functional_dygraph_context_manager.__enter__()
# call disable_dygraph when Python exit # call disable_dygraph when Python exit
CleanupFuncRegistrar.register(disable_dygraph) CleanupFuncRegistrar.register(disable_dygraph)
...@@ -263,10 +260,10 @@ def disable_dygraph(): ...@@ -263,10 +260,10 @@ def disable_dygraph():
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
""" """
global _functional_dygraph_context_manager global global_var
if _functional_dygraph_context_manager is not None: if global_var._functional_dygraph_context_manager is not None:
_functional_dygraph_context_manager.__exit__(*sys.exc_info()) global_var._functional_dygraph_context_manager.__exit__(*sys.exc_info())
_functional_dygraph_context_manager = None global_var._functional_dygraph_context_manager = None
@signature_safe_contextmanager @signature_safe_contextmanager
......
...@@ -74,7 +74,7 @@ def monkey_patch_math_varbase(): ...@@ -74,7 +74,7 @@ def monkey_patch_math_varbase():
@no_grad @no_grad
def create_tensor(value, dtype, shape): def create_tensor(value, dtype, shape):
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
out = _C_ops.full( out = _C_ops.full(
shape, value, dtype, framework._current_expected_place() shape, value, dtype, framework._current_expected_place()
) )
...@@ -251,7 +251,7 @@ def monkey_patch_math_varbase(): ...@@ -251,7 +251,7 @@ def monkey_patch_math_varbase():
# 2. create varbase for scalar # 2. create varbase for scalar
lhs_dtype = self.dtype lhs_dtype = self.dtype
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
other_var_should_be = core.eager.Tensor other_var_should_be = core.eager.Tensor
else: else:
other_var_should_be = core.VarBase other_var_should_be = core.VarBase
...@@ -486,7 +486,7 @@ def monkey_patch_math_varbase(): ...@@ -486,7 +486,7 @@ def monkey_patch_math_varbase():
global _already_patch_varbase global _already_patch_varbase
global _already_patch_eager_tensor global _already_patch_eager_tensor
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
local_already_patch = _already_patch_eager_tensor local_already_patch = _already_patch_eager_tensor
_already_patch_eager_tensor = True _already_patch_eager_tensor = True
local_tensor = core.eager.Tensor local_tensor = core.eager.Tensor
...@@ -496,7 +496,7 @@ def monkey_patch_math_varbase(): ...@@ -496,7 +496,7 @@ def monkey_patch_math_varbase():
local_tensor = core.VarBase local_tensor = core.VarBase
if not local_already_patch: if not local_already_patch:
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
for method_name in eager_cpp_level_patch: for method_name in eager_cpp_level_patch:
method_impl = getattr(local_tensor, method_name, None) method_impl = getattr(local_tensor, method_name, None)
if method_impl: if method_impl:
......
...@@ -54,7 +54,9 @@ class TensorHookRemoveHelper: ...@@ -54,7 +54,9 @@ class TensorHookRemoveHelper:
def __init__(self, tensor, hook_id): def __init__(self, tensor, hook_id):
self._tensor = ( self._tensor = (
tensor if framework._in_eager_mode_ else weakref.ref(tensor) tensor
if framework.global_var._in_eager_mode_
else weakref.ref(tensor)
) )
self._hook_id = hook_id self._hook_id = hook_id
...@@ -65,7 +67,11 @@ class TensorHookRemoveHelper: ...@@ -65,7 +67,11 @@ class TensorHookRemoveHelper:
Returns: Returns:
bool: Return True if removed successfully bool: Return True if removed successfully
""" """
tensor = self._tensor if framework._in_eager_mode_ else self._tensor() tensor = (
self._tensor
if framework.global_var._in_eager_mode_
else self._tensor()
)
if tensor is not None: if tensor is not None:
res = tensor._remove_grad_hook(self._hook_id) res = tensor._remove_grad_hook(self._hook_id)
if res is True: if res is True:
...@@ -178,7 +184,7 @@ def monkey_patch_varbase(): ...@@ -178,7 +184,7 @@ def monkey_patch_varbase():
out = linear(t) # call with different weight out = linear(t) # call with different weight
""" """
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
base_tensor = core.eager.Tensor base_tensor = core.eager.Tensor
else: else:
base_tensor = core.VarBase base_tensor = core.VarBase
...@@ -282,7 +288,7 @@ def monkey_patch_varbase(): ...@@ -282,7 +288,7 @@ def monkey_patch_varbase():
) )
record_event.begin() record_event.begin()
if grad_tensor is not None: if grad_tensor is not None:
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
assert isinstance( assert isinstance(
grad_tensor, core.eager.Tensor grad_tensor, core.eager.Tensor
), "The type of grad_tensor must be paddle.Tensor" ), "The type of grad_tensor must be paddle.Tensor"
...@@ -296,7 +302,7 @@ def monkey_patch_varbase(): ...@@ -296,7 +302,7 @@ def monkey_patch_varbase():
grad_tensor.name, grad_tensor.shape, self.name, self.shape grad_tensor.name, grad_tensor.shape, self.name, self.shape
) )
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
if grad_tensor is None: if grad_tensor is None:
grad_tensor = [] grad_tensor = []
else: else:
...@@ -311,7 +317,7 @@ def monkey_patch_varbase(): ...@@ -311,7 +317,7 @@ def monkey_patch_varbase():
): ):
# TODO(liuyuhui): Currently only for xpu. Will be removed in the future. # TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
scaled_loss = scale_loss(self) scaled_loss = scale_loss(self)
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
core.eager.run_backward( core.eager.run_backward(
[scaled_loss], grad_tensor, retain_graph [scaled_loss], grad_tensor, retain_graph
) )
...@@ -323,7 +329,7 @@ def monkey_patch_varbase(): ...@@ -323,7 +329,7 @@ def monkey_patch_varbase():
framework._dygraph_tracer(), framework._dygraph_tracer(),
) )
else: else:
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
core.eager.run_backward([self], grad_tensor, retain_graph) core.eager.run_backward([self], grad_tensor, retain_graph)
else: else:
core.dygraph_run_backward( core.dygraph_run_backward(
...@@ -368,7 +374,7 @@ def monkey_patch_varbase(): ...@@ -368,7 +374,7 @@ def monkey_patch_varbase():
# [500.] # [500.]
""" """
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
if self.grad is None: if self.grad is None:
return None return None
if self.grad.is_selected_rows(): if self.grad.is_selected_rows():
...@@ -673,7 +679,7 @@ def monkey_patch_varbase(): ...@@ -673,7 +679,7 @@ def monkey_patch_varbase():
# [[0.30574632, 0.55739117, 0.30902600, 0.39413780, 0.44830436], # [[0.30574632, 0.55739117, 0.30902600, 0.39413780, 0.44830436],
# [0.79010487, 0.53972793, 0.09495186, 0.44267157, 0.72112119]]) # [0.79010487, 0.53972793, 0.09495186, 0.44267157, 0.72112119]])
""" """
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
from paddle.tensor.to_string import tensor_to_string from paddle.tensor.to_string import tensor_to_string
return tensor_to_string(self) return tensor_to_string(self)
...@@ -707,7 +713,7 @@ def monkey_patch_varbase(): ...@@ -707,7 +713,7 @@ def monkey_patch_varbase():
raise RuntimeError( raise RuntimeError(
"Only Leaf Tensor support the deepcopy at the moment, non-Leaf Tensors contains graph information that does't support deepcopy" "Only Leaf Tensor support the deepcopy at the moment, non-Leaf Tensors contains graph information that does't support deepcopy"
) )
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
new_varbase = core.eager.Tensor() new_varbase = core.eager.Tensor()
else: else:
new_varbase = core.VarBase() new_varbase = core.VarBase()
...@@ -725,7 +731,7 @@ def monkey_patch_varbase(): ...@@ -725,7 +731,7 @@ def monkey_patch_varbase():
assert ( assert (
numel == 1 numel == 1
), "When Variable is used as the condition of if/while , Variable can only contain one element." ), "When Variable is used as the condition of if/while , Variable can only contain one element."
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
assert self._is_initialized(), "tensor not initialized" assert self._is_initialized(), "tensor not initialized"
return bool(np.all(self.numpy() > 0)) return bool(np.all(self.numpy() > 0))
else: else:
...@@ -850,7 +856,7 @@ def monkey_patch_varbase(): ...@@ -850,7 +856,7 @@ def monkey_patch_varbase():
return _setitem_impl_(self, item, value) return _setitem_impl_(self, item, value)
else: else:
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
return self.__setitem_eager_tensor__(item, value) return self.__setitem_eager_tensor__(item, value)
else: else:
# Call c++ func __setitem_varbase__ to speedup. # Call c++ func __setitem_varbase__ to speedup.
...@@ -1020,7 +1026,7 @@ def monkey_patch_varbase(): ...@@ -1020,7 +1026,7 @@ def monkey_patch_varbase():
def __hash__(self): def __hash__(self):
return hash(id(self)) return hash(id(self))
if framework._in_eager_mode_ and not hasattr(core, "eager"): if framework.global_var._in_eager_mode_ and not hasattr(core, "eager"):
return return
for method_name, method in ( for method_name, method in (
...@@ -1047,12 +1053,12 @@ def monkey_patch_varbase(): ...@@ -1047,12 +1053,12 @@ def monkey_patch_varbase():
("to_dense", to_dense), ("to_dense", to_dense),
("to_sparse_coo", to_sparse_coo), ("to_sparse_coo", to_sparse_coo),
): ):
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
setattr(core.eager.Tensor, method_name, method) setattr(core.eager.Tensor, method_name, method)
else: else:
setattr(core.VarBase, method_name, method) setattr(core.VarBase, method_name, method)
if framework._in_eager_mode_: if framework.global_var._in_eager_mode_:
setattr(core.eager.Tensor, "_set_grad_ivar", _set_grad_ivar) setattr(core.eager.Tensor, "_set_grad_ivar", _set_grad_ivar)
setattr(core.eager.Tensor, "value", value) setattr(core.eager.Tensor, "value", value)
setattr(core.eager.Tensor, "cpu", cpu) setattr(core.eager.Tensor, "cpu", cpu)
......
...@@ -36,6 +36,7 @@ import paddle.version as fluid_version ...@@ -36,6 +36,7 @@ import paddle.version as fluid_version
import warnings import warnings
import functools import functools
from .variable_index import _getitem_impl_, _setitem_impl_ from .variable_index import _getitem_impl_, _setitem_impl_
import threading
__all__ = [ __all__ = [
'Program', 'Program',
...@@ -70,8 +71,42 @@ GRAD_VAR_SUFFIX = core.kGradVarSuffix() ...@@ -70,8 +71,42 @@ GRAD_VAR_SUFFIX = core.kGradVarSuffix()
ZERO_VAR_SUFFIX = core.kZeroVarSuffix() ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName() CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
# use thread local to create thread save global variables.
class GlobalThreadLocal(threading.local):
def __init__(self):
"""
init the thread local data.
TODO(xiongkun): how to access another thread local data ?
"""
global _dygraph_tracer_
self._in_declarative_mode_ = False
self._functional_dygraph_context_manager = None
self._dygraph_tracer_ = _dygraph_tracer_
self._in_eager_mode_ = True
def __str__(self):
strings = []
strings.append(
"_in_declarative_mode_:" + str(self._in_declarative_mode_)
)
strings.append(
"_functional_dygraph_context_manager:"
+ str(self._functional_dygraph_context_manager)
)
strings.append("_dygraph_tracer_:" + str(self._dygraph_tracer_))
strings.append("_in_eager_mode_:" + str(self._in_eager_mode_))
return "\n".join(strings)
def __setattr__(self, name, val):
if name == '_dygraph_tracer_':
global _dygraph_tracer_
_dygraph_tracer_ = val
self.__dict__[name] = val
_dygraph_tracer_ = None _dygraph_tracer_ = None
_in_eager_mode_ = True global_var = GlobalThreadLocal()
_global_expected_place_ = None _global_expected_place_ = None
_current_device = None _current_device = None
global_prog_seed = 0 global_prog_seed = 0
...@@ -155,20 +190,17 @@ def _switch_tensor_bind_type(is_eager): ...@@ -155,20 +190,17 @@ def _switch_tensor_bind_type(is_eager):
def _enable_legacy_dygraph(): def _enable_legacy_dygraph():
global _in_eager_mode_ global_var._in_eager_mode_ = False
_in_eager_mode_ = False
_update_monkey_methods(is_eager=False) _update_monkey_methods(is_eager=False)
def _disable_legacy_dygraph(): def _disable_legacy_dygraph():
global _in_eager_mode_ global_var._in_eager_mode_ = True
_in_eager_mode_ = True
_update_monkey_methods(is_eager=True) _update_monkey_methods(is_eager=True)
def _in_eager_without_dygraph_check(): def _in_eager_without_dygraph_check():
global _in_eager_mode_ return global_var._in_eager_mode_
return _in_eager_mode_
# FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but # FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but
...@@ -177,7 +209,6 @@ _is_first_import_ = True ...@@ -177,7 +209,6 @@ _is_first_import_ = True
def _fallback_legacy_dygraph(): def _fallback_legacy_dygraph():
global _in_eager_mode_
global _is_first_import_ global _is_first_import_
need_fallback = False need_fallback = False
# Only enable eager on CPU/GPU/XPU # Only enable eager on CPU/GPU/XPU
...@@ -187,12 +218,12 @@ def _fallback_legacy_dygraph(): ...@@ -187,12 +218,12 @@ def _fallback_legacy_dygraph():
or core.is_compiled_with_mlu() or core.is_compiled_with_mlu()
) )
if _in_eager_mode_ and is_not_support: if global_var._in_eager_mode_ and is_not_support:
# switch into legacy dygraph mode # switch into legacy dygraph mode
warnings.warn( warnings.warn(
"We will fallback into legacy dygraph on NPU/XPU/MLU/IPU/ROCM devices. Because we only support new eager dygraph mode on CPU/GPU currently. " "We will fallback into legacy dygraph on NPU/XPU/MLU/IPU/ROCM devices. Because we only support new eager dygraph mode on CPU/GPU currently. "
) )
_in_eager_mode_ = False global_var._in_eager_mode_ = False
if not _is_first_import_: if not _is_first_import_:
_enable_legacy_dygraph() _enable_legacy_dygraph()
need_fallback = True need_fallback = True
...@@ -234,11 +265,13 @@ def in_dygraph_mode(): ...@@ -234,11 +265,13 @@ def in_dygraph_mode():
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
""" """
return (_dygraph_tracer_ is not None) and _in_eager_mode_ return (
global_var._dygraph_tracer_ is not None
) and global_var._in_eager_mode_
def _non_static_mode(): def _non_static_mode():
return _dygraph_tracer_ is not None return global_var._dygraph_tracer_ is not None
@signature_safe_contextmanager @signature_safe_contextmanager
...@@ -603,7 +636,7 @@ non_static_only = wrap_decorator(_non_static_only_) ...@@ -603,7 +636,7 @@ non_static_only = wrap_decorator(_non_static_only_)
def _dygraph_tracer(): def _dygraph_tracer():
return _dygraph_tracer_ return global_var._dygraph_tracer_
def _global_flags(): def _global_flags():
...@@ -671,9 +704,8 @@ def _current_expected_place(): ...@@ -671,9 +704,8 @@ def _current_expected_place():
def _set_dygraph_tracer_expected_place(place): def _set_dygraph_tracer_expected_place(place):
global _dygraph_tracer_ if global_var._dygraph_tracer_ is not None:
if _dygraph_tracer_ is not None: global_var._dygraph_tracer_._expected_place = place
_dygraph_tracer_._expected_place = place
def _set_expected_place(place): def _set_expected_place(place):
...@@ -1315,7 +1347,7 @@ def _varbase_creator( ...@@ -1315,7 +1347,7 @@ def _varbase_creator(
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if _in_eager_mode_: if global_var._in_eager_mode_:
eager_tensor = core.eager.Tensor( eager_tensor = core.eager.Tensor(
dtype if dtype else core.VarDesc.VarType.FP32, dtype if dtype else core.VarDesc.VarType.FP32,
list(shape) if shape else [], list(shape) if shape else [],
...@@ -7460,16 +7492,17 @@ def _get_var(name, program=None): ...@@ -7460,16 +7492,17 @@ def _get_var(name, program=None):
@signature_safe_contextmanager @signature_safe_contextmanager
def _dygraph_guard(tracer): def _dygraph_guard(tracer):
global _dygraph_tracer_ tmp_tracer = global_var._dygraph_tracer_
tmp_tracer = _dygraph_tracer_ global_var._dygraph_tracer_ = tracer
_dygraph_tracer_ = tracer if tracer is not None:
core._switch_tracer(tracer) core._switch_tracer(tracer)
try: try:
yield yield
finally: finally:
if tmp_tracer is not None:
core._switch_tracer(tmp_tracer) core._switch_tracer(tmp_tracer)
_dygraph_tracer_ = tmp_tracer global_var._dygraph_tracer_ = tmp_tracer
@signature_safe_contextmanager @signature_safe_contextmanager
......
...@@ -59,8 +59,8 @@ class LazyInitHelper: ...@@ -59,8 +59,8 @@ class LazyInitHelper:
self.enable() self.enable()
if self._in_guard: if self._in_guard:
return return
self._tracer = framework._dygraph_tracer_ self._tracer = framework.global_var._dygraph_tracer_
framework._dygraph_tracer_ = None framework.global_var._dygraph_tracer_ = None
self._in_guard = True self._in_guard = True
def __exit__(self, *args, **kwargs): def __exit__(self, *args, **kwargs):
...@@ -71,7 +71,7 @@ class LazyInitHelper: ...@@ -71,7 +71,7 @@ class LazyInitHelper:
if not self._in_guard: if not self._in_guard:
return return
assert self._tracer is not None assert self._tracer is not None
framework._dygraph_tracer_ = self._tracer framework.global_var._dygraph_tracer_ = self._tracer
self._tracer = None self._tracer = None
self._in_guard = False self._in_guard = False
......
...@@ -36,7 +36,7 @@ class TestDy2staticException(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestDy2staticException(unittest.TestCase):
with self.assertRaisesRegex(Dygraph2StaticException, self.error): with self.assertRaisesRegex(Dygraph2StaticException, self.error):
paddle.jit.enable_to_static(True) paddle.jit.enable_to_static(True)
self.assertTrue(to_static(self.dyfunc)(self.x)) self.assertTrue(to_static(self.dyfunc)(self.x))
paddle.fluid.dygraph.base._in_declarative_mode_ = False paddle.fluid.dygraph.base.global_var._in_declarative_mode_ = False
paddle.jit.enable_to_static(False) paddle.jit.enable_to_static(False)
......
...@@ -65,7 +65,7 @@ class TestDy2staticException(unittest.TestCase): ...@@ -65,7 +65,7 @@ class TestDy2staticException(unittest.TestCase):
with self.assertRaisesRegex(Dygraph2StaticException, self.error): with self.assertRaisesRegex(Dygraph2StaticException, self.error):
paddle.jit.enable_to_static(True) paddle.jit.enable_to_static(True)
self.assertTrue(paddle.jit.to_static(self.dyfunc)(self.x)) self.assertTrue(paddle.jit.to_static(self.dyfunc)(self.x))
paddle.fluid.dygraph.base._in_declarative_mode_ = False paddle.fluid.dygraph.base.global_var._in_declarative_mode_ = False
paddle.jit.enable_to_static(False) paddle.jit.enable_to_static(False)
...@@ -463,7 +463,7 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1): ...@@ -463,7 +463,7 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
# that the code block is under @to_static, but in this UT # that the code block is under @to_static, but in this UT
# an exception is thrown during Dy2St, making the `_in_declarative_mode_` # an exception is thrown during Dy2St, making the `_in_declarative_mode_`
# a wrong value. So We need set `_in_declarative_mode_` to False manually. # a wrong value. So We need set `_in_declarative_mode_` to False manually.
paddle.fluid.dygraph.base._in_declarative_mode_ = False paddle.fluid.dygraph.base.global_var._in_declarative_mode_ = False
paddle.jit.enable_to_static(False) paddle.jit.enable_to_static(False)
......
...@@ -25,7 +25,7 @@ from paddle import _C_ops, _legacy_C_ops ...@@ -25,7 +25,7 @@ from paddle import _C_ops, _legacy_C_ops
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import core, framework, executor from paddle.fluid import core, framework, executor
from paddle.fluid.layers.utils import _hash_with_id from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.framework import _in_eager_mode_ from paddle.fluid.framework import global_var
paddle.enable_static() paddle.enable_static()
np.random.seed(1243) np.random.seed(1243)
...@@ -135,7 +135,7 @@ class RunProgramNPUOpTest(unittest.TestCase): ...@@ -135,7 +135,7 @@ class RunProgramNPUOpTest(unittest.TestCase):
def prepare_dygraph_input(self, place, return_param_list=False): def prepare_dygraph_input(self, place, return_param_list=False):
def create_var_base(is_input, name, np_value, stop_gradient): def create_var_base(is_input, name, np_value, stop_gradient):
if _in_eager_mode_: if global_var._in_eager_mode_:
var = core.eager.Tensor( var = core.eager.Tensor(
value=np_value, name=name, place=place, zero_copy=True value=np_value, name=name, place=place, zero_copy=True
) )
...@@ -176,7 +176,7 @@ class RunProgramNPUOpTest(unittest.TestCase): ...@@ -176,7 +176,7 @@ class RunProgramNPUOpTest(unittest.TestCase):
for name in self.output_names['Out']: for name in self.output_names['Out']:
outputs['Out'].append(create_var_base(False, name)) outputs['Out'].append(create_var_base(False, name))
if _in_eager_mode_: if global_var._in_eager_mode_:
outputs['OutScope'] = [core.Scope()] outputs['OutScope'] = [core.Scope()]
else: else:
outputs['OutScope'] = framework._varbase_creator( outputs['OutScope'] = framework._varbase_creator(
......
...@@ -26,7 +26,7 @@ from paddle.fluid.executor import ( ...@@ -26,7 +26,7 @@ from paddle.fluid.executor import (
_is_dy2st_enable_standalone_executor, _is_dy2st_enable_standalone_executor,
_is_enable_standalone_executor, _is_enable_standalone_executor,
) )
from paddle.fluid.framework import _in_eager_mode_ from paddle.fluid.framework import global_var
from paddle.fluid.layers.utils import _hash_with_id from paddle.fluid.layers.utils import _hash_with_id
paddle.enable_static() paddle.enable_static()
...@@ -177,7 +177,7 @@ class RunProgramOpTest(unittest.TestCase): ...@@ -177,7 +177,7 @@ class RunProgramOpTest(unittest.TestCase):
def prepare_dygraph_input(self, place, return_param_list=False): def prepare_dygraph_input(self, place, return_param_list=False):
def create_var_base(is_input, name, np_value, stop_gradient): def create_var_base(is_input, name, np_value, stop_gradient):
if _in_eager_mode_: if global_var._in_eager_mode_:
var = core.eager.Tensor( var = core.eager.Tensor(
value=np_value, name=name, place=place, zero_copy=True value=np_value, name=name, place=place, zero_copy=True
) )
...@@ -218,7 +218,7 @@ class RunProgramOpTest(unittest.TestCase): ...@@ -218,7 +218,7 @@ class RunProgramOpTest(unittest.TestCase):
for name in self.output_names['Out']: for name in self.output_names['Out']:
outputs['Out'].append(create_var_base(False, name)) outputs['Out'].append(create_var_base(False, name))
if _in_eager_mode_: if global_var._in_eager_mode_:
outputs['OutScope'] = [core.Scope()] outputs['OutScope'] = [core.Scope()]
else: else:
outputs['OutScope'] = framework._varbase_creator( outputs['OutScope'] = framework._varbase_creator(
......
...@@ -619,7 +619,7 @@ class PartialProgramLayer: ...@@ -619,7 +619,7 @@ class PartialProgramLayer:
if "@GRAD" in name: if "@GRAD" in name:
var_desc = block.vars[name].desc var_desc = block.vars[name].desc
var_base = None var_base = None
if not framework._in_eager_mode_: if not framework.global_var._in_eager_mode_:
var_base = core.VarBase( var_base = core.VarBase(
var_desc.dtype(), var_desc.dtype(),
var_desc.shape(), var_desc.shape(),
...@@ -874,7 +874,7 @@ class PartialProgramLayer: ...@@ -874,7 +874,7 @@ class PartialProgramLayer:
for i, value in enumerate(flatten_inputs): for i, value in enumerate(flatten_inputs):
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
var = None var = None
if not framework._in_eager_mode_: if not framework.global_var._in_eager_mode_:
var = core.VarBase( var = core.VarBase(
value=value, value=value,
name=self._inputs[i].desc.name(), name=self._inputs[i].desc.name(),
...@@ -918,7 +918,7 @@ class PartialProgramLayer: ...@@ -918,7 +918,7 @@ class PartialProgramLayer:
if var_desc.name() in out_varbase_map: if var_desc.name() in out_varbase_map:
return out_varbase_map[var_desc.name()] return out_varbase_map[var_desc.name()]
if not framework._in_eager_mode_: if not framework.global_var._in_eager_mode_:
var_base = core.VarBase( var_base = core.VarBase(
var_desc.dtype(), var_desc.dtype(),
var_desc.shape(), var_desc.shape(),
...@@ -949,7 +949,7 @@ class PartialProgramLayer: ...@@ -949,7 +949,7 @@ class PartialProgramLayer:
inner_scope = self._get_scope( inner_scope = self._get_scope(
program_id=program_id, use_scope_cache=use_scope_cache program_id=program_id, use_scope_cache=use_scope_cache
) )
if not framework._in_eager_mode_: if not framework.global_var._in_eager_mode_:
tmp_scope_vec = core.VarBase( tmp_scope_vec = core.VarBase(
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
[], [],
...@@ -1102,7 +1102,7 @@ def _create_fake_var(): ...@@ -1102,7 +1102,7 @@ def _create_fake_var():
""" """
Create a fake_var (force on CPU) to handle empty input or output Create a fake_var (force on CPU) to handle empty input or output
""" """
if not framework._in_eager_mode_: if not framework.global_var._in_eager_mode_:
return [ return [
core.VarBase( core.VarBase(
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
import paddle import paddle
from ..fluid.data_feeder import check_type, check_variable_and_dtype from ..fluid.data_feeder import check_type, check_variable_and_dtype
from ..fluid.framework import _in_eager_mode_ from ..fluid.framework import global_var
from ..static import Variable from ..static import Variable
from .layer_function_generator import templatedoc from .layer_function_generator import templatedoc
if _in_eager_mode_: if global_var._in_eager_mode_:
Tensor = paddle.fluid.framework.core.eager.Tensor Tensor = paddle.fluid.framework.core.eager.Tensor
else: else:
from ..framework import VarBase as Tensor from ..framework import VarBase as Tensor
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册