#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from .. import core
from ..framework import Variable, unique_name
from .layer_function_generator import OpProtoHolder
from ..initializer import force_init_on_cpu

# Integer element types. Membership in this list keeps integer
# ``elementwise_div`` off the float ``scale`` fast path, and makes a float
# scalar combined with an integer variable be checked for an integral value.
# Built with a generator (not a list comprehension) so no loop variable leaks
# into module scope under Python 2.
_supported_int_dtype_ = list(
    getattr(core.VarDesc.VarType, type_name)
    for type_name in ("UINT8", "INT8", "INT16", "INT32", "INT64"))


def monkey_patch_variable():
    """Install operator overloads and ``astype`` on the ``Variable`` class.

    For every entry in the method table at the end of this function, an
    ``__impl__`` closure is built and attached to ``Variable`` with
    ``setattr``. Python operators on static-graph variables (``x + 1``,
    ``2 - x``, ``x / y``, ``x < y``, ...) then append the matching operator
    to the variable's block and return the output variable.
    ``Variable.astype`` is attached as well and appends a ``cast`` op.

    Mutates ``Variable`` in place; returns ``None``.
    """

    # Methods whose operator yields a boolean tensor instead of a tensor with
    # the left operand's dtype.
    compare_ops = ('__eq__', '__ne__', '__lt__', '__le__', '__gt__', '__ge__')

    def unique_tmp_name():
        # A fresh unique name generated from the "tmp" prefix.
        return unique_name.generate("tmp")

    def safe_get_dtype(var):
        """Return ``var.dtype``.

        Raises:
            ValueError: if reading ``var.dtype`` fails (for instance when
                ``var`` has no ``dtype`` attribute).
        """
        try:
            return var.dtype
        except Exception:
            # `except Exception` instead of a bare `except:` so that
            # KeyboardInterrupt/SystemExit are not turned into ValueError.
            # The name is %-formatted into the message; passing it as a second
            # ValueError argument would leave the "%s" unformatted. getattr
            # keeps this error path from failing on objects without `.name`.
            raise ValueError("Cannot get data type from %s" %
                             getattr(var, "name", var))

    def current_block(var):
        # The block that owns `var`; helper ops and variables go there.
        return var.block

    def create_new_tmp_var(block, dtype):
        # Declare a fresh, uniquely named variable of `dtype` in `block`.
        return block.create_var(name=unique_tmp_name(), dtype=dtype)

    def create_tensor(block, value, dtype, shape):
        """Append a ``fill_constant`` op producing a ``shape`` tensor of
        ``value`` (converted to float) and return its output variable.

        The output is marked ``stop_gradient`` on both the op and the var.
        """
        value = float(value)
        var = create_new_tmp_var(block, dtype)
        block.append_op(
            type="fill_constant",
            outputs={'Out': [var]},
            attrs={
                'dtype': var.dtype,
                'shape': shape,
                'value': value,
                'force_cpu': force_init_on_cpu()
            },
            stop_gradient=True)
        var.stop_gradient = True
        return var

    def create_scalar(block, value, dtype):
        # A shape-[1] constant tensor holding `value`.
        return create_tensor(block, value, dtype, shape=[1])

    def create_tensor_with_batchsize(ref_var, value, dtype):
        """Append a ``fill_constant_batch_size_like`` op whose output is
        shaped like ``ref_var`` and filled with ``value``.

        The first negative dimension of ``ref_var.shape`` is used as both
        ``input_dim_idx`` and ``output_dim_idx``, so the op sizes that
        dimension from ``ref_var``. ``ref_var`` must have at least one
        negative dimension (asserted).
        """
        assert isinstance(ref_var, Variable)
        value = float(value)
        block = current_block(ref_var)
        var = create_new_tmp_var(block, dtype)
        batch_dim = next(
            (i for i, d in enumerate(ref_var.shape) if d < 0), -1)
        assert batch_dim != -1
        block.append_op(
            type='fill_constant_batch_size_like',
            outputs={'Out': [var]},
            inputs={'Input': [ref_var]},
            attrs={
                'shape': ref_var.shape,
                'value': value,
                # Pass the target dtype explicitly, as create_tensor does.
                # Without it the op uses its own default `dtype` attribute,
                # which need not match `var`'s declared dtype.
                'dtype': var.dtype,
                'input_dim_idx': batch_dim,
                'output_dim_idx': batch_dim
            },
            stop_gradient=True)
        var.stop_gradient = True
        return var

    def astype(self, dtype):
        """
        Cast a variable to a specified data type.

        NOTE: The variable must be a Tensor.

        Args:
            self(Variable): The source variable
            dtype: The target dtype

        Returns:
            Variable: a new variable holding ``self`` cast to ``dtype``.
        """
        block = current_block(self)
        out = create_new_tmp_var(block, dtype)
        block.append_op(
            type="cast",
            inputs={"X": [self]},
            outputs={"Out": [out]},
            attrs={"in_dtype": self.dtype,
                   "out_dtype": out.dtype})
        return out

    def _scalar_elementwise_op_(var, scale, bias):
        # One `scale` op computing `scale * var + bias`; used as a fast path
        # when the other operand is a Python scalar.
        block = current_block(var)
        out = create_new_tmp_var(block, var.dtype)
        block.append_op(
            type="scale",
            inputs={"X": [var]},
            outputs={"Out": [out]},
            attrs={"scale": scale,
                   "bias": bias})
        return out

    def _scalar_elementwise_add_(var, value):
        # var + value
        return _scalar_elementwise_op_(var, 1.0, value)

    def _scalar_elementwise_sub_(var, value):
        # var - value
        return _scalar_elementwise_op_(var, 1.0, -value)

    def _scalar_elementwise_rsub_(var, value):
        # value - var
        return _scalar_elementwise_op_(var, -1.0, value)

    def _scalar_elementwise_mul_(var, value):
        # var * value
        return _scalar_elementwise_op_(var, value, 0.0)

    def _scalar_elementwise_div_(var, value):
        # var / value, expressed as multiplication by the reciprocal.
        return _scalar_elementwise_op_(var, 1.0 / value, 0.0)

    def _elemwise_method_creator_(method_name,
                                  op_type,
                                  reverse=False,
                                  scalar_method=None):
        """Build the ``Variable`` method ``method_name`` backed by ``op_type``.

        Args:
            method_name: name the returned function is installed under.
            op_type: binary operator type appended for the tensor path.
            reverse: if True, the Python operand is the *left* argument of
                the op (used for ``__rsub__``, ``__rdiv__``, ...).
            scalar_method: optional fast path taking ``(var, float)`` that is
                used instead of ``op_type`` when the other operand is a
                Python ``int``/``float``.
        """

        def __impl__(self, other_var):
            # FIXME(zjl): elementwise_div between integers cannot be converted
            # to scale, which may lose accuracy. This is a hot fix for
            # release 1.6.
            if scalar_method is not None and not (
                    op_type == 'elementwise_div' and
                    self.dtype in _supported_int_dtype_):
                if isinstance(other_var, float):
                    if self.dtype in _supported_int_dtype_:
                        # is_integer() is False for inf/nan instead of raising
                        # OverflowError/ValueError the way int() does.
                        assert other_var.is_integer(), \
                            "float value {} cannot convert to integer".format(other_var)
                    return scalar_method(self, other_var)
                elif isinstance(other_var, int):
                    return scalar_method(self, float(other_var))

            lhs_dtype = safe_get_dtype(self)

            # Turn a Python scalar into a constant tensor of self's dtype.
            if not isinstance(other_var, Variable):
                if reverse:
                    # The constant becomes the left operand, so it must be
                    # full-shaped. A negative dimension can only be sized at
                    # run time, from `self`.
                    if any(d < 0 for d in self.shape):
                        other_var = create_tensor_with_batchsize(
                            self, other_var, lhs_dtype)
                    else:
                        other_var = create_tensor(
                            current_block(self),
                            other_var,
                            dtype=lhs_dtype,
                            shape=self.shape)
                else:
                    # add fill_op to current_block
                    other_var = create_scalar(
                        current_block(self), value=other_var, dtype=lhs_dtype)

            rhs_dtype = safe_get_dtype(other_var)
            if lhs_dtype != rhs_dtype:
                other_var = astype(other_var, lhs_dtype)
            if reverse:
                self, other_var = other_var, self

            # Comparison operators produce booleans; declaring their output
            # with the operand dtype would mislabel the result.
            out_dtype = (core.VarDesc.VarType.BOOL
                         if method_name in compare_ops else lhs_dtype)
            out = create_new_tmp_var(current_block(self), dtype=out_dtype)

            # When Y starts with an unknown (batch) dimension, align it with
            # X from axis 0. Guard against rank-0 shapes before indexing.
            axis = -1
            if other_var.shape and other_var.shape[0] == -1:
                axis = 0
            assert len(self.shape) >= len(other_var.shape), (
                "The rank of the first argument of an binary operator cannot "
                "be smaller than the rank of its second argument: %s vs %s" %
                (len(self.shape), len(other_var.shape)))

            current_block(self).append_op(
                type=op_type,
                inputs={'X': [self],
                        'Y': [other_var]},
                outputs={'Out': [out]},
                attrs={'axis': axis})
            return out

        comment = OpProtoHolder.instance().get_op_proto(op_type).comment

        __impl__.__doc__ = """
        {0}
        Args:
            self(Variable): left hand variable
            other_var(Variable|float|int): right hand variable

        Returns:
            Variable
        """.format(comment)
        __impl__.__name__ = method_name
        return __impl__

    # inject methods
    for method_name, op_type, reverse, scalar_method in (
        ("__add__", "elementwise_add", False, _scalar_elementwise_add_),
        # a+b == b+a. Do not need to reverse explicitly
        ("__radd__", "elementwise_add", False, _scalar_elementwise_add_),
        ("__sub__", "elementwise_sub", False, _scalar_elementwise_sub_),
        ("__rsub__", "elementwise_sub", True, _scalar_elementwise_rsub_),
        ("__mul__", "elementwise_mul", False, _scalar_elementwise_mul_),
        # a*b == b*a. Do not need to reverse explicitly
        ("__rmul__", "elementwise_mul", False, _scalar_elementwise_mul_),
        ("__div__", "elementwise_div", False, _scalar_elementwise_div_),
        ("__truediv__", "elementwise_div", False, _scalar_elementwise_div_),
        ("__rdiv__", "elementwise_div", True, None),
        ("__rtruediv__", "elementwise_div", True, None),
        ("__pow__", "elementwise_pow", False, None),
        ("__rpow__", "elementwise_pow", True, None),
        ("__floordiv__", "elementwise_floordiv", False, None),
        ("__mod__", "elementwise_mod", False, None),
        # for logical compare
        ("__eq__", "equal", False, None),
        ("__ne__", "not_equal", False, None),
        ("__lt__", "less_than", False, None),
        ("__le__", "less_equal", False, None),
        ("__gt__", "greater_than", False, None),
        ("__ge__", "greater_equal", False, None)):
        setattr(Variable, method_name,
                _elemwise_method_creator_(method_name, op_type, reverse,
                                          scalar_method))

    Variable.astype = astype