Unverified commit 5671f9d9, authored by cuicheng01 and committed via GitHub

Update resnet.py

Parent commit: 465a8e5f
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import, division, print_function

import numpy as np
import paddle
@@ -25,10 +23,36 @@ from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
import math

from theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {
    "ResNet18": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams",
    "ResNet18_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams",
    "ResNet34": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_pretrained.pdparams",
    "ResNet34_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_vd_pretrained.pdparams",
    "ResNet50": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams",
    "ResNet50_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams",
    "ResNet101": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet101_pretrained.pdparams",
    "ResNet101_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet101_vd_pretrained.pdparams",
    "ResNet152": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_pretrained.pdparams",
    "ResNet152_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_vd_pretrained.pdparams",
    "ResNet200_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet200_vd_pretrained.pdparams",
}

__all__ = MODEL_URLS.keys()
'''
ResNet config: dict.
    key: depth of ResNet.
    values: config dict of the specific model.
        keys:
            block_type: Which block to use; BasicBlock and BottleneckBlock are available.
            block_depth: The number of blocks in each stage of ResNet.
            num_channels: The number of input channels of each stage.
'''

NET_CONFIG = {
    "18": {
        "block_type": "BasicBlock", "block_depth": [2, 2, 2, 2], "num_channels": [64, 64, 128, 256]},
@@ -45,7 +69,6 @@ NET_CONFIG = {
}

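# Illustrative sketch (not part of this commit): how a NET_CONFIG entry is
# consumed further down. "block_type" names a block class defined in this
# module and is resolved via globals(), while "block_depth" gives the number
# of blocks per stage, e.g.:
#
#     cfg = NET_CONFIG["18"]
#     block_cls_name = cfg["block_type"]                       # "BasicBlock"
#     stages = list(zip(cfg["num_channels"], cfg["block_depth"]))
#     # -> [(64, 2), (64, 2), (128, 2), (256, 2)]
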
class ConvBNLayer(TheseusLayer):
    def __init__(self,
                 num_channels,
@@ -56,7 +79,7 @@ class ConvBNLayer(TheseusLayer):
                 is_vd_mode=False,
                 act=None,
                 lr_mult=1.0):
        super().__init__()
        self.is_vd_mode = is_vd_mode
        self.act = act
        self.avgpool = AvgPool2D(
@@ -72,7 +95,6 @@ class ConvBNLayer(TheseusLayer):
            bias_attr=False)
        self.bn = BatchNorm(
            num_filters,
            param_attr=ParamAttr(learning_rate=lr_mult),
            bias_attr=ParamAttr(learning_rate=lr_mult))
        self.relu = nn.ReLU()
@@ -96,20 +118,20 @@ class BottleneckBlock(TheseusLayer):
                 if_first=False,
                 lr_mult=1.0,
                 ):
        super().__init__()
        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act="relu",
            lr_mult=lr_mult)
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act="relu",
            lr_mult=lr_mult)
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
@@ -152,14 +174,15 @@ class BasicBlock(TheseusLayer):
                 shortcut=True,
                 if_first=False,
                 lr_mult=1.0):
        super().__init__()
        self.stride = stride
        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act="relu",
            lr_mult=lr_mult)
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
@@ -167,7 +190,6 @@ class BasicBlock(TheseusLayer):
            filter_size=3,
            act=None,
            lr_mult=lr_mult)
        if not shortcut:
            self.short = ConvBNLayer(
                num_channels=num_channels,
@@ -176,7 +198,6 @@ class BasicBlock(TheseusLayer):
                stride=stride if if_first else 1,
                is_vd_mode=False if if_first else True,
                lr_mult=lr_mult)
        self.shortcut = shortcut
        self.relu = nn.ReLU()
@@ -184,43 +205,46 @@ class BasicBlock(TheseusLayer):
        identity = x
        x = self.conv0(x)
        x = self.conv1(x)
        if self.shortcut:
            short = identity
        else:
            short = self.short(identity)
        x = paddle.add(x=x, y=short)
        x = self.relu(x)
        return x

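# Illustrative sketch (not part of this commit), assuming the 3x3 convolutions
# in ConvBNLayer use same-padding as in the standard implementation: with
# stride=1 and shortcut=True the residual add keeps the input shape, e.g.:
#
#     blk = BasicBlock(num_channels=64, num_filters=64, stride=1, shortcut=True)
#     y = blk(paddle.rand([1, 64, 56, 56]))    # expected shape: [1, 64, 56, 56]
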
class ResNet(TheseusLayer):
    """
    ResNet
    Args:
        config: dict. config of ResNet.
        version: str="vb". Different version of ResNet, version vd can perform better.
        class_num: int=1000. The number of classes.
        lr_mult_list: list. Control the learning rate of different stages.
        pretrained: (True or False) or path of pretrained_model. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific ResNet model depends on args.
    """

    def __init__(self,
                 config,
                 version="vb",
                 class_num=1000,
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
                 pretrained=False):
        super().__init__()

        self.cfg = config
        self.lr_mult_list = lr_mult_list
        self.is_vd_mode = version == "vd"
        self.class_num = class_num
        self.num_filters = [64, 128, 256, 512]
        self.block_depth = self.cfg["block_depth"]
        self.block_type = self.cfg["block_type"]
        self.num_channels = self.cfg["num_channels"]
        self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
        self.pretrained = pretrained

        assert isinstance(self.lr_mult_list, (
            list, tuple
        )), "lr_mult_list should be in (list, tuple) but got {}".format(
@@ -229,10 +253,10 @@ class ResNet(TheseusLayer):
            self.lr_mult_list
        ) == 5, "lr_mult_list length should be 5 but got {}".format(
            len(self.lr_mult_list))

        self.stem_cfg = {
            # num_channels, num_filters, filter_size, stride
            "vb": [[3, 64, 7, 2]],
            "vd": [[3, 32, 3, 2],
                   [32, 32, 3, 1],
@@ -244,39 +268,35 @@ class ResNet(TheseusLayer):
                num_filters=out_c,
                filter_size=k,
                stride=s,
                act="relu",
                lr_mult=self.lr_mult_list[0])
            for in_c, out_c, k, s in self.stem_cfg[version]
        ])
        self.maxpool = MaxPool2D(kernel_size=3, stride=2, padding=1)

        block_list = []
        for block_idx in range(len(self.block_depth)):
            shortcut = False
            for i in range(self.block_depth[block_idx]):
                block_list.append(
                    globals()[self.block_type](
                        num_channels=self.num_channels[block_idx]
                        if i == 0 else self.num_filters[block_idx] * self.channels_mult,
                        num_filters=self.num_filters[block_idx],
                        stride=2 if i == 0 and block_idx != 0 else 1,
                        shortcut=shortcut,
                        if_first=block_idx == i == 0 if version == "vd" else True,
                        lr_mult=self.lr_mult_list[block_idx + 1]))
                shortcut = True
        self.blocks = nn.Sequential(*block_list)

        self.avgpool = AdaptiveAvgPool2D(1)
        self.avgpool_channels = self.num_channels[-1] * 2
        stdv = 1.0 / math.sqrt(self.avgpool_channels * 1.0)
        self.out = Linear(
            self.avgpool_channels,
            self.class_num,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv)))
@@ -291,42 +311,253 @@ class ResNet(TheseusLayer):

def ResNet18(**args):
    """
    ResNet18
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet18` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["18"], version="vb", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet18"])
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

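# Illustrative sketch (not part of this commit): the factory forwards keyword
# arguments to ResNet via **args, e.g.:
#
#     model = ResNet18(class_num=100, pretrained=False)
#     # or, to pull the ImageNet weights published under MODEL_URLS["ResNet18"]:
#     # model = ResNet18(pretrained=True)
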
def ResNet18_vd(**args):
    """
    ResNet18_vd
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet18_vd` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["18"], version="vd", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet18_vd"])
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

def ResNet34(**args):
    """
    ResNet34
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet34` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["34"], version="vb", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet34"])
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

def ResNet34_vd(**args):
    """
    ResNet34_vd
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet34_vd` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["34"], version="vd", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(
                model, MODEL_URLS["ResNet34_vd"], use_ssld=True)
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

def ResNet50(**args):
    """
    ResNet50
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet50` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["50"], version="vb", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet50"])
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

def ResNet50_vd(**args):
    """
    ResNet50_vd
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet50_vd` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["50"], version="vd", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(
                model, MODEL_URLS["ResNet50_vd"], use_ssld=True)
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

def ResNet101(**args):
    """
    ResNet101
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet101` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["101"], version="vb", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet101"])
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

def ResNet101_vd(**args):
    """
    ResNet101_vd
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet101_vd` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["101"], version="vd", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(
                model, MODEL_URLS["ResNet101_vd"], use_ssld=True)
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

def ResNet152(**args):
    """
    ResNet152
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet152` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["152"], version="vb", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet152"])
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

def ResNet152_vd(**args):
    """
    ResNet152_vd
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet152_vd` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["152"], version="vd", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet152_vd"])
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

def ResNet200(**args):
    model = ResNet(config=NET_CONFIG["200"], version="vb", **args)
    return model

def ResNet200_vd(**args):
    """
    ResNet200_vd
    Args:
        kwargs:
            class_num: int=1000. Output dim of last fc layer.
            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific `ResNet200_vd` model depends on args.
    """
    model = ResNet(config=NET_CONFIG["200"], version="vd", **args)
    if isinstance(model.pretrained, bool):
        if model.pretrained is True:
            load_dygraph_pretrain_from_url(
                model, MODEL_URLS["ResNet200_vd"], use_ssld=True)
    elif isinstance(model.pretrained, str):
        load_dygraph_pretrain(model, model.pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type")
    return model

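# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this commit). It assumes paddle is
# installed and this module is importable; `class_num` and `pretrained` are
# the keyword arguments documented above.
#
#     import paddle
#
#     model = ResNet50_vd(class_num=10, pretrained=False)
#     model.eval()
#
#     x = paddle.rand([1, 3, 224, 224])       # dummy NCHW batch
#     logits = model(x)                        # expected shape: [1, 10]
# ---------------------------------------------------------------------------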