# inference_transpiler.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
import shutil
from . import core


class InferenceTranspiler(object):
    '''
    Rewrite a program for inference by folding batch normalization into the
    convolution that precedes it.
    '''

    def transpile(self, program, scope, place):
        '''
        Transpile the program into an inference program by fusing batch
        normalization.

        A batch normalization that follows a convolution (or, later, a fully
        connected layer) can be folded into that layer's weights and bias.
        Doing so speeds up the forward pass, especially in environments like
        mobile or embedded.

        For input X:
        - Conv process:        X = input * W + bias
        - Batch norm process:  X' = (X - mean) / std,
                               std = sqrt(variance + epsilon)
        - Scale process:       Y = a * X' + b

        After fusing into one operation:

        Y = (input * W + bias - mean) / std * a + b
          = input * a * W / std + ((bias - mean) / std * a + b)

        The operator transformation is:
        - before:
          - conv->batch_norm->any_other_op (bias == 0)
          - conv->elementwise_add->batch_norm->any_other_op (bias != 0)
        - after:
          - conv->elementwise_add->any_other_op

        The transpile stages are:
        1. insert an elementwise_add op when bias == 0.
        2. fuse the batch_norm's parameters into the conv and elementwise_add
           operators.
        3. remove the fused batch_norm ops.
        4. make any_other_op read the output of the elementwise_add operator.
        5. remove variables that are no longer referenced by any op.

        A pattern is fused only when its ops are adjacent in the block AND
        each op reads the output of the op before it. A conv2d at the end of
        the block is left untouched.

        :param program: program to transpile (block 0 is modified in place)
        :type program: Program
        :param scope: inference scope holding the parameter tensors
        :type scope: Scope
        :param place: inference place used when writing updated tensors
        :type place: Place
        :return: the same program object, with batch normalization fused
        :rtype: Program
        '''
        self.scope = scope
        self.place = place
        self.block = program.block(0)
        # fused batch_norm output name -> elementwise_add output name that now
        # carries the same values; consumed by _adjust_input
        self.input_map = {}

        i = 0
        while i < len(self.block.ops):
            current_op = self.block.ops[i]
            # TODO(luotao1): consider only conv2d now. fc would be dealt later.
            if current_op.type in ['conv2d']:
                conv_out = current_op.output("Output")[0]
                next_op = self._op_at(i + 1)
                # conv2d without bias: conv2d -> batch_norm
                if (next_op is not None and next_op.type == 'batch_norm' and
                        next_op.input("X")[0] == conv_out):
                    # insert bias op right after conv2d
                    bias_op = self._insert_bias_op(i + 1, current_op, next_op)
                    # fuse batch_norm
                    self._fuse_param(current_op, next_op, bias_op, 0)
                    # the insertion shifted batch_norm to i + 2; remove it
                    self.block.remove_op(i + 2)
                    # skip over the new elementwise_add op
                    i = i + 1
                # conv2d with bias: conv2d -> elementwise_add -> batch_norm
                elif (next_op is not None and
                      next_op.type == 'elementwise_add' and
                      next_op.input("X")[0] == conv_out):
                    next_next_op = self._op_at(i + 2)
                    if (next_next_op is not None and
                            next_next_op.type == 'batch_norm' and
                            next_next_op.input("X")[0] ==
                            next_op.output("Out")[0]):
                        # fuse batch_norm
                        self._fuse_param(current_op, next_next_op, next_op, 1)
                        # remove batch_norm_op
                        self.block.remove_op(i + 2)
                        # skip over the existing elementwise_add op
                        i = i + 1
            i = i + 1

        self._adjust_input()
        self._remove_unused_var()
        return program

    # ====================== private transpiler functions =====================
    def _op_at(self, index):
        '''
        Return the operator at ``index`` in the block, or None when ``index``
        is past the last operator.

        :param index: position in self.block.ops
        :type index: Int
        :return: the operator, or None
        :rtype: Operator
        '''
        if index < len(self.block.ops):
            return self.block.ops[index]
        return None

    def _insert_bias_op(self, index, current_op, bn_op):
        '''
        Construct an elementwise_add operator for adding bias
        and insert it into the program.

        The new op reads current_op's "Output" (X) and bn_op's "Bias"
        variable (Y), and writes to bn_op's output "Y" variable. The Bias
        variable is later overwritten with the fused bias by _fuse_param.

        :param index: insert location of bias_op
        :type index: Int
        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :return: bias_op
        :rtype: Operator
        '''
        # The input of bias_op is current_op's output and Bias of bn_op
        # The output of bias_op is bn_op's output
        x_var = self.block.var(current_op.output("Output")[0])
        y_var = self.block.var(bn_op.input("Bias")[0])
        out_var = self.block.var(bn_op.output("Y")[0])

        bias_op = self.block.insert_op(
            index,
            type="elementwise_add",
            inputs={"X": x_var,
                    "Y": y_var},
            outputs={"Out": out_var},
            attrs={"axis": 1})  # dim_start=1
        return bias_op

    def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
        '''
        Fuse the batch_norm_op's parameters into current_op (conv or fc) and
        into bias_op's "Y" (bias) tensor.

        Both the filter tensor and the bias tensor are overwritten in place in
        self.scope. NOTE(review): a parameter shared with other ops would be
        affected as well.

        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :param bias_op: elementwise_add operator for adding bias
        :type bias_op: Operator
        :param with_bias: If current operator has bias, with_bias = 1; otherwise 0.
        :type with_bias: Int
        '''

        def _load_tensor(param_name):
            return self.scope.find_var(param_name[0]).get_tensor()

        def _load_param(param_name):
            return np.array(_load_tensor(param_name))

        bias_bn = _load_param(bn_op.input("Bias"))  # Bias
        scale_bn = _load_param(bn_op.input("Scale"))  # Scale
        mean_bn = _load_param(bn_op.input("Mean"))  # Mean
        var_bn = _load_param(bn_op.input("Variance"))  # Variance

        # TODO(luotao1): consider only conv2d now. fc would be dealt later.
        current_param = _load_param(current_op.input("Filter"))
        current_tensor = _load_tensor(current_op.input("Filter"))

        # batch_norm divides by sqrt(variance + epsilon): honour the op's own
        # epsilon, falling back to 1e-5 only when the attribute is absent.
        if bn_op.has_attr("epsilon"):
            epsilon = bn_op.attr("epsilon")
        else:
            epsilon = 1e-5
        std_bn = np.float32(np.sqrt(np.add(var_bn, epsilon)))
        tmp = np.float32(np.divide(scale_bn, std_bn))

        # add bias of batch_norm_op to conv2d
        if with_bias:
            bias = _load_param(bias_op.input("Y"))
        else:
            bias = np.zeros(bias_bn.shape)
        bias = np.float32(
            np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))
        bias_tensor = _load_tensor(bias_op.input("Y"))
        bias_tensor.set(bias, self.place)

        # re-compute weight of conv2d: scale every slice along the filter's
        # leading axis by its channel's factor (assumes the leading axis is
        # the output channel matching the batch_norm channels)
        tmp = tmp.reshape(tmp.shape[0], -1)
        dst_param = current_param.reshape((tmp.shape[0], -1))
        dst_param = np.float32(np.multiply(dst_param, tmp))
        dst_param = dst_param.reshape(current_param.shape)

        # set the updated parameters
        current_tensor.set(np.array(dst_param), self.place)

        # collect the renamed input
        self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]

    def _adjust_input(self):
        '''
        Rename every op input listed in self.input_map, so consumers of a
        fused batch_norm read the elementwise_add output instead.
        '''
        if not self.input_map:
            return
        for current_op in self.block.ops:
            # snapshot the names because rename_input edits the op
            for input_arg in list(current_op.input_arg_names):
                if input_arg in self.input_map:
                    current_op.rename_input(input_arg,
                                            self.input_map[input_arg])

    def _remove_unused_var(self):
        '''
        Remove variables in the block that no op reads or writes.
        '''
        used = set()
        for current_op in self.block.ops:
            used.update(current_op.input_arg_names)
            used.update(current_op.output_arg_names)

        # snapshot the names: remove_var mutates block.vars, and mutating a
        # dict while iterating its keys view raises RuntimeError in Python 3
        for var in list(self.block.vars.keys()):
            if var not in used:
                self.block.remove_var(var)