math_op_patch.py 17.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .. import core
16 17 18 19 20
from ..framework import (
    Variable,
    convert_np_dtype_to_dtype_,
    in_dygraph_mode,
)
21
from ..framework import _create_tensor as framework_create_tensor
22
from ..layers.layer_function_generator import OpProtoHolder
23
from . import no_grad
J
Jiabin Yang 已提交
24
from .. import framework
25

26
import numpy as np
27
import warnings
28
from paddle import _C_ops, _legacy_C_ops
29

30 31 32 33 34 35
# Dtypes treated as "integer-like" by the operator overloads below: a tensor
# with one of these dtypes is cast to float32 before it is combined with a
# Python float, and before a true division. BOOL is listed together with the
# integer types so that boolean tensors follow the same rule.
_supported_int_dtype_ = [
    core.VarDesc.VarType.UINT8,
    core.VarDesc.VarType.INT8,
    core.VarDesc.VarType.INT16,
    core.VarDesc.VarType.INT32,
    core.VarDesc.VarType.INT64,
    core.VarDesc.VarType.BOOL,
]

39 40 41 42
# NOTE(chenweihang): We currently do not fully support type promotion
# between tensors. Partial support is provided here because real and complex
# numbers interoperate very frequently in Paddle Quantum, e.g. binary
# operations between `float` and `complex64`, so the APIs used by Paddle
# Quantum must promote types correctly.
# For now this is only checked in dygraph mode (Paddle Quantum is built on
# dygraph). Full type promotion support still needs to be verified later.
_supported_promote_complex_types_ = [
    '__add__',
    '__radd__',
    '__sub__',
    '__rsub__',
    '__mul__',
    '__rmul__',
53
    '__div__',
54
    '__truediv__',
55
    '__rdiv__',
56 57 58 59
    '__rtruediv__',
    '__matmul__',
]

60 61 62 63 64
# Complex dtypes; if either operand of a supported magic method has one of
# these, both operands are promoted instead of casting right to left.
_complex_dtypes = [
    core.VarDesc.VarType.COMPLEX64,
    core.VarDesc.VarType.COMPLEX128,
]

# Flipped to True by the first call of monkey_patch_math_varbase(); later
# calls take a different patching path.
_already_patch_eager_tensor = False
66

67 68 69 70 71 72 73

def monkey_patch_math_varbase():
    """
    Attach math and magic methods to ``core.eager.Tensor``.

    Similar to monkey_patch_variable.
    The difference is, in dygraph mode, use auto-generated op functions for better performance.

    The behaviour depends on the module-level flag
    ``_already_patch_eager_tensor``, which is set to True on every call:

    * First call: in eager mode the entries of ``eager_methods`` are set on
      the class (the names in ``eager_cpp_level_patch`` are only re-assigned
      to whatever the class already provides); otherwise the entries of
      ``varbase_methods`` are set.
    * Later calls: functions named in ``paddle.tensor.tensor_method_func``
      that the class does not have yet, and the pairs in
      ``paddle.tensor.magic_method_func``, are set on the class.
    """

74
    @no_grad
    def create_tensor(value, dtype, shape):
        """
        Create a tensor of the given shape filled with ``value``.

        Args:
            value: Fill value.
            dtype: Data type of the created tensor.
            shape: Shape of the created tensor.

        Returns:
            The filled tensor, with ``stop_gradient`` set to True.
        """
        if not framework.global_var._in_eager_mode_:
            # Legacy path: create the output tensor first, then let the
            # fill_constant kernel populate it.
            placeholder = framework_create_tensor(dtype=dtype)
            filled = _legacy_C_ops.fill_constant(
                placeholder,
                'dtype',
                dtype,
                'shape',
                shape,
                'value',
                value,
                'force_cpu',
                False,
            )
        else:
            place = framework._current_expected_place()
            filled = _C_ops.full(shape, value, dtype, place)

        # The constant never takes part in gradient computation.
        filled.stop_gradient = True
        return filled
95 96

    def create_scalar(value, dtype):
        """Wrap ``value`` in a tensor of dtype ``dtype`` with ``shape=[]``."""
        empty_shape = []
        return create_tensor(value, dtype=dtype, shape=empty_shape)
