未验证 提交 fd591ecb 编写于 作者: A Aurelius84 提交者: GitHub

[Eager]Polish enable/disable_legacy_dygraph logic (#41364)

* [Eager]Polish enable/disable_legacy_dygraph logic

* merge yunfei PR

* merge other pr
上级 1ae0730f
...@@ -115,15 +115,39 @@ def _update_monkey_methods(is_eager): ...@@ -115,15 +115,39 @@ def _update_monkey_methods(is_eager):
from .dygraph.varbase_patch_methods import monkey_patch_varbase from .dygraph.varbase_patch_methods import monkey_patch_varbase
from .dygraph import monkey_patch_math_varbase from .dygraph import monkey_patch_math_varbase
global _already_patch_eager_tensor
global _already_patch_varbase
assert isinstance(is_eager, bool) assert isinstance(is_eager, bool)
# switch into eager mode
if is_eager: if is_eager:
_C_ops.switch_to_eager_ops() _C_ops.switch_to_eager_ops()
if not _already_patch_eager_tensor:
monkey_patch_varbase()
monkey_patch_math_varbase()
_already_patch_eager_tensor = True
# switch back into legacy mode
else: else:
_C_ops.switch_to_core_ops() _C_ops.switch_to_core_ops()
if not _already_patch_varbase:
monkey_patch_varbase() monkey_patch_varbase()
monkey_patch_math_varbase() monkey_patch_math_varbase()
_already_patch_varbase = True
# switch Paddle.Tensor bind type
_switch_tensor_bind_type(is_eager)
def _switch_tensor_bind_type(is_eager):
import paddle
if is_eager:
paddle.Tensor = core.eager.Tensor
else:
paddle.Tensor = core.VarBase
paddle.Tensor.__qualname__ = 'Tensor'
def _enable_legacy_dygraph(): def _enable_legacy_dygraph():
global _in_eager_mode_ global _in_eager_mode_
...@@ -183,35 +207,10 @@ def _non_static_mode(): ...@@ -183,35 +207,10 @@ def _non_static_mode():
@signature_safe_contextmanager @signature_safe_contextmanager
def _test_eager_guard(place=None): def _test_eager_guard(place=None):
_disable_legacy_dygraph() _disable_legacy_dygraph()
from paddle import _C_ops
_C_ops.switch_to_eager_ops()
global _already_patch_eager_tensor
global _already_patch_varbase
from .dygraph.varbase_patch_methods import monkey_patch_varbase
from .dygraph import monkey_patch_math_varbase
if not _already_patch_eager_tensor:
monkey_patch_varbase()
monkey_patch_math_varbase()
# Ugly setting
from paddle.tensor.manipulation import fill_, zero_, fill_diagonal_, fill_diagonal_tensor_, tolist
setattr(core.eager.Tensor, 'fill_', fill_)
setattr(core.eager.Tensor, 'zero_', zero_)
setattr(core.eager.Tensor, 'fill_diagonal_', fill_diagonal_)
setattr(core.eager.Tensor, 'fill_diagonal_tensor_',
fill_diagonal_tensor_)
setattr(core.eager.Tensor, 'tolist', tolist)
_already_patch_eager_tensor = True
try: try:
yield yield
finally: finally:
_enable_legacy_dygraph() _enable_legacy_dygraph()
if not _already_patch_varbase:
monkey_patch_varbase()
monkey_patch_math_varbase()
_already_patch_varbase = True
_C_ops.switch_to_core_ops()
global_ipu_index = None global_ipu_index = None
......
...@@ -76,9 +76,6 @@ def fill_(x, value): ...@@ -76,9 +76,6 @@ def fill_(x, value):
float(value), "value_int", int(value)) float(value), "value_int", int(value))
setattr(core.VarBase, 'fill_', fill_)
@dygraph_only @dygraph_only
def zero_(x): def zero_(x):
""" """
...@@ -107,9 +104,6 @@ def zero_(x): ...@@ -107,9 +104,6 @@ def zero_(x):
return _C_ops.fill_any_(x, "value_float", 0., "value_int", int(0)) return _C_ops.fill_any_(x, "value_float", 0., "value_int", int(0))
setattr(core.VarBase, 'zero_', zero_)
@dygraph_only @dygraph_only
def fill_diagonal_(x, value, offset=0, wrap=False, name=None): def fill_diagonal_(x, value, offset=0, wrap=False, name=None):
""" """
...@@ -156,9 +150,6 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None): ...@@ -156,9 +150,6 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None):
True) True)
setattr(core.VarBase, 'fill_diagonal_', fill_diagonal_)
def _fill_diagonal_tensor_impl(x, y, offset=0, dim1=0, dim2=1, inplace=False): def _fill_diagonal_tensor_impl(x, y, offset=0, dim1=0, dim2=1, inplace=False):
inshape = x.shape inshape = x.shape
assert dim1 < len(inshape) and dim1 >= -len(inshape), ( assert dim1 < len(inshape) and dim1 >= -len(inshape), (
...@@ -226,9 +217,6 @@ def fill_diagonal_tensor_(x, y, offset=0, dim1=0, dim2=1, name=None): ...@@ -226,9 +217,6 @@ def fill_diagonal_tensor_(x, y, offset=0, dim1=0, dim2=1, name=None):
x, y, offset=offset, dim1=dim1, dim2=dim2, inplace=True) x, y, offset=offset, dim1=dim1, dim2=dim2, inplace=True)
setattr(core.VarBase, 'fill_diagonal_tensor_', fill_diagonal_tensor_)
def fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1, name=None): def fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1, name=None):
""" """
This function fill the source Tensor y into the x Tensor's diagonal. This function fill the source Tensor y into the x Tensor's diagonal.
...@@ -262,12 +250,6 @@ def fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1, name=None): ...@@ -262,12 +250,6 @@ def fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1, name=None):
x, y, offset=offset, dim1=dim1, dim2=dim2, inplace=False) x, y, offset=offset, dim1=dim1, dim2=dim2, inplace=False)
setattr(core.VarBase, 'fill_diagonal_tensor', fill_diagonal_tensor)
if _in_eager_without_dygraph_check():
setattr(core.eager.Tensor, 'fill_diagonal_tensor', fill_diagonal_tensor)
@dygraph_only @dygraph_only
def tolist(x): def tolist(x):
""" """
...@@ -301,9 +283,6 @@ def tolist(x): ...@@ -301,9 +283,6 @@ def tolist(x):
return x.numpy().tolist() return x.numpy().tolist()
setattr(core.VarBase, 'tolist', tolist)
def concat(x, axis=0, name=None): def concat(x, axis=0, name=None):
""" """
...@@ -2961,3 +2940,17 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'): ...@@ -2961,3 +2940,17 @@ def put_along_axis_(arr, indices, values, axis, reduce='assign'):
values = paddle.broadcast_to(values, indices.shape) values = paddle.broadcast_to(values, indices.shape)
return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis, "Reduce", return _C_ops.put_along_axis_(arr, indices, values, "Axis", axis, "Reduce",
reduce) reduce)
# TODO(dev): We need avoid implementing it by this way.
__METHODS = {
'fill_': fill_,
'zero_': zero_,
'fill_diagonal_': fill_diagonal_,
'fill_diagonal_tensor_': fill_diagonal_tensor_,
"fill_diagonal_tensor": fill_diagonal_tensor,
'tolist': tolist
}
for name, func in __METHODS.items():
setattr(core.VarBase, name, func)
setattr(core.eager.Tensor, name, func)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册