Unverified commit 5ca3bc6d, authored by Meteor Liu, committed by GitHub

rename monkey_patch_{math_}varbase as monkey_patch_{math_}tensor (#53191)

* rename monkey_patch_varbase as monkey_patch_tensor & monkey_patch_math_varbase as monkey_patch_math_tensor

* rename monkey_patch_varbase as monkey_patch_tensor & monkey_patch_math_varbase as monkey_patch_math_tensor

* rename monkey_patch_varbase as monkey_patch_tensor & monkey_patch_math_varbase as monkey_patch_math_tensor v2

* rename monkey_patch_varbase as monkey_patch_tensor & monkey_patch_math_varbase as monkey_patch_math_tensor fixed bug
Parent a1f8f411
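For context, every `monkey_patch_*` helper touched by this commit follows the same pattern: at import time it attaches plain Python functions as methods on the C++-backed tensor class with `setattr`. Below is a minimal, self-contained sketch of that idea; `MyTensor` and `_add_` are illustrative stand-ins, not Paddle internals.

```python
# Minimal sketch of the monkey-patch pattern; MyTensor and _add_ are
# illustrative stand-ins, not Paddle code.
class MyTensor:
    def __init__(self, value):
        self.value = value

def _add_(self, other):
    # In Paddle, implementations like this dispatch to auto-generated ops.
    other_value = other.value if isinstance(other, MyTensor) else other
    return MyTensor(self.value + other_value)

def monkey_patch_my_tensor():
    # Bind methods onto the class after its definition, which is what
    # monkey_patch_tensor()/monkey_patch_math_tensor() do for Tensor.
    for name, impl in [('__add__', _add_), ('add', _add_)]:
        setattr(MyTensor, name, impl)

monkey_patch_my_tensor()
assert (MyTensor(1) + MyTensor(2)).value == 3
```

The rename does not change this mechanism; it only replaces the obsolete "varbase" terminology with "tensor" now that `core.eager.Tensor` is the single tensor type.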
@@ -24,11 +24,15 @@ except ImportError:
     )
 from .batch import batch  # noqa: F401
+
+# Do the *DUPLICATED* monkey-patch for the tensor object.
+# We need remove the duplicated code here once we fix
+# the illogical implement in the monkey-patch methods later.
 from .framework import monkey_patch_variable
-from .framework import monkey_patch_math_varbase
+from .framework import monkey_patch_math_tensor
 monkey_patch_variable()
-monkey_patch_math_varbase()
+monkey_patch_math_tensor()
 from .framework import disable_signal_handler  # noqa: F401
 from .framework import get_flags  # noqa: F401
...
@@ -79,7 +79,7 @@ from . import compiler
 from .compiler import *
 from paddle.fluid.layers.math_op_patch import monkey_patch_variable
 from .dygraph.base import enable_dygraph, disable_dygraph
-from .dygraph.varbase_patch_methods import monkey_patch_varbase
+from .dygraph.tensor_patch_methods import monkey_patch_tensor
 from .core import _cuda_synchronize
 from .trainer_desc import (
     TrainerDesc,
@@ -211,7 +211,7 @@ def __bootstrap__():
 # Consider paddle.init(args) or paddle.main(args)
 monkey_patch_variable()
 __bootstrap__()
-monkey_patch_varbase()
+monkey_patch_tensor()
 # NOTE(Aurelius84): clean up ExecutorCacheInfo in advance manually.
 atexit.register(core.clear_executor_cache)
...
@@ -21,8 +21,6 @@ from .tracer import *
 from . import learning_rate_scheduler
 from .learning_rate_scheduler import *
-from .math_op_patch import monkey_patch_math_varbase
-
 __all__ = []
 __all__ += base.__all__
 __all__ += learning_rate_scheduler.__all__
@@ -65,7 +65,7 @@ _complex_dtypes = [
 _already_patch_eager_tensor = False

-def monkey_patch_math_varbase():
+def monkey_patch_math_tensor():
     """
     Similar to monkey_patch_variable.
     The difference is, in dygraph mode, use auto-generated op functions for better performance.
@@ -248,7 +248,7 @@ def monkey_patch_math_varbase():
             # do nothing
            pass

-        # 2. create varbase for scalar
+        # 2. create Tensor for scalar
         lhs_dtype = self.dtype
         other_var_should_be = core.eager.Tensor
         if not isinstance(other_var, other_var_should_be):
@@ -343,7 +343,7 @@ def monkey_patch_math_varbase():
         __impl__.__name__ = method_name
         return __impl__

-    varbase_methods = [
+    tensor_methods = [
         ('__neg__', _neg_),
         ('__float__', _float_),
         ('__long__', _long_),
@@ -498,7 +498,7 @@ def monkey_patch_math_varbase():
             setattr(local_tensor, method_name, method_impl)
     else:
-        for method in varbase_methods:
+        for method in tensor_methods:
             method_name = method[0]
             method_impl = method[1]
             setattr(local_tensor, method_name, method_impl)
...
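The `tensor_methods` hunk above shows the method-factory pattern: each operator implementation is generated as `__impl__`, renamed to the target dunder via `__impl__.__name__ = method_name`, and bound onto the class with `setattr`. A self-contained sketch of that pattern, with illustrative names rather than Paddle's real dispatch:

```python
import operator

class DemoTensor:
    def __init__(self, v):
        self.v = v

def _make_impl(method_name, op):
    def __impl__(self, other):
        other_v = other.v if isinstance(other, DemoTensor) else other
        return DemoTensor(op(self.v, other_v))
    # Same trick as in the hunk: rename the generated function so
    # tracebacks and introspection show __add__ instead of __impl__.
    __impl__.__name__ = method_name
    return __impl__

tensor_methods = [
    ('__add__', _make_impl('__add__', operator.add)),
    ('__sub__', _make_impl('__sub__', operator.sub)),
]

for method_name, method_impl in tensor_methods:
    setattr(DemoTensor, method_name, method_impl)

assert (DemoTensor(3) - DemoTensor(1)).v == 2
```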
@@ -32,7 +32,7 @@ from ..framework import (
     in_dygraph_mode,
 )
 from .base import switch_to_static_graph
-from .math_op_patch import monkey_patch_math_varbase
+from .math_op_patch import monkey_patch_math_tensor
 from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
 import paddle.utils.deprecated as deprecated
 import paddle.profiler as profiler
@@ -86,7 +86,7 @@ class TensorHookRemoveHelper:
 _already_patch_repr = False

-def monkey_patch_varbase():
+def monkey_patch_tensor():
     @switch_to_static_graph
     def _to_static_var(self, to_parameter=False, **kwargs):
         """
@@ -110,8 +110,8 @@ def monkey_patch_varbase():
             data = np.ones([3, 1024], dtype='float32')
             with fluid.dygraph.guard():
-                var_base = to_variable(data)
-                static_var = var_base._to_static_var()
+                tensor = to_variable(data)
+                static_var = tensor._to_static_var()
         """
@@ -700,11 +700,11 @@ def monkey_patch_varbase():
            raise RuntimeError(
                "Only Leaf Tensor support the deepcopy at the moment, non-Leaf Tensors contains graph information that does't support deepcopy"
            )
-        new_varbase = core.eager.Tensor()
-        new_varbase.name = self.name + unique_name.generate("_deepcopy")
-        memo[id(self)] = new_varbase
-        new_varbase.copy_(self, True)
-        return new_varbase
+        new_tensor = core.eager.Tensor()
+        new_tensor.name = self.name + unique_name.generate("_deepcopy")
+        memo[id(self)] = new_tensor
+        new_tensor.copy_(self, True)
+        return new_tensor

     @property
     def block(self):
@@ -1073,5 +1073,5 @@ def monkey_patch_varbase():
         setattr(core.VarDesc.VarType, "__str__", dtype_str)
         _already_patch_repr = True

-    # patch math methods for varbase
-    monkey_patch_math_varbase()
+    # patch math methods for tensor
+    monkey_patch_math_tensor()
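The `_deepcopy` lines renamed above implement the standard `copy.deepcopy` protocol: the copy is registered in `memo` under `id(self)` *before* its state is filled in, so self-referential structures resolve to the same copy instead of recursing forever. A standalone sketch of the protocol on a plain class, not Paddle's Tensor:

```python
import copy

class LeafBox:
    def __init__(self, data=None):
        self.data = data

    def __deepcopy__(self, memo):
        # Same ordering as new_tensor in the hunk: create the object,
        # record it in memo keyed by id(self), then copy the state over.
        new_box = LeafBox()
        memo[id(self)] = new_box
        new_box.data = copy.deepcopy(self.data, memo)
        return new_box

a = LeafBox([1, 2, 3])
a.data.append(a)       # cycle: the list refers back to its owner
b = copy.deepcopy(a)
assert b.data[3] is b  # the memo resolved the cycle to the copy
```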
@@ -112,8 +112,6 @@ _global_expected_place_ = None
 _current_device = None
 global_prog_seed = 0
 _current_pipeline_stage = None
-_already_patch_eager_tensor = False
-_already_patch_varbase = False
 _current_cuda_graph_mode = None
 _global_flags_ = core.globals()
@@ -182,35 +180,6 @@ extra_op_attrs = {
 # to make sure in most case, we find new dygraph mode first with only one if statement.

-def _update_monkey_methods():
-    """
-    Update monkey methods of Tensor or eager.Tensor while
-    switching eager mode and legacy mode.
-    """
-    from paddle import _C_ops, _legacy_C_ops
-    from .dygraph.varbase_patch_methods import monkey_patch_varbase
-    from .dygraph import monkey_patch_math_varbase
-
-    global _already_patch_eager_tensor
-    global _already_patch_varbase
-
-    if not _already_patch_eager_tensor:
-        monkey_patch_varbase()
-        monkey_patch_math_varbase()
-        _already_patch_eager_tensor = True
-
-    # switch Paddle.Tensor bind type
-    _switch_tensor_bind_type()
-
-
-def _switch_tensor_bind_type():
-    import paddle
-
-    paddle.Tensor = core.eager.Tensor
-    paddle.Tensor.__qualname__ = 'Tensor'
-
-
 def _in_eager_without_dygraph_check():
     return global_var._in_eager_mode_
...
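The deleted `_update_monkey_methods` existed only to re-patch method sets when toggling between eager and legacy modes; with `core.eager.Tensor` as the only binding, a single module-level guard (like the `_already_patch_eager_tensor` flag kept in math_op_patch.py) is all the remaining helpers need. A sketch of that idempotent-patch guard, with illustrative names rather than the exact Paddle flags:

```python
# Idempotent one-shot patching; replaces the mode-switching re-patch logic.
_already_patched = False

def patch_once(cls, methods):
    global _already_patched
    if _already_patched:
        return  # nothing to re-patch: there is only one tensor binding now
    for name, impl in methods:
        setattr(cls, name, impl)
    _already_patched = True
```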
@@ -118,7 +118,7 @@ class TestDeprecatedDocorator(unittest.TestCase):
         with warnings.catch_warnings(record=True) as w:
             grad = x.gradient()
             assert (
-                'API "paddle.fluid.dygraph.varbase_patch_methods.gradient" is '
+                'API "paddle.fluid.dygraph.tensor_patch_methods.gradient" is '
                 'deprecated since 2.1.0'
             ) in str(w[-1].message)
...
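The updated test captures the deprecation message with `warnings.catch_warnings(record=True)` and matches it as a string. A minimal self-contained version of that testing technique; `legacy_api` is an illustrative stand-in for a deprecated Paddle method:

```python
import warnings

def legacy_api():
    # Illustrative stand-in for a @deprecated-decorated Paddle method.
    warnings.warn(
        'API "pkg.mod.legacy_api" is deprecated since 2.1.0',
        DeprecationWarning,
    )

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")  # record even already-seen warnings
    legacy_api()

assert 'deprecated since 2.1.0' in str(w[-1].message)
```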
@@ -44,8 +44,11 @@ from .io_utils import _pack_loaded_dict
 from .io_utils import _unpack_saved_dict
 from .io_utils import _load_program_scope

-from ..fluid import monkey_patch_variable
-from ..fluid.dygraph import monkey_patch_math_varbase
+# Do the *DUPLICATED* monkey-patch for the tensor object.
+# We need remove the duplicated code here once we fix
+# the illogical implement in the monkey-patch methods later.
+from ..fluid.layers.math_op_patch import monkey_patch_variable
+from ..fluid.dygraph.math_op_patch import monkey_patch_math_tensor
 from ..fluid.framework import disable_signal_handler  # noqa: F401
 from ..fluid.framework import get_flags  # noqa: F401
 from ..fluid.framework import set_flags  # noqa: F401
...