98 99 100 101

    def astype(self, dtype):
        """
        Cast this Tensor to the given data type.

        Args:
            dtype: The target data type. A ``core.VarDesc.VarType`` is used
                as is; anything else is first converted with
                ``convert_np_dtype_to_dtype_``.

        Returns:
            Tensor: a new Tensor with target dtype

        Examples:
            .. code-block:: python

                import paddle

                original_tensor = paddle.ones([2, 2])
                print("original tensor's dtype is: {}".format(original_tensor.dtype))
                new_tensor = original_tensor.astype('float32')
                print("new tensor's dtype is: {}".format(new_tensor.dtype))

        """
        target_dtype = (
            dtype
            if isinstance(dtype, core.VarDesc.VarType)
            else convert_np_dtype_to_dtype_(dtype)
        )
        return _C_ops.cast(self, target_dtype)
125 126

    def _scalar_elementwise_op_(var, scale, bias):
        """
        Apply the ``scale`` kernel to ``var`` with the given scale and bias.

        Used by the scalar fast paths (add, sub, mul, div, neg) instead of
        an elementwise kernel.

        Args:
            var: The tensor operand.
            scale: Multiplicative factor (converted with ``float()`` on the
                dygraph path only).
            bias: Additive term.

        Returns:
            The tensor returned by the scale kernel.
        """
        if not framework.in_dygraph_mode():
            return _legacy_C_ops.scale(var, 'scale', scale, 'bias', bias)
        # NOTE(review): the trailing True is presumably the
        # `bias_after_scale` flag -- confirm against _C_ops.scale.
        return _C_ops.scale(var, float(scale), bias, True)
131

132 133 134
    def _neg_(var):
        """Implement unary minus as a scale by -1.0 with zero bias."""
        return _scalar_elementwise_op_(var, scale=-1.0, bias=0.0)

135 136
    def _float_(var):
        """
        Convert a single-element tensor to a Python ``float``.

        Raises:
            AssertionError: if the tensor does not hold exactly one element
                or its underlying tensor is not initialized.
        """
        assert (
            np.prod(var.shape) == 1
        ), "only one element variable can be converted to float."
        assert (
            var.value().get_tensor()._is_initialized()
        ), "variable's tensor is not initialized"
        first_element = np.array(var).flatten()[0]
        return float(first_element)
143 144 145 146 147 148

    def _long_(var):
        """
        Convert a single-element tensor to a Python ``int`` (``__long__``).

        Raises:
            AssertionError: if the tensor does not hold exactly one element
                or its underlying tensor is not initialized.
        """
        assert (
            np.prod(var.shape) == 1
        ), "only one element variable can be converted to long."
        assert (
            var.value().get_tensor()._is_initialized()
        ), "variable's tensor is not initialized"
        first_element = np.array(var).flatten()[0]
        return int(first_element)
150 151 152 153 154 155

    def _int_(var):
        """
        Convert a single-element tensor to a Python ``int``.

        Raises:
            AssertionError: if the tensor does not hold exactly one element
                or its underlying tensor is not initialized.
        """
        assert (
            np.prod(var.shape) == 1
        ), "only one element variable can be converted to int."
        assert (
            var.value().get_tensor()._is_initialized()
        ), "variable's tensor is not initialized"
        first_element = np.array(var).flatten()[0]
        return int(first_element)
157 158

    def _len_(var):
        """
        Implement ``len()`` for a tensor.

        VOCAB and STRINGS variables report the length of their map / string
        tensor; every other variable reports the size of its first
        dimension.

        Raises:
            AssertionError: if the tensor is 0-D.
        """
        assert var.ndim > 0, "len() of a 0D tensor is wrong"

        var_types = core.VarDesc.VarType
        if var.type == var_types.VOCAB:
            return len(var.value().get_map_tensor())
        if var.type == var_types.STRINGS:
            return len(var.value().get_string_tensor())
        return var.shape[0]
166 167 168

    def _index_(var):
        """
        Convert a single-element tensor to a Python ``int`` for indexing.

        Raises:
            AssertionError: if the tensor does not hold exactly one element
                or its underlying tensor is not initialized.
        """
        assert (
            np.prod(var.shape) == 1
        ), "only one element variable can be converted to python index."
        assert (
            var.value().get_tensor()._is_initialized()
        ), "variable's tensor is not initialized"
        first_element = np.array(var).flatten()[0]
        return int(first_element)
175

