resnet.py 17.4 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

15
import math
Q
qingqing01 已提交
16 17 18
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
19
from paddle import ParamAttr
Q
qingqing01 已提交
20 21 22 23
from ppdet.core.workspace import register, serializable
from paddle.regularizer import L2Decay
from .name_adapter import NameAdapter
from numbers import Integral
F
Feng Ni 已提交
24
from ppdet.modeling.layers import DeformableConvV2
Q
qingqing01 已提交
25

26 27
# Names exported by ``from ... import *``: the backbone and the RoI res5 head.
__all__ = [
    'ResNet',
    'Res5Head',
]

# Number of residual blocks in each stage (res2, res3, res4, res5), keyed by
# network depth. Depths >= 50 are built from BottleNeck blocks, shallower ones
# from BasicBlock (see Blocks).
ResNet_cfg = dict([
    (18, [2, 2, 2, 2]),
    (34, [3, 4, 6, 3]),
    (50, [3, 4, 6, 3]),
    (101, [3, 4, 23, 3]),
    (152, [3, 8, 36, 3]),
])


class ConvNormLayer(nn.Layer):
    """Convolution (plain or deformable v2) + batch norm + optional activation.

    Args:
        ch_in (int): input channels.
        ch_out (int): output channels.
        filter_size (int): square kernel size; padding is
            ``(filter_size - 1) // 2``.
        stride (int): convolution stride.
        name_adapter: object whose ``fix_conv_norm_name(name)`` returns the
            prefix used for the norm layer's parameter names.
        groups (int): convolution groups.
        act (str|None): name of a function in ``paddle.nn.functional`` applied
            after the norm (e.g. ``'relu'``); falsy disables it.
        norm_type (str): ``'bn'`` or ``'sync_bn'``.
        norm_decay (float): L2 decay coefficient for the norm scale/offset.
        freeze_norm (bool): if True the norm scale/offset get learning rate 0,
            are created untrainable and have ``stop_gradient`` set; the
            ``'bn'`` variant additionally uses global (moving) statistics.
        lr (float): learning-rate multiplier for the conv weight, and for the
            norm parameters when they are not frozen.
        dcn_v2 (bool): use ``DeformableConvV2`` instead of ``nn.Conv2D``.
        name (str): parameter-name prefix. Required: it is concatenated into
            every parameter name, so ``None`` is rejected with TypeError.

    Raises:
        TypeError: if ``name`` is None.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride,
                 name_adapter,
                 groups=1,
                 act=None,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 lr=1.0,
                 dcn_v2=False,
                 name=None):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn']
        # `name` is string-concatenated below; fail early with a clear
        # message instead of an opaque `None + str` TypeError.
        if name is None:
            raise TypeError(
                "ConvNormLayer requires `name`: it prefixes the conv and "
                "norm parameter names.")
        self.norm_type = norm_type
        self.act = act

        # Arguments shared by the plain and the deformable convolution.
        conv_args = dict(
            in_channels=ch_in,
            out_channels=ch_out,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(learning_rate=lr, name=name + '_weights'),
            bias_attr=False)
        if dcn_v2:
            self.conv = DeformableConvV2(name=name, **conv_args)
        else:
            self.conv = nn.Conv2D(**conv_args)

        bn_name = name_adapter.fix_conv_norm_name(name)
        # A frozen norm gets learning rate 0 and untrainable parameters.
        norm_lr = 0. if freeze_norm else lr
        trainable = not freeze_norm
        param_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay),
            name=bn_name + "_scale",
            trainable=trainable)
        bias_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay),
            name=bn_name + "_offset",
            trainable=trainable)

        if norm_type == 'sync_bn':
            self.norm = nn.SyncBatchNorm(
                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
        else:
            self.norm = nn.BatchNorm(
                ch_out,
                act=None,
                param_attr=param_attr,
                bias_attr=bias_attr,
                use_global_stats=bool(freeze_norm),
                moving_mean_name=bn_name + '_mean',
                moving_variance_name=bn_name + '_variance')

        if freeze_norm:
            for param in self.norm.parameters():
                param.stop_gradient = True

    def forward(self, inputs):
        """Apply conv -> norm -> optional activation to ``inputs``."""
        # norm_type is restricted to 'bn'/'sync_bn' in __init__, so the norm
        # layer always exists and is always applied.
        out = self.norm(self.conv(inputs))
        if self.act:
            out = getattr(F, self.act)(out)
        return out


class BasicBlock(nn.Layer):
    """Residual block of two 3x3 ConvNormLayers (used for depth < 50).

    Output channels equal ``ch_out``. When ``shortcut`` is False, a 1x1
    projection (``self.short``) maps the input to the residual shape. For
    ``variant == 'd'`` with stride 2, that projection is a 2x2 average pool
    followed by a stride-1 1x1 conv. Deformable conv is not supported
    (``dcn_v2`` must be False).
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 name_adapter,
                 name,
                 variant='b',
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False):
        super(BasicBlock, self).__init__()
        assert dcn_v2 is False, "Not implemented yet."
        first_name, second_name, proj_name = \
            name_adapter.fix_basicblock_name(name)

        # Keyword arguments identical for every ConvNormLayer in this block.
        common = dict(
            name_adapter=name_adapter,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)

        self.shortcut = shortcut
        if not shortcut:
            if variant != 'd' or stride != 2:
                self.short = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=ch_out,
                    filter_size=1,
                    stride=stride,
                    name=proj_name,
                    **common)
            else:
                pooled = nn.Sequential()
                pooled.add_sublayer(
                    'pool',
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True))
                pooled.add_sublayer(
                    'conv',
                    ConvNormLayer(
                        ch_in=ch_in,
                        ch_out=ch_out,
                        filter_size=1,
                        stride=1,
                        name=proj_name,
                        **common))
                self.short = pooled

        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=3,
            stride=stride,
            act='relu',
            name=first_name,
            **common)
        self.branch2b = ConvNormLayer(
            ch_in=ch_out,
            ch_out=ch_out,
            filter_size=3,
            stride=1,
            act=None,
            name=second_name,
            **common)

    def forward(self, inputs):
        """Return ``relu(branch2b(branch2a(inputs)) + shortcut(inputs))``."""
        out = inputs
        for branch in (self.branch2a, self.branch2b):
            out = branch(out)
        identity = inputs if self.shortcut else self.short(inputs)
        return F.relu(paddle.add(x=out, y=identity))


