# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# reference: https://arxiv.org/pdf/1512.03385

from __future__ import absolute_import, division, print_function

import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear, BatchNorm2D
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
from paddle.regularizer import L2Decay
import math

from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {
    "ResNet18":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams",
    "ResNet18_vd":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams",
    "ResNet34":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams",
    "ResNet34_vd":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams",
    "ResNet50":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams",
    "ResNet50_vd":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams",
    "ResNet101":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams",
    "ResNet101_vd":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams",
    "ResNet152":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams",
    "ResNet152_vd":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams",
    "ResNet200_vd":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams",
}

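# Names of the blocks whose outputs mark the end of each stage; passed to
# TheseusLayer.init_res so intermediate features can be returned via
# `return_patterns` / `return_stages`.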
MODEL_STAGES_PATTERN = {
    "ResNet18": ["blocks[1]", "blocks[3]", "blocks[5]", "blocks[7]"],
    "ResNet34": ["blocks[2]", "blocks[6]", "blocks[12]", "blocks[15]"],
    "ResNet50": ["blocks[2]", "blocks[6]", "blocks[12]", "blocks[15]"],
    "ResNet101": ["blocks[2]", "blocks[6]", "blocks[29]", "blocks[32]"],
    "ResNet152": ["blocks[2]", "blocks[10]", "blocks[46]", "blocks[49]"],
    "ResNet200": ["blocks[2]", "blocks[14]", "blocks[62]", "blocks[65]"]
}

__all__ = list(MODEL_URLS.keys())
'''
ResNet config: dict.
    key: depth of ResNet.
    value: config dict of the specific model.
        keys:
            block_type: the block used by the network, either BasicBlock or BottleneckBlock.
            block_depth: the number of blocks in each of the four stages.
            num_channels: the number of input channels of the first block in each stage.
'''
NET_CONFIG = {
    "18": {
        "block_type": "BasicBlock",
        "block_depth": [2, 2, 2, 2],
        "num_channels": [64, 64, 128, 256]
    },
    "34": {
        "block_type": "BasicBlock",
        "block_depth": [3, 4, 6, 3],
        "num_channels": [64, 64, 128, 256]
    },
    "50": {
        "block_type": "BottleneckBlock",
        "block_depth": [3, 4, 6, 3],
        "num_channels": [64, 256, 512, 1024]
    },
    "101": {
        "block_type": "BottleneckBlock",
        "block_depth": [3, 4, 23, 3],
        "num_channels": [64, 256, 512, 1024]
    },
    "152": {
        "block_type": "BottleneckBlock",
        "block_depth": [3, 8, 36, 3],
        "num_channels": [64, 256, 512, 1024]
    },
    "200": {
        "block_type": "BottleneckBlock",
        "block_depth": [3, 12, 48, 3],
        "num_channels": [64, 256, 512, 1024]
    },
}
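
# A minimal sketch of how a NET_CONFIG entry is consumed (illustrative only,
# mirroring what the ResNet constructor below does):
#
#   cfg = NET_CONFIG["50"]
#   cfg["block_type"]        # "BottleneckBlock": the block class built per stage
#   cfg["block_depth"]       # [3, 4, 6, 3]: blocks per stage, 16 in total
#   cfg["num_channels"]      # [64, 256, 512, 1024]: input width of each stage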


class ConvBNLayer(TheseusLayer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 is_vd_mode=False,
                 act=None,
                 lr_mult=1.0,
                 norm_decay=0.,
                 data_format="NCHW"):
        super().__init__()
        self.is_vd_mode = is_vd_mode
        self.act = act
        self.avg_pool = AvgPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self.conv = Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(learning_rate=lr_mult),
            bias_attr=False,
            data_format=data_format)

        weight_attr = ParamAttr(
            learning_rate=lr_mult,
            regularizer=L2Decay(norm_decay),
            trainable=True)
        bias_attr = ParamAttr(
            learning_rate=lr_mult,
            regularizer=L2Decay(norm_decay),
            trainable=True)

        self.bn = BatchNorm2D(
            num_filters, weight_attr=weight_attr, bias_attr=bias_attr)
        self.relu = nn.ReLU()

    def forward(self, x):
        if self.is_vd_mode:
            x = self.avg_pool(x)
        x = self.conv(x)
        x = self.bn(x)
        if self.act:
            x = self.relu(x)
        return x
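
# A hedged shape sketch for ConvBNLayer (not executed on import; NCHW assumed):
#
#   layer = ConvBNLayer(num_channels=3, num_filters=64, filter_size=7,
#                       stride=2, act="relu")
#   y = layer(paddle.rand([1, 3, 224, 224]))  # -> [1, 64, 112, 112]
#
# With is_vd_mode=True, a 2x2 average pool halves the spatial size before the
# convolution, which is how the "vd" shortcut downsamples without the
# information loss of a strided 1x1 convolution.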


class BottleneckBlock(TheseusLayer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True,
                 if_first=False,
                 lr_mult=1.0,
                 norm_decay=0.,
                 data_format="NCHW"):
        super().__init__()

        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act="relu",
            lr_mult=lr_mult,
            norm_decay=norm_decay,
            data_format=data_format)
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act="relu",
            lr_mult=lr_mult,
            norm_decay=norm_decay,
            data_format=data_format)
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters * 4,
            filter_size=1,
            act=None,
            lr_mult=lr_mult,
            norm_decay=norm_decay,
            data_format=data_format)

        if not shortcut:
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 4,
                filter_size=1,
                stride=stride if if_first else 1,
                is_vd_mode=False if if_first else True,
                lr_mult=lr_mult,
                norm_decay=norm_decay,
                data_format=data_format)

        self.relu = nn.ReLU()
        self.shortcut = shortcut

    def forward(self, x):
        identity = x
        x = self.conv0(x)
        x = self.conv1(x)
        x = self.conv2(x)

        if self.shortcut:
            short = identity
        else:
            short = self.short(identity)
        x = paddle.add(x=x, y=short)
        x = self.relu(x)
        return x
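
# Channel bookkeeping for BottleneckBlock (descriptive note): the residual path
# maps num_channels -> num_filters -> num_filters -> num_filters * 4, so the
# first ResNet50 stage turns 64 input channels into 256 output channels, which
# is why NET_CONFIG["50"]["num_channels"] lists [64, 256, 512, 1024].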


class BasicBlock(TheseusLayer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True,
                 if_first=False,
                 lr_mult=1.0,
                 norm_decay=0.,
                 data_format="NCHW"):
        super().__init__()

        self.stride = stride
        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act="relu",
            lr_mult=lr_mult,
            norm_decay=norm_decay,
            data_format=data_format)
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            act=None,
            lr_mult=lr_mult,
            norm_decay=norm_decay,
            data_format=data_format)
        if not shortcut:
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters,
                filter_size=1,
                stride=stride if if_first else 1,
                is_vd_mode=False if if_first else True,
                lr_mult=lr_mult,
                norm_decay=norm_decay,
                data_format=data_format)
        self.shortcut = shortcut
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x
        x = self.conv0(x)
        x = self.conv1(x)
        if self.shortcut:
            short = identity
        else:
            short = self.short(identity)
        x = paddle.add(x=x, y=short)
        x = self.relu(x)
        return x


class ResNet(TheseusLayer):
    """
    ResNet
    Args:
        config: dict. config of ResNet.
        stages_pattern: list. name patterns of the blocks that mark the end of each stage.
        version: str="vb". Variant of ResNet; the "vd" version usually performs better.
        class_num: int=1000. The number of classes.
        lr_mult_list: list. Learning rate multipliers for the stem and the four stages.
    Returns:
        model: nn.Layer. The specific ResNet model determined by the args.
    """

    def __init__(self,
                 config,
                 stages_pattern,
                 version="vb",
                 stem_act="relu",
                 class_num=1000,
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
                 norm_decay=0.,
                 data_format="NCHW",
                 input_image_channel=3,
                 return_patterns=None,
                 return_stages=None):
        super().__init__()

        self.cfg = config
        self.lr_mult_list = lr_mult_list
        self.is_vd_mode = version == "vd"
        self.class_num = class_num
        self.num_filters = [64, 128, 256, 512]
        self.block_depth = self.cfg["block_depth"]
        self.block_type = self.cfg["block_type"]
        self.num_channels = self.cfg["num_channels"]
        self.channels_mult = 1 if self.num_channels[-1] == 256 else 4

        assert isinstance(self.lr_mult_list, (
            list, tuple
        )), "lr_mult_list should be in (list, tuple) but got {}".format(
            type(self.lr_mult_list))
        assert len(self.lr_mult_list
                   ) == 5, "lr_mult_list length should be 5 but got {}".format(
                       len(self.lr_mult_list))

        self.stem_cfg = {
            # num_channels, num_filters, filter_size, stride
            "vb": [[input_image_channel, 64, 7, 2]],
            "vd":
            [[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
        }

        self.stem = nn.Sequential(* [
            ConvBNLayer(
                num_channels=in_c,
                num_filters=out_c,
                filter_size=k,
                stride=s,
                act=stem_act,
                lr_mult=self.lr_mult_list[0],
                norm_decay=norm_decay,
                data_format=data_format)
            for in_c, out_c, k, s in self.stem_cfg[version]
        ])

        self.max_pool = MaxPool2D(
            kernel_size=3, stride=2, padding=1, data_format=data_format)
        block_list = []
        for block_idx in range(len(self.block_depth)):
            shortcut = False
            for i in range(self.block_depth[block_idx]):
                block_list.append(globals()[self.block_type](
                    num_channels=self.num_channels[block_idx] if i == 0 else
                    self.num_filters[block_idx] * self.channels_mult,
                    num_filters=self.num_filters[block_idx],
                    stride=2 if i == 0 and block_idx != 0 else 1,
                    shortcut=shortcut,
                    if_first=block_idx == i == 0 if version == "vd" else True,
                    lr_mult=self.lr_mult_list[block_idx + 1],
                    norm_decay=norm_decay,
                    data_format=data_format))
                shortcut = True
        self.blocks = nn.Sequential(*block_list)

        self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format)
        self.flatten = nn.Flatten()
        self.avg_pool_channels = self.num_channels[-1] * 2
        stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
        self.fc = Linear(
            self.avg_pool_channels,
            self.class_num,
            weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))

        self.data_format = data_format

        super().init_res(
            stages_pattern,
            return_patterns=return_patterns,
            return_stages=return_stages)

    def forward(self, x):
        with paddle.static.amp.fp16_guard():
            if self.data_format == "NHWC":
                x = paddle.transpose(x, [0, 2, 3, 1])
                x.stop_gradient = True
            x = self.stem(x)
            x = self.max_pool(x)
            x = self.blocks(x)
            x = self.avg_pool(x)
            x = self.flatten(x)
            x = self.fc(x)
        return x


def _load_pretrained(pretrained, model, model_url, use_ssld):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def ResNet18(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet18
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet18` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["18"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet18"],
        version="vb",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld)
    return model


def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet18_vd
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet18_vd` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["18"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet18"],
        version="vd",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld)
    return model


def ResNet34(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet34
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet34` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["34"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet34"],
        version="vb",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld)
    return model


def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet34_vd
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet34_vd` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["34"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet34"],
        version="vd",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld)
    return model


def ResNet50(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet50
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet50` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["50"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet50"],
        version="vb",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
    return model


def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet50_vd
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet50_vd` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["50"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet50"],
        version="vd",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld)
    return model


def ResNet101(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet101
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet101` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["101"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet101"],
        version="vb",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld)
    return model


def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet101_vd
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet101_vd` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["101"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet101"],
        version="vd",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld)
    return model


def ResNet152(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet152
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet152` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["152"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet152"],
        version="vb",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld)
    return model


def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet152_vd
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet152_vd` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["152"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet152"],
        version="vd",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld)
    return model


def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs):
    """
    ResNet200_vd
    Args:
        pretrained: bool=False or str. If `True`, load the pretrained parameters; `False` otherwise.
                    If str, it is the path of the pretrained model.
        use_ssld: bool=False. Whether to use the distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `ResNet200_vd` model depends on args.
    """
    model = ResNet(
        config=NET_CONFIG["200"],
        stages_pattern=MODEL_STAGES_PATTERN["ResNet200"],
        version="vd",
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld)
    return model
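

# A hedged usage sketch; kept under __main__ so importing this module stays
# side-effect free. `pretrained=True` assumes the ImageNet weights in
# MODEL_URLS are reachable for download.
if __name__ == "__main__":
    net = ResNet50_vd(pretrained=True)
    net.eval()
    logits = net(paddle.rand([1, 3, 224, 224]))
    print(logits.shape)  # [1, 1000]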