176 177 178 179
    @property
    def _ndim_(var):
        return len(var.shape)

180 181 182 183
    @property
    def _size_(var):
        return np.prod(var.shape)

184 185 186 187 188 189 190
    @property
    def _T_(var):
        """
        Transpose of the tensor with all of its axes reversed.

        A tensor whose shape has exactly one entry is returned unchanged.
        """
        rank = len(var.shape)
        if rank == 1:
            return var
        # Axis order rank-1, ..., 1, 0.
        reversed_axes = list(range(rank - 1, -1, -1))
        return _C_ops.transpose(var, reversed_axes)

194
    def _scalar_add_(var, value):
        """``var + value`` through the scale kernel (scale 1.0, bias value)."""
        return _scalar_elementwise_op_(var, scale=1.0, bias=value)

    def _scalar_sub_(var, value):
        """``var - value`` through the scale kernel (scale 1.0, bias -value)."""
        return _scalar_elementwise_op_(var, scale=1.0, bias=-value)

    def _scalar_rsub_(var, value):
        """``value - var`` through the scale kernel (scale -1.0, bias value)."""
        return _scalar_elementwise_op_(var, scale=-1.0, bias=value)

    def _scalar_mul_(var, value):
        """``var * value`` through the scale kernel (scale value, bias 0.0)."""
        return _scalar_elementwise_op_(var, scale=value, bias=0.0)

    def _scalar_div_(var, value):
        """``var / value`` as a multiplication by the reciprocal of ``value``."""
        # The reciprocal is computed in Python, so a zero `value` raises
        # ZeroDivisionError here rather than reaching the kernel.
        reciprocal = 1.0 / value
        return _scalar_elementwise_op_(var, scale=reciprocal, bias=0.0)

209
    # for binary operator such as elementwise, compare
