#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .. import core
from ..framework import (
    Variable,
    convert_np_dtype_to_dtype_,
    _varbase_creator,
    _in_legacy_dygraph,
    in_dygraph_mode,
)
from ..layers.layer_function_generator import OpProtoHolder
from . import no_grad
from .. import framework

import numpy as np
import warnings
from paddle import _C_ops, _legacy_C_ops

_supported_int_dtype_ = [
    core.VarDesc.VarType.UINT8,
    core.VarDesc.VarType.INT8,
    core.VarDesc.VarType.INT16,
    core.VarDesc.VarType.INT32,
    core.VarDesc.VarType.INT64,
    core.VarDesc.VarType.BOOL,
]

# NOTE(chenweihang): We currently do not fully support type promotion between
# tensors. Partial support is provided here because paddle quantum frequently
# mixes real and complex numbers, e.g. binary operations between `float` and
# `complex64`, so the APIs paddle quantum uses must promote types correctly.
# For now this is only checked in dygraph mode (paddle quantum is dygraph based).
# Full type promotion support will need to be verified later.
_supported_promote_complex_types_ = [
    '__add__',
    '__radd__',
    '__sub__',
    '__rsub__',
    '__mul__',
    '__rmul__',
    '__div__',
    '__truediv__',
    '__rdiv__',
    '__rtruediv__',
    '__matmul__',
]

_complex_dtypes = [
    core.VarDesc.VarType.COMPLEX64,
    core.VarDesc.VarType.COMPLEX128,
]

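# Flags ensuring the Python-level patch is applied to each tensor class
# (core.VarBase / core.eager.Tensor) only once.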
_already_patch_varbase = False
_already_patch_eager_tensor = False


def monkey_patch_math_varbase():
    """
    Similar to monkey_patch_variable.
    The difference is that, in dygraph mode, it uses auto-generated op functions for better performance.
    """

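    # Helper for wrapping Python scalars: builds a constant-filled tensor of the
    # given dtype and shape on the current place, with gradients disabled.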
    @no_grad
    def create_tensor(value, dtype, shape):
        if framework._in_eager_mode_:
            out = _C_ops.full(
                shape, value, dtype, framework._current_expected_place()
            )
        else:
            out = _varbase_creator(dtype=dtype)
            out = _legacy_C_ops.fill_constant(
                out,
                'dtype',
                dtype,
                'shape',
                shape,
                'value',
                value,
                'force_cpu',
                False,
            )
        out.stop_gradient = True
        return out

    def create_scalar(value, dtype):
        return create_tensor(value, dtype, shape=[1])

    def astype(self, dtype):
        """

        Cast a Tensor to a specified data type.

        Args:
            dtype: The target data type.

        Returns:
            Tensor: a new Tensor with target dtype

        Examples:
            .. code-block:: python

                import paddle

                original_tensor = paddle.ones([2, 2])
                print("original tensor's dtype is: {}".format(original_tensor.dtype))
                new_tensor = original_tensor.astype('float64')
                print("new tensor's dtype is: {}".format(new_tensor.dtype))

        """
        if not isinstance(dtype, core.VarDesc.VarType):
            dtype = convert_np_dtype_to_dtype_(dtype)

        if _in_legacy_dygraph():
            return _legacy_C_ops.cast(
                self, 'in_dtype', self.dtype, 'out_dtype', dtype
            )
        return _C_ops.cast(self, dtype)

    def _scalar_elementwise_op_(var, scale, bias):
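        # computes scale * var + bias in a single fused `scale` op; used as the
        # fast path for tensor-scalar arithmetic below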
        if framework.in_dygraph_mode():
            return _C_ops.scale(var, float(scale), bias, True)
        return _legacy_C_ops.scale(var, 'scale', scale, 'bias', bias)

    def _neg_(var):
        return _scalar_elementwise_op_(var, -1.0, 0.0)

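    # The scalar conversion helpers below (_float_, _long_, _int_, _index_)
    # require a tensor with exactly one element and assert this before reading
    # the value.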
    def _float_(var):
        numel = np.prod(var.shape)
        assert (
            numel == 1
        ), "only one element variable can be converted to float."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return float(var.numpy().flatten()[0])

    def _long_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only one element variable can be converted to long."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(var.numpy().flatten()[0])

    def _int_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only one element variable can be converted to int."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(var.numpy().flatten()[0])

    def _len_(var):
        assert var.ndim > 0, "len() is not defined for a 0-D tensor"
        if var.type == core.VarDesc.VarType.VOCAB:
            return len(var.value().get_map_tensor())
        elif var.type == core.VarDesc.VarType.STRINGS:
            return len(var.value().get_string_tensor())
        else:
            return var.shape[0]

    def _index_(var):
        numel = np.prod(var.shape)
        assert (
            numel == 1
        ), "only one element variable can be converted to python index."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(var.numpy().flatten()[0])

    @property
    def _ndim_(var):
        return len(var.shape)

    @property
    def _size_(var):
        return np.prod(var.shape)

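    # Tensor.T reverses the order of all dimensions with a single transpose;
    # 1-D tensors are returned unchanged.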
    @property
    def _T_(var):
        if len(var.shape) == 1:
            return var
        perm = []
        for i in range(len(var.shape)):
            perm.insert(0, i)
        if _in_legacy_dygraph():
            out, _ = _legacy_C_ops.transpose2(var, 'axis', perm)
        else:
            out = _C_ops.transpose(var, perm)
        return out

    def _scalar_add_(var, value):
        return _scalar_elementwise_op_(var, 1.0, value)

    def _scalar_sub_(var, value):
        return _scalar_elementwise_op_(var, 1.0, -value)

    def _scalar_rsub_(var, value):
        return _scalar_elementwise_op_(var, -1.0, value)

    def _scalar_mul_(var, value):
        return _scalar_elementwise_op_(var, value, 0.0)

    def _scalar_div_(var, value):
        return _scalar_elementwise_op_(var, 1.0 / value, 0.0)

    # for binary operators such as elementwise arithmetic and compare ops
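    # _binary_creator_ is a factory that returns the concrete magic-method
    # implementation. `reverse=True` swaps the operands for reflected variants
    # (__rsub__, __rdiv__, ...), `scalar_method` is an optional fast path used
    # when the right operand is a Python scalar, and `call_final_api` switches
    # to the new-style _C_ops signatures.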
    def _binary_creator_(
        method_name,
        op_type,
        reverse=False,
        scalar_method=None,
        call_final_api=False,
    ):
        def __impl__(self, other_var):
            # 1. scalar operand cases
            # combine tensor.dtype and scalar.dtype and cast to the correct type
            if isinstance(other_var, float):
                # in all cases (+, -, *, /, **, //, %), the tensor dtype must be cast to float
                if self.dtype in _supported_int_dtype_:
                    self = astype(self, 'float32')
                # use `scale` instead of `elementwise` ops for better performance,
                # but only +, -, *, / can take this path
                if scalar_method is not None:
                    return scalar_method(self, other_var)
            elif isinstance(other_var, int):
                # in all cases (+, -, *, /, **, //, %), an int scalar can be cast to float
                # because the output tensor's dtype depends on the dtype of the input tensor
                other_var = float(other_var)
                # division is a special case
                # NOTE(chenweihang): because we cast the tensor to float32 instead of float64,
                # the division result is only accurate to about 6 digits after the decimal
                # point. NumPy computes in float64, so results here may differ from NumPy's
                # beyond the 6th decimal place. If necessary, float64 could be used here
                # instead. torch's behavior here is consistent with ours.
                if (
                    op_type == "divide" or op_type == "elementwise_div"
                ) and self.dtype in _supported_int_dtype_:
                    self = astype(self, 'float32')
                # use `scale` instead of `elementwise` ops for better performance,
                # but only +, -, *, / can take this path
                if scalar_method is not None:
                    return scalar_method(self, other_var)
            else:
                # do nothing
                pass

            # 2. create varbase for scalar
            lhs_dtype = self.dtype
            if framework._in_eager_mode_:
                other_var_should_be = core.eager.Tensor
            else:
                other_var_should_be = core.VarBase
            if not isinstance(other_var, other_var_should_be):
                if isinstance(other_var, complex):
                    import paddle

                    other_var = paddle.to_tensor(other_var, dtype='complex64')
                else:
                    if reverse:
                        other_var = create_tensor(
                            other_var, dtype=lhs_dtype, shape=self.shape
                        )
                    else:
                        # add fill_op
                        other_var = create_scalar(
                            value=other_var, dtype=lhs_dtype
                        )

            # 3. promote types or unify right var type to left var
            rhs_dtype = other_var.dtype
            if lhs_dtype != rhs_dtype:
                if method_name in _supported_promote_complex_types_ and (
                    lhs_dtype in _complex_dtypes or rhs_dtype in _complex_dtypes
                ):
                    # the dtype is promoted only when lhs_dtype or rhs_dtype is
                    # a complex type; in all other cases lhs_dtype is used
                    # directly, which is consistent with the original rule
                    promote_dtype = core._promote_types_if_complex_exists(
                        lhs_dtype, rhs_dtype
                    )
                    self = (
                        self
                        if lhs_dtype == promote_dtype
                        else astype(self, promote_dtype)
                    )
                    other_var = (
                        other_var
                        if rhs_dtype == promote_dtype
                        else astype(other_var, promote_dtype)
                    )
                else:
                    warnings.warn(
                        'The dtypes of the left and right variables are not the same: the left dtype is {}, but the right dtype is {}; the right operand will be converted to {}.'.format(
                            lhs_dtype, rhs_dtype, lhs_dtype
                        )
                    )
                    other_var = astype(other_var, lhs_dtype)

            if reverse:
                self, other_var = other_var, self

            if (
                op_type == "divide" or op_type == "elementwise_div"
            ) and self.dtype in _supported_int_dtype_:
                self = astype(self, 'float32')
                other_var = astype(other_var, 'float32')

            # 4. calculation
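            # when call_final_api is set, the new-style _C_ops entry points are
            # called with positional arguments; otherwise the legacy op receives
            # the 'axis' attribute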
            axis = -1
            if in_dygraph_mode():
                math_op = getattr(_C_ops, op_type)
            else:
                math_op = getattr(_legacy_C_ops, op_type)
            if call_final_api:
                if op_type == "matmul":
                    return math_op(self, other_var, False, False)
                if op_type == "pow":
                    return _C_ops.elementwise_pow(self, other_var)
                return math_op(self, other_var, -1)
            return math_op(self, other_var, 'axis', axis)

        if call_final_api:
            comment = ""
        else:
            comment = OpProtoHolder.instance().get_op_proto(op_type).comment

        __impl__.__doc__ = """
        {0}
        Args:
            other_var(Tensor|float|int): the right-hand operand

        Returns:
            Tensor
        """.format(
            comment
        )
        __impl__.__name__ = method_name
        return __impl__

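    # Method table bound onto core.VarBase in legacy dygraph mode. The
    # eager-mode table below is smaller because most binary operators are
    # provided at the C++ level for core.eager.Tensor (see
    # eager_cpp_level_patch).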
    varbase_methods = [
        ('__neg__', _neg_),
        ('__float__', _float_),
        ('__long__', _long_),
        ('__int__', _int_),
        ('__len__', _len_),
        ('__index__', _index_),
        ('astype', astype),
        ('dim', lambda x: len(x.shape)),
        ('ndimension', lambda x: len(x.shape)),
        ('ndim', _ndim_),
        ('size', _size_),
        ('T', _T_),
        (
            '__add__',
            _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_),
        ),
        # a + b == b + a, so no explicit reverse variant is needed
        (
            '__radd__',
            _binary_creator_(
                '__radd__', 'elementwise_add', False, _scalar_add_
            ),
        ),
        (
            '__sub__',
            _binary_creator_('__sub__', 'elementwise_sub', False, _scalar_sub_),
        ),
        (
            '__rsub__',
            _binary_creator_(
                '__rsub__', 'elementwise_sub', True, _scalar_rsub_
            ),
        ),
        (
            '__mul__',
            _binary_creator_('__mul__', 'elementwise_mul', False, _scalar_mul_),
        ),
        # a * b == b * a, so no explicit reverse variant is needed
        (
            '__rmul__',
            _binary_creator_(
                '__rmul__', 'elementwise_mul', False, _scalar_mul_
            ),
        ),
        (
            '__div__',
            _binary_creator_('__div__', 'elementwise_div', False, _scalar_div_),
        ),
        (
            '__truediv__',
            _binary_creator_(
                '__truediv__', 'elementwise_div', False, _scalar_div_
            ),
        ),
        (
            '__rdiv__',
            _binary_creator_('__rdiv__', 'elementwise_div', True, None),
        ),
        (
            '__rtruediv__',
            _binary_creator_('__rtruediv__', 'elementwise_div', True, None),
        ),
        (
            '__pow__',
            _binary_creator_('__pow__', 'elementwise_pow', False, None),
        ),
        (
            '__rpow__',
            _binary_creator_('__rpow__', 'elementwise_pow', True, None),
        ),
        (
            '__floordiv__',
            _binary_creator_(
                '__floordiv__', 'elementwise_floordiv', False, None
            ),
        ),
        (
            '__mod__',
            _binary_creator_('__mod__', 'elementwise_mod', False, None),
        ),
        (
            '__matmul__',
            _binary_creator_('__matmul__', "matmul_v2", False, None),
        ),
        # for logical compare
        ('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
        ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
        ('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),
        ('__le__', _binary_creator_('__le__', 'less_equal', False, None)),
        ('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)),
        ('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)),
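        # setting __array_ufunc__ to None makes NumPy's operators defer to the
        # tensor's reflected methods (e.g. `np_array + tensor` calls __radd__)
        # instead of coercing the tensor into an ndarray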
        ('__array_ufunc__', None),
    ]

    eager_methods = [
        ('__neg__', _neg_),
        ('__float__', _float_),
        ('__long__', _long_),
        ('__int__', _int_),
        ('__len__', _len_),
        ('__index__', _index_),
        ('astype', astype),
        ('dim', lambda x: len(x.shape)),
        ('ndimension', lambda x: len(x.shape)),
        ('ndim', _ndim_),
        ('size', _size_),
        ('T', _T_),
        # for logical compare
        ('__array_ufunc__', None),
    ]

    eager_cpp_level_patch = [
        "__add__",
        "__radd__",
        '__sub__',
        '__rsub__',
        '__mul__',
        '__rmul__',
        '__div__',
        '__truediv__',
        '__rdiv__',
        '__rtruediv__',
        '__mod__',
        '__matmul__',
        '__gt__',
        '__ge__',
        '__lt__',
        '__le__',
        '__floordiv__',
        '__pow__',
        '__rpow__',
        '__eq__',
        '__ne__',
    ]

    global _already_patch_varbase
    global _already_patch_eager_tensor

    if framework._in_eager_mode_:
        local_already_patch = _already_patch_eager_tensor
        _already_patch_eager_tensor = True
        local_tensor = core.eager.Tensor
    else:
        local_already_patch = _already_patch_varbase
        _already_patch_varbase = True
        local_tensor = core.VarBase

    if not local_already_patch:
        if framework._in_eager_mode_:
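            # keep the existing C++-level bindings for the operators listed in
            # eager_cpp_level_patch instead of overriding them with the Python
            # implementations above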
            for method_name in eager_cpp_level_patch:
                method_impl = getattr(local_tensor, method_name, None)
                if method_impl:
                    setattr(local_tensor, method_name, method_impl)

            for method in eager_methods:
                method_name = method[0]
                method_impl = method[1]
                setattr(local_tensor, method_name, method_impl)

        else:
            for method in varbase_methods:
                method_name = method[0]
                method_impl = method[1]
                setattr(local_tensor, method_name, method_impl)
    else:
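        # the magic methods were installed by an earlier call; attach the
        # remaining Tensor methods defined in the paddle.tensor module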
        import paddle.tensor

        # Tensor method from module paddle.tensor
        for method_name in paddle.tensor.tensor_method_func:
            if hasattr(local_tensor, method_name):
                continue
            method_impl = getattr(paddle.tensor, method_name, None)
            if method_impl:
                setattr(local_tensor, method_name, method_impl)

        for magic_method, origin_method in paddle.tensor.magic_method_func:
            impl = getattr(paddle.tensor, origin_method, None)
            if impl:
                setattr(local_tensor, magic_method, impl)