repvgg.py 13.3 KB
Newer Older
jm_12138's avatar
jm_12138 已提交
1 2 3 4
import paddle.nn as nn
import paddle
import numpy as np

C
cuicheng01 已提交
5 6
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

littletomatodonkey's avatar
littletomatodonkey 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
# Every pretrained-weight file lives under the same prefix and is named after
# its model, so the table is generated from the ordered list of model names.
_PRETRAINED_URL_FORMAT = (
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/"
    "{}_pretrained.pdparams")

# Maps each model name to the URL of its pretrained parameter file.
MODEL_URLS = {
    model_name: _PRETRAINED_URL_FORMAT.format(model_name)
    for model_name in (
        "RepVGG_A0",
        "RepVGG_A1",
        "RepVGG_A2",
        "RepVGG_B0",
        "RepVGG_B1",
        "RepVGG_B2",
        "RepVGG_B3",
        "RepVGG_B1g2",
        "RepVGG_B1g4",
        "RepVGG_B2g2",
        "RepVGG_B2g4",
        "RepVGG_B3g2",
        "RepVGG_B3g4",
    )
}
C
cuicheng01 已提交
35 36 37 38 39 40

# The public API is exactly the set of model names listed in MODEL_URLS.
__all__ = [*MODEL_URLS]

# Layer indices (every even index from 2 to 26) whose convolutions may be
# grouped; the two maps assign 2 or 4 groups to each of those indices.
optional_groupwise_layers = list(range(2, 27, 2))
g2_map = dict.fromkeys(optional_groupwise_layers, 2)
g4_map = dict.fromkeys(optional_groupwise_layers, 4)
jm_12138's avatar
jm_12138 已提交
41 42 43