210 211 212 213 214 215 216
    def _binary_creator_(
        method_name,
        op_type,
        reverse=False,
        scalar_method=None,
        call_final_api=False,
    ):
        """
        Build the implementation of one binary magic method.

        Args:
            method_name (str): Name assigned to the generated function
                (``__name__``). It is also looked up in
                ``_supported_promote_complex_types_`` to decide whether
                complex type promotion applies.
            op_type (str): Attribute name of the kernel; it is fetched from
                ``_C_ops`` in dygraph mode and from ``_legacy_C_ops``
                otherwise.
            reverse (bool): Swap the two operands before calling the kernel
                (used for reflected operators such as ``__rsub__``).
            scalar_method (callable|None): Fast path ``f(tensor, scalar)``
                taken when the right operand is a Python ``float``/``int``.
            call_final_api (bool): Call the kernel with positional
                arguments only and skip the OpProto comment lookup for the
                generated docstring.

        Returns:
            function: ``__impl__(self, other_var)``.
        """
        # Both spellings of the division kernel get the int -> float32 cast.
        is_division = op_type in ("divide", "elementwise_div")

        def __impl__(self, other_var):
            # 1. scalar exists cases
            # we need combine the tensor.dtype and scalar.dtype, cast correct object
            if isinstance(other_var, float):
                # in all cases(+, -, *, /, **, //, %), we need cast tensor.dtype to float
                if self.dtype in _supported_int_dtype_:
                    self = astype(self, 'float32')
                # here use `scale` replace `elementwise` to get better performance
                # but only +, -, *, / can use this method
                if scalar_method is not None:
                    return scalar_method(self, other_var)
            elif isinstance(other_var, int):
                # in all cases(+, -, *, /, **, //, %), we can cast it to float
                # because the output tensor.dtype depend on the type of input tensor
                other_var = float(other_var)
                # division is a special case
                # NOTE(chenweihang): because we cast tensor to float32 instead float64,
                # the division result can only guarantee the numerical accuracy of 6 digits
                # after the decimal point. The result of numpy calculation is of float64 type,
                # so the calculation result here and the calculation result of numpy are
                # different after 6 decimal point. If necessary, we can also use float64 here.
                # torch's behavior here is consistent with ours
                if is_division and self.dtype in _supported_int_dtype_:
                    self = astype(self, 'float32')
                # here use `scale` replace `elementwise` to get better performance
                # but only +, -, *, / can use this method
                if scalar_method is not None:
                    return scalar_method(self, other_var)

            # 2. create a tensor for a non-tensor right operand
            lhs_dtype = self.dtype
            if not isinstance(other_var, core.eager.Tensor):
                if isinstance(other_var, complex):
                    import paddle

                    other_var = paddle.to_tensor(other_var, dtype='complex64')
                elif reverse:
                    # The scalar ends up as the left operand after the swap
                    # below, so it is created with the full shape of `self`.
                    other_var = create_tensor(
                        other_var, dtype=lhs_dtype, shape=self.shape
                    )
                else:
                    # add fill_op
                    other_var = create_scalar(value=other_var, dtype=lhs_dtype)

            # 3. promote types or unify right var type to left var
            rhs_dtype = other_var.dtype
            if lhs_dtype != rhs_dtype:
                if method_name in _supported_promote_complex_types_ and (
                    lhs_dtype in _complex_dtypes or rhs_dtype in _complex_dtypes
                ):
                    # only when lhs_dtype or rhs_dtype is complex type,
                    # the dtype will promote, in other cases, directly
                    # use lhs_dtype, this is consistent with the original rule
                    promote_dtype = core._promote_types_if_complex_exists(
                        lhs_dtype, rhs_dtype
                    )
                    if lhs_dtype != promote_dtype:
                        self = astype(self, promote_dtype)
                    if rhs_dtype != promote_dtype:
                        other_var = astype(other_var, promote_dtype)
                else:
                    warnings.warn(
                        'The dtype of left and right variables are not the same, left dtype is {}, but right dtype is {}, the right dtype will convert to {}'.format(
                            lhs_dtype, rhs_dtype, lhs_dtype
                        )
                    )
                    other_var = astype(other_var, lhs_dtype)

            if reverse:
                self, other_var = other_var, self

            # Integer-like operands of a true division are computed in
            # float32 (checked after the swap, i.e. on the actual dividend).
            if is_division and self.dtype in _supported_int_dtype_:
                self = astype(self, 'float32')
                other_var = astype(other_var, 'float32')

            # 4. calculation
            if in_dygraph_mode():
                math_op = getattr(_C_ops, op_type)
            else:
                math_op = getattr(_legacy_C_ops, op_type)

            if call_final_api:
                if op_type == "matmul":
                    return math_op(self, other_var, False, False)
                if op_type == "pow":
                    # The original tested isinstance(other_var, Tensor) here
                    # but made the same call in both arms.
                    return _C_ops.elementwise_pow(self, other_var)
                return math_op(self, other_var, -1)
            return math_op(self, other_var, 'axis', -1)

        if call_final_api:
            comment = ""
        else:
            comment = OpProtoHolder.instance().get_op_proto(op_type).comment

        __impl__.__doc__ = """
        {0}
        Args:
            other_var(Tensor|float|int): right hand Tensor

        Returns:
            Tensor
        """.format(
            comment
        )
        __impl__.__name__ = method_name
        return __impl__

