#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .. import core
from ..framework import (
    Variable,
    convert_np_dtype_to_dtype_,
    in_dygraph_mode,
)
from ..framework import _create_tensor as framework_create_tensor
from ..layers.layer_function_generator import OpProtoHolder
from . import no_grad
from .. import framework

import numpy as np
import warnings
from paddle import _C_ops, _legacy_C_ops

_supported_int_dtype_ = [
    core.VarDesc.VarType.UINT8,
    core.VarDesc.VarType.INT8,
    core.VarDesc.VarType.INT16,
    core.VarDesc.VarType.INT32,
    core.VarDesc.VarType.INT64,
    core.VarDesc.VarType.BOOL,
]

# NOTE(chenweihang): We currently do not fully support type promotion
# between tensors. Partial support is provided here because the interoperation
# of real and complex numbers in paddle quantum is very frequent, e.g. the
# binary operation between `float` and `complex64`, so we must support the
# correct type promotion on the APIs paddle quantum uses.
# For now this is only checked in dygraph mode (paddle quantum is dygraph-based).
# Full type promotion support will need to be fully verified later.
_supported_promote_complex_types_ = [
    '__add__',
    '__radd__',
    '__sub__',
    '__rsub__',
    '__mul__',
    '__rmul__',
    '__div__',
    '__truediv__',
    '__rdiv__',
    '__rtruediv__',
    '__matmul__',
]
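# For example (illustrative), a binary op between a float32 and a complex64
# tensor in dygraph mode is expected to promote to complex64:
#     x = paddle.ones([2], dtype='float32')
#     y = paddle.to_tensor([1j, 2j], dtype='complex64')
#     (x * y).dtype   # paddle.complex64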

_complex_dtypes = [
    core.VarDesc.VarType.COMPLEX64,
    core.VarDesc.VarType.COMPLEX128,
]

_already_patch_eager_tensor = False


def monkey_patch_math_tensor():
    """
    Similar to monkey_patch_variable. The difference is that, in dygraph mode,
    the auto-generated op functions are used directly for better performance.
    """

    def astype(self, dtype):
        """

        Cast a Tensor to a specified data type.

        Args:
            dtype: The target data type.

        Returns:
            Tensor: a new Tensor with the target dtype.

        Examples:
            .. code-block:: python

                import paddle

                original_tensor = paddle.ones([2, 2])
                print("original tensor's dtype is: {}".format(original_tensor.dtype))
                new_tensor = original_tensor.astype('bool')
                print("new tensor's dtype is: {}".format(new_tensor.dtype))

        """
        if not isinstance(dtype, core.VarDesc.VarType):
            dtype = convert_np_dtype_to_dtype_(dtype)
        return _C_ops.cast(self, dtype)

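    # Helper: computes `scale * var + bias` with a single scale op; used by
    # the unary helpers below (e.g. negation).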
    def _scalar_elementwise_op_(var, scale, bias):
        if framework.in_dygraph_mode():
            return _C_ops.scale(var, float(scale), bias, True)
        else:
            return _legacy_C_ops.scale(var, 'scale', scale, 'bias', bias)

    def _neg_(var):
        return _scalar_elementwise_op_(var, -1.0, 0.0)

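    # Python scalar conversions below (__float__, __long__, __int__, __index__)
    # require a single-element, initialized tensor and copy its value to host.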
    def _float_(var):
        numel = np.prod(var.shape)
        assert (
            numel == 1
        ), "only a one-element Tensor can be converted to a Python float."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return float(np.array(var).flatten()[0])

    def _long_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only one element variable can be converted to long."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(np.array(var).flatten()[0])

    def _int_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only a one-element Tensor can be converted to a Python int."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(np.array(var).flatten()[0])

    def _len_(var):
        assert var.ndim > 0, "len() is not defined for a 0-D tensor"
        if var.type == core.VarDesc.VarType.VOCAB:
            return len(var.value().get_map_tensor())
        elif var.type == core.VarDesc.VarType.STRINGS:
            return len(var.value().get_string_tensor())
        else:
            return var.shape[0]

    def _index_(var):
        numel = np.prod(var.shape)
        assert (
            numel == 1
        ), "only a one-element Tensor can be converted to a Python index."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(np.array(var).flatten()[0])

    @property
    def _ndim_(var):
        return len(var.shape)

    @property
    def _size_(var):
        return int(np.prod(var.shape))

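    # Full transpose: reverses all dimensions (perm = [n-1, ..., 1, 0]);
    # a 1-D tensor is returned unchanged.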
    @property
    def _T_(var):
        if len(var.shape) == 1:
            return var
        perm = []
        for i in range(len(var.shape)):
            perm.insert(0, i)
        out = _C_ops.transpose(var, perm)
        return out

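    # Python-level methods/properties to attach to core.eager.Tensor.
    # Setting `__array_ufunc__` to None makes NumPy defer mixed Tensor/ndarray
    # operations to the Tensor's own (reflected) operators.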
    eager_methods = [
        ('__neg__', _neg_),
        ('__float__', _float_),
        ('__long__', _long_),
        ('__int__', _int_),
        ('__len__', _len_),
        ('__index__', _index_),
        ('astype', astype),
        ('dim', lambda x: len(x.shape)),
        ('ndimension', lambda x: len(x.shape)),
        ('ndim', _ndim_),
        ('size', _size_),
        ('T', _T_),
        # for logical compare
        ('__array_ufunc__', None),
    ]

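    # Operators expected to already be implemented at the C++ level on
    # core.eager.Tensor; they are only (re)bound below if an implementation
    # is actually present.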
    eager_cpp_level_patch = [
        "__add__",
        "__radd__",
        '__sub__',
        '__rsub__',
        '__mul__',
        '__rmul__',
        '__div__',
        '__truediv__',
        '__rdiv__',
        '__rtruediv__',
        '__mod__',
        '__matmul__',
        '__gt__',
        '__ge__',
        '__lt__',
        '__le__',
        '__floordiv__',
        '__pow__',
        '__rpow__',
        '__eq__',
        '__ne__',
    ]

    global _already_patch_eager_tensor

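    # core.eager.Tensor is patched only on the first call; later calls bind the
    # Python APIs from paddle.tensor (skipping names that already exist) and
    # their magic-method aliases.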
    local_already_patch = _already_patch_eager_tensor
    _already_patch_eager_tensor = True
    local_tensor = core.eager.Tensor

    if not local_already_patch:
        for method_name in eager_cpp_level_patch:
            method_impl = getattr(local_tensor, method_name, None)
            if method_impl:
                setattr(local_tensor, method_name, method_impl)

        for method_name, method_impl in eager_methods:
            setattr(local_tensor, method_name, method_impl)
    else:
        import paddle.tensor

        # Tensor method from module paddle.tensor
        for method_name in paddle.tensor.tensor_method_func:
            if hasattr(local_tensor, method_name):
                continue
            method_impl = getattr(paddle.tensor, method_name, None)
            if method_impl:
                setattr(local_tensor, method_name, method_impl)

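        # Bind magic methods (e.g. __add__) to their implementations in paddle.tensor.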
        for magic_method, origin_method in paddle.tensor.magic_method_func:
            impl = getattr(paddle.tensor, origin_method, None)
            if impl:
                setattr(local_tensor, magic_method, impl)