class BottleNeck(nn.Layer):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``ch_out * 4`` channels.

    Args:
        ch_in (int): input channels.
        ch_out (int): base channels; the block emits ``ch_out * 4``.
        stride (int): block stride, applied by the first 1x1 conv when
            ``variant == 'a'`` and by the 3x3 conv otherwise.
        shortcut (bool): if True the input is added unchanged; otherwise a
            projection branch ``self.short`` is built.
        name_adapter: provides ``fix_bottleneck_name`` and is forwarded to
            every ConvNormLayer.
        name (str): block name from which parameter names are derived.
        variant (str): 'a' moves the stride (see above); 'd' with stride 2
            downsamples the projection branch with a 2x2 average pool.
        groups (int): groups of the 3x3 conv; values other than 1 also widen
            the inner convs (ResNeXt).
        base_width (int), base_channels (int): width factors used only when
            ``groups != 1``.
        lr, norm_type, norm_decay, freeze_norm: forwarded to ConvNormLayer.
        dcn_v2 (bool): use deformable conv v2 for the 3x3 conv.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 name_adapter,
                 name,
                 variant='b',
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False):
        super(BottleNeck, self).__init__()
        stride1, stride2 = (stride, 1) if variant == 'a' else (1, stride)

        # ResNeXt: grouped convs use a wider bottleneck.
        if groups != 1:
            width = int(
                math.floor(ch_out * (base_width * 1.0 / base_channels)) *
                groups)
        else:
            width = ch_out
        expanded = ch_out * 4

        name_a, name_b, name_c, proj_name = \
            name_adapter.fix_bottleneck_name(name)

        # Keyword arguments identical for every ConvNormLayer in this block.
        common = dict(
            name_adapter=name_adapter,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)

        self.shortcut = shortcut
        if not shortcut:
            if variant != 'd' or stride != 2:
                self.short = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=expanded,
                    filter_size=1,
                    stride=stride,
                    name=proj_name,
                    **common)
            else:
                pooled = nn.Sequential()
                pooled.add_sublayer(
                    'pool',
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True))
                pooled.add_sublayer(
                    'conv',
                    ConvNormLayer(
                        ch_in=ch_in,
                        ch_out=expanded,
                        filter_size=1,
                        stride=1,
                        name=proj_name,
                        **common))
                self.short = pooled

        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=width,
            filter_size=1,
            stride=stride1,
            groups=1,
            act='relu',
            name=name_a,
            **common)
        self.branch2b = ConvNormLayer(
            ch_in=width,
            ch_out=width,
            filter_size=3,
            stride=stride2,
            groups=groups,
            act='relu',
            dcn_v2=dcn_v2,
            name=name_b,
            **common)
        self.branch2c = ConvNormLayer(
            ch_in=width,
            ch_out=expanded,
            filter_size=1,
            stride=1,
            groups=1,
            name=name_c,
            **common)

    def forward(self, inputs):
        """Return ``relu(branch2c(branch2b(branch2a(x))) + shortcut(x))``."""
        out = inputs
        for branch in (self.branch2a, self.branch2b, self.branch2c):
            out = branch(out)
        identity = inputs if self.shortcut else self.short(inputs)
        return F.relu(paddle.add(x=out, y=identity))


