#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import paddle
import paddle.nn as nn

from . import utils


class Identity(nn.Layer):
    '''a layer to replace bn or relu layers'''

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input):
        return input


def fuse_conv_bn(model):
    '''fuse each Conv2D/BatchNorm2D pair found in the model'''
    is_train = False
    if model.training:
        model.eval()
        is_train = True
    fuse_list = []
    tmp_pair = [None, None]
    for name, layer in model.named_sublayers():
        if isinstance(layer, nn.Conv2D):
            tmp_pair[0] = name
        if isinstance(layer, nn.BatchNorm2D):
            tmp_pair[1] = name

        if tmp_pair[0] and tmp_pair[1] and len(tmp_pair) == 2:
            fuse_list.append(tmp_pair)
            tmp_pair = [None, None]
    model = fuse_layers(model, fuse_list)
    if is_train:
        model.train()
    return model
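
# A minimal usage sketch (illustrative only; the toy network below is an
# assumption, not part of this module): fuse the Conv2D/BatchNorm2D pairs of a
# model and get back a copy in which every fused BatchNorm2D is an Identity.
#
#     net = nn.Sequential(nn.Conv2D(3, 8, 3), nn.BatchNorm2D(8), nn.ReLU())
#     fused_net = fuse_conv_bn(net)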


def fuse_layers(model, layers_to_fuse, inplace=False):
    '''
    fuse layers in layers_to_fuse

    Args:
        model(paddle.nn.Layer): The model to be fused.
        layers_to_fuse(list): Names of the layers to be fused, grouped into
            pairs. For example,
            "layers_to_fuse = [["conv1", "bn1"], ["conv2", "bn2"]]".
        inplace(bool): Whether to apply the fusing to the input model itself.
            If False, the model is deep-copied before fusing.
            Default: False.

    Returns:
        fused_model(paddle.nn.Layer): The fused model.
    '''
    if inplace is False:
        model = copy.deepcopy(model)
    for layers in layers_to_fuse:
        _fuse_layers(model, layers)
    return model


def _fuse_layers(model, layers_list):
    '''fuse all the layers in layers_list'''
    layer_list = []
    for layer_name in layers_list:
        parent_layer, sub_name = utils.find_parent_layer_and_sub_name(
            model, layer_name
        )
        layer_list.append(getattr(parent_layer, sub_name))
    new_layers = _fuse_func(layer_list)
    for i, item in enumerate(layers_list):
        parent_layer, sub_name = utils.find_parent_layer_and_sub_name(
            model, item
        )
        setattr(parent_layer, sub_name, new_layers[i])


def _fuse_func(layer_list):
    '''choose the fuser method and fuse layers'''
    types = tuple(type(m) for m in layer_list)
    fusion_method = types_to_fusion_method.get(types, None)
    new_layers = [None] * len(layer_list)
    fused_layer = fusion_method(*layer_list)
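    # move the pre-hooks of the first layer and the post-hooks of the last
    # layer onto the fused layer, so existing hooks keep firing after fusion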
    # iterate over a copy so entries can be deleted while transferring hooks
    for handle_id, pre_hook_fn in list(
        layer_list[0]._forward_pre_hooks.items()
    ):
        fused_layer.register_forward_pre_hook(pre_hook_fn)
        del layer_list[0]._forward_pre_hooks[handle_id]
    for handle_id, hook_fn in list(
        layer_list[-1]._forward_post_hooks.items()
    ):
        fused_layer.register_forward_post_hook(hook_fn)
        del layer_list[-1]._forward_post_hooks[handle_id]
    new_layers[0] = fused_layer
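    # the remaining positions are filled with Identity so that the parent
    # layers keep their original sub-layer names and structure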
    for i in range(1, len(layer_list)):
        identity = Identity()
        identity.training = layer_list[0].training
        new_layers[i] = identity
    return new_layers


def _fuse_conv_bn(conv, bn):
    '''fuse conv and bn for train or eval'''
    assert (
        conv.training == bn.training
    ), "Conv and BN both must be in the same mode (train or eval)."
    if conv.training:
        assert (
            bn._num_features == conv._out_channels
        ), 'out_channels of Conv2D must match num_features of BatchNorm2D'
        raise NotImplementedError(
            'Fusing Conv2D and BatchNorm2D in train mode is not supported yet'
        )
    else:
        return _fuse_conv_bn_eval(conv, bn)


def _fuse_conv_bn_eval(conv, bn):
    '''fuse conv and bn for eval'''
    assert not (conv.training or bn.training), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_weight, fused_bias = _fuse_conv_bn_weights(
        fused_conv.weight,
        fused_conv.bias,
        bn._mean,
        bn._variance,
        bn._epsilon,
        bn.weight,
        bn.bias,
    )
    fused_conv.weight.set_value(fused_weight)
    if fused_conv.bias is None:
        fused_conv.bias = paddle.create_parameter(
            shape=[fused_conv._out_channels], is_bias=True, dtype=bn.bias.dtype
        )
    fused_conv.bias.set_value(fused_bias)
    return fused_conv


def _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    '''fuse weights and bias of conv and bn'''
    if conv_b is None:
        conv_b = paddle.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = paddle.ones_like(bn_rm)
    if bn_b is None:
        bn_b = paddle.zeros_like(bn_rm)
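    # fold bn into conv:
    #   w' = w * gamma / sqrt(var + eps)
    #   b' = (b - mu) * gamma / sqrt(var + eps) + beta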
    bn_var_rsqrt = paddle.rsqrt(bn_rv + bn_eps)
    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(
        [-1] + [1] * (len(conv_w.shape) - 1)
    )
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
    return conv_w, conv_b


def _fuse_linear_bn(linear, bn):
    '''fuse linear and bn'''
    assert (
        linear.training == bn.training
    ), "Linear and BN both must be in the same mode (train or eval)."
    if linear.training:
        assert (
            bn._num_features == linear.weight.shape[1]
        ), 'out_features of Linear must match num_features of BatchNorm1D'
        raise NotImplementedError(
            'Fusing Linear and BatchNorm1D in train mode is not supported yet'
        )
    else:
        return _fuse_linear_bn_eval(linear, bn)


def _fuse_linear_bn_eval(linear, bn):
    '''fuse linear and bn for eval'''
    assert not (linear.training or bn.training), "Fusion only for eval!"
    fused_linear = copy.deepcopy(linear)

    fused_weight, fused_bias = _fuse_linear_bn_weights(
        fused_linear.weight,
        fused_linear.bias,
        bn._mean,
        bn._variance,
        bn._epsilon,
        bn.weight,
        bn.bias,
    )
    fused_linear.weight.set_value(fused_weight)
    if fused_linear.bias is None:
        fused_linear.bias = paddle.create_parameter(
            shape=[fused_linear.weight.shape[1]],
            is_bias=True,
            dtype=bn.bias.dtype,
        )
    fused_linear.bias.set_value(fused_bias)
    return fused_linear


def _fuse_linear_bn_weights(
    linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b
):
    '''fuse weights and bias of linear and bn'''
    if linear_b is None:
        linear_b = paddle.zeros_like(bn_rm)
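    # same fold as for conv: w' = w * gamma / sqrt(var + eps) along the output
    # features, and b' = (b - mu) * gamma / sqrt(var + eps) + beta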
    bn_scale = bn_w * paddle.rsqrt(bn_rv + bn_eps)
    # Paddle's Linear weight is laid out as [in_features, out_features], so
    # the per-feature scale broadcasts along the last (output) axis
    fused_w = linear_w * bn_scale
    fused_b = (linear_b - bn_rm) * bn_scale + bn_b
    return fused_w, fused_b


types_to_fusion_method = {
    (nn.Conv2D, nn.BatchNorm2D): _fuse_conv_bn,
    (nn.Linear, nn.BatchNorm1D): _fuse_linear_bn,
}
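
# ``types_to_fusion_method`` acts as a registry: _fuse_func looks up the tuple
# of layer types for each fuse pair and dispatches to the matching fuser.
# Additional patterns could, in principle, be supported by adding an entry here
# together with a fuser function (an assumption about intended use, not an
# existing feature of this module).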