resnet.py 18.4 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

15
import math
16 17
from numbers import Integral

Q
qingqing01 已提交
18 19 20 21 22
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from paddle.regularizer import L2Decay
W
wangxinxin08 已提交
23
from paddle.nn.initializer import Uniform
F
Feng Ni 已提交
24
from ppdet.modeling.layers import DeformableConvV2
25 26
from .name_adapter import NameAdapter
from ..shape_spec import ShapeSpec
Q
qingqing01 已提交
27

W
wangxinxin08 已提交
28
# Public classes exported by this module.
__all__ = [
    'ResNet',
    'Res5Head',
    'Blocks',
    'BasicBlock',
    'BottleNeck',
]
29

30 31 32 33 34 35 36 37
# Number of residual blocks in each of the four stages (res2..res5),
# keyed by network depth. ResNet looks this up with ResNet_cfg[depth].
ResNet_cfg = dict([
    (18, [2, 2, 2, 2]),
    (34, [3, 4, 6, 3]),
    (50, [3, 4, 6, 3]),
    (101, [3, 4, 23, 3]),
    (152, [3, 8, 36, 3]),
])

Q
qingqing01 已提交
38 39 40 41 42 43 44

class ConvNormLayer(nn.Layer):
    """Convolution (plain or deformable v2), then batch norm, then an optional activation.

    Args:
        ch_in (int): input channels.
        ch_out (int): output channels.
        filter_size (int): convolution kernel size.
        stride (int): convolution stride.
        groups (int): convolution groups.
        act (str|None): name of a function in ``paddle.nn.functional``
            applied after the norm, or None for no activation.
        norm_type (str): 'bn' or 'sync_bn'.
        norm_decay (float): L2 decay applied to the norm's scale and bias.
        freeze_norm (bool): if True, the norm's parameters get learning rate
            0, are non-trainable and have stop_gradient set; the plain 'bn'
            norm also uses global statistics.
        lr (float): learning rate multiplier for the conv weight (and the
            norm parameters when not frozen).
        dcn_v2 (bool): use DeformableConvV2 instead of nn.Conv2D.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride,
                 groups=1,
                 act=None,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 lr=1.0,
                 dcn_v2=False):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn']
        self.norm_type = norm_type
        self.act = act

        # Both conv classes are built with exactly the same arguments.
        conv_layer = DeformableConvV2 if dcn_v2 else nn.Conv2D
        self.conv = conv_layer(
            in_channels=ch_in,
            out_channels=ch_out,
            kernel_size=filter_size,
            stride=stride,
            # keeps the spatial size for odd kernels at stride 1
            padding=(filter_size - 1) // 2,
            groups=groups,
            weight_attr=paddle.ParamAttr(learning_rate=lr),
            bias_attr=False)

        norm_lr = lr if not freeze_norm else 0.

        def make_norm_attr():
            # A fresh ParamAttr (and regularizer) per norm parameter.
            return paddle.ParamAttr(
                learning_rate=norm_lr,
                regularizer=L2Decay(norm_decay),
                trainable=not freeze_norm)

        scale_attr = make_norm_attr()
        shift_attr = make_norm_attr()
        if norm_type == 'sync_bn':
            self.norm = nn.SyncBatchNorm(
                ch_out, weight_attr=scale_attr, bias_attr=shift_attr)
        else:
            self.norm = nn.BatchNorm(
                ch_out,
                act=None,
                param_attr=scale_attr,
                bias_attr=shift_attr,
                use_global_stats=bool(freeze_norm))

        if freeze_norm:
            for weight in self.norm.parameters():
                weight.stop_gradient = True

    def forward(self, inputs):
        y = self.conv(inputs)
        if self.norm_type in ['bn', 'sync_bn']:
            y = self.norm(y)
        return getattr(F, self.act)(y) if self.act else y


W
wangxinxin08 已提交
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
class SELayer(nn.Layer):
    """Squeeze-and-excitation channel reweighting.

    Global average pooling is followed by a bottleneck of two Linear layers
    (``ch -> ch // reduction_ratio -> ch``, ReLU then sigmoid). The
    resulting per-channel weights scale the input. The squeeze/unsqueeze
    over axes 2 and 3 suggests NCHW input (assumption, not checked here).
    """

    def __init__(self, ch, reduction_ratio=16):
        super(SELayer, self).__init__()
        hidden = ch // reduction_ratio

        def fc(in_features, out_features):
            # weights ~ U(-1/sqrt(in_features), 1/sqrt(in_features)), with a bias
            bound = 1.0 / math.sqrt(in_features)
            return nn.Linear(
                in_features,
                out_features,
                weight_attr=paddle.ParamAttr(
                    initializer=Uniform(-bound, bound)),
                bias_attr=True)

        self.pool = nn.AdaptiveAvgPool2D(1)
        self.squeeze = fc(ch, hidden)
        self.extract = fc(hidden, ch)

    def forward(self, inputs):
        gate = paddle.squeeze(self.pool(inputs), axis=[2, 3])
        gate = F.relu(self.squeeze(gate))
        gate = F.sigmoid(self.extract(gate))
        gate = paddle.unsqueeze(gate, axis=[2, 3])
        return gate * inputs


145
class BasicBlock(nn.Layer):
    """Two 3x3 ConvNormLayers plus a residual connection (used for depth < 50).

    The output has ``ch_out * expansion`` (= ch_out) channels.
    """

    expansion = 1

    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 variant='b',
                 groups=1,
                 base_width=64,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False,
                 std_senet=False):
        """
        Args:
            ch_in (int): input channels.
            ch_out (int): output channels.
            stride (int): stride of the first 3x3 conv and of the projection.
            shortcut (bool): if True the input is added unchanged; otherwise
                a 1x1 projection branch (``self.short``) is built.
            variant (str): with 'd' and stride 2 the projection uses
                average pooling followed by a stride-1 1x1 conv.
            groups, base_width: must be 1 and 64 (asserted).
            lr (float): learning rate multiplier for all convs.
            norm_type, norm_decay, freeze_norm: forwarded to ConvNormLayer.
            dcn_v2 (bool): must be False (asserted).
            std_senet (bool): apply an SELayer before the residual add.
        """
        super(BasicBlock, self).__init__()
        assert dcn_v2 is False, "Not implemented yet."
        assert groups == 1 and base_width == 64, 'BasicBlock only supports groups=1 and base_width=64'

        # Options shared by every ConvNormLayer in this block.
        common = dict(
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)

        self.shortcut = shortcut
        if not shortcut:
            if variant == 'd' and stride == 2:
                self.short = nn.Sequential()
                self.short.add_sublayer(
                    'pool',
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True))
                self.short.add_sublayer(
                    'conv',
                    ConvNormLayer(
                        ch_in=ch_in,
                        ch_out=ch_out,
                        filter_size=1,
                        stride=1,
                        **common))
            else:
                self.short = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=ch_out,
                    filter_size=1,
                    stride=stride,
                    **common)

        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=3,
            stride=stride,
            act='relu',
            **common)
        self.branch2b = ConvNormLayer(
            ch_in=ch_out,
            ch_out=ch_out,
            filter_size=3,
            stride=1,
            act=None,
            **common)

        self.std_senet = std_senet
        if self.std_senet:
            self.se = SELayer(ch_out)

    def forward(self, inputs):
        residual = self.branch2b(self.branch2a(inputs))
        if self.std_senet:
            residual = self.se(residual)

        identity = inputs if self.shortcut else self.short(inputs)
        return F.relu(paddle.add(x=residual, y=identity))


class BottleNeck(nn.Layer):
    """1x1 -> 3x3 -> 1x1 residual block (used for depth >= 50 and ResNeXt).

    The output has ``ch_out * expansion`` (= 4 * ch_out) channels.
    """

    expansion = 4

    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 variant='b',
                 groups=1,
                 base_width=4,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False,
                 std_senet=False):
        """
        Args:
            ch_in (int): input channels.
            ch_out (int): base channels; the block outputs ch_out * 4.
            stride (int): spatial stride of the block.
            shortcut (bool): if True the input is added unchanged; otherwise
                a 1x1 projection branch (``self.short``) is built.
            variant (str): 'a' puts the stride on the first 1x1 conv, other
                variants put it on the 3x3 conv; 'd' with stride 2 builds the
                projection as average pooling + stride-1 1x1 conv.
            groups (int): groups of the 3x3 conv.
            base_width (int): per-group width; the inner width is
                int(ch_out * base_width / 64) * groups.
            lr (float): learning rate multiplier for all convs.
            norm_type, norm_decay, freeze_norm: forwarded to ConvNormLayer.
            dcn_v2 (bool): use deformable conv v2 for the 3x3 conv.
            std_senet (bool): apply an SELayer before the residual add.
        """
        super(BottleNeck, self).__init__()
        stride1, stride2 = (stride, 1) if variant == 'a' else (1, stride)

        # ResNeXt inner width; equals ch_out when groups=1, base_width=64.
        width = int(ch_out * (base_width / 64.)) * groups
        out_channels = ch_out * self.expansion

        # Options shared by every ConvNormLayer in this block.
        common = dict(
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)

        self.shortcut = shortcut
        if not shortcut:
            if variant == 'd' and stride == 2:
                self.short = nn.Sequential()
                self.short.add_sublayer(
                    'pool',
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True))
                self.short.add_sublayer(
                    'conv',
                    ConvNormLayer(
                        ch_in=ch_in,
                        ch_out=out_channels,
                        filter_size=1,
                        stride=1,
                        **common))
            else:
                self.short = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=out_channels,
                    filter_size=1,
                    stride=stride,
                    **common)

        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=width,
            filter_size=1,
            stride=stride1,
            groups=1,
            act='relu',
            **common)
        self.branch2b = ConvNormLayer(
            ch_in=width,
            ch_out=width,
            filter_size=3,
            stride=stride2,
            groups=groups,
            act='relu',
            dcn_v2=dcn_v2,
            **common)
        self.branch2c = ConvNormLayer(
            ch_in=width,
            ch_out=out_channels,
            filter_size=1,
            stride=1,
            groups=1,
            **common)

        self.std_senet = std_senet
        if self.std_senet:
            self.se = SELayer(out_channels)

    def forward(self, inputs):
        residual = self.branch2c(self.branch2b(self.branch2a(inputs)))
        if self.std_senet:
            residual = self.se(residual)

        identity = inputs if self.shortcut else self.short(inputs)
        return F.relu(paddle.add(x=residual, y=identity))


class Blocks(nn.Layer):
    """One ResNet stage: ``count`` instances of ``block`` applied in sequence.

    Only the first block gets a projection shortcut and, outside res2
    (``stage_num == 2``), stride 2. The remaining blocks use identity
    shortcuts and stride 1. Sublayer names come from
    ``name_adapter.fix_layer_warp_name``.
    """

    def __init__(self,
                 block,
                 ch_in,
                 ch_out,
                 count,
                 name_adapter,
                 stage_num,
                 variant='b',
                 groups=1,
                 base_width=64,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False,
                 std_senet=False):
        super(Blocks, self).__init__()

        self.blocks = []
        for idx in range(count):
            is_first = idx == 0
            sub_name = name_adapter.fix_layer_warp_name(stage_num, count, idx)
            unit = self.add_sublayer(
                sub_name,
                block(
                    ch_in=ch_in,
                    ch_out=ch_out,
                    stride=2 if is_first and stage_num != 2 else 1,
                    shortcut=not is_first,
                    variant=variant,
                    groups=groups,
                    base_width=base_width,
                    lr=lr,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    dcn_v2=dcn_v2,
                    std_senet=std_senet))
            self.blocks.append(unit)
            # Every block after the first consumes the expanded channel count.
            ch_in = ch_out * block.expansion

    def forward(self, inputs):
        feat = inputs
        for unit in self.blocks:
            feat = unit(feat)
        return feat


@register
@serializable
class ResNet(nn.Layer):
    __shared__ = ['norm_type']

    def __init__(self,
                 depth=50,
                 ch_in=64,
                 variant='b',
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0],
                 groups=1,
                 base_width=64,
                 norm_type='bn',
                 norm_decay=0,
                 freeze_norm=True,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 dcn_v2_stages=[-1],
                 num_stages=4,
                 std_senet=False):
        """
        Residual Network, see https://arxiv.org/abs/1512.03385

        Args:
            depth (int): ResNet depth, one of 18, 34, 50, 101, 152.
                Depths below 50 use BasicBlock, others use BottleNeck.
            ch_in (int): output channels of the stem (conv1), which are also
                the input channels of the first residual stage, default 64
            variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
            lr_mult_list (list): learning rate ratio of the resnet stages
                (2, 3, 4, 5); must have exactly 4 entries. A lower ratio is
                needed for pretrained models obtained via distillation
                (default [1.0, 1.0, 1.0, 1.0]).
            groups (int): group convolution cardinality
            base_width (int): base width of each group convolution
            norm_type (str): normalization type, 'bn' or 'sync_bn'
            norm_decay (float): weight decay for normalization layer weights
            freeze_norm (bool): freeze normalization layers
            freeze_at (int): index of the stage whose output has
                stop_gradient set in forward, so no gradient flows back into
                the stem or the stages up to and including it
            return_idx (list|int): index of the stages whose feature maps
                are returned
            dcn_v2_stages (list|int): index of stages that use deformable
                conv v2; the default [-1] matches no stage
            num_stages (int): total number of stages, 1 to 4
            std_senet (bool): whether to use SE blocks, default False
        """
        super(ResNet, self).__init__()
        self._model_type = 'ResNet' if groups == 1 else 'ResNeXt'
        assert num_stages >= 1 and num_stages <= 4
        assert depth in ResNet_cfg, \
            'depth must be one of {}, but received {}'.format(
                list(ResNet_cfg.keys()), depth)
        self.depth = depth
        self.variant = variant
        self.groups = groups
        self.base_width = base_width
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        self.freeze_norm = freeze_norm
        self.freeze_at = freeze_at
        if isinstance(return_idx, Integral):
            return_idx = [return_idx]
        assert max(return_idx) < num_stages, \
            'the maximum return index must smaller than num_stages, ' \
            'but received maximum return index is {} and num_stages ' \
            'is {}'.format(max(return_idx), num_stages)
        self.return_idx = return_idx
        self.num_stages = num_stages
        assert len(lr_mult_list) == 4, \
            "lr_mult_list length must be 4 but got {}".format(len(lr_mult_list))

        # A bare int is accepted as a single stage index (normalized once).
        if isinstance(dcn_v2_stages, Integral):
            dcn_v2_stages = [dcn_v2_stages]
        assert max(dcn_v2_stages) < num_stages
        self.dcn_v2_stages = dcn_v2_stages

        block_nums = ResNet_cfg[depth]
        na = NameAdapter(self)

        conv1_name = na.fix_c1_stage_name()
        if variant in ['c', 'd']:
            # Variants 'c'/'d' replace the single 7x7 stem conv with three
            # 3x3 convs (only the first one is strided).
            conv_def = [
                [3, ch_in // 2, 3, 2, "conv1_1"],
                [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
                [ch_in // 2, ch_in, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, ch_in, 7, 2, conv1_name]]
        self.conv1 = nn.Sequential()
        for (c_in, c_out, k, s, _name) in conv_def:
            self.conv1.add_sublayer(
                _name,
                ConvNormLayer(
                    ch_in=c_in,
                    ch_out=c_out,
                    filter_size=k,
                    stride=s,
                    groups=1,
                    act='relu',
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    lr=1.0))

        self.ch_in = ch_in
        ch_out_list = [64, 128, 256, 512]
        block = BottleNeck if depth >= 50 else BasicBlock

        self._out_channels = [block.expansion * v for v in ch_out_list]
        # stem conv (stride 2) + max-pool (stride 2) give 4; each later
        # stage halves the resolution again.
        self._out_strides = [4, 8, 16, 32]

        self.res_layers = []
        for i in range(num_stages):
            lr_mult = lr_mult_list[i]
            stage_num = i + 2
            res_name = "res{}".format(stage_num)
            res_layer = self.add_sublayer(
                res_name,
                Blocks(
                    block,
                    self.ch_in,
                    ch_out_list[i],
                    count=block_nums[i],
                    name_adapter=na,
                    stage_num=stage_num,
                    variant=variant,
                    groups=groups,
                    base_width=base_width,
                    lr=lr_mult,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    dcn_v2=(i in self.dcn_v2_stages),
                    std_senet=std_senet))
            self.res_layers.append(res_layer)
            self.ch_in = self._out_channels[i]

    @property
    def out_shape(self):
        """ShapeSpec (channels, stride) for each stage listed in return_idx."""
        return [
            ShapeSpec(
                channels=self._out_channels[i], stride=self._out_strides[i])
            for i in self.return_idx
        ]

    def forward(self, inputs):
        """
        Args:
            inputs (dict): must contain the key 'image', the tensor fed to
                the stem.

        Returns:
            list: outputs of the stages whose index is in ``return_idx``,
                in stage order.
        """
        x = inputs['image']
        conv1 = self.conv1(x)
        x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx == self.freeze_at:
                # Cut the graph here: earlier layers receive no gradient.
                x.stop_gradient = True
            if idx in self.return_idx:
                outs.append(x)
        return outs
558 559 560 561


@register
class Res5Head(nn.Layer):
    def __init__(self, depth=50):
        """
        The res5 stage (3 residual blocks, stride 2) applied to RoI features.

        Args:
            depth (int): ResNet depth. Depths below 50 use BasicBlock with 256
                input channels; otherwise BottleNeck with 1024 input channels.
        """
        super(Res5Head, self).__init__()
        feat_in, feat_out = [1024, 512]
        if depth < 50:
            feat_in = 256
        na = NameAdapter(self)
        # Blocks expects the block class as its first argument. Passing the
        # integer depth here made Blocks call `depth(...)` and fail.
        block = BottleNeck if depth >= 50 else BasicBlock
        self.res5 = Blocks(
            block, feat_in, feat_out, count=3, name_adapter=na, stage_num=5)
        # 512 for BasicBlock (expansion 1), 2048 for BottleNeck (expansion 4).
        self.feat_out = feat_out * block.expansion

    @property
    def out_shape(self):
        return [ShapeSpec(
            channels=self.feat_out,
            stride=16, )]

    def forward(self, roi_feat, stage=0):
        y = self.res5(roi_feat)
        return y