import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform

from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

# Pretrained-weight download locations, keyed by the public factory name.
# Every URL follows "<dygraph bucket>/<key>_pretrained.pdparams".
# NOTE(review): the 32x16d entry previously read "ResNeXt101_32x816_wsl"
# (typo); it now follows the same scheme as its siblings -- confirm the
# object exists on the bucket.
MODEL_URLS = {
    "ResNeXt101_32x8d_wsl":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt101_32x8d_wsl_pretrained.pdparams",
    "ResNeXt101_32x16d_wsl":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt101_32x16d_wsl_pretrained.pdparams",
    "ResNeXt101_32x32d_wsl":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt101_32x32d_wsl_pretrained.pdparams",
    "ResNeXt101_32x48d_wsl":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNeXt101_32x48d_wsl_pretrained.pdparams",
}

# Public API: the factory functions, named after (and ordered like) the
# MODEL_URLS keys.
__all__ = list(MODEL_URLS.keys())


class ConvBNLayer(nn.Layer):
    """Bias-free ``Conv2D`` followed by ``BatchNorm`` with optional activation.

    Parameter names are derived from ``name``:

    * ``"conv1"`` (the stem): conv ``conv1``, batch norm ``bn1``;
    * a name containing ``"downsample"``: conv ``<name>.0``, batch norm
      ``<name>.1``;
    * ``"<prefix>.convK"``: conv ``<prefix>.convK``, batch norm
      ``<prefix>.bnK``.

    Args:
        input_channels (int): Number of input channels.
        output_channels (int): Number of output channels.
        filter_size (int): Square kernel size; padding is
            ``(filter_size - 1) // 2``, which keeps the spatial size for odd
            kernels at stride 1.
        stride (int): Convolution stride.
        groups (int): Number of convolution groups.
        act (str|None): Activation applied inside ``BatchNorm``.
        name (str|None): Parameter-name prefix; ``None`` leaves parameter
            naming to Paddle.
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None,
                 name=None):
        super(ConvBNLayer, self).__init__()
        conv_name, bn_name = self._param_prefixes(name)

        self._conv = Conv2D(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(name=self._suffixed(conv_name, ".weight")),
            bias_attr=False)
        self._bn = BatchNorm(
            num_channels=output_channels,
            act=act,
            param_attr=ParamAttr(name=self._suffixed(bn_name, ".weight")),
            bias_attr=ParamAttr(name=self._suffixed(bn_name, ".bias")),
            moving_mean_name=self._suffixed(bn_name, ".running_mean"),
            moving_variance_name=self._suffixed(bn_name, ".running_var"))

    @staticmethod
    def _param_prefixes(name):
        """Return the ``(conv_name, bn_name)`` prefixes derived from ``name``."""
        if name is None:
            return None, None
        if "downsample" in name:
            # Shortcut projection: conv is "<name>.0", its batch norm "<name>.1".
            return name + ".0", name + ".1"
        if name == "conv1":
            # Stem convolution: "conv1" -> "bn1".
            return name, "bn" + name[-1]
        # "<prefix>.convK" -> "<prefix>.bnK"; splitting on the last "." keeps
        # this correct for block indices of any number of digits.
        return name, name.rsplit(".", 1)[0] + ".bn" + name[-1]

    @staticmethod
    def _suffixed(prefix, suffix):
        """Append ``suffix`` to ``prefix``, or return ``None`` when unnamed."""
        return None if prefix is None else prefix + suffix

    def forward(self, inputs):
        x = self._conv(inputs)
        x = self._bn(x)
        return x

class ShortCut(nn.Layer):
    """Residual shortcut branch of a bottleneck block.

    Passes the input through unchanged when ``input_channels ==
    output_channels`` and ``stride == 1``; otherwise projects it with a 1x1
    ``ConvBNLayer`` (no activation) whose parameter prefix is ``name``.
    """

    def __init__(self, input_channels, output_channels, stride, name=None):
        super(ShortCut, self).__init__()

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride
        # Decide once whether a projection is needed; forward() only reads
        # this flag.
        self._use_projection = (input_channels != output_channels or
                                stride != 1)
        if self._use_projection:
            self._conv = ConvBNLayer(
                input_channels,
                output_channels,
                filter_size=1,
                stride=stride,
                name=name)

    def forward(self, inputs):
        if self._use_projection:
            return self._conv(inputs)
        return inputs

class BottleneckBlock(nn.Layer):
    """ResNeXt bottleneck: 1x1 conv -> 3x3 grouped conv -> 1x1 conv + shortcut.

    Args:
        input_channels (int): Channels of the block input.
        output_channels (int): Channels of the two inner convolutions.
        stride (int): Stride of the 3x3 convolution and of the shortcut.
        cardinality (int): Group count of the 3x3 convolution.
        width (int): Per-group width; the last 1x1 convolution produces
            ``output_channels // (width // 8)`` channels, so ``width`` must
            be at least 8.
        name (str): Parameter-name prefix, e.g. ``"layer1.0"``.
    """

    def __init__(self, input_channels, output_channels, stride, cardinality,
                 width, name):
        super(BottleneckBlock, self).__init__()
        # Channels produced by the block, shared by the last conv and the
        # shortcut so the two branches can be added.
        block_channels = output_channels // (width // 8)

        self._conv0 = ConvBNLayer(
            input_channels,
            output_channels,
            filter_size=1,
            act="relu",
            name=name + ".conv1")
        self._conv1 = ConvBNLayer(
            output_channels,
            output_channels,
            filter_size=3,
            act="relu",
            stride=stride,
            groups=cardinality,
            name=name + ".conv2")
        self._conv2 = ConvBNLayer(
            output_channels,
            block_channels,
            filter_size=1,
            act=None,
            name=name + ".conv3")
        self._short = ShortCut(
            input_channels,
            block_channels,
            stride=stride,
            name=name + ".downsample")

    def forward(self, inputs):
        x = self._conv0(inputs)
        x = self._conv1(x)
        x = self._conv2(x)
        y = self._short(inputs)
        y = paddle.add(x, y)
        y = F.relu(y)
        return y


class ResNeXt101WSL(nn.Layer):
    """ResNeXt-101 (32xNd) classification network.

    Layout: 7x7 stride-2 stem conv and 3x3 stride-2 max-pool, four stages of
    ``BottleneckBlock``s with ``self.depth = [3, 4, 23, 3]`` blocks, global
    average pooling and a ``Linear`` classifier.

    Args:
        layers (int): Stored as ``self.layers``; the stage layout is always
            ``self.depth`` regardless of this value.
        cardinality (int): Group count of every 3x3 grouped convolution.
        width (int): Per-group width. Must be at least 8, because block
            output channels are divided by ``width // 8``.
        class_dim (int): Number of classifier outputs.
    """

    def __init__(self, layers=101, cardinality=32, width=48, class_dim=1000):
        super(ResNeXt101WSL, self).__init__()

        self.class_dim = class_dim

        self.layers = layers
        self.cardinality = cardinality
        self.width = width
        self.scale = width // 8

        self.depth = [3, 4, 23, 3]
        self.base_width = cardinality * width
        # Width of the inner convolutions of each stage. Each block outputs
        # num_filters[i] // self.scale channels, which is 256/512/1024/2048
        # for the 32x8d, 32x16d, 32x32d and 32x48d variants.
        num_filters = [self.base_width * i for i in [1, 2, 4, 8]]

        self._conv_stem = ConvBNLayer(
            3, 64, 7, stride=2, act="relu", name="conv1")
        self._pool = MaxPool2D(kernel_size=3, stride=2, padding=1)

        # Bottleneck blocks are registered, in execution order, under the
        # attribute names "_conv{stage}_{index}" with parameter prefixes
        # "layer{stage}.{index}" (stage 1-4, index from 0). self.block_list
        # keeps them in order for forward().
        self.block_list = []
        in_channels = 64
        for stage, (num_blocks, inner_channels) in enumerate(
                zip(self.depth, num_filters), start=1):
            for index in range(num_blocks):
                # Only the first block of stages 2-4 downsamples.
                stride = 2 if index == 0 and stage > 1 else 1
                block = self.add_sublayer(
                    "_conv{}_{}".format(stage, index),
                    BottleneckBlock(
                        in_channels,
                        inner_channels,
                        stride=stride,
                        cardinality=self.cardinality,
                        width=self.width,
                        name="layer{}.{}".format(stage, index)))
                self.block_list.append(block)
                in_channels = inner_channels // self.scale

        self._avg_pool = AdaptiveAvgPool2D(1)
        self._out = Linear(
            num_filters[3] // self.scale,
            class_dim,
            weight_attr=ParamAttr(name="fc.weight"),
            bias_attr=ParamAttr(name="fc.bias"))

    def forward(self, inputs):
        x = self._conv_stem(inputs)
        x = self._pool(x)
        for block in self.block_list:
            x = block(x)
        x = self._avg_pool(x)
        # Drop the two 1x1 spatial axes left by the adaptive pool.
        x = paddle.squeeze(x, axis=[2, 3])
        x = self._out(x)
        return x


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def ResNeXt101_32x8d_wsl(pretrained=False, use_ssld=False, **kwargs):
    """Build ResNeXt101 32x8d (cardinality 32, width 8).

    Args:
        pretrained (bool|str): Passed to ``_load_pretrained`` together with
            ``MODEL_URLS["ResNeXt101_32x8d_wsl"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments for ``ResNeXt101WSL`` (e.g.
            ``class_dim``).

    Returns:
        ResNeXt101WSL: The constructed (optionally weight-loaded) model.
    """
    model = ResNeXt101WSL(cardinality=32, width=8, **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["ResNeXt101_32x8d_wsl"],
        use_ssld=use_ssld)
    return model


def ResNeXt101_32x16d_wsl(pretrained=False, use_ssld=False, **kwargs):
    """Build ResNeXt101 32x16d (cardinality 32, width 16).

    Args:
        pretrained (bool|str): Passed to ``_load_pretrained`` together with
            ``MODEL_URLS["ResNeXt101_32x16d_wsl"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments for ``ResNeXt101WSL`` (e.g.
            ``class_dim``).

    Returns:
        ResNeXt101WSL: The constructed (optionally weight-loaded) model.
    """
    model = ResNeXt101WSL(cardinality=32, width=16, **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["ResNeXt101_32x16d_wsl"],
        use_ssld=use_ssld)
    return model


def ResNeXt101_32x32d_wsl(pretrained=False, use_ssld=False, **kwargs):
    """Build ResNeXt101 32x32d (cardinality 32, width 32).

    Args:
        pretrained (bool|str): Passed to ``_load_pretrained`` together with
            ``MODEL_URLS["ResNeXt101_32x32d_wsl"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments for ``ResNeXt101WSL`` (e.g.
            ``class_dim``).

    Returns:
        ResNeXt101WSL: The constructed (optionally weight-loaded) model.
    """
    model = ResNeXt101WSL(cardinality=32, width=32, **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["ResNeXt101_32x32d_wsl"],
        use_ssld=use_ssld)
    return model


def ResNeXt101_32x48d_wsl(pretrained=False, use_ssld=False, **kwargs):
    """Build ResNeXt101 32x48d (cardinality 32, width 48).

    Args:
        pretrained (bool|str): Passed to ``_load_pretrained`` together with
            ``MODEL_URLS["ResNeXt101_32x48d_wsl"]``.
        use_ssld (bool): Forwarded to ``_load_pretrained``.
        **kwargs: Extra keyword arguments for ``ResNeXt101WSL`` (e.g.
            ``class_dim``).

    Returns:
        ResNeXt101WSL: The constructed (optionally weight-loaded) model.
    """
    model = ResNeXt101WSL(cardinality=32, width=48, **kwargs)
    _load_pretrained(
        pretrained,
        model,
        MODEL_URLS["ResNeXt101_32x48d_wsl"],
        use_ssld=use_ssld)
    return model