resnext101_wsl.py 13.8 KB
Newer Older
W
WuHaobo 已提交
1
import paddle
littletomatodonkey's avatar
littletomatodonkey 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
from paddle.nn.initializer import Uniform

# Public model constructors exported by this module.
__all__ = [
    "ResNeXt101_32x8d_wsl",
    "ResNeXt101_32x16d_wsl",
    "ResNeXt101_32x32d_wsl",
    "ResNeXt101_32x48d_wsl",
]


class ConvBNLayer(nn.Layer):
    """Bias-free ``Conv2d`` followed by ``BatchNorm`` with optional activation.

    The convolution pads by ``(filter_size - 1) // 2``, so for odd kernel
    sizes the spatial size only changes through ``stride``.

    Args:
        input_channels (int): channels of the input feature map.
        output_channels (int): channels produced by the conv and normalized
            by the batch norm.
        filter_size (int): square kernel size.
        stride (int): convolution stride.
        groups (int): number of convolution groups.
        act (str|None): activation passed to ``BatchNorm`` (e.g. ``"relu"``).
        name (str|None): prefix from which explicit parameter names are
            derived (see ``_param_names``). When ``None``, no explicit
            ``ParamAttr`` names are set.
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None,
                 name=None):
        super(ConvBNLayer, self).__init__()

        if name is None:
            # Nothing to derive names from; leave every attribute at its
            # default instead of failing on ``"downsample" in None``.
            conv_weight_attr = None
            bn_param_attr = None
            bn_bias_attr = None
            moving_mean_name = None
            moving_variance_name = None
        else:
            conv_name, bn_name = self._param_names(name)
            conv_weight_attr = ParamAttr(name=conv_name + ".weight")
            bn_param_attr = ParamAttr(name=bn_name + ".weight")
            bn_bias_attr = ParamAttr(name=bn_name + ".bias")
            moving_mean_name = bn_name + ".running_mean"
            moving_variance_name = bn_name + ".running_var"

        self._conv = Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            weight_attr=conv_weight_attr,
            bias_attr=False)
        self._bn = BatchNorm(
            num_channels=output_channels,
            act=act,
            param_attr=bn_param_attr,
            bias_attr=bn_bias_attr,
            moving_mean_name=moving_mean_name,
            moving_variance_name=moving_variance_name)

    @staticmethod
    def _param_names(name):
        """Map a layer ``name`` to its ``(conv_name, bn_name)`` prefixes.

        * ``"<block>.downsample"`` -> ``"<block>.downsample.0"``,
          ``"<block>.downsample.1"``
        * ``"conv1"`` (the stem)   -> ``"conv1"``, ``"bn1"``
        * ``"<block>.convK"``      -> ``"<block>.convK"``, ``"<block>.bnK"``
        """
        if "downsample" in name:
            return name + ".0", name + ".1"
        if name == "conv1":
            return name, "bn" + name[-1]
        # Cut at the last "." rather than at a fixed column so multi-digit
        # indices such as "layer3.10" keep their separator.
        prefix = name[:name.rfind(".") + 1]
        return name, prefix + "bn" + name[-1]

    def forward(self, inputs):
        x = self._conv(inputs)
        x = self._bn(x)
        return x

littletomatodonkey's avatar
littletomatodonkey 已提交
59 60

class ShortCut(nn.Layer):
    """Shortcut branch of a bottleneck block.

    Returns the input unchanged when the channel count already matches and
    the stride is 1; otherwise applies a strided 1x1 ``ConvBNLayer`` (no
    activation) to project the input to ``output_channels``.
    """

    def __init__(self, input_channels, output_channels, stride, name=None):
        super(ShortCut, self).__init__()

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride

        if self._needs_projection():
            self._conv = ConvBNLayer(
                input_channels,
                output_channels,
                filter_size=1,
                stride=stride,
                name=name)

    def _needs_projection(self):
        """Return True unless channels match and the stride is 1."""
        same_shape = (self.input_channels == self.output_channels
                      and self.stride == 1)
        return not same_shape

    def forward(self, inputs):
        if not self._needs_projection():
            return inputs
        return self._conv(inputs)
79

littletomatodonkey's avatar
littletomatodonkey 已提交
80 81 82 83

class BottleneckBlock(nn.Layer):
    """ResNeXt bottleneck: 1x1 conv -> grouped 3x3 conv -> 1x1 conv + shortcut.

    Args:
        input_channels (int): channels entering the block.
        output_channels (int): channel count of the two inner convolutions.
        stride (int): stride of the 3x3 conv and of the shortcut projection.
        cardinality (int): number of groups in the 3x3 conv.
        width (int): per-group width of the model variant. The block emits
            ``output_channels // (width // 8)`` channels, so ``width`` must be
            at least 8.
        name (str): parameter-name prefix, e.g. ``"layer1.0"``.
    """

    def __init__(self, input_channels, output_channels, stride, cardinality,
                 width, name):
        super(BottleneckBlock, self).__init__()

        # Channels leaving the block (after the last 1x1 conv / shortcut).
        block_out_channels = output_channels // (width // 8)

        self._conv0 = ConvBNLayer(
            input_channels,
            output_channels,
            filter_size=1,
            act="relu",
            name=name + ".conv1")
        self._conv1 = ConvBNLayer(
            output_channels,
            output_channels,
            filter_size=3,
            act="relu",
            stride=stride,
            groups=cardinality,
            name=name + ".conv2")
        self._conv2 = ConvBNLayer(
            output_channels,
            block_out_channels,
            filter_size=1,
            act=None,
            name=name + ".conv3")
        self._short = ShortCut(
            input_channels,
            block_out_channels,
            stride=stride,
            name=name + ".downsample")

    def forward(self, inputs):
        x = self._conv0(inputs)
        x = self._conv1(x)
        x = self._conv2(x)
        y = self._short(inputs)
        # Residual sum then ReLU, written as an explicit add + activation
        # instead of the fluid-style ``elementwise_add(..., act="relu")``.
        return F.relu(paddle.add(x, y))

119

littletomatodonkey's avatar
littletomatodonkey 已提交
120
class ResNeXt101WSL(nn.Layer):
    """ResNeXt-101 (32 x ``width``d) backbone with a linear classifier head.

    The four stages contain ``self.depth = [3, 4, 23, 3]`` bottleneck blocks.
    Blocks are registered as sub-layers named ``_conv{stage}_{block}`` and
    their parameters are named ``layer{stage}.{block}.*``.

    Args:
        layers (int): stored on the instance only; the depth is fixed by
            ``self.depth`` regardless of this value.
        cardinality (int): groups in every grouped 3x3 convolution.
        width (int): per-group width; must be at least 8 because
            ``width // 8`` is used as a channel divisor.
        class_dim (int): output size of the final ``Linear`` layer.

    The stem expects 3 input channels; the ``squeeze`` over axes 2 and 3 in
    ``forward`` assumes channels-first (NCHW) feature maps.
    """

    def __init__(self, layers=101, cardinality=32, width=48, class_dim=1000):
        super(ResNeXt101WSL, self).__init__()

        self.class_dim = class_dim

        self.layers = layers
        self.cardinality = cardinality
        self.width = width
        # Ratio between a block's inner width and its output channels.
        self.scale = width // 8

        self.depth = [3, 4, 23, 3]
        self.base_width = cardinality * width
        # Inner widths per stage, e.g. [256, 512, 1024, 2048] for 32x8d.
        num_filters = [self.base_width * i for i in [1, 2, 4, 8]]

        self._conv_stem = ConvBNLayer(
            3, 64, 7, stride=2, act="relu", name="conv1")
        self._pool = MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Build the residual stages in order. The first block of every stage
        # after the first uses stride 2; each block's output channel count
        # (filters // self.scale) becomes the next block's input. Attribute
        # names and construction order are one-per-block, as before, so
        # structured state-dict keys and parameter creation order are kept.
        self._block_names = []
        in_channels = 64
        for stage, (num_blocks, filters) in enumerate(
                zip(self.depth, num_filters), start=1):
            for index in range(num_blocks):
                attr_name = "_conv{}_{}".format(stage, index)
                block = BottleneckBlock(
                    in_channels,
                    filters,
                    stride=2 if (stage > 1 and index == 0) else 1,
                    cardinality=self.cardinality,
                    width=self.width,
                    name="layer{}.{}".format(stage, index))
                setattr(self, attr_name, block)
                self._block_names.append(attr_name)
                in_channels = filters // self.scale

        self._avg_pool = AdaptiveAvgPool2d(1)
        self._out = Linear(
            num_filters[3] // self.scale,
            class_dim,
            weight_attr=ParamAttr(name="fc.weight"),
            bias_attr=ParamAttr(name="fc.bias"))

    def forward(self, inputs):
        x = self._conv_stem(inputs)
        x = self._pool(x)

        # Run every bottleneck block in registration order.
        for block_name in self._block_names:
            x = getattr(self, block_name)(x)

        x = self._avg_pool(x)
        x = paddle.squeeze(x, axis=[2, 3])
        x = self._out(x)
        return x
W
WuHaobo 已提交
426

littletomatodonkey's avatar
littletomatodonkey 已提交
427

W
wqz960 已提交
428 429
def ResNeXt101_32x8d_wsl(**args):
    """Build the ResNeXt-101 32x8d WSL model.

    Keyword arguments (e.g. ``class_dim``) are forwarded to ``ResNeXt101WSL``;
    ``cardinality`` and ``width`` are fixed here.
    """
    return ResNeXt101WSL(cardinality=32, width=8, **args)

W
WuHaobo 已提交
432

W
wqz960 已提交
433 434
def ResNeXt101_32x16d_wsl(**args):
    """Build the ResNeXt-101 32x16d WSL model.

    Keyword arguments (e.g. ``class_dim``) are forwarded to ``ResNeXt101WSL``;
    ``cardinality`` and ``width`` are fixed here.
    """
    return ResNeXt101WSL(cardinality=32, width=16, **args)

W
WuHaobo 已提交
437

W
wqz960 已提交
438 439
def ResNeXt101_32x32d_wsl(**args):
    """Build the ResNeXt-101 32x32d WSL model.

    Keyword arguments (e.g. ``class_dim``) are forwarded to ``ResNeXt101WSL``;
    ``cardinality`` and ``width`` are fixed here.
    """
    return ResNeXt101WSL(cardinality=32, width=32, **args)

W
WuHaobo 已提交
442

W
wqz960 已提交
443 444
def ResNeXt101_32x48d_wsl(**args):
    """Build the ResNeXt-101 32x48d WSL model.

    Keyword arguments (e.g. ``class_dim``) are forwarded to ``ResNeXt101WSL``;
    ``cardinality`` and ``width`` are fixed here.
    """
    return ResNeXt101WSL(cardinality=32, width=48, **args)