未验证 提交 18745e6f 编写于 作者: X xiongkun 提交者: GitHub

[Dy2Static] fix switch static graph affects dataloader (#49821)

* rebase merge

* code fix

* fix bugs
上级 611da7fc
......@@ -55,7 +55,7 @@ from .framework.dtype import bool # noqa: F401
from .framework.dtype import complex64 # noqa: F401
from .framework.dtype import complex128 # noqa: F401
if fluid.framework._in_eager_mode_:
if fluid.framework.global_var._in_eager_mode_:
Tensor = framework.core.eager.Tensor
else:
from .framework import VarBase as Tensor # noqa: F401
......
......@@ -107,7 +107,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
each_tensor, (paddle.Tensor, core.eager.Tensor)
), "The argument 'grad_tensors' of paddle.autograd.backward is invalid, it can be 'None', 'paddle.Tensor' or 'list[None/paddle.Tensor]'."
else:
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
grad_tensors = []
else:
grad_tensors = [None] * len(tensors)
......@@ -119,7 +119,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
assert isinstance(retain_graph, bool), "retain_graph must be True or False"
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
core.eager.run_backward(tensors, grad_tensors, retain_graph)
else:
core.dygraph_run_backward(
......
......@@ -20,6 +20,7 @@ import sys
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.framework import global_var
from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar
from .tracer import Tracer
import logging
......@@ -44,7 +45,6 @@ __all__ = [
]
# Flag that indicates whether running code under `@to_static`
_in_declarative_mode_ = False
def in_declarative_mode():
......@@ -52,7 +52,7 @@ def in_declarative_mode():
Return a bool value that indicates whether running code under `@to_static`
"""
return _in_declarative_mode_
return global_var._in_declarative_mode_
def declarative_unsupport_argument_warning(
......@@ -86,11 +86,11 @@ switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
@signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True):
global _in_declarative_mode_
original_val = _in_declarative_mode_
_in_declarative_mode_ = is_declarative
global global_var
original_val = global_var._in_declarative_mode_
global_var._in_declarative_mode_ = is_declarative
yield
_in_declarative_mode_ = original_val
global_var._in_declarative_mode_ = original_val
@signature_safe_contextmanager
......@@ -106,9 +106,6 @@ def program_desc_tracing_guard(enable):
tracer._enable_program_desc_tracing = original_val
_functional_dygraph_context_manager = None
@signature_safe_contextmanager
def param_guard(parameters):
# Note: parameters is a reference of self._parameters or self._buffers
......@@ -228,12 +225,12 @@ def enable_dygraph(place=None):
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
"""
global _functional_dygraph_context_manager
if _functional_dygraph_context_manager is None:
_functional_dygraph_context_manager = guard(
global global_var
if global_var._functional_dygraph_context_manager is None:
global_var._functional_dygraph_context_manager = guard(
place=_get_paddle_place(place)
)
_functional_dygraph_context_manager.__enter__()
global_var._functional_dygraph_context_manager.__enter__()
# call disable_dygraph when Python exit
CleanupFuncRegistrar.register(disable_dygraph)
......@@ -263,10 +260,10 @@ def disable_dygraph():
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
"""
global _functional_dygraph_context_manager
if _functional_dygraph_context_manager is not None:
_functional_dygraph_context_manager.__exit__(*sys.exc_info())
_functional_dygraph_context_manager = None
global global_var
if global_var._functional_dygraph_context_manager is not None:
global_var._functional_dygraph_context_manager.__exit__(*sys.exc_info())
global_var._functional_dygraph_context_manager = None
@signature_safe_contextmanager
......
......@@ -74,7 +74,7 @@ def monkey_patch_math_varbase():
@no_grad
def create_tensor(value, dtype, shape):
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
out = _C_ops.full(
shape, value, dtype, framework._current_expected_place()
)
......@@ -251,7 +251,7 @@ def monkey_patch_math_varbase():
# 2. create varbase for scalar
lhs_dtype = self.dtype
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
other_var_should_be = core.eager.Tensor
else:
other_var_should_be = core.VarBase
......@@ -486,7 +486,7 @@ def monkey_patch_math_varbase():
global _already_patch_varbase
global _already_patch_eager_tensor
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
local_already_patch = _already_patch_eager_tensor
_already_patch_eager_tensor = True
local_tensor = core.eager.Tensor
......@@ -496,7 +496,7 @@ def monkey_patch_math_varbase():
local_tensor = core.VarBase
if not local_already_patch:
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
for method_name in eager_cpp_level_patch:
method_impl = getattr(local_tensor, method_name, None)
if method_impl:
......
......@@ -54,7 +54,9 @@ class TensorHookRemoveHelper:
def __init__(self, tensor, hook_id):
self._tensor = (
tensor if framework._in_eager_mode_ else weakref.ref(tensor)
tensor
if framework.global_var._in_eager_mode_
else weakref.ref(tensor)
)
self._hook_id = hook_id
......@@ -65,7 +67,11 @@ class TensorHookRemoveHelper:
Returns:
bool: Return True if removed successfully
"""
tensor = self._tensor if framework._in_eager_mode_ else self._tensor()
tensor = (
self._tensor
if framework.global_var._in_eager_mode_
else self._tensor()
)
if tensor is not None:
res = tensor._remove_grad_hook(self._hook_id)
if res is True:
......@@ -178,7 +184,7 @@ def monkey_patch_varbase():
out = linear(t) # call with different weight
"""
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
base_tensor = core.eager.Tensor
else:
base_tensor = core.VarBase
......@@ -282,7 +288,7 @@ def monkey_patch_varbase():
)
record_event.begin()
if grad_tensor is not None:
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
assert isinstance(
grad_tensor, core.eager.Tensor
), "The type of grad_tensor must be paddle.Tensor"
......@@ -296,7 +302,7 @@ def monkey_patch_varbase():
grad_tensor.name, grad_tensor.shape, self.name, self.shape
)
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
if grad_tensor is None:
grad_tensor = []
else:
......@@ -311,7 +317,7 @@ def monkey_patch_varbase():
):
# TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
scaled_loss = scale_loss(self)
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
core.eager.run_backward(
[scaled_loss], grad_tensor, retain_graph
)
......@@ -323,7 +329,7 @@ def monkey_patch_varbase():
framework._dygraph_tracer(),
)
else:
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
core.eager.run_backward([self], grad_tensor, retain_graph)
else:
core.dygraph_run_backward(
......@@ -368,7 +374,7 @@ def monkey_patch_varbase():
# [500.]
"""
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
if self.grad is None:
return None
if self.grad.is_selected_rows():
......@@ -673,7 +679,7 @@ def monkey_patch_varbase():
# [[0.30574632, 0.55739117, 0.30902600, 0.39413780, 0.44830436],
# [0.79010487, 0.53972793, 0.09495186, 0.44267157, 0.72112119]])
"""
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
from paddle.tensor.to_string import tensor_to_string
return tensor_to_string(self)
......@@ -707,7 +713,7 @@ def monkey_patch_varbase():
raise RuntimeError(
"Only Leaf Tensor support the deepcopy at the moment, non-Leaf Tensors contains graph information that does't support deepcopy"
)
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
new_varbase = core.eager.Tensor()
else:
new_varbase = core.VarBase()
......@@ -725,7 +731,7 @@ def monkey_patch_varbase():
assert (
numel == 1
), "When Variable is used as the condition of if/while , Variable can only contain one element."
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
assert self._is_initialized(), "tensor not initialized"
return bool(np.all(self.numpy() > 0))
else:
......@@ -850,7 +856,7 @@ def monkey_patch_varbase():
return _setitem_impl_(self, item, value)
else:
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
return self.__setitem_eager_tensor__(item, value)
else:
# Call c++ func __setitem_varbase__ to speedup.
......@@ -1020,7 +1026,7 @@ def monkey_patch_varbase():
def __hash__(self):
return hash(id(self))
if framework._in_eager_mode_ and not hasattr(core, "eager"):
if framework.global_var._in_eager_mode_ and not hasattr(core, "eager"):
return
for method_name, method in (
......@@ -1047,12 +1053,12 @@ def monkey_patch_varbase():
("to_dense", to_dense),
("to_sparse_coo", to_sparse_coo),
):
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
setattr(core.eager.Tensor, method_name, method)
else:
setattr(core.VarBase, method_name, method)
if framework._in_eager_mode_:
if framework.global_var._in_eager_mode_:
setattr(core.eager.Tensor, "_set_grad_ivar", _set_grad_ivar)
setattr(core.eager.Tensor, "value", value)
setattr(core.eager.Tensor, "cpu", cpu)
......
......@@ -36,6 +36,7 @@ import paddle.version as fluid_version
import warnings
import functools
from .variable_index import _getitem_impl_, _setitem_impl_
import threading
__all__ = [
'Program',
......@@ -70,8 +71,42 @@ GRAD_VAR_SUFFIX = core.kGradVarSuffix()
ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
# use thread local to create thread save global variables.
class GlobalThreadLocal(threading.local):
def __init__(self):
"""
init the thread local data.
TODO(xiongkun): how to access another thread local data ?
"""
global _dygraph_tracer_
self._in_declarative_mode_ = False
self._functional_dygraph_context_manager = None
self._dygraph_tracer_ = _dygraph_tracer_
self._in_eager_mode_ = True
def __str__(self):
strings = []
strings.append(
"_in_declarative_mode_:" + str(self._in_declarative_mode_)
)
strings.append(
"_functional_dygraph_context_manager:"
+ str(self._functional_dygraph_context_manager)
)
strings.append("_dygraph_tracer_:" + str(self._dygraph_tracer_))
strings.append("_in_eager_mode_:" + str(self._in_eager_mode_))
return "\n".join(strings)
def __setattr__(self, name, val):
if name == '_dygraph_tracer_':
global _dygraph_tracer_
_dygraph_tracer_ = val
self.__dict__[name] = val
_dygraph_tracer_ = None
_in_eager_mode_ = True
global_var = GlobalThreadLocal()
_global_expected_place_ = None
_current_device = None
global_prog_seed = 0
......@@ -155,20 +190,17 @@ def _switch_tensor_bind_type(is_eager):
def _enable_legacy_dygraph():
global _in_eager_mode_
_in_eager_mode_ = False
global_var._in_eager_mode_ = False
_update_monkey_methods(is_eager=False)
def _disable_legacy_dygraph():
global _in_eager_mode_
_in_eager_mode_ = True
global_var._in_eager_mode_ = True
_update_monkey_methods(is_eager=True)
def _in_eager_without_dygraph_check():
global _in_eager_mode_
return _in_eager_mode_
return global_var._in_eager_mode_
# FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but
......@@ -177,7 +209,6 @@ _is_first_import_ = True
def _fallback_legacy_dygraph():
global _in_eager_mode_
global _is_first_import_
need_fallback = False
# Only enable eager on CPU/GPU/XPU
......@@ -187,12 +218,12 @@ def _fallback_legacy_dygraph():
or core.is_compiled_with_mlu()
)
if _in_eager_mode_ and is_not_support:
if global_var._in_eager_mode_ and is_not_support:
# switch into legacy dygraph mode
warnings.warn(
"We will fallback into legacy dygraph on NPU/XPU/MLU/IPU/ROCM devices. Because we only support new eager dygraph mode on CPU/GPU currently. "
)
_in_eager_mode_ = False
global_var._in_eager_mode_ = False
if not _is_first_import_:
_enable_legacy_dygraph()
need_fallback = True
......@@ -234,11 +265,13 @@ def in_dygraph_mode():
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
"""
return (_dygraph_tracer_ is not None) and _in_eager_mode_
return (
global_var._dygraph_tracer_ is not None
) and global_var._in_eager_mode_
def _non_static_mode():
return _dygraph_tracer_ is not None
return global_var._dygraph_tracer_ is not None
@signature_safe_contextmanager
......@@ -603,7 +636,7 @@ non_static_only = wrap_decorator(_non_static_only_)
def _dygraph_tracer():
return _dygraph_tracer_
return global_var._dygraph_tracer_
def _global_flags():
......@@ -671,9 +704,8 @@ def _current_expected_place():
def _set_dygraph_tracer_expected_place(place):
global _dygraph_tracer_
if _dygraph_tracer_ is not None:
_dygraph_tracer_._expected_place = place
if global_var._dygraph_tracer_ is not None:
global_var._dygraph_tracer_._expected_place = place
def _set_expected_place(place):
......@@ -1315,7 +1347,7 @@ def _varbase_creator(
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if _in_eager_mode_:
if global_var._in_eager_mode_:
eager_tensor = core.eager.Tensor(
dtype if dtype else core.VarDesc.VarType.FP32,
list(shape) if shape else [],
......@@ -7460,16 +7492,17 @@ def _get_var(name, program=None):
@signature_safe_contextmanager
def _dygraph_guard(tracer):
global _dygraph_tracer_
tmp_tracer = _dygraph_tracer_
_dygraph_tracer_ = tracer
core._switch_tracer(tracer)
tmp_tracer = global_var._dygraph_tracer_
global_var._dygraph_tracer_ = tracer
if tracer is not None:
core._switch_tracer(tracer)
try:
yield
finally:
core._switch_tracer(tmp_tracer)
_dygraph_tracer_ = tmp_tracer
if tmp_tracer is not None:
core._switch_tracer(tmp_tracer)
global_var._dygraph_tracer_ = tmp_tracer
@signature_safe_contextmanager
......
......@@ -59,8 +59,8 @@ class LazyInitHelper:
self.enable()
if self._in_guard:
return
self._tracer = framework._dygraph_tracer_
framework._dygraph_tracer_ = None
self._tracer = framework.global_var._dygraph_tracer_
framework.global_var._dygraph_tracer_ = None
self._in_guard = True
def __exit__(self, *args, **kwargs):
......@@ -71,7 +71,7 @@ class LazyInitHelper:
if not self._in_guard:
return
assert self._tracer is not None
framework._dygraph_tracer_ = self._tracer
framework.global_var._dygraph_tracer_ = self._tracer
self._tracer = None
self._in_guard = False
......
......@@ -36,7 +36,7 @@ class TestDy2staticException(unittest.TestCase):
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
paddle.jit.enable_to_static(True)
self.assertTrue(to_static(self.dyfunc)(self.x))
paddle.fluid.dygraph.base._in_declarative_mode_ = False
paddle.fluid.dygraph.base.global_var._in_declarative_mode_ = False
paddle.jit.enable_to_static(False)
......
......@@ -65,7 +65,7 @@ class TestDy2staticException(unittest.TestCase):
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
paddle.jit.enable_to_static(True)
self.assertTrue(paddle.jit.to_static(self.dyfunc)(self.x))
paddle.fluid.dygraph.base._in_declarative_mode_ = False
paddle.fluid.dygraph.base.global_var._in_declarative_mode_ = False
paddle.jit.enable_to_static(False)
......@@ -463,7 +463,7 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
# that the code block is under @to_static, but in this UT
# an exception is thrown during Dy2St, making the `_in_declarative_mode_`
# a wrong value. So We need set `_in_declarative_mode_` to False manually.
paddle.fluid.dygraph.base._in_declarative_mode_ = False
paddle.fluid.dygraph.base.global_var._in_declarative_mode_ = False
paddle.jit.enable_to_static(False)
......
......@@ -25,7 +25,7 @@ from paddle import _C_ops, _legacy_C_ops
import paddle.fluid as fluid
from paddle.fluid import core, framework, executor
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.framework import _in_eager_mode_
from paddle.fluid.framework import global_var
paddle.enable_static()
np.random.seed(1243)
......@@ -135,7 +135,7 @@ class RunProgramNPUOpTest(unittest.TestCase):
def prepare_dygraph_input(self, place, return_param_list=False):
def create_var_base(is_input, name, np_value, stop_gradient):
if _in_eager_mode_:
if global_var._in_eager_mode_:
var = core.eager.Tensor(
value=np_value, name=name, place=place, zero_copy=True
)
......@@ -176,7 +176,7 @@ class RunProgramNPUOpTest(unittest.TestCase):
for name in self.output_names['Out']:
outputs['Out'].append(create_var_base(False, name))
if _in_eager_mode_:
if global_var._in_eager_mode_:
outputs['OutScope'] = [core.Scope()]
else:
outputs['OutScope'] = framework._varbase_creator(
......
......@@ -26,7 +26,7 @@ from paddle.fluid.executor import (
_is_dy2st_enable_standalone_executor,
_is_enable_standalone_executor,
)
from paddle.fluid.framework import _in_eager_mode_
from paddle.fluid.framework import global_var
from paddle.fluid.layers.utils import _hash_with_id
paddle.enable_static()
......@@ -177,7 +177,7 @@ class RunProgramOpTest(unittest.TestCase):
def prepare_dygraph_input(self, place, return_param_list=False):
def create_var_base(is_input, name, np_value, stop_gradient):
if _in_eager_mode_:
if global_var._in_eager_mode_:
var = core.eager.Tensor(
value=np_value, name=name, place=place, zero_copy=True
)
......@@ -218,7 +218,7 @@ class RunProgramOpTest(unittest.TestCase):
for name in self.output_names['Out']:
outputs['Out'].append(create_var_base(False, name))
if _in_eager_mode_:
if global_var._in_eager_mode_:
outputs['OutScope'] = [core.Scope()]
else:
outputs['OutScope'] = framework._varbase_creator(
......
......@@ -619,7 +619,7 @@ class PartialProgramLayer:
if "@GRAD" in name:
var_desc = block.vars[name].desc
var_base = None
if not framework._in_eager_mode_:
if not framework.global_var._in_eager_mode_:
var_base = core.VarBase(
var_desc.dtype(),
var_desc.shape(),
......@@ -874,7 +874,7 @@ class PartialProgramLayer:
for i, value in enumerate(flatten_inputs):
if isinstance(value, np.ndarray):
var = None
if not framework._in_eager_mode_:
if not framework.global_var._in_eager_mode_:
var = core.VarBase(
value=value,
name=self._inputs[i].desc.name(),
......@@ -918,7 +918,7 @@ class PartialProgramLayer:
if var_desc.name() in out_varbase_map:
return out_varbase_map[var_desc.name()]
if not framework._in_eager_mode_:
if not framework.global_var._in_eager_mode_:
var_base = core.VarBase(
var_desc.dtype(),
var_desc.shape(),
......@@ -949,7 +949,7 @@ class PartialProgramLayer:
inner_scope = self._get_scope(
program_id=program_id, use_scope_cache=use_scope_cache
)
if not framework._in_eager_mode_:
if not framework.global_var._in_eager_mode_:
tmp_scope_vec = core.VarBase(
core.VarDesc.VarType.FP32,
[],
......@@ -1102,7 +1102,7 @@ def _create_fake_var():
"""
Create a fake_var (force on CPU) to handle empty input or output
"""
if not framework._in_eager_mode_:
if not framework.global_var._in_eager_mode_:
return [
core.VarBase(
core.VarDesc.VarType.FP32,
......
......@@ -17,11 +17,11 @@
import paddle
from ..fluid.data_feeder import check_type, check_variable_and_dtype
from ..fluid.framework import _in_eager_mode_
from ..fluid.framework import global_var
from ..static import Variable
from .layer_function_generator import templatedoc
if _in_eager_mode_:
if global_var._in_eager_mode_:
Tensor = paddle.fluid.framework.core.eager.Tensor
else:
from ..framework import VarBase as Tensor
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册