#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from .. import core
from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator, ComplexVariable
from ..layers.layer_function_generator import OpProtoHolder
from . import no_grad

import numpy as np
import six

_supported_int_dtype_ = [
    core.VarDesc.VarType.UINT8,
    core.VarDesc.VarType.INT8,
    core.VarDesc.VarType.INT16,
    core.VarDesc.VarType.INT32,
    core.VarDesc.VarType.INT64,
]

_already_patch_varbase = False


def monkey_patch_math_varbase():
    """
    Similar to monkey_patch_variable.
    The difference is that, in dygraph mode, the auto-generated op functions are used for better performance.
    """

    @no_grad
    def create_tensor(value, dtype, shape):
        out = _varbase_creator(dtype=dtype)
        out = core.ops.fill_constant(out, 'dtype', dtype, 'shape', shape,
                                     'value', value, 'force_cpu', False)
        out.stop_gradient = True
        return out

    def create_scalar(value, dtype):
        return create_tensor(value, dtype, shape=[1])

    def astype(self, dtype):
        """

        Cast a Tensor to a specified data type.

        Args:
            dtype: The target data type.

        Returns:
            Tensor: a new Tensor with target dtype

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                original_tensor = paddle.ones([2, 2])
                print("original tensor's dtype is: {}".format(original_tensor.dtype))
                new_tensor = original_tensor.astype('float64')
                print("new tensor's dtype is: {}".format(new_tensor.dtype))

        """
        if not isinstance(dtype, core.VarDesc.VarType):
            dtype = convert_np_dtype_to_dtype_(dtype)
        return core.ops.cast(self, 'in_dtype', self.dtype, 'out_dtype', dtype)

    def _scalar_elementwise_op_(var, scale, bias):
        return core.ops.scale(var, 'scale', scale, 'bias', bias)

    def _neg_(var):
        return _scalar_elementwise_op_(var, -1.0, 0.0)

    def _float_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only a one-element Tensor can be converted to float."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return float(var.numpy().flatten()[0])

    def _long_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only a one-element Tensor can be converted to long."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        if six.PY2:
            return long(var.numpy().flatten()[0])
        else:
            return int(var.numpy().flatten()[0])

    def _int_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only a one-element Tensor can be converted to int."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        return int(var.numpy().flatten()[0])

    def _len_(var):
        return var.shape[0]

    def _index_(var):
        numel = np.prod(var.shape)
        assert numel == 1, "only a one-element Tensor can be converted to a Python index."
        tensor = var.value().get_tensor()
        assert tensor._is_initialized(), "variable's tensor is not initialized"
        if six.PY2:
            return long(var.numpy().flatten()[0])
        else:
            return int(var.numpy().flatten()[0])
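
    # Note: the conversions above only apply to single-element tensors; e.g.
    # (assuming dygraph mode) `int(paddle.ones([1]))` returns 1, and a
    # one-element integer tensor can be used directly as a Python index.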

    @property
    def _ndim_(var):
        return len(var.shape)

    @property
    def _size_(var):
        return np.prod(var.shape)
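
    # For example, a tensor of shape [2, 3] reports ndim == 2 and size == 6.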

    def _scalar_add_(var, value):
        return _scalar_elementwise_op_(var, 1.0, value)

    def _scalar_sub_(var, value):
        return _scalar_elementwise_op_(var, 1.0, -value)

    def _scalar_rsub_(var, value):
        return _scalar_elementwise_op_(var, -1.0, value)

    def _scalar_mul_(var, value):
        return _scalar_elementwise_op_(var, value, 0.0)

    def _scalar_div_(var, value):
        return _scalar_elementwise_op_(var, 1.0 / value, 0.0)
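
    # All of the scalar fast paths above lower to a single `scale` op, which
    # computes `out = scale * x + bias` (bias added after scaling by default),
    # for example:
    #
    #   x + 3   ->  scale(x, scale=1.0,  bias=3.0)
    #   3 - x   ->  scale(x, scale=-1.0, bias=3.0)
    #   x / 2   ->  scale(x, scale=0.5,  bias=0.0)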

    # factory for binary operators such as elementwise and compare ops
    def _binary_creator_(method_name,
                         op_type,
                         reverse=False,
                         scalar_method=None):
        def __impl__(self, other_var):
            # Tensor and ComplexVariable operator
            if isinstance(other_var, ComplexVariable):
                # need to import paddle inside the closure
                import paddle
                math_op = getattr(paddle.incubate.complex.tensor, op_type)
                return math_op(self, other_var)

            # FIXME(zjl): elementwise_div between integers cannot be converted to scale,
            # which may lose accuracy. This is a hot fix for release 1.6.
            if scalar_method is not None and not (
                    op_type == 'elementwise_div' and
                    self.dtype in _supported_int_dtype_):
                if isinstance(other_var, float):
                    if self.dtype in _supported_int_dtype_:
                        assert other_var == int(other_var), \
                            "float value {} cannot convert to integer".format(other_var)
                    return scalar_method(self, other_var)
                elif isinstance(other_var, int):
                    return scalar_method(self, float(other_var))

            lhs_dtype = self.dtype

            if not isinstance(other_var, core.VarBase):
                if reverse:
                    other_var = create_tensor(
                        other_var, dtype=lhs_dtype, shape=self.shape)
                else:
                    # wrap the scalar in a 1-element tensor (adds a fill_constant op)
                    other_var = create_scalar(value=other_var, dtype=lhs_dtype)

            rhs_dtype = other_var.dtype
            if lhs_dtype != rhs_dtype:
                other_var = astype(other_var, lhs_dtype)
            if reverse:
                self, other_var = other_var, self

            axis = -1
            math_op = getattr(core.ops, op_type)
            return math_op(self, other_var, 'axis', axis)

        comment = OpProtoHolder.instance().get_op_proto(op_type).comment

        __impl__.__doc__ = """
        {0}
        Args:
            other_var(Tensor|float|int): the right-hand operand

        Returns:
            Tensor
        """.format(comment)
        __impl__.__name__ = method_name
        return __impl__
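
    # A minimal sketch of how the factory above is used (mirroring the
    # varbase_methods table below): each generated closure is attached to
    # core.VarBase as a method, e.g.
    #
    #   sub = _binary_creator_('__sub__', 'elementwise_sub', False, _scalar_sub_)
    #   # sub(x, 2)  -> scale(x, scale=1.0, bias=-2.0)    (scalar fast path)
    #   # sub(x, y)  -> core.ops.elementwise_sub(x, y, 'axis', -1)
    #   rsub = _binary_creator_('__rsub__', 'elementwise_sub', True, _scalar_rsub_)
    #   # rsub(x, 2) -> scale(x, scale=-1.0, bias=2.0), i.e. 2 - x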

    varbase_methods = [
        ('__neg__', _neg_),
        ('__float__', _float_),
        ('__long__', _long_),
        ('__int__', _int_),
        ('__len__', _len_),
        ('__index__', _index_),
        ('astype', astype),
        ('dim', lambda x: len(x.shape)),
        ('ndimension', lambda x: len(x.shape)),
        ('ndim', _ndim_),
        ('size', _size_),
        ('__add__',
         _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)),
        # a + b == b + a, so no explicit reverse is needed
        ('__radd__',
         _binary_creator_('__radd__', 'elementwise_add', False, _scalar_add_)),
        ('__sub__', _binary_creator_('__sub__', 'elementwise_sub', False,
                                     _scalar_sub_)),
        ('__rsub__', _binary_creator_('__rsub__', 'elementwise_sub', True,
                                      _scalar_rsub_)),
        ('__mul__', _binary_creator_('__mul__', 'elementwise_mul', False,
                                     _scalar_mul_)),
        # a * b == b * a, so no explicit reverse is needed
        ('__rmul__',
         _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
        ('__div__', _binary_creator_('__div__', 'elementwise_div', False,
                                     _scalar_div_)),
        ('__truediv__', _binary_creator_('__truediv__', 'elementwise_div',
                                         False, _scalar_div_)),
        ('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True,
                                      None)),
        ('__rtruediv__', _binary_creator_('__rtruediv__', 'elementwise_div', True,
                                          None)),
        ('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False,
                                     None)),
        ('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
                                      None)),
        ('__floordiv__', _binary_creator_('__floordiv__',
                                          'elementwise_floordiv', False, None)),
        ('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False,
                                     None)),
        # comparison operators
        ('__eq__', _binary_creator_('__eq__', 'equal', False, None)),
        ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)),
        ('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),
        ('__le__', _binary_creator_('__le__', 'less_equal', False, None)),
        ('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)),
        ('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)),
        ('__array_ufunc__', None)
    ]

    global _already_patch_varbase
    if not _already_patch_varbase:
        for method in varbase_methods:
            method_name = method[0]
            method_impl = method[1]
            setattr(core.VarBase, method_name, method_impl)
    else:
        import paddle.tensor
        # Tensor methods from the paddle.tensor module
        tensor_methods = paddle.tensor.linalg.__all__ + \
                         paddle.tensor.math.__all__ + \
                         paddle.tensor.logic.__all__ + \
                         paddle.tensor.manipulation.__all__ + \
                         paddle.tensor.search.__all__ + \
                         paddle.tensor.stat.__all__ + \
                         paddle.tensor.attribute.__all__
        for method_name in tensor_methods:
            if hasattr(core.VarBase, method_name): continue
            method_impl = getattr(paddle.tensor, method_name, None)
            if method_impl: setattr(core.VarBase, method_name, method_impl)

    _already_patch_varbase = True
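
    # Rough illustration (assuming paddle is fully imported): after the second
    # pass above, free functions from paddle.tensor also work as methods, e.g.
    # `x.matmul(y)` calls the same function as `paddle.matmul(x, y)`, because
    # the function object itself is attached to core.VarBase.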