#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .. import core
from ..framework import (
    Variable,
    convert_np_dtype_to_dtype_,
    _varbase_creator,
    in_dygraph_mode,
)
from ..layers.layer_function_generator import OpProtoHolder
from . import no_grad
from .. import framework

import numpy as np
import warnings
from paddle import _C_ops, _legacy_C_ops

_supported_int_dtype_ = [
    core.VarDesc.VarType.UINT8,
    core.VarDesc.VarType.INT8,
    core.VarDesc.VarType.INT16,
    core.VarDesc.VarType.INT32,
    core.VarDesc.VarType.INT64,
    core.VarDesc.VarType.BOOL,
]

# NOTE(chenweihang): We currently do not fully support type promotion
# between tensors. Partial support is added here because the interoperation of
# real and complex numbers in paddle quantum is very frequent, such as the
# binary operation between `float` and `complex64`, so we must support the
# correct type promotion on the APIs paddle quantum uses.
# For now this is only checked in dygraph (paddle quantum is dygraph-based).
# Full type promotion support will need to be fully verified later.
_supported_promote_complex_types_ = [
    '__add__',
    '__radd__',
    '__sub__',
    '__rsub__',
    '__mul__',
    '__rmul__',
    '__div__',
    '__truediv__',
    '__rdiv__',
    '__rtruediv__',
    '__matmul__',
]

_complex_dtypes = [
    core.VarDesc.VarType.COMPLEX64,
    core.VarDesc.VarType.COMPLEX128,
]
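# Illustrative example of the complex promotion above (assuming paddle's default
# complex dtype of complex64 for Python complex values):
#   x = paddle.ones([2], dtype='float32')
#   y = paddle.to_tensor([1j, 2j])        # complex64
#   (x * y).dtype                         # -> paddle.complex64, both operands promoted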

_already_patch_varbase = False
_already_patch_eager_tensor = False


def monkey_patch_math_varbase():
    """
    Similar to monkey_patch_variable.
    The difference is that, in dygraph mode, auto-generated op functions are used
    for better performance.
    """

    @no_grad
    def create_tensor(value, dtype, shape):
        if framework.global_var._in_eager_mode_:
            out = _C_ops.full(
                shape, value, dtype, framework._current_expected_place()
            )
        else:
            out = _varbase_creator(dtype=dtype)
            out = _legacy_C_ops.fill_constant(
                out,
                'dtype',
                dtype,
                'shape',
                shape,
                'value',
                value,
                'force_cpu',
                False,
            )
        out.stop_gradient = True
        return out

    def create_scalar(value, dtype):
        return create_tensor(value, dtype, shape=[])

    def astype(self, dtype):
        """

        Cast a Tensor to a specified data type.

        Args:
            dtype: The target data type.

        Returns:
            Tensor: a new Tensor with target dtype

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                original_tensor = paddle.ones([2, 2])
                print("original tensor's dtype is: {}".format(original_tensor.dtype))
                new_tensor = original_tensor.astype('bool')
                print("new tensor's dtype is: {}".format(new_tensor.dtype))

        """
        if not isinstance(dtype, core.VarDesc.VarType):
            dtype = convert_np_dtype_to_dtype_(dtype)
        return _C_ops.cast(self, dtype)

    def _scalar_elementwise_op_(var, scale, bias):
        if framework.in_dygraph_mode():
            return _C_ops.scale(var, float(scale), bias, True)
        else:
            return _legacy_C_ops.scale(var, 'scale', scale, 'bias', bias)

    def _neg_(var):
        return _scalar_elementwise_op_(var, -1.0, 0.0)

    def _float_(var):
        numel = np.prod(var.shape)
        assert (
            numel == 1
        ), "only one element variable can be converted to float."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return float(var.item())

    def _long_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only one element variable can be converted to long."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(var.item())

    def _int_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only one element variable can be converted to int."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(var.item())

    def _len_(var):
        assert var.ndim > 0, "len() of a 0D tensor is wrong"
        if var.type == core.VarDesc.VarType.VOCAB:
            return len(var.value().get_map_tensor())
        elif var.type == core.VarDesc.VarType.STRINGS:
            return len(var.value().get_string_tensor())
        else:
            return var.shape[0]

    def _index_(var):
        numel = np.prod(var.shape)
        assert (
            numel == 1
        ), "only one element variable can be converted to python index."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(var.item())
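    # The conversions above only work for single-element tensors, e.g. (illustrative):
    #   int(paddle.to_tensor([3]))     # -> 3
    #   float(paddle.ones([2, 2]))     # raises AssertionError (4 elements)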

    @property
    def _ndim_(var):
        return len(var.shape)

    @property
    def _size_(var):
        return np.prod(var.shape)

    @property
    def _T_(var):
        if len(var.shape) == 1:
            return var
        perm = []
        for i in range(len(var.shape)):
            perm.insert(0, i)
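        # perm reverses the axis order, e.g. [2, 1, 0] for a 3-D tensor,
        # so a tensor of shape [2, 3, 4] is transposed to shape [4, 3, 2]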
        out = _C_ops.transpose(var, perm)
        return out

    def _scalar_add_(var, value):
        return _scalar_elementwise_op_(var, 1.0, value)

    def _scalar_sub_(var, value):
        return _scalar_elementwise_op_(var, 1.0, -value)

    def _scalar_rsub_(var, value):
        return _scalar_elementwise_op_(var, -1.0, value)

    def _scalar_mul_(var, value):
        return _scalar_elementwise_op_(var, value, 0.0)

    def _scalar_div_(var, value):
        return _scalar_elementwise_op_(var, 1.0 / value, 0.0)
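    # Illustrative effect of the scalar fast path used by _binary_creator_ below:
    #   x + 2.0  ->  scale(x, scale=1.0,  bias=2.0)
    #   2.0 - x  ->  scale(x, scale=-1.0, bias=2.0)
    #   x * 3.0  ->  scale(x, scale=3.0,  bias=0.0)
    #   x / 4.0  ->  scale(x, scale=0.25, bias=0.0)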

    # for binary operator such as elementwise, compare
    def _binary_creator_(
        method_name,
        op_type,
        reverse=False,
        scalar_method=None,
        call_final_api=False,
    ):
        def __impl__(self, other_var):
            # 1. cases where a scalar operand is involved
            # we need to combine tensor.dtype and scalar.dtype and cast to the correct type
            if isinstance(other_var, float):
                # in all cases (+, -, *, /, **, //, %), we need to cast tensor.dtype to float
                if self.dtype in _supported_int_dtype_:
                    self = astype(self, 'float32')
                # here we use `scale` instead of elementwise ops for better performance,
                # but only +, -, *, / can use this fast path
                if scalar_method is not None:
                    return scalar_method(self, other_var)
            elif isinstance(other_var, int):
                # in all cases (+, -, *, /, **, //, %), we can cast it to float
                # because the output tensor.dtype depends on the type of the input tensor
                other_var = float(other_var)
                # division is a special case
                # NOTE(chenweihang): because we cast the tensor to float32 instead of float64,
                # the division result can only guarantee numerical accuracy to 6 digits
                # after the decimal point. The result of the numpy calculation is of float64
                # type, so the result here and the numpy result may differ after the 6th
                # decimal place. If necessary, we can also use float64 here.
                # torch's behavior here is consistent with ours.
                if (
                    op_type == "divide" or op_type == "elementwise_div"
                ) and self.dtype in _supported_int_dtype_:
                    self = astype(self, 'float32')
                # here we use `scale` instead of elementwise ops for better performance,
                # but only +, -, *, / can use this fast path
                if scalar_method is not None:
                    return scalar_method(self, other_var)
            else:
                # do nothing
                pass

            # 2. create varbase for scalar
            lhs_dtype = self.dtype
            if framework.global_var._in_eager_mode_:
                other_var_should_be = core.eager.Tensor
            else:
                other_var_should_be = core.VarBase
            if not isinstance(other_var, other_var_should_be):
                if isinstance(other_var, complex):
                    import paddle

                    other_var = paddle.to_tensor(other_var, dtype='complex64')
                else:
                    if reverse:
                        other_var = create_tensor(
                            other_var, dtype=lhs_dtype, shape=self.shape
                        )
                    else:
                        # add fill_op
                        other_var = create_scalar(
                            value=other_var, dtype=lhs_dtype
                        )

            # 3. promote types or unify right var type to left var
            rhs_dtype = other_var.dtype
            if lhs_dtype != rhs_dtype:
                if method_name in _supported_promote_complex_types_ and (
                    lhs_dtype in _complex_dtypes or rhs_dtype in _complex_dtypes
                ):
                    # only when lhs_dtype or rhs_dtype is a complex type
                    # will the dtype be promoted; in other cases, directly
                    # use lhs_dtype, which is consistent with the original rule
                    promote_dtype = core._promote_types_if_complex_exists(
                        lhs_dtype, rhs_dtype
                    )
                    self = (
                        self
                        if lhs_dtype == promote_dtype
                        else astype(self, promote_dtype)
                    )
                    other_var = (
                        other_var
                        if rhs_dtype == promote_dtype
                        else astype(other_var, promote_dtype)
                    )
                else:
                    warnings.warn(
                        'The dtype of left and right variables are not the same, left dtype is {}, but right dtype is {}, the right dtype will convert to {}'.format(
                            lhs_dtype, rhs_dtype, lhs_dtype
                        )
                    )
                    other_var = astype(other_var, lhs_dtype)

            if reverse:
                tmp = self
                self = other_var
                other_var = tmp

            if (
                op_type == "divide" or op_type == "elementwise_div"
            ) and self.dtype in _supported_int_dtype_:
                self = astype(self, 'float32')
                other_var = astype(other_var, 'float32')

            # 4. calculation
            axis = -1
            if in_dygraph_mode():
                math_op = getattr(_C_ops, op_type)
            else:
                math_op = getattr(_legacy_C_ops, op_type)
            if call_final_api:
                if op_type == "matmul":
                    return math_op(self, other_var, False, False)
                if op_type == "pow":
                    if isinstance(other_var, core.eager.Tensor):
                        return _C_ops.elementwise_pow(self, other_var)
                    else:
                        return _C_ops.elementwise_pow(self, other_var)
                return math_op(self, other_var, -1)
            return math_op(self, other_var, 'axis', axis)

        if call_final_api:
            comment = ""
        else:
            comment = OpProtoHolder.instance().get_op_proto(op_type).comment

        __impl__.__doc__ = """
        {0}
        Args:
            other_var(Tensor|float|int): right hand Tensor

        Returns:
            Tensor
        """.format(
            comment
        )
        __impl__.__name__ = method_name
        return __impl__
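    # For example, _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)
    # produces the function bound to Tensor.__add__ below, so `a + b` calls the
    # dygraph elementwise_add op and `a + 2.0` uses the scale() fast path.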

    varbase_methods = [
        ('__neg__', _neg_),
        ('__float__', _float_),
        ('__long__', _long_),
        ('__int__', _int_),
        ('__len__', _len_),
        ('__index__', _index_),
        ('astype', astype),
        ('dim', lambda x: len(x.shape)),
        ('ndimension', lambda x: len(x.shape)),
        ('ndim', _ndim_),
        ('size', _size_),
        ('T', _T_),
        (
            '__add__',
            _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_),
        ),
        #  a+b == b+a. Do not need to reverse explicitly
        (
            '__radd__',
            _binary_creator_(
                '__radd__', 'elementwise_add', False, _scalar_add_
            ),
        ),
        (
            '__sub__',
            _binary_creator_('__sub__', 'elementwise_sub', False, _scalar_sub_),
        ),
        (
            '__rsub__',
            _binary_creator_(
                '__rsub__', 'elementwise_sub', True, _scalar_rsub_
            ),
        ),
        (
            '__mul__',
            _binary_creator_('__mul__', 'elementwise_mul', False, _scalar_mul_),
        ),
        ## a*b == b*a. Do not need to reverse explicitly
        (
            '__rmul__',
            _binary_creator_(
                '__rmul__', 'elementwise_mul', False, _scalar_mul_
            ),
        ),
        (
            '__div__',
            _binary_creator_('__div__', 'elementwise_div', False, _scalar_div_),
        ),
        (
            '__truediv__',
            _binary_creator_(
                '__truediv__', 'elementwise_div', False, _scalar_div_
            ),
        ),
        (
            '__rdiv__',
            _binary_creator_('__rdiv__', 'elementwise_div', True, None),
        ),
        (
            '__rtruediv__',
            _binary_creator_('__rtruediv__', 'elementwise_div', True, None),
        ),
        (
            '__pow__',
            _binary_creator_('__pow__', 'elementwise_pow', False, None),
        ),
        (
            '__rpow__',
            _binary_creator_('__rpow__', 'elementwise_pow', True, None),
        ),
        (
            '__floordiv__',
            _binary_creator_(
                '__floordiv__', 'elementwise_floordiv', False, None
            ),
        ),
        (
            '__mod__',
            _binary_creator_('__mod__', 'elementwise_mod', False, None),
        ),
        (
            '__matmul__',
            _binary_creator_('__matmul__', "matmul_v2", False, None),
        ),
        ## for logical compare
        ('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
        ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
        ('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),
        ('__le__', _binary_creator_('__le__', 'less_equal', False, None)),
        ('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)),
        ('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)),
        ('__array_ufunc__', None),
    ]

    eager_methods = [
        ('__neg__', _neg_),
        ('__float__', _float_),
        ('__long__', _long_),
        ('__int__', _int_),
        ('__len__', _len_),
        ('__index__', _index_),
        ('astype', astype),
        ('dim', lambda x: len(x.shape)),
        ('ndimension', lambda x: len(x.shape)),
        ('ndim', _ndim_),
        ('size', _size_),
        ('T', _T_),
        # for logical compare
        ('__array_ufunc__', None),
    ]

    eager_cpp_level_patch = [
        "__add__",
        "__radd__",
        '__sub__',
        '__rsub__',
        '__mul__',
        '__rmul__',
        '__div__',
        '__truediv__',
        '__rdiv__',
        '__rtruediv__',
        '__mod__',
        '__matmul__',
        '__gt__',
        '__ge__',
        '__lt__',
        '__le__',
        '__floordiv__',
        '__pow__',
        '__rpow__',
        '__eq__',
        '__ne__',
    ]
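    # These binary/compare operators already have C++-level implementations on the
    # eager Tensor; the getattr/setattr loop below keeps those implementations in
    # place, while the legacy VarBase path uses the Python _binary_creator_ versions.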

    global _already_patch_varbase
    global _already_patch_eager_tensor

    if framework.global_var._in_eager_mode_:
        local_already_patch = _already_patch_eager_tensor
        _already_patch_eager_tensor = True
        local_tensor = core.eager.Tensor
    else:
        local_already_patch = _already_patch_varbase
        _already_patch_varbase = True
        local_tensor = core.VarBase

    if not local_already_patch:
        if framework.global_var._in_eager_mode_:
            for method_name in eager_cpp_level_patch:
                method_impl = getattr(local_tensor, method_name, None)
                if method_impl:
                    setattr(local_tensor, method_name, method_impl)

            for method in eager_methods:
                method_name = method[0]
                method_impl = method[1]
                setattr(local_tensor, method_name, method_impl)

        else:
            for method in varbase_methods:
                method_name = method[0]
                method_impl = method[1]
                setattr(local_tensor, method_name, method_impl)
    else:
        import paddle.tensor

        # Tensor methods from module paddle.tensor
        for method_name in paddle.tensor.tensor_method_func:
            if hasattr(local_tensor, method_name):
                continue
            method_impl = getattr(paddle.tensor, method_name, None)
            if method_impl:
                setattr(local_tensor, method_name, method_impl)

        for magic_method, origin_method in paddle.tensor.magic_method_func:
            impl = getattr(paddle.tensor, origin_method, None)
            if impl:
                setattr(local_tensor, magic_method, impl)