repvgg.py 13.9 KB
Newer Older
C
cuicheng01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Code was based on https://github.com/DingXiaoH/RepVGG

jm_12138's avatar
jm_12138 已提交
17 18 19 20
import paddle.nn as nn
import paddle
import numpy as np

C
cuicheng01 已提交
21 22
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

littletomatodonkey's avatar
littletomatodonkey 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
# Pretrained ImageNet weights, keyed by the factory-function name that loads
# them. Every URL follows "<root>/<name>_pretrained.pdparams"; the key order
# below is the order exposed through the dict.
MODEL_URLS = {
    name: ("https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/" +
           name + "_pretrained.pdparams")
    for name in (
        "RepVGG_A0",
        "RepVGG_A1",
        "RepVGG_A2",
        "RepVGG_B0",
        "RepVGG_B1",
        "RepVGG_B2",
        "RepVGG_B3",
        "RepVGG_B1g2",
        "RepVGG_B1g4",
        "RepVGG_B2g2",
        "RepVGG_B2g4",
        "RepVGG_B3g2",
        "RepVGG_B3g4",
    )
}
C
cuicheng01 已提交
51 52 53 54 55 56

__all__ = list(MODEL_URLS)  # one public factory per pretrained-weight entry

# Block indices (even numbers 2..26) whose convolutions become grouped in the
# "g2"/"g4" model variants; the index is the running block counter that
# RepVGG advances once per block starting at 1 for the first block of stage1.
optional_groupwise_layers = list(range(2, 27, 2))
g2_map = dict.fromkeys(optional_groupwise_layers, 2)
g4_map = dict.fromkeys(optional_groupwise_layers, 4)
jm_12138's avatar
jm_12138 已提交
57 58 59


