math_op_patch.py 18.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from .. import core
18
from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator, _in_legacy_dygraph, in_dygraph_mode
19
from ..layers.layer_function_generator import OpProtoHolder
20
from . import no_grad
J
Jiabin Yang 已提交
21
from .. import framework
22

23
import numpy as np
24
import warnings
25
from paddle import _C_ops, _legacy_C_ops
26

27 28 29 30 31 32
_supported_int_dtype_ = [
    core.VarDesc.VarType.UINT8,
    core.VarDesc.VarType.INT8,
    core.VarDesc.VarType.INT16,
    core.VarDesc.VarType.INT32,
    core.VarDesc.VarType.INT64,
33
    core.VarDesc.VarType.BOOL,
34 35
]

36 37 38 39
# NOTE(chenweihang): We currently do not fully support the type promotion
# between tensors. Parting support here is because the interoperation of
# real and complex numbers in paddle quantum is very frequent, such as the
# binary operation between `float` and `complex64`, so we must support the
40 41 42 43 44 45 46 47 48 49
# correct type promotion on the APIs paddle quantum used.
# Now only check in dygraph (paddle quantum based dygraph)
# Full type promotion support will need to be fully verified later.
_supported_promote_complex_types_ = [
    '__add__',
    '__radd__',
    '__sub__',
    '__rsub__',
    '__mul__',
    '__rmul__',
50
    '__div__',
51
    '__truediv__',
52
    '__rdiv__',
53 54 55 56
    '__rtruediv__',
    '__matmul__',
]

57 58 59 60 61
_complex_dtypes = [
    core.VarDesc.VarType.COMPLEX64,
    core.VarDesc.VarType.COMPLEX128,
]

62
_already_patch_varbase = False
63
_already_patch_eager_tensor = False
64

65 66 67 68 69 70 71

def monkey_patch_math_varbase():
    """
    Similar to monkey_patch_variable.
    The difference is, in dygraph mode, use auto-generated op functions for better performance.
    """

72
    @no_grad
73
    def create_tensor(value, dtype, shape):
74
        if framework._in_eager_mode_:
75 76
            out = _C_ops.full(shape, value, dtype,
                              framework._current_expected_place())
77 78
        else:
            out = _varbase_creator(dtype=dtype)
79 80 81
            out = _legacy_C_ops.fill_constant(out, 'dtype', dtype, 'shape',
                                              shape, 'value', value,
                                              'force_cpu', False)
82 83
        out.stop_gradient = True
        return out
84 85 86 87 88 89 90

    def create_scalar(value, dtype):
        return create_tensor(value, dtype, shape=[1])

    def astype(self, dtype):
        """

91
        Cast a Tensor to a specified data type.
92 93

        Args:
94
            dtype: The target data type.
95 96

        Returns:
97
            Tensor: a new Tensor with target dtype
98 99 100 101

        Examples:
            .. code-block:: python

102
                import paddle
103 104
                import numpy as np

105 106 107 108
                original_tensor = paddle.ones([2, 2])
                print("original tensor's dtype is: {}".format(original_tensor.dtype))
                new_tensor = original_tensor.astype('float32')
                print("new tensor's dtype is: {}".format(new_tensor.dtype))
109 110

        """
111 112
        if not isinstance(dtype, core.VarDesc.VarType):
            dtype = convert_np_dtype_to_dtype_(dtype)
113 114

        if _in_legacy_dygraph():
115 116 117
            return _legacy_C_ops.cast(self, 'in_dtype', self.dtype, 'out_dtype',
                                      dtype)
        return _C_ops.cast(self, dtype)
118 119

    def _scalar_elementwise_op_(var, scale, bias):
120
        if framework.in_dygraph_mode():
121 122
            return _C_ops.scale(var, float(scale), bias, True)
        return _legacy_C_ops.scale(var, 'scale', scale, 'bias', bias)
123

124 125 126
    def _neg_(var):
        return _scalar_elementwise_op_(var, -1.0, 0.0)

127 128 129 130 131 132 133 134 135 136 137 138
    def _float_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only one element variable can be converted to float."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return float(var.numpy().flatten()[0])

    def _long_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only one element variable can be converted to long."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
T
tianshuo78520a 已提交
139
        return int(var.numpy().flatten()[0])
140 141 142 143 144 145 146 147 148

    def _int_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only one element variable can be converted to int."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(var.numpy().flatten()[0])

    def _len_(var):