346 347 348 349 350 351 352 353 354 355 356
    # (attribute name, implementation) pairs set on the tensor class when
    # the first patch runs outside eager mode.
    varbase_methods = [
        ('__neg__', _neg_),
        ('__float__', _float_),
        ('__long__', _long_),
        ('__int__', _int_),
        ('__len__', _len_),
        ('__index__', _index_),
        ('astype', astype),
        ('dim', lambda x: len(x.shape)),
        ('ndimension', lambda x: len(x.shape)),
        ('ndim', _ndim_),
        ('size', _size_),
        ('T', _T_),
        (
            '__add__',
            _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_),
        ),
        # a+b == b+a. Do not need to reverse explicitly
        (
            '__radd__',
            _binary_creator_(
                '__radd__', 'elementwise_add', False, _scalar_add_
            ),
        ),
        (
            '__sub__',
            _binary_creator_('__sub__', 'elementwise_sub', False, _scalar_sub_),
        ),
        (
            '__rsub__',
            _binary_creator_(
                '__rsub__', 'elementwise_sub', True, _scalar_rsub_
            ),
        ),
        (
            '__mul__',
            _binary_creator_('__mul__', 'elementwise_mul', False, _scalar_mul_),
        ),
        # a*b == b*a. Do not need to reverse explicitly
        (
            '__rmul__',
            _binary_creator_(
                '__rmul__', 'elementwise_mul', False, _scalar_mul_
            ),
        ),
        (
            '__div__',
            _binary_creator_('__div__', 'elementwise_div', False, _scalar_div_),
        ),
        (
            '__truediv__',
            _binary_creator_(
                '__truediv__', 'elementwise_div', False, _scalar_div_
            ),
        ),
        (
            '__rdiv__',
            _binary_creator_('__rdiv__', 'elementwise_div', True, None),
        ),
        (
            '__rtruediv__',
            # The method name must be spelled exactly '__rtruediv__': it is
            # matched against _supported_promote_complex_types_ and becomes
            # the generated function's __name__.
            _binary_creator_('__rtruediv__', 'elementwise_div', True, None),
        ),
        (
            '__pow__',
            _binary_creator_('__pow__', 'elementwise_pow', False, None),
        ),
        (
            '__rpow__',
            _binary_creator_('__rpow__', 'elementwise_pow', True, None),
        ),
        (
            '__floordiv__',
            _binary_creator_(
                '__floordiv__', 'elementwise_floordiv', False, None
            ),
        ),
        (
            '__mod__',
            _binary_creator_('__mod__', 'elementwise_mod', False, None),
        ),
        (
            '__matmul__',
            _binary_creator_('__matmul__', "matmul_v2", False, None),
        ),
        # for logical compare
        ('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
        ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
        ('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),
        ('__le__', _binary_creator_('__le__', 'less_equal', False, None)),
        ('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)),
        ('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)),
        # NumPy treats `__array_ufunc__ = None` as an opt-out from ufunc
        # handling.
        ('__array_ufunc__', None),
    ]

441 442 443 444 445 446 447 448 449 450 451 452 453 454
    # (attribute name, implementation) pairs set on the tensor class when
    # the first patch runs in eager mode. No binary operator is created
    # here; their names are listed in `eager_cpp_level_patch`.
    eager_methods = [
        # conversions to Python scalars / protocol hooks
        ('__neg__', _neg_),
        ('__float__', _float_),
        ('__long__', _long_),
        ('__int__', _int_),
        ('__len__', _len_),
        ('__index__', _index_),
        # dtype and shape helpers
        ('astype', astype),
        ('dim', lambda x: len(x.shape)),
        ('ndimension', lambda x: len(x.shape)),
        ('ndim', _ndim_),
        ('size', _size_),
        ('T', _T_),
        # NumPy treats `__array_ufunc__ = None` as an opt-out from ufunc
        # handling.
        ('__array_ufunc__', None),
    ]

    # Names of the binary magic methods that, in eager mode, are taken from
    # whatever the tensor class already provides instead of being generated
    # in this file.
    eager_cpp_level_patch = [
        '__add__', '__radd__',
        '__sub__', '__rsub__',
        '__mul__', '__rmul__',
        '__div__', '__truediv__',
        '__rdiv__', '__rtruediv__',
        '__mod__',
        '__matmul__',
        '__gt__', '__ge__',
        '__lt__', '__le__',
        '__floordiv__',
        '__pow__', '__rpow__',
        '__eq__', '__ne__',
    ]

482 483
    global _already_patch_eager_tensor

    # Remember whether an earlier call already ran, and mark the patch as
    # done before touching the class.
    patched_before, _already_patch_eager_tensor = (
        _already_patch_eager_tensor,
        True,
    )
    tensor_cls = core.eager.Tensor

    if patched_before:
        import paddle.tensor

        # Tensor method from module paddle.tensor; names the class already
        # has are left untouched.
        for name in paddle.tensor.tensor_method_func:
            if hasattr(tensor_cls, name):
                continue
            func = getattr(paddle.tensor, name, None)
            if func:
                setattr(tensor_cls, name, func)

        # Magic methods are bound even when the class already has them.
        for magic_method, origin_method in paddle.tensor.magic_method_func:
            func = getattr(paddle.tensor, origin_method, None)
            if func:
                setattr(tensor_cls, magic_method, func)
    elif framework.global_var._in_eager_mode_:
        # Re-assign the operators the class already provides under the same
        # names.
        for name in eager_cpp_level_patch:
            existing = getattr(tensor_cls, name, None)
            if existing:
                setattr(tensor_cls, name, existing)

        for name, func in eager_methods:
            setattr(tensor_cls, name, func)
    else:
        for name, func in varbase_methods:
            setattr(tensor_cls, name, func)