Unverified commit b1333175, authored by meteor135, committed by GitHub

[dygraph] remove legacy code: _in_eager_mode_ and _in_eager_without_dygraph_check() (#53761)

* remove _in_eager_mode_

* remove _in_eager_mode_
Parent 98100fd2
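The pattern repeated across every file below is the same: with eager mode now the only dygraph implementation, each branch on `framework.global_var._in_eager_mode_` or `_in_eager_without_dygraph_check()` collapses to its eager-only path. A minimal, self-contained sketch of that reduction, simplified from the `in_dygraph_mode()` hunk further down (the class and function names here are illustrative stand-ins, not the literal Paddle source):

```python
# Illustrative sketch of the cleanup applied throughout this commit (assumed
# stand-ins, not the literal Paddle code): the legacy eager flag was always
# True in practice, so every branch on it reduces to the eager path.


class _GlobalVar:
    """Stand-in for the global thread-local state in paddle.fluid.framework."""

    def __init__(self):
        self._dygraph_tracer_ = object()  # pretend an active tracer exists
        self._in_eager_mode_ = True       # legacy flag removed by this commit


global_var = _GlobalVar()


def in_dygraph_mode_before():
    # Old form: tracer check combined with the legacy eager flag.
    return (global_var._dygraph_tracer_ is not None) and global_var._in_eager_mode_


def in_dygraph_mode_after():
    # New form: the flag is gone, only the tracer check remains.
    return global_var._dygraph_tracer_ is not None


assert in_dygraph_mode_before() == in_dygraph_mode_after()
```

The same reduction drops the `core.dygraph_run_backward` and `core.dygraph_partial_grad` fallbacks in favor of the eager-only `core.eager.run_backward` and `core.eager.run_partial_grad` calls, as shown in the hunks below.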
@@ -107,10 +107,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
each_tensor, (paddle.Tensor, core.eager.Tensor)
), "The argument 'grad_tensors' of paddle.autograd.backward is invalid, it can be 'None', 'paddle.Tensor' or 'list[None/paddle.Tensor]'."
else:
if framework.global_var._in_eager_mode_:
grad_tensors = []
else:
grad_tensors = [None] * len(tensors)
if len(grad_tensors) > 0:
assert len(tensors) == len(
@@ -119,9 +116,4 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
assert isinstance(retain_graph, bool), "retain_graph must be True or False"
if framework.global_var._in_eager_mode_:
core.eager.run_backward(tensors, grad_tensors, retain_graph)
else:
core.dygraph_run_backward(
tensors, grad_tensors, retain_graph, framework._dygraph_tracer()
)
@@ -24,7 +24,6 @@ from .framework import (
default_main_program,
_current_expected_place,
_non_static_mode,
_in_eager_without_dygraph_check,
)
from .framework import _cpu_num, _cuda_ids
@@ -26,10 +26,7 @@ from .tracer import Tracer
import logging
from ..data_feeder import convert_dtype
import warnings
from ..framework import (
_get_paddle_place,
_in_eager_without_dygraph_check,
)
from ..framework import _get_paddle_place
import paddle
import warnings
@@ -796,14 +793,9 @@ def grad(
var, core.eager.Tensor
), "no_grad_vars can only contains Tensor"
else:
if _in_eager_without_dygraph_check():
raise AssertionError(
"no_grad_vars must be None, Tensor or list/tuple/set of Tensors"
)
else:
raise AssertionError(
"no_grad_vars must be None, Variable or list/tuple/set of Variables"
)
assert isinstance(create_graph, bool), "create_graph must be True or False"
@@ -819,7 +811,6 @@ def grad(
assert isinstance(only_inputs, bool), "only_inputs must be True or False"
assert only_inputs, "only_inputs=False is not supported yet"
if _in_eager_without_dygraph_check():
return core.eager.run_partial_grad(
outputs,
inputs,
@@ -830,20 +821,6 @@ def grad(
allow_unused,
no_grad_vars,
)
else:
place = core.Place()
place.set_place(framework._current_expected_place())
return core.dygraph_partial_grad(
inputs,
outputs,
grad_outputs,
no_grad_vars,
place,
create_graph,
retain_graph,
allow_unused,
only_inputs,
)
@framework.dygraph_only
@@ -71,31 +71,6 @@ def monkey_patch_math_tensor():
The difference is, in dygraph mode, use auto-generated op functions for better performance.
"""
@no_grad
def create_tensor(value, dtype, shape):
if framework.global_var._in_eager_mode_:
out = _C_ops.full(
shape, value, dtype, framework._current_expected_place()
)
else:
out = framework_create_tensor(dtype=dtype)
out = _legacy_C_ops.fill_constant(
out,
'dtype',
dtype,
'shape',
shape,
'value',
value,
'force_cpu',
False,
)
out.stop_gradient = True
return out
def create_scalar(value, dtype):
return create_tensor(value, dtype, shape=[])
def astype(self, dtype):
"""
@@ -191,253 +166,6 @@ def monkey_patch_math_tensor():
out = _C_ops.transpose(var, perm)
return out
def _scalar_add_(var, value):
return _scalar_elementwise_op_(var, 1.0, value)
def _scalar_sub_(var, value):
return _scalar_elementwise_op_(var, 1.0, -value)
def _scalar_rsub_(var, value):
return _scalar_elementwise_op_(var, -1.0, value)
def _scalar_mul_(var, value):
return _scalar_elementwise_op_(var, value, 0.0)
def _scalar_div_(var, value):
return _scalar_elementwise_op_(var, 1.0 / value, 0.0)
# for binary operator such as elementwise, compare
def _binary_creator_(
method_name,
op_type,
reverse=False,
scalar_method=None,
call_final_api=False,
):
def __impl__(self, other_var):
# 1. scalar exists cases
# we need combine the tensor.dtype and scalar.dtype, cast correct object
if isinstance(other_var, float):
# in all cases(+, -, *, /, **, //, %), we need cast tensor.dtype to float
if self.dtype in _supported_int_dtype_:
self = astype(self, 'float32')
# here use `scale` replace `elementwise` to get better performance
# but only +, -, *, / can use this method
if scalar_method is not None:
return scalar_method(self, other_var)
elif isinstance(other_var, int):
# in all cases(+, -, *, /, **, //, %), we can cast it to float
# because the output tensor.dtype depend on the type of input tensor
other_var = float(other_var)
# division is a special case
# NOTE(chenweihang): because we cast tensor to float32 instead float64,
# the division result can only guarantee the numerical accuracy of 6 digits
# after the decimal point. The result of numpy calculation is of float64 type,
# so the calculation result here and the calculation result of numpy are
# different after 6 decimal point. If necessary, we can also use float64 here.
# torch's behavior here is consistent with ours
if (
op_type == "divide" or op_type == "elementwise_div"
) and self.dtype in _supported_int_dtype_:
self = astype(self, 'float32')
# here use `scale` replace `elementwise` to get better performance
# but only +, -, *, / can use this method
if scalar_method is not None:
return scalar_method(self, other_var)
else:
# do nothing
pass
# 2. create Tensor for scalar
lhs_dtype = self.dtype
other_var_should_be = core.eager.Tensor
if not isinstance(other_var, other_var_should_be):
if isinstance(other_var, complex):
import paddle
other_var = paddle.to_tensor(other_var, dtype='complex64')
else:
if reverse:
other_var = create_tensor(
other_var, dtype=lhs_dtype, shape=self.shape
)
else:
# add fill_op
other_var = create_scalar(
value=other_var, dtype=lhs_dtype
)
# 3. promote types or unify right var type to left var
rhs_dtype = other_var.dtype
if lhs_dtype != rhs_dtype:
if method_name in _supported_promote_complex_types_ and (
lhs_dtype in _complex_dtypes or rhs_dtype in _complex_dtypes
):
# only when lhs_dtype or rhs_dtype is complex type,
# the dtype will promote, in other cases, directly
# use lhs_dtype, this is consistent will original rule
promote_dtype = core._promote_types_if_complex_exists(
lhs_dtype, rhs_dtype
)
self = (
self
if lhs_dtype == promote_dtype
else astype(self, promote_dtype)
)
other_var = (
other_var
if rhs_dtype == promote_dtype
else astype(other_var, promote_dtype)
)
else:
warnings.warn(
'The dtype of left and right variables are not the same, left dtype is {}, but right dtype is {}, the right dtype will convert to {}'.format(
lhs_dtype, rhs_dtype, lhs_dtype
)
)
other_var = astype(other_var, lhs_dtype)
if reverse:
tmp = self
self = other_var
other_var = tmp
if (
op_type == "divide" or op_type == "elementwise_div"
) and self.dtype in _supported_int_dtype_:
self = astype(self, 'float32')
other_var = astype(other_var, 'float32')
# 4. calculation
axis = -1
if in_dygraph_mode():
math_op = getattr(_C_ops, op_type)
else:
math_op = getattr(_legacy_C_ops, op_type)
if call_final_api:
if op_type == "matmul":
return math_op(self, other_var, False, False)
if op_type == "pow":
if isinstance(other_var, core.eager.Tensor):
return _C_ops.elementwise_pow(self, other_var)
else:
return _C_ops.elementwise_pow(self, other_var)
return math_op(self, other_var, -1)
return math_op(self, other_var, 'axis', axis)
if call_final_api:
comment = ""
else:
comment = OpProtoHolder.instance().get_op_proto(op_type).comment
__impl__.__doc__ = """
{0}
Args:
other_var(Tensor|float|int): right hand Tensor
Returns:
Tensor
""".format(
comment
)
__impl__.__name__ = method_name
return __impl__
tensor_methods = [
('__neg__', _neg_),
('__float__', _float_),
('__long__', _long_),
('__int__', _int_),
('__len__', _len_),
('__index__', _index_),
('astype', astype),
('dim', lambda x: len(x.shape)),
('ndimension', lambda x: len(x.shape)),
('ndim', _ndim_),
('size', _size_),
('T', _T_),
(
'__add__',
_binary_creator_('__add__', 'elementwise_add', False, _scalar_add_),
),
# a+b == b+a. Do not need to reverse explicitly
(
'__radd__',
_binary_creator_(
'__radd__', 'elementwise_add', False, _scalar_add_
),
),
(
'__sub__',
_binary_creator_('__sub__', 'elementwise_sub', False, _scalar_sub_),
),
(
'__rsub__',
_binary_creator_(
'__rsub__', 'elementwise_sub', True, _scalar_rsub_
),
),
(
'__mul__',
_binary_creator_('__mul__', 'elementwise_mul', False, _scalar_mul_),
),
## a*b == b*a. Do not need to reverse explicitly
(
'__rmul__',
_binary_creator_(
'__rmul__', 'elementwise_mul', False, _scalar_mul_
),
),
(
'__div__',
_binary_creator_('__div__', 'elementwise_div', False, _scalar_div_),
),
(
'__truediv__',
_binary_creator_(
'__truediv__', 'elementwise_div', False, _scalar_div_
),
),
(
'__rdiv__',
_binary_creator_('__rdiv__', 'elementwise_div', True, None),
),
(
'__rtruediv__',
_binary_creator_('rtruediv__', 'elementwise_div', True, None),
),
(
'__pow__',
_binary_creator_('__pow__', 'elementwise_pow', False, None),
),
(
'__rpow__',
_binary_creator_('__rpow__', 'elementwise_pow', True, None),
),
(
'__floordiv__',
_binary_creator_(
'__floordiv__', 'elementwise_floordiv', False, None
),
),
(
'__mod__',
_binary_creator_('__mod__', 'elementwise_mod', False, None),
),
(
'__matmul__',
_binary_creator_('__matmul__', "matmul_v2", False, None),
),
## for logical compare
('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),
('__le__', _binary_creator_('__le__', 'less_equal', False, None)),
('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)),
('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)),
('__array_ufunc__', None),
]
eager_methods = [
('__neg__', _neg_),
('__float__', _float_),
@@ -486,7 +214,6 @@ def monkey_patch_math_tensor():
local_tensor = core.eager.Tensor
if not local_already_patch:
if framework.global_var._in_eager_mode_:
for method_name in eager_cpp_level_patch:
method_impl = getattr(local_tensor, method_name, None)
if method_impl:
@@ -496,12 +223,6 @@ def monkey_patch_math_tensor():
method_name = method[0]
method_impl = method[1]
setattr(local_tensor, method_name, method_impl)
else:
for method in tensor_methods:
method_name = method[0]
method_impl = method[1]
setattr(local_tensor, method_name, method_impl)
else:
import paddle.tensor
@@ -54,11 +54,7 @@ class TensorHookRemoveHelper:
"""
def __init__(self, tensor, hook_id):
self._tensor = (
tensor
if framework.global_var._in_eager_mode_
else weakref.ref(tensor)
)
self._tensor = tensor
self._hook_id = hook_id
def remove(self):
@@ -68,11 +64,7 @@ class TensorHookRemoveHelper:
Returns:
bool: Return True if removed successfully
"""
tensor = (
self._tensor
if framework.global_var._in_eager_mode_
else self._tensor()
)
tensor = self._tensor
if tensor is not None:
res = tensor._remove_grad_hook(self._hook_id)
if res is True:
@@ -285,21 +277,16 @@ def monkey_patch_tensor():
)
record_event.begin()
if grad_tensor is not None:
if framework.global_var._in_eager_mode_:
assert isinstance(
grad_tensor, core.eager.Tensor
), "The type of grad_tensor must be paddle.Tensor"
else:
assert isinstance(
grad_tensor, paddle.Tensor
), "The type of grad_tensor must be paddle.Tensor"
assert (
grad_tensor.shape == self.shape
), "Tensor shape not match, Tensor of grad_tensor [ {} ] with shape {} mismatch Tensor [ {} ] with shape {}".format(
grad_tensor.name, grad_tensor.shape, self.name, self.shape
)
if framework.global_var._in_eager_mode_:
if grad_tensor is None:
grad_tensor = []
else:
@@ -307,15 +294,9 @@ def monkey_patch_tensor():
if _grad_scalar:
# When using amp with Fleet DistributedStrategy, we do loss scaling implicitly.
self = _grad_scalar.scale(self)
if framework.global_var._in_eager_mode_:
core.eager.run_backward([self], grad_tensor, retain_graph)
else:
core.dygraph_run_backward(
[self],
[grad_tensor],
retain_graph,
framework._dygraph_tracer(),
)
if in_profiler_mode():
record_event.end()
else:
@@ -352,31 +333,11 @@ def monkey_patch_tensor():
# [500.]
"""
if framework.global_var._in_eager_mode_:
if self.grad is None:
return None
if self.grad.is_selected_rows():
return (np.array(self.grad), np.array(self.grad.rows()))
return np.array(self.grad)
else:
if self._grad_ivar() is None:
return None
new_ivar = self._grad_ivar()
# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if (
_global_flags()['FLAGS_npu_storage_format']
and 'npu' in get_all_custom_device_type()
):
new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1)
new_ivar = new_ivar._copy_to(core.CPUPlace(), True)
if self._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS:
return (
np.array(new_ivar.value().get_selected_rows().get_tensor()),
np.array(new_ivar.value().get_selected_rows().rows()),
)
else:
return np.array(new_ivar.value().get_tensor())
@framework.dygraph_only
def register_hook(self, hook):
@@ -705,13 +666,8 @@ def monkey_patch_tensor():
assert (
numel == 1
), "When Variable is used as the condition of if/while , Variable can only contain one element."
if framework.global_var._in_eager_mode_:
assert self._is_initialized(), "tensor not initialized"
return bool(np.array(self) > 0)
else:
tensor = self.value().get_tensor()
assert tensor._is_initialized(), "tensor not initialized"
return bool(np.array(tensor) > 0)
def __bool__(self):
return self.__nonzero__()
@@ -830,11 +786,7 @@ def monkey_patch_tensor():
return _setitem_impl_(self, item, value)
else:
if framework.global_var._in_eager_mode_:
return self.__setitem_eager_tensor__(item, value)
else:
# Call c++ func __setitem_varbase__ to speedup.
return self.__setitem_varbase__(item, value)
@framework.dygraph_only
def _set_grad_ivar(self, value):
@@ -1000,7 +952,7 @@ def monkey_patch_tensor():
def __hash__(self):
return hash(id(self))
if framework.global_var._in_eager_mode_ and not hasattr(core, "eager"):
if not hasattr(core, "eager"):
return
for method_name, method in (
@@ -83,7 +83,6 @@ class GlobalThreadLocal(threading.local):
self._in_declarative_mode_ = False
self._functional_dygraph_context_manager = None
self._dygraph_tracer_ = _dygraph_tracer_
self._in_eager_mode_ = True
def __str__(self):
strings = []
@@ -95,7 +94,6 @@ class GlobalThreadLocal(threading.local):
+ str(self._functional_dygraph_context_manager)
)
strings.append("_dygraph_tracer_:" + str(self._dygraph_tracer_))
strings.append("_in_eager_mode_:" + str(self._in_eager_mode_))
return "\n".join(strings)
def __setattr__(self, name, val):
@@ -180,10 +178,6 @@ extra_op_attrs = {
# to make sure in most case, we find new dygraph mode first with only one if statement.
def _in_eager_without_dygraph_check():
return global_var._in_eager_mode_
# FIXME(dev): We haven't fully verified eager mode on XPU et.al but
# only GPU/CPU. Remove this after we improve this feature.
_is_first_import_ = True
@@ -216,9 +210,7 @@ def in_dygraph_mode():
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
"""
return (
global_var._dygraph_tracer_ is not None
) and global_var._in_eager_mode_
return global_var._dygraph_tracer_ is not None
def _non_static_mode():
@@ -22,7 +22,6 @@ from .framework import (
default_startup_program,
_non_static_mode,
_current_expected_place,
_in_eager_without_dygraph_check,
)
from . import unique_name
from .param_attr import ParamAttr, WeightNormParamAttr
@@ -29,7 +29,6 @@ from .framework import (
_non_static_mode,
cpu_places,
_current_expected_place,
_in_eager_without_dygraph_check,
)
from .executor import global_scope
from .data_feeder import DataFeeder, BatchedTensorProvider
......@@ -663,12 +662,9 @@ class DygraphGeneratorLoader(DataLoaderBase):
def __next__(self):
try:
if _in_eager_without_dygraph_check():
return core.eager.read_next_tensor_list(
self._reader.read_next_list()[0]
)
else:
return self._reader.read_next_var_list()
except StopIteration:
self._reset()
raise
@@ -21,7 +21,6 @@ import paddle
from paddle import _legacy_C_ops, fluid
from paddle.fluid import core, framework
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import global_var
paddle.enable_static()
@@ -207,16 +206,7 @@ class RunProgramOpTest(unittest.TestCase):
for name in self.output_names['Out']:
outputs['Out'].append(create_var_base(False, name))
if global_var._in_eager_mode_:
outputs['OutScope'] = [core.Scope()]
else:
outputs['OutScope'] = framework._create_tensor(
type=core.VarDesc.VarType.STEP_SCOPES,
name="program_out_scope",
persistable=True,
)
inner_scope = core.Scope()
outputs['OutScope'].value().set_scope(inner_scope)
outputs['DOut'] = [create_var_base(False, "Fake_var")]
return outputs
@@ -24,7 +24,6 @@ from paddle.fluid import core
from paddle.fluid.framework import (
Variable,
_current_expected_place,
_in_eager_without_dygraph_check,
default_main_program,
device_guard,
in_dygraph_mode,
@@ -736,9 +735,7 @@ class Optimizer:
name=var_name,
persistable=True,
dtype=dtype or param.dtype,
type=core.VarDesc.VarType.LOD_TENSOR
if framework._in_eager_without_dygraph_check()
else (param.type if type is None else type),
type=core.VarDesc.VarType.LOD_TENSOR,
shape=shape,
belong_to_optimizer=True,
)
@@ -1380,11 +1377,8 @@ class Optimizer:
if not p.stop_gradient:
param_list.append(p)
if _in_eager_without_dygraph_check():
for p in param_list:
p.clear_gradient(set_to_zero)
else:
core.clear_gradients(param_list, set_to_zero)
@imperative_base.no_grad()
def minimize(
@@ -30,11 +30,7 @@ from ..fluid.data_feeder import (
convert_dtype,
convert_float_to_uint16,
)
from ..fluid.framework import (
Variable,
_in_eager_without_dygraph_check,
device_guard,
)
from ..fluid.framework import Variable, device_guard
from ..fluid.param_attr import ParamAttr
from ..framework import (
LayerHelper,
@@ -625,7 +621,7 @@ def _to_tensor_non_static(data, dtype=None, place=None, stop_gradient=True):
if dtype:
data = _handle_np_dtype(data, dtype)
if _in_eager_without_dygraph_check() and isinstance(data, np.ndarray):
if isinstance(data, np.ndarray):
return core.eager.Tensor(
value=data,
place=place,
@@ -776,7 +776,7 @@ class TestModelFunction(unittest.TestCase):
paddle.summary(nlp_net, (1, 1, 2))
def test_static_flops(self):
if paddle.fluid.framework._in_eager_without_dygraph_check():
if True:
return
paddle.disable_static()
net = models.__dict__['mobilenet_v2'](pretrained=False)