class Blocks(nn.Layer):
    """One ResNet stage: ``count`` residual blocks applied in sequence.

    BottleNeck blocks are used when ``depth >= 50`` and BasicBlock otherwise.
    The first block receives ``ch_in`` channels, uses a projection shortcut
    and has stride 2, except in stage 2, which keeps stride 1. Every later
    block takes the previous block's output and uses an identity shortcut.
    The remaining arguments are forwarded to the blocks; ``groups``,
    ``base_width`` and ``base_channels`` only apply to BottleNeck.
    """

    def __init__(self,
                 depth,
                 ch_in,
                 ch_out,
                 count,
                 name_adapter,
                 stage_num,
                 variant='b',
                 groups=1,
                 base_width=-1,
                 base_channels=-1,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False):
        super(Blocks, self).__init__()

        use_bottleneck = depth >= 50
        # Channels produced by every block: BottleNeck expands by 4.
        block_out = ch_out * 4 if use_bottleneck else ch_out

        self.blocks = []
        for i in range(count):
            conv_name = name_adapter.fix_layer_warp_name(stage_num, count, i)
            is_first = i == 0
            common = dict(
                ch_in=ch_in if is_first else block_out,
                ch_out=ch_out,
                stride=2 if is_first and stage_num != 2 else 1,
                shortcut=not is_first,
                name_adapter=name_adapter,
                name=conv_name,
                variant=variant,
                lr=lr,
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                dcn_v2=dcn_v2)
            if use_bottleneck:
                layer = BottleNeck(
                    groups=groups,
                    base_width=base_width,
                    base_channels=base_channels,
                    **common)
            else:
                layer = BasicBlock(**common)
            block = self.add_sublayer(conv_name, layer)
            self.blocks.append(block)

    def forward(self, inputs):
        """Run ``inputs`` through all blocks of the stage in order."""
        block_out = inputs
        for block in self.blocks:
            block_out = block(block_out)
        return block_out


