inference_transpiler.py 11.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

17
import os
18
import numpy as np
19 20 21
from .. import core
from ..framework import Program
from ..executor import global_scope
22 23


24
class InferenceTranspiler(object):
    '''
    Convert the fluid program to an optimized inference program.

    There are several optimizations:

      - fuse convolution and batch normalization
      - fuse batch normalization and relu (MKLDNN only)

    Examples:

    .. code-block:: python

        # InferenceTranspiler modifies the program it is given in place,
        # so clone the original program before transpiling it.
        inference_transpiler_program = program.clone()
        t = fluid.InferenceTranspiler()
        t.transpile(inference_transpiler_program, place)
    '''

L
Luo Tao 已提交
44
    def transpile(self, program, place, scope=None):
        '''
        Apply every inference optimization to ``program`` in place.

        The arguments are type-checked first. Batch normalization is then
        fused into the preceding convolution, and finally the MKLDNN-only
        batch_norm + relu fusion is applied.

        Args:
            program (Program): program to transpile
            place (Place): inference place, either CPUPlace or CUDAPlace
            scope (Scope|None): inference Scope; the global scope is used
                when None is given

        Raises:
            TypeError: if ``program``, ``place`` or ``scope`` has an
                unexpected type.
        '''
        if not isinstance(program, Program):
            raise TypeError("program should be as Program type")

        on_cpu = isinstance(place, core.CPUPlace)
        if not (on_cpu or isinstance(place, core.CUDAPlace)):
            raise TypeError("place should be as CPUPlace/CUDAPlace type")

        # Fall back to the global scope when the caller did not supply one.
        target_scope = global_scope() if scope is None else scope
        if not isinstance(target_scope, core.Scope):
            raise TypeError("scope should be as Scope type or None")

        self._fuse_batch_norm(program, place, target_scope)
        self._fuse_relu_mkldnn(program)
64

W
Wu Yi 已提交
65
    def _fuse_relu_mkldnn(self, program):
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
        '''
        Transpile the program by fused relu activation for MKLDNN program.

        Relu activation following batch norm OP can be fused by adding
        :math:`fuse_with_relu` attribute to batch norm OP.

        The result of fuse is:

        - before:

          - batch_norm->relu->any_other_op

        - after:

          - batch_norm->any_other_op

        :param program: program to transpile
        :type program: Program
        '''
        use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
        if not use_mkldnn:
            return

        self.block = program.block(0)

        i = 0
        while i < len(self.block.ops) - 1:
            current_op = self.block.ops[i]
            if current_op.type in ['batch_norm']:
                next_op = self.block.ops[i + 1]
                if next_op.type == 'relu':
                    # modify bnorm OP to include relu
                    current_op.set_attr("fuse_with_relu", True)
                    # remove relu OP
W
Wu Yi 已提交
100
                    self.block._remove_op(i + 1)
101 102 103 104 105 106 107
            i = i + 1

        self._remove_unused_var()
        # TODO(luotao): use clone() method to flush the program.desc in force,
        # since some large program.desc will not be flushed immediately.
        # And a better solution will be considered later.
        program = program.clone()
L
Luo Tao 已提交
108

W
Wu Yi 已提交
109
    def _fuse_batch_norm(self, program, place, scope):
L
Luo Tao 已提交
110 111
        '''
        Transpile the program by fused batch normalization.
112 113 114

        The batch normalization followed the convolution or fully connected layer
        can be integrated with them. Doing so will give us a forward acceleration,
115
        especially in environments like mobile or embedded.
116

L
Luo Tao 已提交
117 118
        For input :math:`X`:

119 120
        - Conv process:        :math:`X = input * W + bias`
        - Batch norm process:  :math:`X' = (X - mean) / std`
L
Luo Tao 已提交
121
        - Scale Process:       :math:`Y = a * X' + b`
122 123 124

        After fuse into one operation:

L
Luo Tao 已提交
125 126 127 128
        .. math::

            Y &= (input * W + bias - mean) / std * a + b \\\\
              &= input * a * W / std + ((bias - mean) / std * a + b)
129

130
        The operator transformation is:
L
Luo Tao 已提交
131

132
        - before:
L
Luo Tao 已提交
133

134 135
          - conv->batch_norm->any_other_op (bias == 0)
          - conv->elementwise_add->batch_norm->any_other_op (bias != 0)
136 137

        - after:
L
Luo Tao 已提交
138

139
          - conv->elementwise_add->any_other_op
140

141
        The transpile stages are:
L
Luo Tao 已提交
142

143
        1. insert elementwise_add op when bias == 0.
144
        2. fuse the batch_norm's parameters to conv and elementwise_add operators.
145 146 147
        3. remove batch_norm ops which are not used in any other ops.
        4. adjust the input of any_other_op to be the output of elementwise_add operator.
        5. remove unused variables.
148

L
Luo Tao 已提交
149 150 151 152
        Args:
            program (Program): program to transpile
            place (Place): inference place
            scope (Scope): inference Scope
153

154 155 156
        '''
        self.scope = scope
        self.place = place
157
        self.block = program.block(0)
158
        self.input_map = {}  # store the input names should be adjusted
159

160
        i = 0
161
        while i < len(self.block.ops) - 2:
162
            current_op = self.block.ops[i]
163
            # TODO(luotao1): consider only conv2d now. fc would be delt later.
164
            if current_op.type in ['conv2d']:
165 166
                # TODO(luotao1): consider single chain network now.
                # For branch network, we counldn't use block.ops[i + 1] as
L
Luo Tao 已提交
167
                # the judgment condition.
168
                next_op = self.block.ops[i + 1]
169
                # conv2d without bias
170
                if (next_op.type == 'batch_norm'):