S
Steffy-zxf 已提交
149 150 151 152 153 154
        if var.type == core.VarDesc.VarType.VOCAB:
            return len(var.value().get_map_tensor())
        elif var.type == core.VarDesc.VarType.STRINGS:
            return len(var.value().get_string_tensor())
        else:
            return var.shape[0]
155 156 157 158 159 160

    def _index_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only one element variable can be converted to python index."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
T
tianshuo78520a 已提交
161
        return int(var.numpy().flatten()[0])
162

163 164 165 166
    @property
    def _ndim_(var):
        return len(var.shape)

167 168 169 170
    @property
    def _size_(var):
        return np.prod(var.shape)

171 172 173 174 175 176 177
    @property
    def _T_(var):
        if len(var.shape) == 1:
            return var
        perm = []
        for i in range(len(var.shape)):
            perm.insert(0, i)
178
        if _in_legacy_dygraph():
179
            out, _ = _legacy_C_ops.transpose2(var, 'axis', perm)
180
        else:
181
            out = _C_ops.transpose(var, perm)
182 183
        return out

184
    def _scalar_add_(var, value):
185 186
        return _scalar_elementwise_op_(var, 1.0, value)

187
    def _scalar_sub_(var, value):
188 189
        return _scalar_elementwise_op_(var, 1.0, -value)

190
    def _scalar_rsub_(var, value):
191 192
        return _scalar_elementwise_op_(var, -1.0, value)

193
    def _scalar_mul_(var, value):
194 195
        return _scalar_elementwise_op_(var, value, 0.0)

196 197 198
    def _scalar_div_(var, value):
        return _scalar_elementwise_op_(var, 1.0 / value, 0.0)

199 200 201 202
    # for binary operator such as elementwise, compare
    def _binary_creator_(method_name,
                         op_type,
                         reverse=False,
203 204
                         scalar_method=None,
                         call_final_api=False):
205

206
        def __impl__(self, other_var):
207 208 209 210 211 212 213 214 215
            # 1. scalar exists cases
            # we need combine the tensor.dtype and scalar.dtype, cast correct object
            if isinstance(other_var, float):
                # in all cases(+, -, *, /, **, //, %), we need cast tensor.dtype to float
                if self.dtype in _supported_int_dtype_:
                    self = astype(self, 'float32')
                # here use `scale` replace `elementwise` to get better performance
                # but only +, -, *, / can use this method
                if scalar_method is not None:
216
                    return scalar_method(self, other_var)
217 218 219 220 221 222
            elif isinstance(other_var, int):
                # in all cases(+, -, *, /, **, //, %), we can cast it to float
                # because the output tensor.dtype depend on the type of input tensor
                other_var = float(other_var)
                # division is a special case
                # NOTE(chenweihang): because we cast tensor to float32 instead float64,
223 224 225
                # the division result can only guarantee the numerical accuracy of 6 digits
                # after the decimal point. The result of numpy calculation is of float64 type,
                # so the calculation result here and the calculation result of numpy are
226 227
                # different after 6 decimal point. If necessary, we can also use float64 here.
                # torch's behavior here is consistent with ours
228
                if (op_type == "divide" or op_type == "elementwise_div"
229
                    ) and self.dtype in _supported_int_dtype_:
230 231
                    self = astype(self, 'float32')
                # here use `scale` replace `elementwise` to get better performance
232
                # but only +, -, *, / can use this method
233 234 235 236 237
                if scalar_method is not None:
                    return scalar_method(self, other_var)
            else:
                # do nothing
                pass
238

239
            # 2. create varbase for scalar
240
            lhs_dtype = self.dtype
J
Jiabin Yang 已提交
241
            if framework._in_eager_mode_:
242
                other_var_should_be = core.eager.Tensor
243 244 245
            else:
                other_var_should_be = core.VarBase
            if not isinstance(other_var, other_var_should_be):
246 247 248
                if isinstance(other_var, complex):
                    import paddle
                    other_var = paddle.to_tensor(other_var, dtype='complex64')
249
                else:
250
                    if reverse:
251 252 253
                        other_var = create_tensor(other_var,
                                                  dtype=lhs_dtype,
                                                  shape=self.shape)
254 255
                    else:
                        # add fill_op
256 257
                        other_var = create_scalar(value=other_var,
                                                  dtype=lhs_dtype)
258

259
            # 3. promote types or unify right var type to left var
260
            rhs_dtype = other_var.dtype
261
            if lhs_dtype != rhs_dtype:
262
                if method_name in _supported_promote_complex_types_ and (
263 264
                        lhs_dtype in _complex_dtypes
                        or rhs_dtype in _complex_dtypes):
265 266 267 268 269 270 271 272 273 274
                    # only when lhs_dtype or rhs_dtype is complex type,
                    # the dtype will promote, in other cases, directly
                    # use lhs_dtype, this is consistent will original rule
                    promote_dtype = core._promote_types_if_complex_exists(
                        lhs_dtype, rhs_dtype)
                    self = self if lhs_dtype == promote_dtype else astype(
                        self, promote_dtype)
                    other_var = other_var if rhs_dtype == promote_dtype else astype(
                        other_var, promote_dtype)
                else:
275
                    warnings.warn(
276 277
                        'The dtype of left and right variables are not the same, left dtype is {}, but right dtype is {}, the right dtype will convert to {}'
                        .format(lhs_dtype, rhs_dtype, lhs_dtype))
278 279
                    other_var = astype(other_var, lhs_dtype)

280 281 282 283 284
            if reverse:
                tmp = self
                self = other_var
                other_var = tmp

285
            if (op_type == "divide" or op_type == "elementwise_div"
286
                ) and self.dtype in _supported_int_dtype_:
287 288 289
                self = astype(self, 'float32')
                other_var = astype(other_var, 'float32')

290
            # 4. calculation
291
            axis = -1
292 293 294 295
            if in_dygraph_mode():
                math_op = getattr(_C_ops, op_type)
            else:
                math_op = getattr(_legacy_C_ops, op_type)
296
            if call_final_api:
297
                if op_type == "matmul":
298
                    return math_op(self, other_var, False, False)
299 300 301 302 303
                if op_type == "pow":
                    if isinstance(other_var, core.eager.Tensor):
                        return _C_ops.elementwise_pow(self, other_var)
                    else:
                        return _C_ops.elementwise_pow(self, other_var)
304 305
                return math_op(self, other_var, -1)
            return math_op(self, other_var, 'axis', axis)
306

307 308 309 310
        if call_final_api:
            comment = ""
        else:
            comment = OpProtoHolder.instance().get_op_proto(op_type).comment
311 312 313 314

        __impl__.__doc__ = """
        {0}
        Args:
315
            other_var(Tensor|float|int): right hand Tensor
316 317

        Returns:
318
            Tensor
319 320 321 322
        """.format(comment)
        __impl__.__name__ = method_name
        return __impl__