class ConvBN(nn.Layer):
    """Bias-free 2-D convolution immediately followed by batch normalization.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution; also the
            number of features normalized by the batch-norm layer.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        groups: Number of convolution groups. Defaults to 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1):
        super().__init__()
        # The convolution carries no bias (bias_attr=False); the batch norm
        # that follows has its own learnable shift.
        conv_options = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias_attr=False)
        self.conv = nn.Conv2D(**conv_options)
        self.bn = nn.BatchNorm2D(num_features=out_channels)

    def forward(self, x):
        """Apply the convolution, then the batch normalization, to ``x``."""
        return self.bn(self.conv(x))


class RepVGGBlock(nn.Layer):
    """RepVGG block: parallel 3x3 ConvBN, 1x1 ConvBN and optional identity BN.

    The block output is ``ReLU(rbr_dense(x) + rbr_1x1(x) + rbr_identity(x))``,
    where the identity branch only exists when ``out_channels == in_channels``
    and ``stride == 1``. Calling :meth:`eval` folds all branches into a single
    biased 3x3 convolution, ``rbr_reparam``, which is then used whenever the
    block is not in training mode.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the dense branch; must be 3.
        stride: Stride shared by all convolutions. Defaults to 1.
        padding: Padding of the dense branch; must be 1.
        dilation: Dilation passed to the fused convolution only. Defaults to 1.
        groups: Number of convolution groups. Defaults to 1.
        padding_mode: Padding mode passed to the fused convolution only.
            Defaults to ``'zeros'``.

    Raises:
        AssertionError: If ``kernel_size != 3`` or ``padding != 1``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros'):
        super(RepVGGBlock, self).__init__()
        # Kept so that eval() can build the fused convolution later.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # NOTE: dilation and padding_mode are forwarded to the fused
        # convolution only; the ConvBN branches below are built without them.
        # Values other than the defaults therefore make the fused convolution
        # differ from the training-time branches.
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        assert kernel_size == 3
        assert padding == 1

        # Padding for the 1x1 branch so that its output size matches the 3x3
        # branch (evaluates to 0 given the asserts above).
        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.ReLU()

        # Creation order (identity, dense, 1x1) is kept unchanged.
        self.rbr_identity = nn.BatchNorm2D(
            num_features=in_channels
        ) if out_channels == in_channels and stride == 1 else None
        self.rbr_dense = ConvBN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups)
        self.rbr_1x1 = ConvBN(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding_11,
            groups=groups)

    def forward(self, inputs):
        """Compute the block output for ``inputs``.

        Outside training mode the fused convolution is used if :meth:`eval`
        has built it. Otherwise the individual branches are evaluated.
        """
        # ``rbr_reparam`` only exists after this class's eval() has run. The
        # training flag can also be cleared without it (e.g. a parent layer
        # assigning ``training = False`` directly), so fall back to the
        # branches instead of raising AttributeError in that case.
        if not self.training and hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)
        return self.nonlinearity(
            self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def eval(self):
        """Switch to inference mode and (re)build the fused convolution.

        The fused layer is created on the first call; on every call its weight
        and bias are overwritten with the values derived from the current
        branch parameters, and ``eval()`` is called on all sublayers.
        """
        if not hasattr(self, 'rbr_reparam'):
            self.rbr_reparam = nn.Conv2D(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
                padding_mode=self.padding_mode)
        self.training = False
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam.weight.set_value(kernel)
        self.rbr_reparam.bias.set_value(bias)
        for layer in self.sublayers():
            layer.eval()

    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of the 3x3 convolution equal to all branches.

        Each branch is fused with its batch norm, the 1x1 kernel is padded to
        3x3, and the per-branch kernels and biases are summed.
        """
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        # Yields (0, 0) when the block has no identity branch.
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
            kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Pad ``kernel1x1`` by one element on each side of its last two axes.

        Returns 0 when ``kernel1x1`` is None.
        """
        if kernel1x1 is None:
            return 0
        else:
            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        """Fold a branch's batch norm into a convolution kernel and bias.

        Args:
            branch: A ``ConvBN`` branch, the identity ``nn.BatchNorm2D``
                branch, or None.

        Returns:
            Tuple ``(kernel, bias)``; ``(0, 0)`` when ``branch`` is None.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, ConvBN):
            kernel = branch.conv.weight
            running_mean = branch.bn._mean
            running_var = branch.bn._variance
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn._epsilon
        else:
            assert isinstance(branch, nn.BatchNorm2D)
            if not hasattr(self, 'id_tensor'):
                # Express the identity mapping as a (grouped) 3x3 kernel: for
                # output channel i, a single 1 at the kernel centre of input
                # channel ``i % input_dim``. Built once and cached.
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros(
                    (self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = paddle.to_tensor(kernel_value)
            kernel = self.id_tensor
            running_mean = branch._mean
            running_var = branch._variance
            gamma = branch.weight
            beta = branch.bias
            eps = branch._epsilon
        # gamma * (conv(x) - mean) / std + beta
        #   == conv scaled per output channel by gamma / std,
        #      plus bias beta - mean * gamma / std.
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape((-1, 1, 1, 1))
        return kernel * t, beta - running_mean * gamma / std


class RepVGG(nn.Layer):
    """RepVGG classifier: a stem block, four stages of RepVGG blocks, pooling
    and a linear head.

    Args:
        num_blocks: Sequence whose first four entries give the number of
            blocks in ``stage1`` .. ``stage4``.
        width_multiplier: Sequence of four multipliers applied to the base
            widths 64, 128, 256 and 512 of the four stages.
        override_groups_map: Optional mapping from layer index (starting at 1
            for the first block of ``stage1``) to the number of convolution
            groups of that block. Indices not present use 1 group.
        class_num: Output size of the final linear layer. Defaults to 1000.

    Raises:
        AssertionError: If ``width_multiplier`` does not have four entries or
            ``override_groups_map`` contains the key 0.
    """

    def __init__(self,
                 num_blocks,
                 width_multiplier=None,
                 override_groups_map=None,
                 class_num=1000):
        super().__init__()

        assert len(width_multiplier) == 4
        self.override_groups_map = override_groups_map or {}

        assert 0 not in self.override_groups_map

        # Channel width of each of the four stages.
        stage_widths = [
            int(base_width * width_multiplier[idx])
            for idx, base_width in enumerate((64, 128, 256, 512))
        ]

        # The stem is capped at 64 channels.
        self.in_planes = min(64, stage_widths[0])
        self.stage0 = RepVGGBlock(
            in_channels=3,
            out_channels=self.in_planes,
            kernel_size=3,
            stride=2,
            padding=1)

        # Running index of the next block, used to look up its group count.
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(
            stage_widths[0], num_blocks[0], stride=2)
        self.stage2 = self._make_stage(
            stage_widths[1], num_blocks[1], stride=2)
        self.stage3 = self._make_stage(
            stage_widths[2], num_blocks[2], stride=2)
        self.stage4 = self._make_stage(
            stage_widths[3], num_blocks[3], stride=2)

        self.gap = nn.AdaptiveAvgPool2D(output_size=1)
        self.linear = nn.Linear(stage_widths[3], class_num)

    def _make_stage(self, planes, num_blocks, stride):
        """Build one stage as an ``nn.Sequential`` of ``RepVGGBlock`` layers.

        The first block uses ``stride``; the remaining ``num_blocks - 1``
        blocks use stride 1. Updates ``self.in_planes`` and
        ``self.cur_layer_idx`` as blocks are created.
        """
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage_layers = []
        for block_stride in per_block_strides:
            groups_for_block = self.override_groups_map.get(
                self.cur_layer_idx, 1)
            block = RepVGGBlock(
                in_channels=self.in_planes,
                out_channels=planes,
                kernel_size=3,
                stride=block_stride,
                padding=1,
                groups=groups_for_block)
            stage_layers.append(block)
            # Every later block of the stage takes ``planes`` channels in.
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.Sequential(*stage_layers)

    def eval(self):
        """Put the model and every sublayer into inference mode.

        ``eval()`` is invoked on each sublayer explicitly so that the
        ``RepVGGBlock`` override, which builds the fused convolution, runs.
        """
        self.training = False
        for sub in self.sublayers():
            sub.training = False
            sub.eval()

    def forward(self, x):
        """Run the stem and four stages, pool, flatten and apply the head."""
        features = x
        for stage in (self.stage0, self.stage1, self.stage2, self.stage3,
                      self.stage4):
            features = stage(features)
        pooled = self.gap(features)
        flat = paddle.flatten(pooled, start_axis=1)
        return self.linear(flat)


C
cuicheng01 已提交
253 254 255 256 257 258 259 260 261 262 263
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )
littletomatodonkey's avatar
littletomatodonkey 已提交
264

C
cuicheng01 已提交
265 266 267

def RepVGG_A0(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-A0 model and optionally load pretrained weights.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_A0"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [2, 4, 14, 1]
    widths = [0.75, 0.75, 0.75, 2.5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_A0"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
275 276


C
cuicheng01 已提交
277 278
def RepVGG_A1(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-A1 model and optionally load pretrained weights.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_A1"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [2, 4, 14, 1]
    widths = [1, 1, 1, 2.5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_A1"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
286 287


C
cuicheng01 已提交
288 289
def RepVGG_A2(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-A2 model and optionally load pretrained weights.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_A2"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [2, 4, 14, 1]
    widths = [1.5, 1.5, 1.5, 2.75]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_A2"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
297 298


C
cuicheng01 已提交
299 300
def RepVGG_B0(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-B0 model and optionally load pretrained weights.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_B0"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [4, 6, 16, 1]
    widths = [1, 1, 1, 2.5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B0"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
308 309


C
cuicheng01 已提交
310 311
def RepVGG_B1(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-B1 model and optionally load pretrained weights.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_B1"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [4, 6, 16, 1]
    widths = [2, 2, 2, 4]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B1"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
319 320


C
cuicheng01 已提交
321 322
def RepVGG_B1g2(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-B1g2 model and optionally load pretrained weights.

    Same depths and widths as RepVGG-B1, with the layers listed in ``g2_map``
    using 2 convolution groups.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_B1g2"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [4, 6, 16, 1]
    widths = [2, 2, 2, 4]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g2_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B1g2"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
330 331


C
cuicheng01 已提交
332 333
def RepVGG_B1g4(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-B1g4 model and optionally load pretrained weights.

    Same depths and widths as RepVGG-B1, with the layers listed in ``g4_map``
    using 4 convolution groups.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_B1g4"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [4, 6, 16, 1]
    widths = [2, 2, 2, 4]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g4_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B1g4"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
341 342


C
cuicheng01 已提交
343 344
def RepVGG_B2(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-B2 model and optionally load pretrained weights.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_B2"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [4, 6, 16, 1]
    widths = [2.5, 2.5, 2.5, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B2"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
352 353


C
cuicheng01 已提交
354 355
def RepVGG_B2g2(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-B2g2 model and optionally load pretrained weights.

    Same depths and widths as RepVGG-B2, with the layers listed in ``g2_map``
    using 2 convolution groups.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_B2g2"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [4, 6, 16, 1]
    widths = [2.5, 2.5, 2.5, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g2_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B2g2"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
363 364


C
cuicheng01 已提交
365 366
def RepVGG_B2g4(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-B2g4 model and optionally load pretrained weights.

    Same depths and widths as RepVGG-B2, with the layers listed in ``g4_map``
    using 4 convolution groups.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_B2g4"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [4, 6, 16, 1]
    widths = [2.5, 2.5, 2.5, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g4_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B2g4"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
374 375


C
cuicheng01 已提交
376 377
def RepVGG_B3(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-B3 model and optionally load pretrained weights.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_B3"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [4, 6, 16, 1]
    widths = [3, 3, 3, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=None,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B3"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
385 386


C
cuicheng01 已提交
387 388
def RepVGG_B3g2(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-B3g2 model and optionally load pretrained weights.

    Same depths and widths as RepVGG-B3, with the layers listed in ``g2_map``
    using 2 convolution groups.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_B3g2"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [4, 6, 16, 1]
    widths = [3, 3, 3, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g2_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B3g2"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net
jm_12138's avatar
jm_12138 已提交
396 397


C
cuicheng01 已提交
398 399
def RepVGG_B3g4(pretrained=False, use_ssld=False, **kwargs):
    """Build the RepVGG-B3g4 model and optionally load pretrained weights.

    Same depths and widths as RepVGG-B3, with the layers listed in ``g4_map``
    using 4 convolution groups.

    Args:
        pretrained: ``False`` for no weights, ``True`` to load from
            ``MODEL_URLS["RepVGG_B3g4"]``, or a string passed on to
            ``_load_pretrained``.
        use_ssld: Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``RepVGG``.

    Returns:
        The constructed ``RepVGG`` instance.
    """
    depths = [4, 6, 16, 1]
    widths = [3, 3, 3, 5]
    net = RepVGG(
        num_blocks=depths,
        width_multiplier=widths,
        override_groups_map=g4_map,
        **kwargs)
    weights_url = MODEL_URLS["RepVGG_B3g4"]
    _load_pretrained(pretrained, net, weights_url, use_ssld=use_ssld)
    return net