class ConvBN(nn.Layer):
    """A bias-free ``nn.Conv2D`` followed by ``nn.BatchNorm2D``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels (also the BN width).
        kernel_size (int): Convolution kernel size.
        stride (int): Convolution stride.
        padding (int): Convolution padding.
        groups (int): Convolution group count. Default: 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1):
        super().__init__()
        # The conv has no bias: the following BatchNorm supplies the shift.
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias_attr=False)
        self.bn = nn.BatchNorm2D(num_features=out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))


class RepVGGBlock(nn.Layer):
    """RepVGG block with structural re-parameterization.

    While ``self.training`` is True the block computes
    ``ReLU(rbr_dense(x) + rbr_1x1(x) + rbr_identity(x))``:
    - ``rbr_dense`` is a 3x3 ConvBN;
    - ``rbr_1x1`` is a 1x1 ConvBN;
    - ``rbr_identity`` is a bare BatchNorm2D that exists only when
      ``out_channels == in_channels`` and ``stride == 1``.

    :meth:`eval` folds the three branches into a single 3x3 ``Conv2D``
    (``rbr_reparam``), which ``forward`` uses whenever ``self.training`` is
    False.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Must be 3.
        stride (int): Stride shared by every branch. Default: 1.
        padding (int): Must be 1. NOTE(review): the default of 0 fails the
            assertion below, so callers always have to pass ``padding=1``.
        dilation (int): Applied only to the fused ``rbr_reparam`` conv. The
            training branches are built without it, so values other than 1
            would likely make the fused conv disagree with them.
        groups (int): Group count of the 3x3 and 1x1 branches.
        padding_mode (str): Applied only to the fused ``rbr_reparam`` conv
            (same caveat as ``dilation``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros'):
        super(RepVGGBlock, self).__init__()
        # Kept so eval() can build an equivalent fused Conv2D later.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        assert kernel_size == 3
        assert padding == 1

        # Padding of the 1x1 branch so its output size matches the 3x3 one
        # (evaluates to 0 given the asserts above).
        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.ReLU()

        # The identity branch only exists when input and output shapes match.
        self.rbr_identity = nn.BatchNorm2D(
            num_features=in_channels
        ) if out_channels == in_channels and stride == 1 else None
        self.rbr_dense = ConvBN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups)
        self.rbr_1x1 = ConvBN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding_11,
            groups=groups)

    def forward(self, inputs):
        if not self.training:
            # A parent's nn.Layer.eval() may clear `training` on this block
            # without calling this class's eval() (RepVGG.eval exists to work
            # around exactly that). In that case the fused conv was never
            # built, so build and fuse it now instead of raising
            # AttributeError.
            # NOTE(review): if training resumed after a previous eval() and
            # `training` is cleared again without eval(), rbr_reparam still
            # holds the older fused weights; call eval() after training.
            if not hasattr(self, 'rbr_reparam'):
                self.eval()
            return self.nonlinearity(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)
        return self.nonlinearity(
            self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def eval(self):
        """Switch to inference mode and (re)fuse all branches into
        ``rbr_reparam`` from their current weights and BN statistics."""
        if not hasattr(self, 'rbr_reparam'):
            self.rbr_reparam = nn.Conv2D(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
                padding_mode=self.padding_mode)
        self.training = False
        # The fused values are only copied into rbr_reparam; no gradient
        # graph is needed for this computation.
        with paddle.no_grad():
            kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam.weight.set_value(kernel)
        self.rbr_reparam.bias.set_value(bias)
        for layer in self.sublayers():
            layer.eval()

    def get_equivalent_kernel_bias(self):
        """Return the (3x3 kernel, bias) of one conv equal to the sum of all
        branches; the 1x1 kernel is zero-padded to 3x3 before summing."""
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
            kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Zero-pad a 1x1 kernel by one on each side of its last two axes;
        ``None`` maps to 0 so it can be summed away."""
        if kernel1x1 is None:
            return 0
        else:
            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        """Fold a branch's BatchNorm into (kernel, bias).

        ``branch`` is a ConvBN, a bare BatchNorm2D (identity branch, fused
        with an identity 3x3 kernel) or ``None`` (returns ``(0, 0)``).
        Uses ``kernel * gamma / std`` and ``beta - mean * gamma / std`` with
        ``std = sqrt(running_var + eps)``.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, ConvBN):
            kernel = branch.conv.weight
            running_mean = branch.bn._mean
            running_var = branch.bn._variance
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn._epsilon
        else:
            assert isinstance(branch, nn.BatchNorm2D)
            if not hasattr(self, 'id_tensor'):
                # Identity 3x3 kernel for a (possibly grouped) conv: output
                # channel i reads input channel i within its group, centre tap.
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros(
                    (self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = paddle.to_tensor(kernel_value)
            kernel = self.id_tensor
            running_mean = branch._mean
            running_var = branch._variance
            gamma = branch.weight
            beta = branch.bias
            eps = branch._epsilon
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape((-1, 1, 1, 1))
        return kernel * t, beta - running_mean * gamma / std


class RepVGG(nn.Layer):
    """RepVGG classifier: a RepVGGBlock stem (``stage0``), four block stages,
    global average pooling and a linear head.

    Args:
        num_blocks (list[int]): Number of blocks in stages 1-4.
        width_multiplier (list[float]): Four multipliers scaling the base
            widths 64, 128, 256 and 512 of stages 1-4. Must be given; the
            ``None`` default fails the length check.
        override_groups_map (dict|None): Maps a block's running index (1 for
            the first block of stage1, increasing by one per block) to its
            conv group count; missing indices use 1. Index 0 is rejected.
        class_num (int): Number of output logits. Default: 1000.
    """

    def __init__(self,
                 num_blocks,
                 width_multiplier=None,
                 override_groups_map=None,
                 class_num=1000):
        super().__init__()

        assert len(width_multiplier) == 4
        self.override_groups_map = override_groups_map or dict()

        assert 0 not in self.override_groups_map

        # Output channels of stages 1-4.
        widths = [
            int(base * mult)
            for base, mult in zip((64, 128, 256, 512), width_multiplier)
        ]

        # The stem is capped at 64 channels.
        self.in_planes = min(64, widths[0])
        self.stage0 = RepVGGBlock(
            in_channels=3,
            out_channels=self.in_planes,
            kernel_size=3,
            stride=2,
            padding=1)

        # Running block index consumed by _make_stage for group overrides.
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(widths[0], num_blocks[0], stride=2)
        self.stage2 = self._make_stage(widths[1], num_blocks[1], stride=2)
        self.stage3 = self._make_stage(widths[2], num_blocks[2], stride=2)
        self.stage4 = self._make_stage(widths[3], num_blocks[3], stride=2)
        self.gap = nn.AdaptiveAvgPool2D(output_size=1)
        self.linear = nn.Linear(widths[3], class_num)

    def _make_stage(self, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``.

        Advances ``self.in_planes`` and ``self.cur_layer_idx`` per block.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(
                RepVGGBlock(
                    in_channels=self.in_planes,
                    out_channels=planes,
                    kernel_size=3,
                    stride=block_stride,
                    padding=1,
                    groups=self.override_groups_map.get(
                        self.cur_layer_idx, 1)))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.Sequential(*layers)

    def eval(self):
        """Enter inference mode, calling ``eval()`` on every sublayer so each
        RepVGGBlock fuses its branches."""
        self.training = False
        for sub in self.sublayers():
            sub.training = False
            sub.eval()

    def forward(self, x):
        feat = self.stage0(x)
        feat = self.stage1(feat)
        feat = self.stage2(feat)
        feat = self.stage3(feat)
        feat = self.stage4(feat)
        pooled = paddle.flatten(self.gap(feat), start_axis=1)
        return self.linear(pooled)


C
cuicheng01 已提交
269 270 271 272 273 274 275 276 277 278 279
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    """Optionally load weights into ``model``.

    Args:
        pretrained: ``False`` to skip loading, ``True`` to load from
            ``model_url``, or a ``str`` path passed to
            ``load_dygraph_pretrain``.
        model: The network to load into.
        model_url (str): URL used when ``pretrained`` is ``True``.
        use_ssld (bool): Forwarded to the URL loader.

    Raises:
        RuntimeError: If ``pretrained`` is none of the above.
    """
    if isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
        return
    if pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
        return
    if pretrained is not False:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )
littletomatodonkey's avatar
littletomatodonkey 已提交
280

C
cuicheng01 已提交
281 282 283

def RepVGG_A0(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-A0.

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_A0"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [2, 4, 14, 1]
    widths = [0.75, 0.75, 0.75, 2.5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    url = MODEL_URLS["RepVGG_A0"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
291 292


C
cuicheng01 已提交
293 294
def RepVGG_A1(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-A1.

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_A1"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [2, 4, 14, 1]
    widths = [1, 1, 1, 2.5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    url = MODEL_URLS["RepVGG_A1"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
302 303


C
cuicheng01 已提交
304 305
def RepVGG_A2(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-A2.

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_A2"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [2, 4, 14, 1]
    widths = [1.5, 1.5, 1.5, 2.75]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    url = MODEL_URLS["RepVGG_A2"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
313 314


C
cuicheng01 已提交
315 316
def RepVGG_B0(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B0.

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_B0"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [4, 6, 16, 1]
    widths = [1, 1, 1, 2.5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    url = MODEL_URLS["RepVGG_B0"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
324 325


C
cuicheng01 已提交
326 327
def RepVGG_B1(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B1.

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_B1"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [4, 6, 16, 1]
    widths = [2, 2, 2, 4]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    url = MODEL_URLS["RepVGG_B1"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
335 336


C
cuicheng01 已提交
337 338
def RepVGG_B1g2(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B1g2 (B1 widths, groups=2 on the blocks in ``g2_map``).

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_B1g2"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [4, 6, 16, 1]
    widths = [2, 2, 2, 4]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g2_map,
        **kwargs)
    url = MODEL_URLS["RepVGG_B1g2"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
346 347


C
cuicheng01 已提交
348 349
def RepVGG_B1g4(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B1g4 (B1 widths, groups=4 on the blocks in ``g4_map``).

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_B1g4"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [4, 6, 16, 1]
    widths = [2, 2, 2, 4]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g4_map,
        **kwargs)
    url = MODEL_URLS["RepVGG_B1g4"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
357 358


C
cuicheng01 已提交
359 360
def RepVGG_B2(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B2.

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_B2"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [4, 6, 16, 1]
    widths = [2.5, 2.5, 2.5, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    url = MODEL_URLS["RepVGG_B2"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
368 369


C
cuicheng01 已提交
370 371
def RepVGG_B2g2(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B2g2 (B2 widths, groups=2 on the blocks in ``g2_map``).

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_B2g2"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [4, 6, 16, 1]
    widths = [2.5, 2.5, 2.5, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g2_map,
        **kwargs)
    url = MODEL_URLS["RepVGG_B2g2"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
379 380


C
cuicheng01 已提交
381 382
def RepVGG_B2g4(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B2g4 (B2 widths, groups=4 on the blocks in ``g4_map``).

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_B2g4"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [4, 6, 16, 1]
    widths = [2.5, 2.5, 2.5, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g4_map,
        **kwargs)
    url = MODEL_URLS["RepVGG_B2g4"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
390 391


C
cuicheng01 已提交
392 393
def RepVGG_B3(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B3.

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_B3"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [4, 6, 16, 1]
    widths = [3, 3, 3, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    url = MODEL_URLS["RepVGG_B3"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
401 402


C
cuicheng01 已提交
403 404
def RepVGG_B3g2(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B3g2 (B3 widths, groups=2 on the blocks in ``g2_map``).

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_B3g2"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [4, 6, 16, 1]
    widths = [3, 3, 3, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g2_map,
        **kwargs)
    url = MODEL_URLS["RepVGG_B3g2"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
412 413


C
cuicheng01 已提交
414 415
def RepVGG_B3g4(pretrained=False, use_ssld=False, **kwargs):
    """Build RepVGG-B3g4 (B3 widths, groups=4 on the blocks in ``g4_map``).

    Args:
        pretrained (bool|str): Forwarded to ``_load_pretrained`` together
            with ``MODEL_URLS["RepVGG_B3g4"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra ``RepVGG`` arguments (e.g. ``class_num``).

    Returns:
        RepVGG: The constructed network.
    """
    depths = [4, 6, 16, 1]
    widths = [3, 3, 3, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g4_map,
        **kwargs)
    url = MODEL_URLS["RepVGG_B3g4"]
    _load_pretrained(pretrained, net, url, use_ssld=use_ssld)
    return net