Unverified commit fd591ecb, authored by Aurelius84, committed by GitHub

[Eager]Polish enable/disable_legacy_dygraph logic (#41364)

* [Eager]Polish enable/disable_legacy_dygraph logic

* merge yunfei PR

* merge other pr
Parent 1ae0730f
@@ -115,15 +115,39 @@ def _update_monkey_methods(is_eager):
    from .dygraph.varbase_patch_methods import monkey_patch_varbase
    from .dygraph import monkey_patch_math_varbase

    global _already_patch_eager_tensor
    global _already_patch_varbase

    assert isinstance(is_eager, bool)

    # switch into eager mode
    if is_eager:
        _C_ops.switch_to_eager_ops()
        if not _already_patch_eager_tensor:
            monkey_patch_varbase()
            monkey_patch_math_varbase()
            _already_patch_eager_tensor = True
    # switch back into legacy mode
    else:
        _C_ops.switch_to_core_ops()
        if not _already_patch_varbase:
            monkey_patch_varbase()
            monkey_patch_math_varbase()
            _already_patch_varbase = True

    # switch Paddle.Tensor bind type
    _switch_tensor_bind_type(is_eager)


def _switch_tensor_bind_type(is_eager):
    import paddle

    if is_eager:
        paddle.Tensor = core.eager.Tensor
    else:
        paddle.Tensor = core.VarBase
    paddle.Tensor.__qualname__ = 'Tensor'


def _enable_legacy_dygraph():
    global _in_eager_mode_
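The hunk above follows a patch-once / rebind pattern: module-level flags guard the monkey patching so each tensor type is patched at most once, while the public paddle.Tensor alias is re-pointed on every mode switch. Below is a minimal, self-contained sketch of that pattern using hypothetical stand-in classes (_LegacyTensor, _EagerTensor); it is not Paddle's implementation, only an illustration of the technique.

    # Sketch of the patch-once / rebind pattern (assumed stand-in classes).
    class _LegacyTensor:  # stand-in for core.VarBase
        pass


    class _EagerTensor:  # stand-in for core.eager.Tensor
        pass


    Tensor = _LegacyTensor  # public alias, analogous to paddle.Tensor

    _already_patch_eager = False
    _already_patch_legacy = False


    def _patch_methods(cls):
        # Attach extra methods to the class exactly once.
        cls.double = lambda self, v: 2 * v


    def _update_monkey_methods(is_eager):
        global Tensor, _already_patch_eager, _already_patch_legacy
        assert isinstance(is_eager, bool)
        if is_eager:
            if not _already_patch_eager:
                _patch_methods(_EagerTensor)
                _already_patch_eager = True
            Tensor = _EagerTensor
        else:
            if not _already_patch_legacy:
                _patch_methods(_LegacyTensor)
                _already_patch_legacy = True
            Tensor = _LegacyTensor
        # Keep the alias presenting itself as plain 'Tensor'.
        Tensor.__qualname__ = 'Tensor'


    _update_monkey_methods(is_eager=True)
    assert Tensor is _EagerTensor
    assert Tensor().double(3) == 6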
@@ -183,35 +207,10 @@ def _non_static_mode():
@signature_safe_contextmanager
def _test_eager_guard(place=None):
    _disable_legacy_dygraph()
    from paddle import _C_ops
    _C_ops.switch_to_eager_ops()
    global _already_patch_eager_tensor
    global _already_patch_varbase
    from .dygraph.varbase_patch_methods import monkey_patch_varbase
    from .dygraph import monkey_patch_math_varbase
    if not _already_patch_eager_tensor:
        monkey_patch_varbase()
        monkey_patch_math_varbase()

        # Ugly setting
        from paddle.tensor.manipulation import fill_, zero_, fill_diagonal_, fill_diagonal_tensor_, tolist
        setattr(core.eager.Tensor, 'fill_', fill_)
        setattr(core.eager.Tensor, 'zero_', zero_)
        setattr(core.eager.Tensor, 'fill_diagonal_', fill_diagonal_)
        setattr(core.eager.Tensor, 'fill_diagonal_tensor_',
                fill_diagonal_tensor_)
        setattr(core.eager.Tensor, 'tolist', tolist)
        _already_patch_eager_tensor = True
    try:
        yield
    finally:
        _enable_legacy_dygraph()
        if not _already_patch_varbase:
            monkey_patch_varbase()
            monkey_patch_math_varbase()
            _already_patch_varbase = True
        _C_ops.switch_to_core_ops()


global_ipu_index = None
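With the old inline patching removed, _test_eager_guard reduces to _disable_legacy_dygraph() on entry and _enable_legacy_dygraph() on exit. A hedged usage sketch of how a unit test would exercise both execution modes with this guard follows; the paddle.fluid.framework import path is an assumption based on where the guard lives in Paddle 2.x.

    # Hedged usage sketch: run the same check under eager mode and legacy dygraph.
    import unittest

    import paddle
    from paddle.fluid.framework import _test_eager_guard  # assumed import path


    class TestToList(unittest.TestCase):
        def func_tolist(self):
            x = paddle.to_tensor([1, 2, 3])
            self.assertEqual(x.tolist(), [1, 2, 3])

        def test_tolist(self):
            # First pass runs in eager mode inside the guard,
            # second pass runs in legacy dygraph after the guard restores it.
            with _test_eager_guard():
                self.func_tolist()
            self.func_tolist()


    if __name__ == '__main__':
        unittest.main()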
@@ -76,9 +76,6 @@ def fill_(x, value):
                            float(value), "value_int", int(value))


setattr(core.VarBase, 'fill_', fill_)


@dygraph_only
def zero_(x):
    """
@@ -107,9 +104,6 @@ def zero_(x):
    return _C_ops.fill_any_(x, "value_float", 0., "value_int", int(0))


setattr(core.VarBase, 'zero_', zero_)


@dygraph_only
def fill_diagonal_(x, value, offset=0, wrap=False, name=None):
    """
@@ -156,9 +150,6 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None):
                                      True)


setattr(core.VarBase, 'fill_diagonal_', fill_diagonal_)


def _fill_diagonal_tensor_impl(x, y, offset=0, dim1=0, dim2=1, inplace=False):
    inshape = x.shape
    assert dim1 < len(inshape) and dim1 >= -len(inshape), (
@@ -226,9 +217,6 @@ def fill_diagonal_tensor_(x, y, offset=0, dim1=0, dim2=1, name=None):
        x, y, offset=offset, dim1=dim1, dim2=dim2, inplace=True)


setattr(core.VarBase, 'fill_diagonal_tensor_', fill_diagonal_tensor_)


def fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1, name=None):
    """
    This function fills the source Tensor y into the x Tensor's diagonal.
@@ -262,12 +250,6 @@ def fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1, name=None):
        x, y, offset=offset, dim1=dim1, dim2=dim2, inplace=False)


setattr(core.VarBase, 'fill_diagonal_tensor', fill_diagonal_tensor)

if _in_eager_without_dygraph_check():
    setattr(core.eager.Tensor, 'fill_diagonal_tensor', fill_diagonal_tensor)


@dygraph_only
def tolist(x):
    """
@@ -301,9 +283,6 @@ def tolist(x):
    return x.numpy().tolist()


setattr(core.VarBase, 'tolist', tolist)


def concat(x, axis=0, name=None):
    """
@@ -2961,3 +2940,17 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'):
    values = paddle.broadcast_to(values, indices.shape)
    return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis, "Reduce",
                                  reduce)


# TODO(dev): We need to avoid implementing it this way.
__METHODS = {
    'fill_': fill_,
    'zero_': zero_,
    'fill_diagonal_': fill_diagonal_,
    'fill_diagonal_tensor_': fill_diagonal_tensor_,
    'fill_diagonal_tensor': fill_diagonal_tensor,
    'tolist': tolist
}
for name, func in __METHODS.items():
    setattr(core.VarBase, name, func)
    setattr(core.eager.Tensor, name, func)
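The loop above replaces the scattered setattr calls with a single table that patches both core.VarBase and core.eager.Tensor, so the methods are available in both legacy and eager modes. A hedged usage example of the attached methods follows; it assumes a working Paddle install in dynamic-graph mode, and the printed result reflects the documented in-place semantics.

    # Hedged example of the methods attached by the patch loop above.
    import paddle

    x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
    x.zero_()              # in-place: every element becomes 0.0
    x.fill_(5.0)           # in-place: every element becomes 5.0
    x.fill_diagonal_(9.0)  # in-place: main diagonal becomes 9.0
    print(x.tolist())      # expected: [[9.0, 5.0], [5.0, 9.0]]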