@register
@serializable
class ResNet(nn.Layer):
    """ResNet / ResNeXt backbone returning selected stage feature maps.

    Args:
        depth (int): network depth; must be a key of ``ResNet_cfg``.
        variant (str): 'c' and 'd' replace the 7x7 stem conv with three 3x3
            convs. 'a' and 'd' also change the residual blocks (see
            BottleNeck / BasicBlock).
        lr_mult_list (list[float]): per-stage learning-rate multipliers;
            must have length 4.
        groups (int): groups of the bottleneck 3x3 conv; any value other
            than 1 marks the model as 'ResNeXt'.
        base_width (int), base_channels (int): ResNeXt width factors.
        norm_type (str): 'bn' or 'sync_bn'.
        norm_decay (float): L2 decay for norm parameters.
        freeze_norm (bool): freeze norm parameters (see ConvNormLayer).
        freeze_at (int): ``stop_gradient`` is set on the output of this stage
            index, so no gradient flows back through it into earlier layers.
        return_idx (int|list[int]): stage indices whose outputs are returned.
        dcn_v2_stages (int|list[int]): stage indices that use deformable
            conv v2; the default ``[-1]`` matches no stage.
        num_stages (int): number of residual stages to build, 1 to 4.
    """
    __shared__ = ['norm_type']

    # NOTE: the list defaults are kept because the config schema reads the
    # constructor defaults; they are never mutated and are copied before
    # being stored on the instance.
    def __init__(self,
                 depth=50,
                 variant='b',
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0],
                 groups=1,
                 base_width=-1,
                 base_channels=-1,
                 norm_type='bn',
                 norm_decay=0,
                 freeze_norm=True,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 dcn_v2_stages=[-1],
                 num_stages=4):
        super(ResNet, self).__init__()
        self._model_type = 'ResNet' if groups == 1 else 'ResNeXt'
        assert 1 <= num_stages <= 4
        self.depth = depth
        self.variant = variant
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        self.freeze_norm = freeze_norm
        self.freeze_at = freeze_at
        if isinstance(return_idx, Integral):
            return_idx = [return_idx]
        assert max(return_idx) < num_stages, \
            'the maximum return index must smaller than num_stages, ' \
            'but received maximum return index is {} and num_stages ' \
            'is {}'.format(max(return_idx), num_stages)
        # Copy so instances never alias the shared default list.
        self.return_idx = list(return_idx)
        self.num_stages = num_stages
        assert len(lr_mult_list) == 4, \
            "lr_mult_list length must be 4 but got {}".format(len(lr_mult_list))
        if isinstance(dcn_v2_stages, Integral):
            dcn_v2_stages = [dcn_v2_stages]
        assert max(dcn_v2_stages) < num_stages
        self.dcn_v2_stages = list(dcn_v2_stages)

        block_nums = ResNet_cfg[depth]
        na = NameAdapter(self)

        conv1_name = na.fix_c1_stage_name()
        if variant in ['c', 'd']:
            # Three stacked 3x3 convs replace the single 7x7 stem conv.
            conv_def = [
                [3, 32, 3, 2, "conv1_1"],
                [32, 32, 3, 1, "conv1_2"],
                [32, 64, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, 64, 7, 2, conv1_name]]
        self.conv1 = nn.Sequential()
        for (c_in, c_out, k, s, _name) in conv_def:
            self.conv1.add_sublayer(
                _name,
                ConvNormLayer(
                    ch_in=c_in,
                    ch_out=c_out,
                    filter_size=k,
                    stride=s,
                    name_adapter=na,
                    groups=1,
                    act='relu',
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    lr=1.0,
                    name=_name))

        self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        ch_in_list = [64, 256, 512, 1024]
        ch_out_list = [64, 128, 256, 512]

        self.res_layers = []
        for i in range(num_stages):
            stage_num = i + 2
            res_name = "res{}".format(stage_num)
            # BasicBlock stages (depth < 50) output ch_out rather than
            # ch_out * 4, so stages after the first get 4x fewer inputs.
            stage_ch_in = (ch_in_list[i] // 4
                           if i > 0 and depth < 50 else ch_in_list[i])
            res_layer = self.add_sublayer(
                res_name,
                Blocks(
                    depth,
                    stage_ch_in,
                    ch_out_list[i],
                    count=block_nums[i],
                    name_adapter=na,
                    stage_num=stage_num,
                    variant=variant,
                    groups=groups,
                    base_width=base_width,
                    base_channels=base_channels,
                    lr=lr_mult_list[i],
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    dcn_v2=(i in self.dcn_v2_stages)))
            self.res_layers.append(res_layer)

    def forward(self, inputs):
        """Run the backbone.

        Args:
            inputs (dict): must contain the input image tensor under 'image'.

        Returns:
            list: outputs of the stages whose index is in ``return_idx``,
            in stage order.
        """
        x = inputs['image']
        conv1 = self.conv1(x)
        x = self.pool(conv1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx == self.freeze_at:
                x.stop_gradient = True
            if idx in self.return_idx:
                outs.append(x)
        return outs


@register
class Res5Head(nn.Layer):
    """Stage-5 residual blocks applied to RoI features.

    Args:
        feat_in (int): channels of the incoming RoI feature.
        feat_out (int): base output channels of the res5 blocks.
        depth (int): backbone depth forwarded to ``Blocks``; >= 50 selects
            BottleNeck blocks and < 50 selects BasicBlock. Defaults to 50.

    Attributes:
        feat_out (int): channels produced by the head: ``feat_out * 4`` for
            BottleNeck, ``feat_out`` for BasicBlock.
    """

    def __init__(self, feat_in=1024, feat_out=512, depth=50):
        super(Res5Head, self).__init__()
        na = NameAdapter(self)
        # Kept for backward compatibility; not used by this layer.
        self.res5_conv = []
        # Blocks' first positional parameter is `depth`, followed by
        # ch_in and ch_out.
        self.res5 = self.add_sublayer(
            'res5_roi_feat',
            Blocks(
                depth,
                feat_in,
                feat_out,
                count=3,
                name_adapter=na,
                stage_num=5))
        # BottleNeck expands channels by 4; BasicBlock keeps ch_out.
        self.feat_out = feat_out * 4 if depth >= 50 else feat_out

    def forward(self, roi_feat, stage=0):
        """Apply the res5 blocks to ``roi_feat``; ``stage`` is unused."""
        y = self.res5(roi_feat)
        return y