171 172 173
                    # insert bias op
                    bias_op = self._insert_bias_op(i + 1, current_op, next_op)
                    # fuse batch_norm
174
                    self._fuse_param(current_op, next_op, bias_op, 0)
175
                    # remove batch_norm_op
W
Wu Yi 已提交
176
                    self.block._remove_op(i + 2)
177
                    i = i + 1
178 179 180 181 182 183 184
                # conv2d with bias, the next_op.type is elementwise_add
                elif (next_op.type == 'elementwise_add'):
                    next_next_op = self.block.ops[i + 2]
                    if (next_next_op.type == 'batch_norm'):
                        # fuse batch_norm
                        self._fuse_param(current_op, next_next_op, next_op, 1)
                        # remove batch_norm_op
W
Wu Yi 已提交
185
                        self.block._remove_op(i + 2)
186
                        i = i + 1
187 188
            i = i + 1

189
        self._adjust_input()
190
        self._remove_unused_var()
191 192
        # TODO(luotao): use clone() method to flush the program.desc in force,
        # since some large program.desc will not be flushed immediately.
L
Luo Tao 已提交
193
        # And a better solution will be considered later.
L
Luo Tao 已提交
194
        program = program.clone()
195 196 197 198

    # ====================== private transpiler functions =====================
    def _insert_bias_op(self, index, current_op, bn_op):
        '''
199
        Construct elementwise_add operator for adding bias
200
        and insert it into program.
201

202 203 204 205 206 207 208 209 210 211 212
        :param index: insert location of bias_op
        :type index: Int
        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :return: bias_op
        :rtype: Operator
        '''
        # The input of bias_op is current_op's output and Bias of bn_op
        # The output of bias_op is bn_op's output
213 214 215 216
        x_var = self.block.var(current_op.output("Output")[0])
        y_var = self.block.var(bn_op.input("Bias")[0])
        out_var = self.block.var(bn_op.output("Y")[0])

W
Wu Yi 已提交
217
        bias_op = self.block._insert_op(
218 219 220 221 222 223
            index,
            type="elementwise_add",
            inputs={"X": x_var,
                    "Y": y_var},
            outputs={"Out": out_var},
            attrs={"axis": 1})  # dim_start=1
224 225
        return bias_op

226
    def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
227 228
        '''
        fuse the batch_norm_op' parameters to current_op (conv or fc)
229

230 231 232 233 234 235
        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :param bias_op: elementwise_add operator for adding bias
        :type bias_op: Operator
236
        :param with_bias: If current operator has bias, with_bias = 1; otherwise 0.
237
        :type with_bias: Int
238 239
        '''

L
Luo Tao 已提交
240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
        def _update_param(op, old_param_name, new_param):
            # For the sake of remaining the original variables the same as before,
            # create new variables in scope to store the new parameters.
            old_param_name = old_param_name[0]
            old_var = self.block.vars[old_param_name]
            new_param_name = old_param_name + '_fuse_bn'
            new_var = self.block.create_parameter(
                name=new_param_name.encode('ascii'),
                type=old_var.type,
                dtype=old_var.dtype,
                shape=old_var.shape)
            op.rename_input(old_param_name, new_param_name)
            self.scope.var(new_param_name)

            tensor = self.scope.find_var(new_param_name).get_tensor()
            tensor.set(np.array(new_param), self.place)
256 257

        def _load_param(param_name):
L
Luo Tao 已提交
258
            return np.array(self.scope.find_var(param_name[0]).get_tensor())
259 260 261 262 263 264 265 266 267 268 269 270

        bias_bn = _load_param(bn_op.input("Bias"))  #Bias
        scale_bn = _load_param(bn_op.input("Scale"))  #Scale
        mean_bn = _load_param(bn_op.input("Mean"))  #Mean
        var_bn = _load_param(bn_op.input("Variance"))  #Variance

        # TODO(luotao1): consider only conv2d now. fc would be delt later.
        current_param = _load_param(current_op.input("Filter"))
        std_bn = np.float32(np.sqrt(np.add(var_bn, 1e-5)))
        tmp = np.float32(np.divide(scale_bn, std_bn))

        # add bias of batch_norm_op to conv2d
271 272 273 274
        if with_bias:
            bias = _load_param(bias_op.input("Y"))
        else:
            bias = np.zeros(bias_bn.shape)
275 276 277 278 279 280 281 282 283
        bias = np.float32(
            np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))

        # re-compute weight of conv2d
        tmp = tmp.reshape(tmp.shape[0], -1)
        dst_param = current_param.reshape((tmp.shape[0], -1))
        dst_param = np.float32(np.multiply(dst_param, tmp))
        dst_param = dst_param.reshape(current_param.shape)

L
Luo Tao 已提交
284 285 286
        # update parameters
        _update_param(current_op, current_op.input("Filter"), dst_param)
        _update_param(bias_op, bias_op.input("Y"), bias)
287

288 289 290
        # collect the renamed input
        self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]

291
    def _adjust_input(self):
292 293 294 295 296 297 298
        for i in range(len(self.block.ops)):
            current_op = self.block.ops[i]
            for input_arg in current_op.input_arg_names:
                if input_arg in self.input_map:
                    current_op.rename_input(input_arg,
                                            self.input_map[input_arg])

299 300
    def _remove_unused_var(self):
        '''
301
        remove unused varibles in program
302 303
        '''
        args = []
304 305 306 307
        for i in range(len(self.block.ops)):
            current_op = self.block.ops[i]
            args += current_op.input_arg_names
            args += current_op.output_arg_names
308 309
        args = list(set(args))  # unique the input and output arguments

310
        for var in list(self.block.vars.keys()):
311
            if var not in args:
W
Wu Yi 已提交
312
                self.block._remove_var(var)