323 324 325 326 327 328 329 330 331 332 333
    varbase_methods = [
        ('__neg__', _neg_),
        ('__float__', _float_),
        ('__long__', _long_),
        ('__int__', _int_),
        ('__len__', _len_),
        ('__index__', _index_),
        ('astype', astype),
        ('dim', lambda x: len(x.shape)),
        ('ndimension', lambda x: len(x.shape)),
        ('ndim', _ndim_),
334
        ('size', _size_),
335
        ('T', _T_),
336 337
        ('__add__', _binary_creator_('__add__', 'add', False, _scalar_add_,
                                     True)) if framework._in_eager_mode_ else
338 339
        ('__add__',
         _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
340
        ##  a+b == b+a. Do not need to reverse explicitly
341
        ('__radd__',
342 343
         _binary_creator_('__radd__', 'add', False, _scalar_add_, True))
        if framework._in_eager_mode_ else
344 345 346
        ('__radd__',
         _binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
        ('__sub__',
347 348
         _binary_creator_('__sub__', 'subtract', False, _scalar_sub_, True))
        if framework._in_eager_mode_ else
349 350 351
        ('__sub__',
         _binary_creator_('__sub__', 'elementwise_sub', False, _scalar_sub_)),
        ('__rsub__',
352
         _binary_creator_('__rsub__', 'subtract', True, _scalar_rsub_, True))
353 354 355 356
        if framework._in_eager_mode_ else
        ('__rsub__',
         _binary_creator_('__rsub__', 'elementwise_sub', True, _scalar_rsub_)),
        ('__mul__',
357 358
         _binary_creator_('__mul__', 'multiply', False, _scalar_mul_, True))
        if framework._in_eager_mode_ else
359 360
        ('__mul__',
         _binary_creator_('__mul__', 'elementwise_mul', False, _scalar_mul_)),
361
        ## a*b == b*a. Do not need to reverse explicitly
362
        ('__rmul__',
363 364
         _binary_creator_('__rmul__', 'multiply', False, _scalar_mul_, True))
        if framework._in_eager_mode_ else
365 366 367
        ('__rmul__',
         _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
        ('__div__',
368 369
         _binary_creator_('__div__', 'divide', False, _scalar_div_, True))
        if framework._in_eager_mode_ else
370 371 372
        ('__div__',
         _binary_creator_('__div__', 'elementwise_div', False, _scalar_div_)),
        ('__truediv__',
373 374
         _binary_creator_('__truediv__', 'divide', False, _scalar_div_, True))
        if framework._in_eager_mode_ else
375 376 377
        ('__truediv__',
         _binary_creator_('__truediv__', 'elementwise_div', False,
                          _scalar_div_)),
378
        ('__rdiv__', _binary_creator_('__rdiv__', 'divide', True, None, True))
379 380 381 382
        if framework._in_eager_mode_ else
        ('__rdiv__',
         _binary_creator_('__rdiv__', 'elementwise_div', True, None)),
        ('__rtruediv__',
383
         _binary_creator_('rtruediv__', 'divide', True, None, True))
384 385 386
        if framework._in_eager_mode_ else
        ('__rtruediv__',
         _binary_creator_('rtruediv__', 'elementwise_div', True, None)),
387
        ('__pow__', _binary_creator_('__pow__', 'pow', False, _C_ops.pow, True))
388
        if framework._in_eager_mode_ else
389 390
        ('__pow__',
         _binary_creator_('__pow__', 'elementwise_pow', False, None)),
391 392
        ('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
                                      None)),
393 394 395
        ('__floordiv__',
         _binary_creator_('__floordiv__', 'floor_divide', False, None, True))
        if framework._in_eager_mode_ else
396 397
        ('__floordiv__',
         _binary_creator_('__floordiv__', 'elementwise_floordiv', False, None)),
398
        ('__mod__', _binary_creator_('__mod__', 'modulo', False, None, True))
399 400 401 402
        if framework._in_eager_mode_ else
        ('__mod__',
         _binary_creator_('__mod__', 'elementwise_mod', False, None)),
        ('__matmul__',
403 404
         _binary_creator_('__matmul__', "matmul", False, None, True))
        if framework._in_eager_mode_ else
405 406
        ('__matmul__',
         _binary_creator_('__matmul__', "matmul_v2", False, None)),
407
        ## for logical compare
408
        ('__eq__', _binary_creator_('__eq__', 'equal', False, None, True))
409
        if framework._in_eager_mode_ else
410
        ('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
411
        ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None, True))
412
        if framework._in_eager_mode_ else
413
        ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
414
        ('__lt__', _binary_creator_('__lt__', 'less_than', False, None, True))
415
        if framework._in_eager_mode_ else
416
        ('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),
417 418
        ('__le__', _binary_creator_('__le__', 'less_equal', False, None, True))
        if framework._in_eager_mode_ else
419
        ('__le__', _binary_creator_('__le__', 'less_equal', False, None)),
420 421
        ('__gt__', _binary_creator_('__gt__', 'greater_than', False, None,
                                    True)) if framework._in_eager_mode_ else
422
        ('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)),
423 424
        ('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None,
                                    True)) if framework._in_eager_mode_ else
425
        ('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)),
426
        ('__array_ufunc__', None)
427 428 429
    ]

    global _already_patch_varbase
430 431
    global _already_patch_eager_tensor

J
Jiabin Yang 已提交
432
    if framework._in_eager_mode_:
433 434
        local_already_patch = _already_patch_eager_tensor
        _already_patch_eager_tensor = True
435
        local_tensor = core.eager.Tensor
436 437 438 439
    else:
        local_already_patch = _already_patch_varbase
        _already_patch_varbase = True
        local_tensor = core.VarBase
440

441
    if not local_already_patch:
442 443 444
        for method in varbase_methods:
            method_name = method[0]
            method_impl = method[1]
445
            setattr(local_tensor, method_name, method_impl)
446 447
    else:
        import paddle.tensor
448
        # Tensor method from module paddle.tensor
449
        for method_name in paddle.tensor.tensor_method_func:
450
            if hasattr(local_tensor, method_name): continue
451
            method_impl = getattr(paddle.tensor, method_name, None)
452
            if method_impl: setattr(local_tensor, method_name, method_impl)
453

454 455
        for magic_method, origin_method in paddle.tensor.magic_method_func:
            impl = getattr(paddle.tensor, origin_method, None)
456
            if impl: setattr(local_tensor, magic_method, impl)