# inference_transpiler.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
from .. import core
from ..framework import Program
from ..executor import global_scope


class InferenceTranspiler(object):
    '''
    Convert the fluid program to optimized inference program.

    There are several optimizations:

      - fuse convolution and batch normalization
      - fuse batch normalization and relu (MKLDNN only)

    Examples:

    .. code-block:: python

        # As InferenceTranspiler will modify the original program,
        # please clone it before using it.
        inference_transpiler_program = program.clone()
        t = fluid.InferenceTranspiler()
        t.transpile(inference_transpiler_program, place)
    '''

    # Environment-flag values (compared case-insensitively, after stripping
    # whitespace) that count as "enabled". Anything else, including "0" and
    # "false", counts as disabled.
    _TRUE_FLAG_VALUES = ('1', 'true', 't', 'yes', 'y', 'on')

    def transpile(self, program, place, scope=None):
        '''
        Run the transpiler. The program is modified in place.

        Args:
            program (Program): program to transpile
            place (Place): inference place
            scope (Scope|None): inference Scope; the global scope is used
                when None

        Raises:
            TypeError: if program, place or scope has an unexpected type
        '''
        if not isinstance(program, Program):
            raise TypeError("program should be as Program type")
        if not isinstance(place, core.CPUPlace) and not isinstance(
                place, core.CUDAPlace):
            raise TypeError("place should be as CPUPlace/CUDAPlace type")
        if scope is None:
            scope = global_scope()
        if not isinstance(scope, core.Scope):
            raise TypeError("scope should be as Scope type or None")
        self.fuse_batch_norm(program, place, scope)
        self.fuse_relu_mkldnn(program)

    @staticmethod
    def _env_flag_enabled(name):
        '''
        Return True when environment variable ``name`` holds a true-like value.

        ``bool(os.getenv(name))`` is deliberately avoided: it treats every
        non-empty string, including "0" and "false", as enabled.

        :param name: environment variable name
        :type name: str
        :return: whether the flag is enabled
        :rtype: bool
        '''
        value = os.getenv(name, '')
        return value.strip().lower() in InferenceTranspiler._TRUE_FLAG_VALUES

    def fuse_relu_mkldnn(self, program):
        '''
        Transpile the program by fusing relu activations into batch norm OPs
        for MKLDNN programs.

        A relu activation that directly consumes a batch norm OP's output is
        fused by setting the :math:`fuse_with_relu` attribute on the batch
        norm OP. The relu OP is removed, and every OP that read the relu
        output is rewired to read the batch norm output instead.

        This pass only runs when the ``FLAGS_use_mkldnn`` environment variable
        holds a true-like value (``1``, ``true``, ``yes``, ``on``, ...);
        values such as ``0`` or ``false`` disable it.

        The result of fuse is:

        - before:

          - batch_norm->relu->any_other_op

        - after:

          - batch_norm->any_other_op

        :param program: program to transpile
        :type program: Program
        '''
        if not self._env_flag_enabled("FLAGS_use_mkldnn"):
            return

        self.block = program.block(0)
        # Maps each removed relu output name to the batch_norm output name
        # that now carries the same (relu-applied) values.
        self.input_map = {}

        i = 0
        while i < len(self.block.ops) - 1:
            current_op = self.block.ops[i]
            if current_op.type == 'batch_norm':
                next_op = self.block.ops[i + 1]
                bn_out = current_op.output("Y")[0]
                # TODO(luotao1): consider single chain network now; only fuse
                # a relu that consumes this batch_norm's output directly.
                if next_op.type == 'relu' and next_op.input("X")[0] == bn_out:
                    # modify bnorm OP to include relu
                    current_op.set_attr("fuse_with_relu", True)
                    # consumers of the relu output must now read bn's output
                    self.input_map[next_op.output("Out")[0]] = bn_out
                    # remove relu OP
                    self.block.remove_op(i + 1)
            i += 1

        self._adjust_input()
        self._remove_unused_var()
        # TODO(luotao): use clone() method to flush the program.desc in force,
        # since some large program.desc will not be flushed immediately.
        # And a better solution will be considered later.
        # The call is kept only for that side effect; the returned copy is
        # intentionally discarded (the caller's program is the one modified).
        program.clone()

    def fuse_batch_norm(self, program, place, scope):
        '''
        Transpile the program by fusing batch normalization.

        A batch normalization that follows a convolution or fully connected
        layer can be folded into it. Doing so gives a forward speed-up,
        especially in environments like mobile or embedded.

        For input :math:`X`:

        - Conv process:        :math:`X = input * W + bias`
        - Batch norm process:  :math:`X' = (X - mean) / std`
        - Scale Process:       :math:`Y = a * X' + b`

        where :math:`std = \\sqrt{variance + epsilon}` and :math:`epsilon` is
        the batch norm OP's ``epsilon`` attribute (``1e-5`` if absent).

        After fusing into one operation:

        .. math::

            Y &= (input * W + bias - mean) / std * a + b \\\\
              &= input * a * W / std + ((bias - mean) / std * a + b)

        The operator transformation is:

        - before:

          - conv->batch_norm->any_other_op (bias == 0)
          - conv->elementwise_add->batch_norm->any_other_op (bias != 0)

        - after:

          - conv->elementwise_add->any_other_op

        The transpile stages are:

        1. insert elementwise_add op when bias == 0.
        2. fuse the batch_norm's parameters to conv and elementwise_add operators.
        3. remove batch_norm ops which are not used in any other ops.
        4. adjust the input of any_other_op to be the output of elementwise_add operator.
        5. remove unused variables.

        Args:
            program (Program): program to transpile
            place (Place): inference place
            scope (Scope): inference Scope
        '''
        self.scope = scope
        self.place = place
        self.block = program.block(0)
        self.input_map = {}  # store the input names that should be adjusted

        i = 0
        # A conv2d needs at least one following op to be fusable, so the
        # last op never starts a pattern; "- 1" (not "- 2") keeps a trailing
        # conv2d->batch_norm pair in range.
        while i < len(self.block.ops) - 1:
            current_op = self.block.ops[i]
            # TODO(luotao1): consider only conv2d now. fc would be dealt later.
            if current_op.type == 'conv2d':
                # TODO(luotao1): consider single chain network now.
                # For branch network, we couldn't use block.ops[i + 1] as
                # the judgment condition.
                next_op = self.block.ops[i + 1]
                if next_op.type == 'batch_norm':
                    # conv2d without bias: insert a bias op between them
                    bias_op = self._insert_bias_op(i + 1, current_op, next_op)
                    self._fuse_param(current_op, next_op, bias_op, 0)
                    # ops are now conv2d(i), bias_op(i+1), batch_norm(i+2)
                    self.block.remove_op(i + 2)
                    i += 1
                elif (next_op.type == 'elementwise_add' and
                      i + 2 < len(self.block.ops)):
                    # conv2d with bias: the bias op already exists
                    next_next_op = self.block.ops[i + 2]
                    if next_next_op.type == 'batch_norm':
                        self._fuse_param(current_op, next_next_op, next_op, 1)
                        self.block.remove_op(i + 2)
                        i += 1
            i += 1

        self._adjust_input()
        self._remove_unused_var()
        # TODO(luotao): use clone() method to flush the program.desc in force,
        # since some large program.desc will not be flushed immediately.
        # And a better solution will be considered later.
        # The call is kept only for that side effect; the returned copy is
        # intentionally discarded (the caller's program is the one modified).
        program.clone()

    # ====================== private transpiler functions =====================
    def _insert_bias_op(self, index, current_op, bn_op):
        '''
        Construct an elementwise_add operator for adding bias and insert it
        into the program.

        The inputs of bias_op are current_op's "Output" and bn_op's "Bias";
        its output is bn_op's "Y", so downstream consumers are unchanged.

        :param index: insert location of bias_op
        :type index: Int
        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :return: bias_op
        :rtype: Operator
        '''
        x_var = self.block.var(current_op.output("Output")[0])
        y_var = self.block.var(bn_op.input("Bias")[0])
        out_var = self.block.var(bn_op.output("Y")[0])

        bias_op = self.block.insert_op(
            index,
            type="elementwise_add",
            inputs={"X": x_var,
                    "Y": y_var},
            outputs={"Out": out_var},
            attrs={"axis": 1})  # dim_start=1
        return bias_op

    def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
        '''
        Fuse the batch_norm_op's parameters into current_op (conv or fc).

        New parameters named ``<old>_fuse_bn`` are created and the ops are
        renamed to use them, so the original variables keep their values.

        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :param bias_op: elementwise_add operator for adding bias
        :type bias_op: Operator
        :param with_bias: If current operator has bias, with_bias = 1; otherwise 0.
        :type with_bias: Int
        '''

        def _update_param(op, old_param_name, new_param):
            # For the sake of keeping the original variables the same as
            # before, create new variables in scope to store the new params.
            old_param_name = old_param_name[0]
            old_var = self.block.vars[old_param_name]
            new_param_name = old_param_name + '_fuse_bn'
            # Keep the name a native str: on Python 3, str.encode() returns
            # bytes, which would not match the str name given to
            # rename_input() and scope.var() below. A Python 2 unicode name
            # is still converted to a byte str as before.
            if not isinstance(new_param_name, str):
                new_param_name = new_param_name.encode('ascii')
            self.block.create_parameter(
                name=new_param_name,
                type=old_var.type,
                dtype=old_var.dtype,
                shape=old_var.shape)
            op.rename_input(old_param_name, new_param_name)
            self.scope.var(new_param_name)

            tensor = self.scope.find_var(new_param_name).get_tensor()
            tensor.set(np.array(new_param), self.place)

        def _load_param(param_name):
            return np.array(self.scope.find_var(param_name[0]).get_tensor())

        bias_bn = _load_param(bn_op.input("Bias"))  # Bias
        scale_bn = _load_param(bn_op.input("Scale"))  # Scale
        mean_bn = _load_param(bn_op.input("Mean"))  # Mean
        var_bn = _load_param(bn_op.input("Variance"))  # Variance

        # Use the epsilon the batch_norm op was configured with; 1e-5 is only
        # a fallback for ops that do not carry the attribute.
        if bn_op.has_attr("epsilon"):
            epsilon = bn_op.attr("epsilon")
        else:
            epsilon = 1e-5

        # TODO(luotao1): consider only conv2d now. fc would be dealt later.
        current_param = _load_param(current_op.input("Filter"))
        std_bn = np.float32(np.sqrt(np.add(var_bn, epsilon)))
        tmp = np.float32(np.divide(scale_bn, std_bn))

        # add bias of batch_norm_op to conv2d
        if with_bias:
            bias = _load_param(bias_op.input("Y"))
        else:
            bias = np.zeros(bias_bn.shape)
        bias = np.float32(
            np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))

        # re-compute weight of conv2d: scale each output channel (axis 0)
        tmp = tmp.reshape(tmp.shape[0], -1)
        dst_param = current_param.reshape((tmp.shape[0], -1))
        dst_param = np.float32(np.multiply(dst_param, tmp))
        dst_param = dst_param.reshape(current_param.shape)

        # update parameters
        _update_param(current_op, current_op.input("Filter"), dst_param)
        _update_param(bias_op, bias_op.input("Y"), bias)

        # collect the renamed input
        self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]

    def _adjust_input(self):
        '''
        Rename every op input found in ``self.input_map`` to its mapped name.
        '''
        for op in self.block.ops:
            for input_arg in op.input_arg_names:
                if input_arg in self.input_map:
                    op.rename_input(input_arg, self.input_map[input_arg])

    def _remove_unused_var(self):
        '''
        Remove variables in the block that no op reads or writes.
        '''
        used = set()
        for op in self.block.ops:
            used.update(op.input_arg_names)
            used.update(op.output_arg_names)

        # Iterate over a snapshot: remove_var() deletes from block.vars, and
        # mutating a dict while iterating its live keys view is an error on
        # Python 3.
        for var in list(self.block.vars.keys()):
            if var not in used:
                self.block.remove_var(var)