# PaddleClas · commit `1a74e9cb`
Unverified commit `1a74e9cb`, authored by cuicheng01 on May 31, 2021 and committed via GitHub the same day.
Merge pull request #756 from RainFrost1/develop_reg
legendary models v0.1
Parents: `4d246c20`, `2fa80851`
Showing 7 changed files, with 1,549 additions and 839 deletions.
| File | Changes |
| --- | --- |
| `ppcls/arch/backbone/legendary_models/__init__.py` | +6 / −0 |
| `ppcls/arch/backbone/legendary_models/hrnet.py` | +276 / −131 |
| `ppcls/arch/backbone/legendary_models/inception_v3.py` | +312 / −270 |
| `ppcls/arch/backbone/legendary_models/mobilenet_v1.py` | +104 / −115 |
| `ppcls/arch/backbone/legendary_models/mobilenet_v3.py` | +557 / −0 |
| `ppcls/arch/backbone/legendary_models/resnet.py` | +155 / −211 |
| `ppcls/arch/backbone/legendary_models/vgg.py` | +139 / −112 |
## `ppcls/arch/backbone/legendary_models/__init__.py`
The new package `__init__` re-exports the first batch of "legendary" backbones:

```python
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNet18_vd, ResNet34_vd, ResNet50_vd, ResNet101_vd, ResNet152_vd
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W64_C
from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1
from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
from .inception_v3 import InceptionV3
from .vgg import VGG11, VGG13, VGG16, VGG19
```
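These six lines are the whole file, so the backbones become importable directly from the package. A minimal smoke test, assuming the PaddleClas repository root is on `PYTHONPATH`, `paddlepaddle` is installed, and the factories accept a no-argument call:

```python
import paddle
from ppcls.arch.backbone.legendary_models import ResNet50

model = ResNet50()                  # randomly initialized backbone
x = paddle.rand([1, 3, 224, 224])   # NCHW float32 batch
logits = model(x)
print(logits.shape)                 # [1, 1000] with the default class_num
```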
## `ppcls/arch/backbone/legendary_models/hrnet.py`
The import block gains the weight-loading helpers, and the previously empty `MODEL_URLS` placeholders (which also listed `HRNet_W60_C` and the `SE_HRNet_*` variants, all mapped to `""`) are replaced by seven published weight URLs:

```python
from paddle.nn.initializer import Uniform
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer, Identity
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {
    "HRNet_W18_C":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W18_C_pretrained.pdparams",
    "HRNet_W30_C":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W30_C_pretrained.pdparams",
    "HRNet_W32_C":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W32_C_pretrained.pdparams",
    "HRNet_W40_C":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W40_C_pretrained.pdparams",
    "HRNet_W44_C":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W44_C_pretrained.pdparams",
    "HRNet_W48_C":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W48_C_pretrained.pdparams",
    "HRNet_W64_C":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W64_C_pretrained.pdparams",
}

__all__ = list(MODEL_URLS.keys())
```
A module-level activation factory is added near the top of the file:

```python
def _create_act(act):
    if act == "hardswish":
        return nn.Hardswish()
    elif act == "relu":
        return nn.ReLU()
    elif act is None:
        return Identity()
    else:
        raise RuntimeError(
            "The activation function is not supported: {}".format(act))
```
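Because `None` maps to `Identity`, call sites can always apply `self.act` without a branch. A quick behavioral check, assuming the repository is importable:

```python
import paddle
from ppcls.arch.backbone.legendary_models.hrnet import _create_act

x = paddle.to_tensor([-1.0, 2.0])
print(_create_act("relu")(x).numpy())  # [0. 2.]
print(_create_act(None)(x).numpy())    # [-1.  2.], Identity passes through
```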
In `ConvBNLayer`, `super(ConvBNLayer, self).__init__()` becomes the plain `super().__init__()` and the activation is now built through the new factory. Reconstructed new version (context collapsed by the diff view is marked `# ...`):

```python
class ConvBNLayer(TheseusLayer):
    def __init__(self,
                 num_channels,
                 # ...
                 stride=1,
                 groups=1,
                 act="relu"):
        super().__init__()
        self.conv = nn.Conv2D(
            in_channels=num_channels,
            # ...
            padding=(filter_size - 1) // 2,
            groups=groups,
            bias_attr=False)
        self.bn = nn.BatchNorm(num_filters, act=None)
        self.act = _create_act(act)

    def forward(self, x):
        x = self.conv(x)
        # ...
        return x
```
The old module-level `create_act`, which used to sit below `ConvBNLayer` with the same hardswish/relu/`None` dispatch, is deleted; `_create_act` above replaces it.
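For reference, the conv-BN-act pattern above can be exercised standalone. A self-contained sketch that fills the diff's elided parameters (`num_filters`, `filter_size`) by assumption and uses a fixed ReLU instead of the factory:

```python
import paddle
import paddle.nn as nn

class ConvBN(nn.Layer):
    """Conv2D + BatchNorm + activation, mirroring the ConvBNLayer pattern."""
    def __init__(self, num_channels, num_filters, filter_size, stride=1, groups=1):
        super().__init__()
        self.conv = nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,  # "same" padding for odd kernels
            groups=groups,
            bias_attr=False)                 # bias is redundant before BN
        self.bn = nn.BatchNorm(num_filters, act=None)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

y = ConvBN(3, 64, 3, stride=2)(paddle.rand([1, 3, 224, 224]))
print(y.shape)  # [1, 64, 112, 112]
```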
`BottleneckBlock` and `BasicBlock` get the same `super().__init__()` simplification:

```python
class BottleneckBlock(TheseusLayer):
    def __init__(self,
                 num_channels,
                 # ...
                 has_se,
                 stride=1,
                 downsample=False):
        super().__init__()
        self.has_se = has_se
        self.downsample = downsample
        # ...


class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se=False):
        super().__init__()
        self.has_se = has_se
        # ...
```
`SELayer` renames its pooling attribute from `pool2d_gap` to `avg_pool` (in both `__init__` and `forward`) and drops the two-argument `super` call:

```python
class SELayer(TheseusLayer):
    def __init__(self, num_channels, num_filters, reduction_ratio):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)  # was: self.pool2d_gap
        self._num_channels = num_channels
        # ...
        self.fc_squeeze = nn.Linear(
            num_channels,
            med_ch,
            weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
        self.relu = nn.ReLU()
        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.fc_excitation = nn.Linear(
            # ...

    def forward(self, x, res_dict=None):
        residual = x
        x = self.avg_pool(x)  # was: self.pool2d_gap(x)
        x = paddle.squeeze(x, axis=[2, 3])
        x = self.fc_squeeze(x)
        x = self.relu(x)
        # ...
```
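The view collapses the rest of `SELayer.forward`, but the squeeze-excitation it implements is the standard pattern: global-average-pool to `[N, C]`, a bottleneck MLP, a sigmoid gate, then a channel-wise rescale of the input. A self-contained sketch (not the file's exact code, which also carries `res_dict` plumbing and uniform initialization):

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class SE(nn.Layer):
    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        med = channels // reduction_ratio
        self.fc_squeeze = nn.Linear(channels, med)
        self.fc_excitation = nn.Linear(med, channels)

    def forward(self, x):
        residual = x
        s = paddle.squeeze(self.avg_pool(x), axis=[2, 3])  # [N, C]
        s = F.relu(self.fc_squeeze(s))
        s = F.sigmoid(self.fc_excitation(s))               # per-channel gates in (0, 1)
        s = paddle.unsqueeze(s, axis=[2, 3])               # [N, C, 1, 1]
        return residual * s                                # broadcast rescale

y = SE(64)(paddle.rand([2, 64, 16, 16]))
print(y.shape)  # [2, 64, 16, 16]
```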
`Stage`, `HighResolutionModule`, and `FuseLayers` receive the same treatment: reflowed `__init__` signatures and `super().__init__()`:

```python
class Stage(TheseusLayer):
    def __init__(self,
                 num_modules,
                 num_filters,
                 has_se=False):
        super().__init__()
        self._num_modules = num_modules
        # ...
        for i in range(num_modules):
            self.stage_func_list.append(
                HighResolutionModule(
                    num_filters=num_filters,
                    has_se=has_se))

    def forward(self, x, res_dict=None):
        x = x
        # ...


class HighResolutionModule(TheseusLayer):
    def __init__(self,
                 num_filters,
                 has_se=False):
        super().__init__()
        self.basic_block_list = nn.LayerList()
        # ...
                BasicBlock(
                    num_channels=num_filters[i],
                    num_filters=num_filters[i],
                    has_se=has_se) for j in range(4)
            ]))
        self.fuse_func = FuseLayers(
            in_channels=num_filters, out_channels=num_filters)

    def forward(self, x, res_dict=None):
        out = []
        # ...


class FuseLayers(TheseusLayer):
    def __init__(self,
                 in_channels,
                 out_channels):
        super().__init__()
        self._actual_ch = len(in_channels)
        self._in_channels = in_channels
        # ...
```
`LastClsOut` likewise:

```python
class LastClsOut(TheseusLayer):
    def __init__(self,
                 num_channel_list,
                 has_se,
                 num_filters_list=[32, 64, 128, 256]):
        super().__init__()
        self.func_list = nn.LayerList()
        for idx in range(len(num_channel_list)):
            # ...
```
The `HRNet` class docstring gains a `Returns` section:

```python
        width: int=18. Base channel number of HRNet.
        has_se: bool=False. If 'True', add se module to HRNet.
        class_num: int=1000. Output num of last fc layer.
    Returns:
        model: nn.Layer. Specific HRNet model depends on args.
    """
```
`HRNet.__init__` is unchanged in structure; the edits are `super().__init__()`, `'relu'` to `"relu"` quote style, and argument reflow. Reconstructed new version of the shown hunks (`forward` is collapsed in the view and ends with `return y`):

```python
    def __init__(self, width=18, has_se=False, class_num=1000):
        super().__init__()
        self.width = width
        self.has_se = has_se
        # ...
        channels_2 = [self.width, self.width * 2]
        channels_3 = [self.width, self.width * 2, self.width * 4]
        channels_4 = [
            self.width, self.width * 2, self.width * 4, self.width * 8
        ]

        self.conv_layer1_1 = ConvBNLayer(
            num_channels=3,
            num_filters=64,
            filter_size=3,
            stride=2,
            act="relu")
        self.conv_layer1_2 = ConvBNLayer(
            num_channels=64,
            num_filters=64,
            filter_size=3,
            stride=2,
            act="relu")
        self.layer1 = nn.Sequential(*[
            BottleneckBlock(
                # ...
                num_filters=64,
                has_se=has_se,
                stride=1,
                downsample=True if i == 0 else False) for i in range(4)
        ])
        self.conv_tr1_1 = ConvBNLayer(
            num_channels=256, num_filters=width, filter_size=3)
        self.conv_tr1_2 = ConvBNLayer(
            num_channels=256, num_filters=width * 2, filter_size=3, stride=2)
        self.st2 = Stage(
            num_modules=1, num_filters=channels_2, has_se=self.has_se)
        self.conv_tr2 = ConvBNLayer(
            num_channels=width * 2,
            num_filters=width * 4,
            filter_size=3,
            stride=2)
        self.st3 = Stage(
            num_modules=4, num_filters=channels_3, has_se=self.has_se)
        self.conv_tr3 = ConvBNLayer(
            num_channels=width * 4,
            num_filters=width * 8,
            filter_size=3,
            stride=2)
        self.st4 = Stage(
            num_modules=3, num_filters=channels_4, has_se=self.has_se)

        # classification
        num_filters_list = [32, 64, 128, 256]
        # ...
        self.cls_head_conv_list = nn.LayerList()
        for idx in range(3):
            self.cls_head_conv_list.append(
                ConvBNLayer(
                    num_channels=num_filters_list[idx] * 4,
                    num_filters=last_num_filters[idx],
                    filter_size=3,
                    stride=2))
        self.conv_last = ConvBNLayer(
            num_channels=1024, num_filters=2048, filter_size=1, stride=1)
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        # ...
        return y
```
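The three channel lists are pure functions of `width`; for the default `width=18`:

```python
width = 18
channels_2 = [width, width * 2]                        # [18, 36]
channels_3 = [width, width * 2, width * 4]             # [18, 36, 72]
channels_4 = [width, width * 2, width * 4, width * 8]  # [18, 36, 72, 144]
```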
The largest change: the old one-liner factories, e.g. `def HRNet_W18_C(**args): model = HRNet(width=18, **args); return model`, are replaced by a shared `_load_pretrained` helper plus per-model factories that accept `pretrained` and `use_ssld` and carry docstrings:

```python
def _load_pretrained(pretrained, model, model_url, use_ssld):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def HRNet_W18_C(pretrained=False, use_ssld=False, **kwargs):
    """
    HRNet_W18_C
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `HRNet_W18_C` model depends on args.
    """
    model = HRNet(width=18, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["HRNet_W18_C"], use_ssld)
    return model
```

`HRNet_W30_C` through `HRNet_W64_C` (widths 30, 32, 40, 44, 48, 60, 64) and `SE_HRNet_W18_C` through `SE_HRNet_W64_C` (same widths, passing `has_se=True` to `HRNet`) repeat this pattern verbatim, each looking up its own name in `MODEL_URLS`. Note that `HRNet_W60_C` and all `SE_HRNet_*` factories reference `MODEL_URLS` keys that the trimmed dict at the top of the file no longer defines, so `pretrained=True` would raise a `KeyError` for those variants.
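How the new factory signature is meant to be called (a sketch; the local checkpoint path is hypothetical):

```python
from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C

m1 = HRNet_W18_C()                                          # random init
m2 = HRNet_W18_C(pretrained=True)                           # download from MODEL_URLS
m3 = HRNet_W18_C(pretrained="output/HRNet_W18_C.pdparams")  # hypothetical local weights
m4 = HRNet_W18_C(class_num=100)                             # extra kwargs reach HRNet()
```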
## `ppcls/arch/backbone/legendary_models/inception_v3.py`
`import math` moves up with the other imports, the pretrained-weight URL gains the `legendary_models/` path segment, and `NET_CONFIG` switches to double quotes with a space after each colon:

```python
# limitations under the License.

from __future__ import absolute_import, division, print_function
import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform

from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {
    "InceptionV3":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/InceptionV3_pretrained.pdparams"
}

__all__ = MODEL_URLS.keys()

'''
InceptionV3 config: dict.
    key: inception blocks of InceptionV3.
    values: conv num in different blocks.
'''
NET_CONFIG = {
    "inception_a": [[192, 256, 288], [32, 64, 64]],
    "inception_b": [288],
    "inception_c": [[768, 768, 768, 768], [128, 160, 160, 192]],
    "inception_d": [768],
    "inception_e": [1280, 2048]
}
```
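How `Inception_V3.__init__` (further down) consumes this dict: for `inception_a` and `inception_c` the first sub-list holds each block's input channels and the second its per-block parameter (`pool_features` and `channels_7x7` respectively); for the others each entry is one block's input channels. Counting blocks:

```python
NET_CONFIG = {
    "inception_a": [[192, 256, 288], [32, 64, 64]],
    "inception_b": [288],
    "inception_c": [[768, 768, 768, 768], [128, 160, 160, 192]],
    "inception_d": [768],
    "inception_e": [1280, 2048],
}
counts = {
    "A": len(NET_CONFIG["inception_a"][0]),  # 3 blocks
    "B": len(NET_CONFIG["inception_b"]),     # 1 block
    "C": len(NET_CONFIG["inception_c"][0]),  # 4 blocks
    "D": len(NET_CONFIG["inception_d"]),     # 1 block
    "E": len(NET_CONFIG["inception_e"]),     # 2 blocks
}
print(counts)  # {'A': 3, 'B': 1, 'C': 4, 'D': 1, 'E': 2}
```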
This file's `ConvBNLayer` renames `self.batch_norm` to `self.bn` and simplifies the `super` call:

```python
class ConvBNLayer(TheseusLayer):
    def __init__(self,
                 num_channels,
                 # ...
                 padding=0,
                 groups=1,
                 act="relu"):
        super().__init__()
        self.act = act
        self.conv = Conv2D(
            in_channels=num_channels,
            # ...
            padding=padding,
            groups=groups,
            bias_attr=False)
        self.bn = BatchNorm(num_filters)  # was: self.batch_norm
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.act:
            x = self.relu(x)
        return x
```
`InceptionStem` renames `self.maxpool` to `self.max_pool` and reflows its `ConvBNLayer` calls to one keyword argument per line:

```python
class InceptionStem(TheseusLayer):
    def __init__(self):
        super().__init__()
        self.conv_1a_3x3 = ConvBNLayer(
            num_channels=3,
            num_filters=32,
            filter_size=3,
            stride=2,
            act="relu")
        self.conv_2a_3x3 = ConvBNLayer(
            num_channels=32,
            num_filters=32,
            filter_size=3,
            stride=1,
            act="relu")
        self.conv_2b_3x3 = ConvBNLayer(
            num_channels=32,
            num_filters=64,
            filter_size=3,
            padding=1,
            act="relu")
        self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=0)
        self.conv_3b_1x1 = ConvBNLayer(
            num_channels=64,
            num_filters=80,
            filter_size=1,
            act="relu")
        self.conv_4a_3x3 = ConvBNLayer(
            num_channels=80,
            num_filters=192,
            filter_size=3,
            act="relu")

    def forward(self, x):
        x = self.conv_1a_3x3(x)
        x = self.conv_2a_3x3(x)
        x = self.conv_2b_3x3(x)
        x = self.max_pool(x)
        x = self.conv_3b_1x1(x)
        x = self.conv_4a_3x3(x)
        x = self.max_pool(x)
        return x
```
`InceptionA` through `InceptionE` change only in formatting (one keyword argument per line, wrapped `paddle.concat` calls) and the `super().__init__()` call; branch topology and channel counts are untouched. `InceptionA`:

```python
class InceptionA(TheseusLayer):
    def __init__(self, num_channels, pool_features):
        super().__init__()
        self.branch1x1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=64,
            filter_size=1,
            act="relu")
        self.branch5x5_1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=48,
            filter_size=1,
            act="relu")
        self.branch5x5_2 = ConvBNLayer(
            num_channels=48,
            num_filters=64,
            filter_size=5,
            padding=2,
            act="relu")
        self.branch3x3dbl_1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=64,
            filter_size=1,
            act="relu")
        self.branch3x3dbl_2 = ConvBNLayer(
            num_channels=64,
            num_filters=96,
            filter_size=3,
            padding=1,
            act="relu")
        self.branch3x3dbl_3 = ConvBNLayer(
            num_channels=96,
            num_filters=96,
            filter_size=3,
            padding=1,
            act="relu")
        self.branch_pool = AvgPool2D(
            kernel_size=3, stride=1, padding=1, exclusive=False)
        self.branch_pool_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=pool_features,
            filter_size=1,
            act="relu")

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        # ...
        branch_pool = self.branch_pool(x)
        branch_pool = self.branch_pool_conv(branch_pool)
        x = paddle.concat(
            [branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1)
        return x
```
`InceptionB`:

```python
class InceptionB(TheseusLayer):
    def __init__(self, num_channels):
        super().__init__()
        self.branch3x3 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=384,
            filter_size=3,
            stride=2,
            act="relu")
        self.branch3x3dbl_1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=64,
            filter_size=1,
            act="relu")
        self.branch3x3dbl_2 = ConvBNLayer(
            num_channels=64,
            num_filters=96,
            filter_size=3,
            padding=1,
            act="relu")
        self.branch3x3dbl_3 = ConvBNLayer(
            num_channels=96,
            num_filters=96,
            filter_size=3,
            stride=2,
            act="relu")
        self.branch_pool = MaxPool2D(kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3(x)
        # ...
        return x
```
`InceptionC`:

```python
class InceptionC(TheseusLayer):
    def __init__(self, num_channels, channels_7x7):
        super().__init__()
        self.branch1x1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=192,
            filter_size=1,
            act="relu")
        self.branch7x7_1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=channels_7x7,
            filter_size=1,
            stride=1,
            act="relu")
        self.branch7x7_2 = ConvBNLayer(
            num_channels=channels_7x7,
            num_filters=channels_7x7,
            filter_size=(1, 7),
            stride=1,
            padding=(0, 3),
            act="relu")
        self.branch7x7_3 = ConvBNLayer(
            num_channels=channels_7x7,
            num_filters=192,
            filter_size=(7, 1),
            stride=1,
            padding=(3, 0),
            act="relu")
        self.branch7x7dbl_1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=channels_7x7,
            filter_size=1,
            act="relu")
        self.branch7x7dbl_2 = ConvBNLayer(
            num_channels=channels_7x7,
            num_filters=channels_7x7,
            filter_size=(7, 1),
            padding=(3, 0),
            act="relu")
        self.branch7x7dbl_3 = ConvBNLayer(
            num_channels=channels_7x7,
            num_filters=channels_7x7,
            filter_size=(1, 7),
            padding=(0, 3),
            act="relu")
        self.branch7x7dbl_4 = ConvBNLayer(
            num_channels=channels_7x7,
            num_filters=channels_7x7,
            filter_size=(7, 1),
            padding=(3, 0),
            act="relu")
        self.branch7x7dbl_5 = ConvBNLayer(
            num_channels=channels_7x7,
            num_filters=192,
            filter_size=(1, 7),
            padding=(0, 3),
            act="relu")
        self.branch_pool = AvgPool2D(
            kernel_size=3, stride=1, padding=1, exclusive=False)
        self.branch_pool_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=192,
            filter_size=1,
            act="relu")

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        # ...
        branch_pool = self.branch_pool(x)
        branch_pool = self.branch_pool_conv(branch_pool)
        x = paddle.concat(
            [branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=1)
        return x
```
`InceptionD`:

```python
class InceptionD(TheseusLayer):
    def __init__(self, num_channels):
        super().__init__()
        self.branch3x3_1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=192,
            filter_size=1,
            act="relu")
        self.branch3x3_2 = ConvBNLayer(
            num_channels=192,
            num_filters=320,
            filter_size=3,
            stride=2,
            act="relu")
        self.branch7x7x3_1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=192,
            filter_size=1,
            act="relu")
        self.branch7x7x3_2 = ConvBNLayer(
            num_channels=192,
            num_filters=192,
            filter_size=(1, 7),
            padding=(0, 3),
            act="relu")
        self.branch7x7x3_3 = ConvBNLayer(
            num_channels=192,
            num_filters=192,
            filter_size=(7, 1),
            padding=(3, 0),
            act="relu")
        self.branch7x7x3_4 = ConvBNLayer(
            num_channels=192,
            num_filters=192,
            filter_size=3,
            stride=2,
            act="relu")
        self.branch_pool = MaxPool2D(kernel_size=3, stride=2)

    def forward(self, x):
        # ...
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
        branch_pool = self.branch_pool(x)
        x = paddle.concat([branch3x3, branch7x7x3, branch_pool], axis=1)
        return x
```
`InceptionE`:

```python
class InceptionE(TheseusLayer):
    def __init__(self, num_channels):
        super().__init__()
        self.branch1x1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=320,
            filter_size=1,
            act="relu")
        self.branch3x3_1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=384,
            filter_size=1,
            act="relu")
        self.branch3x3_2a = ConvBNLayer(
            num_channels=384,
            num_filters=384,
            filter_size=(1, 3),
            padding=(0, 1),
            act="relu")
        self.branch3x3_2b = ConvBNLayer(
            num_channels=384,
            num_filters=384,
            filter_size=(3, 1),
            padding=(1, 0),
            act="relu")
        self.branch3x3dbl_1 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=448,
            filter_size=1,
            act="relu")
        self.branch3x3dbl_2 = ConvBNLayer(
            num_channels=448,
            num_filters=384,
            filter_size=3,
            padding=1,
            act="relu")
        self.branch3x3dbl_3a = ConvBNLayer(
            num_channels=384,
            num_filters=384,
            filter_size=(1, 3),
            padding=(0, 1),
            act="relu")
        self.branch3x3dbl_3b = ConvBNLayer(
            num_channels=384,
            num_filters=384,
            filter_size=(3, 1),
            padding=(1, 0),
            act="relu")
        self.branch_pool = AvgPool2D(
            kernel_size=3, stride=1, padding=1, exclusive=False)
        self.branch_pool_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=192,
            filter_size=1,
            act="relu")

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        # ...
        branch_pool = self.branch_pool(x)
        branch_pool = self.branch_pool_conv(branch_pool)
        x = paddle.concat(
            [branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=1)
        return x
```
`Inception_V3.__init__` drops the `pretrained`/`**kwargs` parameters and the `self.pretrained` attribute (weight loading moves to the factory below) and switches the config keys to double quotes:

```python
class Inception_V3(TheseusLayer):
    """
    ...
    Returns:
        model: nn.Layer. Specific Inception_V3 model depends on args.
    """
    def __init__(self, config, class_num=1000):
        super().__init__()

        self.inception_a_list = config["inception_a"]
        self.inception_c_list = config["inception_c"]
        self.inception_b_list = config["inception_b"]
        self.inception_d_list = config["inception_d"]
        self.inception_e_list = config["inception_e"]

        self.inception_stem = InceptionStem()
        self.inception_block_list = nn.LayerList()
        for i in range(len(self.inception_a_list[0])):
            inception_a = InceptionA(self.inception_a_list[0][i],
                                     self.inception_a_list[1][i])
            self.inception_block_list.append(inception_a)
        # ...
            self.inception_block_list.append(inception_b)
        for i in range(len(self.inception_c_list[0])):
            inception_c = InceptionC(self.inception_c_list[0][i],
                                     self.inception_c_list[1][i])
            self.inception_block_list.append(inception_c)
        # ...
        for i in range(len(self.inception_e_list)):
            inception_e = InceptionE(self.inception_e_list[i])
            self.inception_block_list.append(inception_e)

        self.avg_pool = AdaptiveAvgPool2D(1)
        self.dropout = Dropout(p=0.2, mode="downscale_in_infer")
        stdv = 1.0 / math.sqrt(2048 * 1.0)
        self.fc = Linear(
            2048,
            class_num,
            weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
            bias_attr=ParamAttr())

    def forward(self, x):
        x = self.inception_stem(x)
        for inception_block in self.inception_block_list:
            x = inception_block(x)
        x = self.avg_pool(x)
        x = paddle.reshape(x, shape=[-1, 2048])
        x = self.dropout(x)
        # ...
        return x
```
`InceptionV3()` no longer inspects `model.pretrained`; it takes `pretrained`/`use_ssld` directly and delegates to the same `_load_pretrained` helper added in `hrnet.py` (the docstring's lowercase `bool=false`/`` `true` ``/`` `false` `` typos are normalized here):

```python
def _load_pretrained(pretrained, model, model_url, use_ssld):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def InceptionV3(pretrained=False, use_ssld=False, **kwargs):
    """
    InceptionV3
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `InceptionV3` model
    """
    model = Inception_V3(NET_CONFIG, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["InceptionV3"], use_ssld)
    return model
```
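Typical calls under the new API (a sketch, assuming the repository is importable):

```python
from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3

model = InceptionV3(pretrained=True)  # fetches the legendary_models weights
tuned = InceptionV3(class_num=37)     # class_num is forwarded to Inception_V3
```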
## `ppcls/arch/backbone/legendary_models/mobilenet_v1.py`
`import numpy as np` and `import paddle` are removed as unused, the `save_load` import now pulls `load_dygraph_pretrain` (previously the misnamed `load_dygraph_pretrain_from`), and all four weight URLs move under `legendary_models/`:

```python
from __future__ import absolute_import, division, print_function

from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear, ReLU, Flatten
from paddle.nn import AdaptiveAvgPool2D
from paddle.nn.initializer import KaimingNormal

from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {
    "MobileNetV1_x0_25":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_25_pretrained.pdparams",
    "MobileNetV1_x0_5":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_5_pretrained.pdparams",
    "MobileNetV1_x0_75":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_x0_75_pretrained.pdparams",
    "MobileNetV1":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_pretrained.pdparams"
}

__all__ = MODEL_URLS.keys()
```
This file's `ConvBNLayer` gets the `super().__init__()` cleanup:

```python
class ConvBNLayer(TheseusLayer):
    def __init__(self,
                 num_channels,
                 # ...
                 stride,
                 padding,
                 num_groups=1):
        super().__init__()
        self.conv = Conv2D(
            in_channels=num_channels,
            # ...
            groups=num_groups,
            weight_attr=ParamAttr(initializer=KaimingNormal()),
            bias_attr=False)
        self.bn = BatchNorm(num_filters)
        self.relu = ReLU()

    def forward(self, x):
        # ...
```
`DepthwiseSeparable` reflows its signature and simplifies the `super` call:

```python
class DepthwiseSeparable(TheseusLayer):
    def __init__(self, num_channels, num_filters1, num_filters2, num_groups,
                 stride, scale):
        super().__init__()
        self.depthwise_conv = ConvBNLayer(
            num_channels=num_channels,
            # ...
```
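The two `ConvBNLayer`s inside `DepthwiseSeparable` factor a dense convolution into a per-channel spatial convolution plus a 1x1 channel mixer. The factorization in isolation (plain `Conv2D`s, BN omitted):

```python
import paddle
import paddle.nn as nn

# Depthwise 3x3 (one filter per input channel, groups == channels)
# followed by a pointwise 1x1 that mixes channels:
dw = nn.Conv2D(64, 64, kernel_size=3, padding=1, groups=64, bias_attr=False)
pw = nn.Conv2D(64, 128, kernel_size=1, bias_attr=False)

x = paddle.rand([1, 64, 56, 56])
print(pw(dw(x)).shape)  # [1, 128, 56, 56]
```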
class
MobileNet
(
TheseusLayer
):
class
MobileNet
(
TheseusLayer
):
def
__init__
(
self
,
scale
=
1.0
,
class_num
=
1000
,
pretrained
=
False
):
"""
super
(
MobileNet
,
self
).
__init__
()
MobileNet
Args:
scale: float=1.0. The coefficient that controls the size of network parameters.
class_num: int=1000. The number of classes.
Returns:
model: nn.Layer. Specific MobileNet model depends on args.
"""
def
__init__
(
self
,
scale
=
1.0
,
class_num
=
1000
):
super
().
__init__
()
self
.
scale
=
scale
self
.
scale
=
scale
self
.
pretrained
=
pretrained
self
.
conv
=
ConvBNLayer
(
self
.
conv
=
ConvBNLayer
(
num_channels
=
3
,
num_channels
=
3
,
...
@@ -110,30 +112,31 @@ class MobileNet(TheseusLayer):
...
@@ -110,30 +112,31 @@ class MobileNet(TheseusLayer):
num_filters
=
int
(
32
*
scale
),
num_filters
=
int
(
32
*
scale
),
stride
=
2
,
stride
=
2
,
padding
=
1
)
padding
=
1
)
#num_channels, num_filters1, num_filters2, num_groups, stride
#num_channels, num_filters1, num_filters2, num_groups, stride
self
.
cfg
=
[[
int
(
32
*
scale
),
32
,
64
,
32
,
1
],
self
.
cfg
=
[[
int
(
32
*
scale
),
32
,
64
,
32
,
1
],
[
int
(
64
*
scale
),
64
,
128
,
64
,
2
],
[
int
(
64
*
scale
),
64
,
128
,
64
,
2
],
[
int
(
128
*
scale
),
128
,
128
,
128
,
1
],
[
int
(
128
*
scale
),
128
,
128
,
128
,
1
],
[
int
(
128
*
scale
),
128
,
256
,
128
,
2
],
[
int
(
128
*
scale
),
128
,
256
,
128
,
2
],
[
int
(
256
*
scale
),
256
,
256
,
256
,
1
],
[
int
(
256
*
scale
),
256
,
256
,
256
,
1
],
[
int
(
256
*
scale
),
256
,
512
,
256
,
2
],
[
int
(
256
*
scale
),
256
,
512
,
256
,
2
],
[
int
(
512
*
scale
),
512
,
512
,
512
,
1
],
[
int
(
512
*
scale
),
512
,
512
,
512
,
1
],
[
int
(
512
*
scale
),
512
,
512
,
512
,
1
],
[
int
(
512
*
scale
),
512
,
512
,
512
,
1
],
[
int
(
512
*
scale
),
512
,
512
,
512
,
1
],
[
int
(
512
*
scale
),
512
,
512
,
512
,
1
],
[
int
(
512
*
scale
),
512
,
512
,
512
,
1
],
[
int
(
512
*
scale
),
512
,
512
,
512
,
1
],
[
int
(
512
*
scale
),
512
,
512
,
512
,
1
],
[
int
(
512
*
scale
),
512
,
512
,
512
,
1
],
[
int
(
512
*
scale
),
512
,
1024
,
512
,
2
],
[
int
(
512
*
scale
),
512
,
1024
,
512
,
2
],
[
int
(
1024
*
scale
),
1024
,
1024
,
1024
,
1
]]
[
int
(
1024
*
scale
),
1024
,
1024
,
1024
,
1
]]
self
.
blocks
=
nn
.
Sequential
(
*
[
self
.
blocks
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
DepthwiseSeparable
(
num_channels
=
params
[
0
],
num_channels
=
params
[
0
],
num_filters1
=
params
[
1
],
num_filters1
=
params
[
1
],
num_filters2
=
params
[
2
],
num_filters2
=
params
[
2
],
num_groups
=
params
[
3
],
num_groups
=
params
[
3
],
stride
=
params
[
4
],
stride
=
params
[
4
],
scale
=
scale
)
for
params
in
self
.
cfg
])
scale
=
scale
)
for
params
in
self
.
cfg
])
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
self
.
flatten
=
Flatten
(
start_axis
=
1
,
stop_axis
=-
1
)
self
.
flatten
=
Flatten
(
start_axis
=
1
,
stop_axis
=-
1
)
...
@@ -142,7 +145,7 @@ class MobileNet(TheseusLayer):
...
@@ -142,7 +145,7 @@ class MobileNet(TheseusLayer):
int
(
1024
*
scale
),
int
(
1024
*
scale
),
class_num
,
class_num
,
weight_attr
=
ParamAttr
(
initializer
=
KaimingNormal
()))
weight_attr
=
ParamAttr
(
initializer
=
KaimingNormal
()))
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
x
=
self
.
blocks
(
x
)
x
=
self
.
blocks
(
x
)
...
@@ -152,91 +155,77 @@ class MobileNet(TheseusLayer):
        return x


+def _load_pretrained(pretrained, model, model_url, use_ssld):
+    if pretrained is False:
+        pass
+    elif pretrained is True:
+        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
+    elif isinstance(pretrained, str):
+        load_dygraph_pretrain(model, pretrained)
+    else:
+        raise RuntimeError(
+            "pretrained type is not available. Please use `string` or `boolean` type."
+        )


-def MobileNetV1_x0_25(**args):
-    """
-    MobileNetV1_x0_25
-    Args:
-        pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-    Returns:
-        model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args.
-    """
-    model = MobileNet(scale=0.25, **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1_x0_25"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def MobileNetV1_x0_25(pretrained=False, use_ssld=False, **kwargs):
+    """
+    MobileNetV1_x0_25
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args.
+    """
+    model = MobileNet(scale=0.25, **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_25"],
+                     use_ssld)
+    return model


-def MobileNetV1_x0_5(**args):
-    """
-    MobileNetV1_x0_5
-    Args:
-        pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-    Returns:
-        model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args.
-    """
-    model = MobileNet(scale=0.5, **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1_x0_5"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def MobileNetV1_x0_5(pretrained=False, use_ssld=False, **kwargs):
+    """
+    MobileNetV1_x0_5
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args.
+    """
+    model = MobileNet(scale=0.5, **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_5"],
+                     use_ssld)
+    return model


-def MobileNetV1_x0_75(**args):
-    """
-    MobileNetV1_x0_75
-    Args:
-        pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-    Returns:
-        model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args.
-    """
-    model = MobileNet(scale=0.75, **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1_x0_75"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def MobileNetV1_x0_75(pretrained=False, use_ssld=False, **kwargs):
+    """
+    MobileNetV1_x0_75
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args.
+    """
+    model = MobileNet(scale=0.75, **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1_x0_75"],
+                     use_ssld)
+    return model


-def MobileNetV1(**args):
-    """
-    MobileNetV1
-    Args:
-        pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-    Returns:
-        model: nn.Layer. Specific `MobileNetV1` model depends on args.
-    """
-    model = MobileNet(scale=1.0, **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["MobileNetV1"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def MobileNetV1(pretrained=False, use_ssld=False, **kwargs):
+    """
+    MobileNetV1
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `MobileNetV1` model depends on args.
+    """
+    model = MobileNet(scale=1.0, **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV1"], use_ssld)
+    return model
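With this refactor every MobileNetV1 entry point shares one signature. A minimal usage sketch, assuming the package layout above, not part of the diff:

import paddle
from ppcls.arch.backbone.legendary_models.mobilenet_v1 import MobileNetV1, MobileNetV1_x0_25

model = MobileNetV1(pretrained=False, class_num=1000)  # random init
small = MobileNetV1_x0_25(pretrained=False)            # 0.25x width multiplier
x = paddle.rand([1, 3, 224, 224])
print(model(x).shape)  # [1, 1000]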
ppcls/arch/backbone/legendary_models/mobilenet_v3.py  0 → 100644  View file @ 1a74e9cb
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import, division, print_function

import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
from paddle.regularizer import L2Decay

from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {
    "MobileNetV3_small_x0_35":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_35_pretrained.pdparams",
    "MobileNetV3_small_x0_5":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_5_pretrained.pdparams",
    "MobileNetV3_small_x0_75":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_75_pretrained.pdparams",
    "MobileNetV3_small_x1_0":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_0_pretrained.pdparams",
    "MobileNetV3_small_x1_25":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_25_pretrained.pdparams",
    "MobileNetV3_large_x0_35":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_35_pretrained.pdparams",
    "MobileNetV3_large_x0_5":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_5_pretrained.pdparams",
    "MobileNetV3_large_x0_75":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x0_75_pretrained.pdparams",
    "MobileNetV3_large_x1_0":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x1_0_pretrained.pdparams",
    "MobileNetV3_large_x1_25":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x1_25_pretrained.pdparams",
}

__all__ = MODEL_URLS.keys()

# "large", "small" is just for MobileNetV3_large, MobileNetV3_small respectively.
# The type of "large" or "small" config is a list. Each element(list) represents
# a depthwise block, which is composed of k, exp, c, se, act, s.
# k: kernel_size
# exp: middle channel number in depthwise block
# c: output channel number in depthwise block
# se: whether to use SE block
# act: which activation to use
# s: stride in depthwise block
NET_CONFIG = {
    "large": [
        # k, exp, c, se, act, s
        [3, 16, 16, False, "relu", 1],
        [3, 64, 24, False, "relu", 2],
        [3, 72, 24, False, "relu", 1],
        [5, 72, 40, True, "relu", 2],
        [5, 120, 40, True, "relu", 1],
        [5, 120, 40, True, "relu", 1],
        [3, 240, 80, False, "hardswish", 2],
        [3, 200, 80, False, "hardswish", 1],
        [3, 184, 80, False, "hardswish", 1],
        [3, 184, 80, False, "hardswish", 1],
        [3, 480, 112, True, "hardswish", 1],
        [3, 672, 112, True, "hardswish", 1],
        [5, 672, 160, True, "hardswish", 2],
        [5, 960, 160, True, "hardswish", 1],
        [5, 960, 160, True, "hardswish", 1],
    ],
    "small": [
        # k, exp, c, se, act, s
        [3, 16, 16, True, "relu", 2],
        [3, 72, 24, False, "relu", 2],
        [3, 88, 24, False, "relu", 1],
        [5, 96, 40, True, "hardswish", 2],
        [5, 240, 40, True, "hardswish", 1],
        [5, 240, 40, True, "hardswish", 1],
        [5, 120, 48, True, "hardswish", 1],
        [5, 144, 48, True, "hardswish", 1],
        [5, 288, 96, True, "hardswish", 2],
        [5, 576, 96, True, "hardswish", 1],
        [5, 576, 96, True, "hardswish", 1],
    ]
}
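Read with the key above, the first "large" row [3, 16, 16, False, "relu", 1] is a 3x3 depthwise block that expands to 16 channels, outputs 16 channels, uses no SE module, applies ReLU, and has stride 1. A tiny decoding sketch, editorial rather than part of the file:

for k, exp, c, se, act, s in NET_CONFIG["small"][:2]:
    print(f"{k}x{k} dw, expand={exp}, out={c}, se={se}, act={act}, stride={s}")
# 3x3 dw, expand=16, out=16, se=True, act=relu, stride=2
# 3x3 dw, expand=72, out=24, se=False, act=relu, stride=2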
# first conv output channel number in MobileNetV3
STEM_CONV_NUMBER = 16
# last second conv output channel for "small"
LAST_SECOND_CONV_SMALL = 576
# last second conv output channel for "large"
LAST_SECOND_CONV_LARGE = 960
# last conv output channel number for "large" and "small"
LAST_CONV = 1280


def _make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
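_make_divisible rounds every scaled channel count to a hardware-friendly multiple of 8 while never dropping more than 10% below the requested value. Worked through by hand, as an editorial check rather than part of the file:

print(_make_divisible(16 * 0.35))  # 5.6 rounds to 8; max(8, 8) = 8
print(_make_divisible(96 * 0.75))  # 72 is already a multiple of 8, stays 72
print(_make_divisible(60))         # rounds up to 64; 64 >= 0.9 * 60, so keep 64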
def _create_act(act):
    if act == "hardswish":
        return nn.Hardswish()
    elif act == "relu":
        return nn.ReLU()
    elif act is None:
        return None
    else:
        raise RuntimeError(
            "The activation function is not supported: {}".format(act))


class MobileNetV3(TheseusLayer):
    """
    MobileNetV3
    Args:
        config: list. MobileNetV3 depthwise blocks config.
        scale: float=1.0. The coefficient that controls the size of network parameters.
        class_num: int=1000. The number of classes.
        inplanes: int=16. The output channel number of first convolution layer.
        class_squeeze: int=960. The output channel number of penultimate convolution layer.
        class_expand: int=1280. The output channel number of last convolution layer.
        dropout_prob: float=0.2. Probability of setting units to zero.
    Returns:
        model: nn.Layer. Specific MobileNetV3 model depends on args.
    """

    def __init__(self,
                 config,
                 scale=1.0,
                 class_num=1000,
                 inplanes=STEM_CONV_NUMBER,
                 class_squeeze=LAST_SECOND_CONV_LARGE,
                 class_expand=LAST_CONV,
                 dropout_prob=0.2):
        super().__init__()

        self.cfg = config
        self.scale = scale
        self.inplanes = inplanes
        self.class_squeeze = class_squeeze
        self.class_expand = class_expand
        self.class_num = class_num

        self.conv = ConvBNLayer(
            in_c=3,
            out_c=_make_divisible(self.inplanes * self.scale),
            filter_size=3,
            stride=2,
            padding=1,
            num_groups=1,
            if_act=True,
            act="hardswish")

        self.blocks = nn.Sequential(*[
            ResidualUnit(
                in_c=_make_divisible(self.inplanes * self.scale if i == 0 else
                                     self.cfg[i - 1][2] * self.scale),
                mid_c=_make_divisible(self.scale * exp),
                out_c=_make_divisible(self.scale * c),
                filter_size=k,
                stride=s,
                use_se=se,
                act=act) for i, (k, exp, c, se, act, s) in enumerate(self.cfg)
        ])

        self.last_second_conv = ConvBNLayer(
            in_c=_make_divisible(self.cfg[-1][2] * self.scale),
            out_c=_make_divisible(self.scale * self.class_squeeze),
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=True,
            act="hardswish")

        self.avg_pool = AdaptiveAvgPool2D(1)

        self.last_conv = Conv2D(
            in_channels=_make_divisible(self.scale * self.class_squeeze),
            out_channels=self.class_expand,
            kernel_size=1,
            stride=1,
            padding=0,
            bias_attr=False)

        self.hardswish = nn.Hardswish()
        self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
        self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)

        self.fc = Linear(self.class_expand, class_num)

    def forward(self, x):
        x = self.conv(x)
        x = self.blocks(x)
        x = self.last_second_conv(x)
        x = self.avg_pool(x)
        x = self.last_conv(x)
        x = self.hardswish(x)
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
class ConvBNLayer(TheseusLayer):
    def __init__(self,
                 in_c,
                 out_c,
                 filter_size,
                 stride,
                 padding,
                 num_groups=1,
                 if_act=True,
                 act=None):
        super().__init__()

        self.conv = Conv2D(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            bias_attr=False)
        self.bn = BatchNorm(
            num_channels=out_c,
            act=None,
            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self.if_act = if_act
        self.act = _create_act(act)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.if_act:
            x = self.act(x)
        return x


class ResidualUnit(TheseusLayer):
    def __init__(self,
                 in_c,
                 mid_c,
                 out_c,
                 filter_size,
                 stride,
                 use_se,
                 act=None):
        super().__init__()
        self.if_shortcut = stride == 1 and in_c == out_c
        self.if_se = use_se

        self.expand_conv = ConvBNLayer(
            in_c=in_c,
            out_c=mid_c,
            filter_size=1,
            stride=1,
            padding=0,
            if_act=True,
            act=act)
        self.bottleneck_conv = ConvBNLayer(
            in_c=mid_c,
            out_c=mid_c,
            filter_size=filter_size,
            stride=stride,
            padding=int((filter_size - 1) // 2),
            num_groups=mid_c,
            if_act=True,
            act=act)
        if self.if_se:
            self.mid_se = SEModule(mid_c)
        self.linear_conv = ConvBNLayer(
            in_c=mid_c,
            out_c=out_c,
            filter_size=1,
            stride=1,
            padding=0,
            if_act=False,
            act=None)

    def forward(self, x):
        identity = x
        x = self.expand_conv(x)
        x = self.bottleneck_conv(x)
        if self.if_se:
            x = self.mid_se(x)
        x = self.linear_conv(x)
        if self.if_shortcut:
            x = paddle.add(identity, x)
        return x


# nn.Hardsigmoid can't transfer "slope" and "offset" in nn.functional.hardsigmoid
class Hardsigmoid(TheseusLayer):
    def __init__(self, slope=0.2, offset=0.5):
        super().__init__()
        self.slope = slope
        self.offset = offset

    def forward(self, x):
        return nn.functional.hardsigmoid(
            x, slope=self.slope, offset=self.offset)
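The custom Hardsigmoid wrapper exists only because nn.Hardsigmoid exposes no way to set the slope and offset that nn.functional.hardsigmoid accepts. With slope=0.2 and offset=0.5 it computes the MobileNetV3 gating function

    hardsigmoid(x) = max(0, min(1, 0.2 * x + 0.5))

that is, a piecewise-linear ramp that is 0 below x = -2.5 and 1 above x = 2.5.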
class SEModule(TheseusLayer):
    def __init__(self, channel, reduction=4):
        super().__init__()
        self.avg_pool = AdaptiveAvgPool2D(1)
        self.conv1 = Conv2D(
            in_channels=channel,
            out_channels=channel // reduction,
            kernel_size=1,
            stride=1,
            padding=0)
        self.relu = nn.ReLU()
        self.conv2 = Conv2D(
            in_channels=channel // reduction,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0)
        self.hardsigmoid = Hardsigmoid(slope=0.2, offset=0.5)

    def forward(self, x):
        identity = x
        x = self.avg_pool(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.hardsigmoid(x)
        return paddle.multiply(x=identity, y=x)


def _load_pretrained(pretrained, model, model_url, use_ssld):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def MobileNetV3_small_x0_35(pretrained=False, use_ssld=False, **kwargs):
    """
    MobileNetV3_small_x0_35
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `MobileNetV3_small_x0_35` model depends on args.
    """
    model = MobileNetV3(
        config=NET_CONFIG["small"],
        scale=0.35,
        class_squeeze=LAST_SECOND_CONV_SMALL,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_35"],
                     use_ssld)
    return model


def MobileNetV3_small_x0_5(pretrained=False, use_ssld=False, **kwargs):
    """
    MobileNetV3_small_x0_5
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `MobileNetV3_small_x0_5` model depends on args.
    """
    model = MobileNetV3(
        config=NET_CONFIG["small"],
        scale=0.5,
        class_squeeze=LAST_SECOND_CONV_SMALL,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_5"],
                     use_ssld)
    return model


def MobileNetV3_small_x0_75(pretrained=False, use_ssld=False, **kwargs):
    """
    MobileNetV3_small_x0_75
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `MobileNetV3_small_x0_75` model depends on args.
    """
    model = MobileNetV3(
        config=NET_CONFIG["small"],
        scale=0.75,
        class_squeeze=LAST_SECOND_CONV_SMALL,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x0_75"],
                     use_ssld)
    return model


def MobileNetV3_small_x1_0(pretrained=False, use_ssld=False, **kwargs):
    """
    MobileNetV3_small_x1_0
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `MobileNetV3_small_x1_0` model depends on args.
    """
    model = MobileNetV3(
        config=NET_CONFIG["small"],
        scale=1.0,
        class_squeeze=LAST_SECOND_CONV_SMALL,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x1_0"],
                     use_ssld)
    return model


def MobileNetV3_small_x1_25(pretrained=False, use_ssld=False, **kwargs):
    """
    MobileNetV3_small_x1_25
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `MobileNetV3_small_x1_25` model depends on args.
    """
    model = MobileNetV3(
        config=NET_CONFIG["small"],
        scale=1.25,
        class_squeeze=LAST_SECOND_CONV_SMALL,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_small_x1_25"],
                     use_ssld)
    return model


def MobileNetV3_large_x0_35(pretrained=False, use_ssld=False, **kwargs):
    """
    MobileNetV3_large_x0_35
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `MobileNetV3_large_x0_35` model depends on args.
    """
    model = MobileNetV3(
        config=NET_CONFIG["large"],
        scale=0.35,
        class_squeeze=LAST_SECOND_CONV_LARGE,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_35"],
                     use_ssld)
    return model


def MobileNetV3_large_x0_5(pretrained=False, use_ssld=False, **kwargs):
    """
    MobileNetV3_large_x0_5
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `MobileNetV3_large_x0_5` model depends on args.
    """
    model = MobileNetV3(
        config=NET_CONFIG["large"],
        scale=0.5,
        class_squeeze=LAST_SECOND_CONV_LARGE,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_5"],
                     use_ssld)
    return model


def MobileNetV3_large_x0_75(pretrained=False, use_ssld=False, **kwargs):
    """
    MobileNetV3_large_x0_75
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `MobileNetV3_large_x0_75` model depends on args.
    """
    model = MobileNetV3(
        config=NET_CONFIG["large"],
        scale=0.75,
        class_squeeze=LAST_SECOND_CONV_LARGE,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x0_75"],
                     use_ssld)
    return model


def MobileNetV3_large_x1_0(pretrained=False, use_ssld=False, **kwargs):
    """
    MobileNetV3_large_x1_0
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `MobileNetV3_large_x1_0` model depends on args.
    """
    model = MobileNetV3(
        config=NET_CONFIG["large"],
        scale=1.0,
        class_squeeze=LAST_SECOND_CONV_LARGE,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x1_0"],
                     use_ssld)
    return model


def MobileNetV3_large_x1_25(pretrained=False, use_ssld=False, **kwargs):
    """
    MobileNetV3_large_x1_25
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. Specific `MobileNetV3_large_x1_25` model depends on args.
    """
    model = MobileNetV3(
        config=NET_CONFIG["large"],
        scale=1.25,
        class_squeeze=LAST_SECOND_CONV_LARGE,
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["MobileNetV3_large_x1_25"],
                     use_ssld)
    return model
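All ten entry points share the same contract. A minimal usage sketch, assuming the package layout above:

import paddle
from ppcls.arch.backbone.legendary_models.mobilenet_v3 import MobileNetV3_small_x1_0

model = MobileNetV3_small_x1_0(pretrained=False, class_num=1000)
x = paddle.rand([1, 3, 224, 224])
print(model(x).shape)  # [1, 1000]
# pretrained=True downloads MODEL_URLS[...]; pretrained="/path/model.pdparams" loads a local file.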
ppcls/arch/backbone/legendary_models/resnet.py  View file @ 1a74e9cb
...
@@ -24,26 +24,34 @@ from paddle.nn.initializer import Uniform
import math

from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
-from ppcls.utils.save_load import load_dygraph_pretrain_from, load_dygraph_pretrain_from_url
+from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {
-    "ResNet18": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams",
-    "ResNet18_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams",
-    "ResNet34": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_pretrained.pdparams",
-    "ResNet34_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_vd_pretrained.pdparams",
-    "ResNet50": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams",
-    "ResNet50_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams",
-    "ResNet101": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet101_pretrained.pdparams",
-    "ResNet101_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet101_vd_pretrained.pdparams",
-    "ResNet152": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_pretrained.pdparams",
-    "ResNet152_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet152_vd_pretrained.pdparams",
-    "ResNet200_vd": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet200_vd_pretrained.pdparams",
+    "ResNet18":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams",
+    "ResNet18_vd":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams",
+    "ResNet34":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams",
+    "ResNet34_vd":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams",
+    "ResNet50":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams",
+    "ResNet50_vd":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams",
+    "ResNet101":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams",
+    "ResNet101_vd":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams",
+    "ResNet152":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams",
+    "ResNet152_vd":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams",
+    "ResNet200_vd":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams",
}

__all__ = MODEL_URLS.keys()

'''
ResNet config: dict.
    key: depth of ResNet.
...
@@ -55,17 +63,35 @@ ResNet config: dict.
'''

NET_CONFIG = {
    "18": {
        "block_type": "BasicBlock",
        "block_depth": [2, 2, 2, 2],
        "num_channels": [64, 64, 128, 256]
    },
    "34": {
        "block_type": "BasicBlock",
        "block_depth": [3, 4, 6, 3],
        "num_channels": [64, 64, 128, 256]
    },
    "50": {
        "block_type": "BottleneckBlock",
        "block_depth": [3, 4, 6, 3],
        "num_channels": [64, 256, 512, 1024]
    },
    "101": {
        "block_type": "BottleneckBlock",
        "block_depth": [3, 4, 23, 3],
        "num_channels": [64, 256, 512, 1024]
    },
    "152": {
        "block_type": "BottleneckBlock",
        "block_depth": [3, 8, 36, 3],
        "num_channels": [64, 256, 512, 1024]
    },
    "200": {
        "block_type": "BottleneckBlock",
        "block_depth": [3, 12, 48, 3],
        "num_channels": [64, 256, 512, 1024]
    },
}
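The "block_depth" lists are where the familiar depth names come from: each BottleneckBlock holds three convolutions and each BasicBlock two, plus the single "vb" stem conv and the final fc layer. A quick sanity check, editorial rather than part of the file:

depth_50 = sum([3, 4, 6, 3]) * 3 + 1 + 1    # 16 bottlenecks * 3 convs + stem + fc = 50
depth_101 = sum([3, 4, 23, 3]) * 3 + 1 + 1  # 99 + 2 = 101
depth_18 = sum([2, 2, 2, 2]) * 2 + 1 + 1    # BasicBlock has 2 convs: 16 + 2 = 18
print(depth_50, depth_101, depth_18)        # 50 101 18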
...
@@ -110,14 +136,14 @@ class ConvBNLayer(TheseusLayer):
class BottleneckBlock(TheseusLayer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True,
                 if_first=False,
                 lr_mult=1.0, ):
        super().__init__()
        self.conv0 = ConvBNLayer(
...
@@ -222,16 +248,15 @@ class ResNet(TheseusLayer):
        version: str="vb". Different version of ResNet, version vd can perform better.
        class_num: int=1000. The number of classes.
        lr_mult_list: list. Control the learning rate of different stages.
-        pretrained: (True or False) or path of pretrained_model. Whether to load the pretrained model.
    Returns:
        model: nn.Layer. Specific ResNet model depends on args.
    """

    def __init__(self,
                 config,
                 version="vb",
                 class_num=1000,
-                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
-                 pretrained=False):
+                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
        super().__init__()

        self.cfg = config
...
@@ -243,51 +268,46 @@ class ResNet(TheseusLayer):
        self.block_type = self.cfg["block_type"]
        self.num_channels = self.cfg["num_channels"]
        self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
-        self.pretrained = pretrained

        assert isinstance(self.lr_mult_list, (
            list, tuple
        )), "lr_mult_list should be in (list, tuple) but got {}".format(
            type(self.lr_mult_list))
        assert len(
            self.lr_mult_list
        ) == 5, "lr_mult_list length should be 5 but got {}".format(
            len(self.lr_mult_list))

        self.stem_cfg = {
            #num_channels, num_filters, filter_size, stride
            "vb": [[3, 64, 7, 2]],
            "vd": [[3, 32, 3, 2],
                   [32, 32, 3, 1],
                   [32, 64, 3, 1]]}

        self.stem = nn.Sequential(*[
            ConvBNLayer(
                num_channels=in_c,
                num_filters=out_c,
                filter_size=k,
                stride=s,
                act="relu",
                lr_mult=self.lr_mult_list[0])
            for in_c, out_c, k, s in self.stem_cfg[version]
        ])

        self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
        block_list = []
        for block_idx in range(len(self.block_depth)):
            shortcut = False
            for i in range(self.block_depth[block_idx]):
                block_list.append(globals()[self.block_type](
                    num_channels=self.num_channels[block_idx]
                    if i == 0 else self.num_filters[block_idx] * self.channels_mult,
                    num_filters=self.num_filters[block_idx],
                    stride=2 if i == 0 and block_idx != 0 else 1,
                    shortcut=shortcut,
                    if_first=block_idx == i == 0 if version == "vd" else True,
                    lr_mult=self.lr_mult_list[block_idx + 1]))
                shortcut = True
        self.blocks = nn.Sequential(*block_list)

        self.avg_pool = AdaptiveAvgPool2D(1)
...
@@ -297,8 +317,7 @@ class ResNet(TheseusLayer):
        self.fc = Linear(
            self.avg_pool_channels,
            self.class_num,
            weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))

    def forward(self, x):
        x = self.stem(x)
...
@@ -310,254 +329,179 @@ class ResNet(TheseusLayer):
        return x


+def _load_pretrained(pretrained, model, model_url, use_ssld):
+    if pretrained is False:
+        pass
+    elif pretrained is True:
+        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
+    elif isinstance(pretrained, str):
+        load_dygraph_pretrain(model, pretrained)
+    else:
+        raise RuntimeError(
+            "pretrained type is not available. Please use `string` or `boolean` type."
+        )


-def ResNet18(**args):
-    """
-    ResNet18
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet18` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["18"], version="vb", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet18"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet18(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet18
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet18` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["18"], version="vb", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld)
+    return model


-def ResNet18_vd(**args):
-    """
-    ResNet18_vd
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet18_vd` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["18"], version="vd", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet18_vd"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet18_vd
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet18_vd` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["18"], version="vd", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld)
+    return model


-def ResNet34(**args):
-    """
-    ResNet34
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet34` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["34"], version="vb", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet34"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet34(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet34
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet34` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["34"], version="vb", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld)
+    return model


-def ResNet34_vd(**args):
-    """
-    ResNet34_vd
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet34_vd` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["34"], version="vd", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(
-                model, MODEL_URLS["ResNet34_vd"], use_ssld=True)
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet34_vd
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet34_vd` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["34"], version="vd", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld)
+    return model


-def ResNet50(**args):
-    """
-    ResNet50
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet50` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["50"], version="vb", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet50"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet50(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet50
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet50` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["50"], version="vb", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
+    return model


-def ResNet50_vd(**args):
-    """
-    ResNet50_vd
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet50_vd` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["50"], version="vd", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(
-                model, MODEL_URLS["ResNet50_vd"], use_ssld=True)
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet50_vd
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet50_vd` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["50"], version="vd", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld)
+    return model


-def ResNet101(**args):
-    """
-    ResNet101
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet101` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["101"], version="vb", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet101"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet101(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet101
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet101` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["101"], version="vb", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld)
+    return model


-def ResNet101_vd(**args):
-    """
-    ResNet101_vd
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet101_vd` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["101"], version="vd", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(
-                model, MODEL_URLS["ResNet101_vd"], use_ssld=True)
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet101_vd
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet101_vd` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["101"], version="vd", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld)
+    return model


-def ResNet152(**args):
-    """
-    ResNet152
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet152` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["152"], version="vb", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet152"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet152(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet152
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet152` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["152"], version="vb", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld)
+    return model


-def ResNet152_vd(**args):
-    """
-    ResNet152_vd
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet152_vd` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["152"], version="vd", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(model, MODEL_URLS["ResNet152_vd"])
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet152_vd
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet152_vd` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["152"], version="vd", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld)
+    return model


-def ResNet200_vd(**args):
-    """
-    ResNet200_vd
-    Args:
-        kwargs:
-            class_num: int=1000. Output dim of last fc layer.
-            lr_mult_list: list=[1.0, 1.0, 1.0, 1.0, 1.0]. Control the learning rate of different stages.
-            pretrained: bool or str, default: bool=False. Whether to load the pretrained model.
-    Returns:
-        model: nn.Layer. Specific `ResNet200_vd` model depends on args.
-    """
-    model = ResNet(config=NET_CONFIG["200"], version="vd", **args)
-    if isinstance(model.pretrained, bool):
-        if model.pretrained is True:
-            load_dygraph_pretrain_from_url(
-                model, MODEL_URLS["ResNet200_vd"], use_ssld=True)
-    elif isinstance(model.pretrained, str):
-        load_dygraph_pretrain(model, model.pretrained)
-    else:
-        raise RuntimeError(
-            "pretrained type is not available. Please use `string` or `boolean` type")
-    return model
+def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs):
+    """
+    ResNet200_vd
+    Args:
+        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `ResNet200_vd` model depends on args.
+    """
+    model = ResNet(config=NET_CONFIG["200"], version="vd", **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld)
+    return model
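The lr_mult_list hook is the main extra knob these models expose: entry 0 scales the stem's learning rate and entries 1-4 scale the four stages, which is how downstream tasks slow or freeze early layers. A minimal sketch, assuming the package layout above:

import paddle
from ppcls.arch.backbone.legendary_models.resnet import ResNet50_vd

# Stem effectively frozen (0.0), early stages slowed, last stage at full rate.
model = ResNet50_vd(pretrained=False,
                    class_num=1000,
                    lr_mult_list=[0.0, 0.25, 0.25, 0.5, 1.0])
x = paddle.rand([1, 3, 224, 224])
print(model(x).shape)  # [1, 1000]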
ppcls/arch/backbone/legendary_models/vgg.py  View file @ 1a74e9cb
...
@@ -14,16 +14,24 @@
...
@@ -14,16 +14,24 @@
from
__future__
import
absolute_import
,
division
,
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
import
paddle.nn
as
nn
from
paddle.nn
import
Conv2D
,
BatchNorm
,
Linear
,
Dropout
from
paddle.nn
import
Conv2D
,
BatchNorm
,
Linear
,
Dropout
from
paddle.nn
import
MaxPool2D
from
paddle.nn
import
MaxPool2D
from
ppcls.arch.backbone.base.theseus_layer
import
TheseusLayer
from
ppcls.arch.backbone.base.theseus_layer
import
TheseusLayer
from
ppcls.utils.save_load
import
load_dygraph_pretrain
from
ppcls.utils.save_load
import
load_dygraph_pretrain
,
load_dygraph_pretrain_from_url
__all__
=
[
"VGG11"
,
"VGG13"
,
"VGG16"
,
"VGG19"
]
MODEL_URLS
=
{
"VGG11"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG11_pretrained.pdparams"
,
"VGG13"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG13_pretrained.pdparams"
,
"VGG16"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG16_pretrained.pdparams"
,
"VGG19"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/VGG19_pretrained.pdparams"
,
}
__all__
=
MODEL_URLS
.
keys
()
# VGG config
# VGG config
# key: VGG network depth
# key: VGG network depth
...
@@ -36,68 +44,12 @@ NET_CONFIG = {
...
@@ -36,68 +44,12 @@ NET_CONFIG = {
}
}
def
VGG11
(
**
args
):
"""
VGG11
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
Returns:
model: nn.Layer. Specific `VGG11` model depends on args.
"""
model
=
VGGNet
(
config
=
NET_CONFIG
[
11
],
**
args
)
return
model
def
VGG13
(
**
args
):
"""
VGG13
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
Returns:
model: nn.Layer. Specific `VGG11` model depends on args.
"""
model
=
VGGNet
(
config
=
NET_CONFIG
[
13
],
**
args
)
return
model
def
VGG16
(
**
args
):
"""
VGG16
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
Returns:
model: nn.Layer. Specific `VGG11` model depends on args.
"""
model
=
VGGNet
(
config
=
NET_CONFIG
[
16
],
**
args
)
return
model
def
VGG19
(
**
args
):
"""
VGG19
Args:
kwargs:
class_num: int=1000. Output dim of last fc layer.
stop_grad_layers: int=0. The parameters in blocks which index larger than `stop_grad_layers`, will be set `param.trainable=False`
Returns:
model: nn.Layer. Specific `VGG11` model depends on args.
"""
model
=
VGGNet
(
config
=
NET_CONFIG
[
19
],
**
args
)
return
model
class
ConvBlock
(
TheseusLayer
):
class
ConvBlock
(
TheseusLayer
):
def
__init__
(
self
,
input_channels
,
output_channels
,
groups
):
def
__init__
(
self
,
input_channels
,
output_channels
,
groups
):
super
(
ConvBlock
,
self
).
__init__
()
super
().
__init__
()
self
.
groups
=
groups
self
.
groups
=
groups
self
.
_conv_
1
=
Conv2D
(
self
.
conv
1
=
Conv2D
(
in_channels
=
input_channels
,
in_channels
=
input_channels
,
out_channels
=
output_channels
,
out_channels
=
output_channels
,
kernel_size
=
3
,
kernel_size
=
3
,
...
@@ -105,7 +57,7 @@ class ConvBlock(TheseusLayer):
...
@@ -105,7 +57,7 @@ class ConvBlock(TheseusLayer):
padding
=
1
,
padding
=
1
,
bias_attr
=
False
)
bias_attr
=
False
)
if
groups
==
2
or
groups
==
3
or
groups
==
4
:
if
groups
==
2
or
groups
==
3
or
groups
==
4
:
self
.
_conv_
2
=
Conv2D
(
self
.
conv
2
=
Conv2D
(
in_channels
=
output_channels
,
in_channels
=
output_channels
,
out_channels
=
output_channels
,
out_channels
=
output_channels
,
kernel_size
=
3
,
kernel_size
=
3
,
...
@@ -113,7 +65,7 @@ class ConvBlock(TheseusLayer):
...
@@ -113,7 +65,7 @@ class ConvBlock(TheseusLayer):
padding
=
1
,
padding
=
1
,
bias_attr
=
False
)
bias_attr
=
False
)
if
groups
==
3
or
groups
==
4
:
if
groups
==
3
or
groups
==
4
:
self
.
_conv_
3
=
Conv2D
(
self
.
conv
3
=
Conv2D
(
in_channels
=
output_channels
,
in_channels
=
output_channels
,
out_channels
=
output_channels
,
out_channels
=
output_channels
,
kernel_size
=
3
,
kernel_size
=
3
,
...
@@ -121,7 +73,7 @@ class ConvBlock(TheseusLayer):
...
@@ -121,7 +73,7 @@ class ConvBlock(TheseusLayer):
padding
=
1
,
padding
=
1
,
bias_attr
=
False
)
bias_attr
=
False
)
if
groups
==
4
:
if
groups
==
4
:
self
.
_conv_
4
=
Conv2D
(
self
.
conv
4
=
Conv2D
(
in_channels
=
output_channels
,
in_channels
=
output_channels
,
out_channels
=
output_channels
,
out_channels
=
output_channels
,
kernel_size
=
3
,
kernel_size
=
3
,
...
@@ -129,73 +81,148 @@ class ConvBlock(TheseusLayer):
...
@@ -129,73 +81,148 @@ class ConvBlock(TheseusLayer):
padding
=
1
,
padding
=
1
,
bias_attr
=
False
)
bias_attr
=
False
)
self
.
_pool
=
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
self
.
max
_pool
=
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
self
.
_
relu
=
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
x
=
self
.
_conv_
1
(
inputs
)
x
=
self
.
conv
1
(
inputs
)
x
=
self
.
_
relu
(
x
)
x
=
self
.
relu
(
x
)
if
self
.
groups
==
2
or
self
.
groups
==
3
or
self
.
groups
==
4
:
if
self
.
groups
==
2
or
self
.
groups
==
3
or
self
.
groups
==
4
:
x
=
self
.
_conv_
2
(
x
)
x
=
self
.
conv
2
(
x
)
x
=
self
.
_
relu
(
x
)
x
=
self
.
relu
(
x
)
if
self
.
groups
==
3
or
self
.
groups
==
4
:
if
self
.
groups
==
3
or
self
.
groups
==
4
:
x
=
self
.
_conv_
3
(
x
)
x
=
self
.
conv
3
(
x
)
x
=
self
.
_
relu
(
x
)
x
=
self
.
relu
(
x
)
if
self
.
groups
==
4
:
if
self
.
groups
==
4
:
x
=
self
.
_conv_
4
(
x
)
x
=
self
.
conv
4
(
x
)
x
=
self
.
_
relu
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
_pool
(
x
)
x
=
self
.
max
_pool
(
x
)
return
x
return
x
 class VGGNet(TheseusLayer):
-    def __init__(self,
-                 config,
-                 stop_grad_layers=0,
-                 class_num=1000,
-                 pretrained=False,
-                 **args):
+    """
+    VGGNet
+    Args:
+        config: list. VGGNet config.
+        stop_grad_layers: int=0. Parameters in blocks whose index is larger than `stop_grad_layers` will be set to `param.trainable=False`.
+        class_num: int=1000. The number of classes.
+    Returns:
+        model: nn.Layer. Specific VGG model depending on args.
+    """
+
+    def __init__(self, config, stop_grad_layers=0, class_num=1000):
         super().__init__()
         self.stop_grad_layers = stop_grad_layers
-        self._conv_block_1 = ConvBlock(3, 64, config[0])
-        self._conv_block_2 = ConvBlock(64, 128, config[1])
-        self._conv_block_3 = ConvBlock(128, 256, config[2])
-        self._conv_block_4 = ConvBlock(256, 512, config[3])
-        self._conv_block_5 = ConvBlock(512, 512, config[4])
+        self.conv_block_1 = ConvBlock(3, 64, config[0])
+        self.conv_block_2 = ConvBlock(64, 128, config[1])
+        self.conv_block_3 = ConvBlock(128, 256, config[2])
+        self.conv_block_4 = ConvBlock(256, 512, config[3])
+        self.conv_block_5 = ConvBlock(512, 512, config[4])
-        self._relu = nn.ReLU()
-        self._flatten = nn.Flatten(start_axis=1, stop_axis=-1)
+        self.relu = nn.ReLU()
+        self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)

         for idx, block in enumerate([
-                self._conv_block_1, self._conv_block_2, self._conv_block_3,
-                self._conv_block_4, self._conv_block_5
+                self.conv_block_1, self.conv_block_2, self.conv_block_3,
+                self.conv_block_4, self.conv_block_5
         ]):
             if self.stop_grad_layers >= idx + 1:
                 for param in block.parameters():
                     param.trainable = False

-        self._drop = Dropout(p=0.5, mode="downscale_in_infer")
-        self._fc1 = Linear(7 * 7 * 512, 4096)
-        self._fc2 = Linear(4096, 4096)
-        self._out = Linear(4096, class_num)
+        self.drop = Dropout(p=0.5, mode="downscale_in_infer")
+        self.fc1 = Linear(7 * 7 * 512, 4096)
+        self.fc2 = Linear(4096, 4096)
+        self.fc3 = Linear(4096, class_num)
-        if pretrained is not None:
-            load_dygraph_pretrain(self, pretrained)

     def forward(self, inputs):
-        x = self._conv_block_1(inputs)
-        x = self._conv_block_2(x)
-        x = self._conv_block_3(x)
-        x = self._conv_block_4(x)
-        x = self._conv_block_5(x)
-        x = self._flatten(x)
-        x = self._fc1(x)
-        x = self._relu(x)
-        x = self._drop(x)
-        x = self._fc2(x)
-        x = self._relu(x)
-        x = self._drop(x)
-        x = self._out(x)
+        x = self.conv_block_1(inputs)
+        x = self.conv_block_2(x)
+        x = self.conv_block_3(x)
+        x = self.conv_block_4(x)
+        x = self.conv_block_5(x)
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.drop(x)
+        x = self.fc3(x)
         return x
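The five `ConvBlock`s consume `config` as per-block convolution counts. `NET_CONFIG`, referenced by the factory functions below, is defined earlier in vgg.py and is not part of this excerpt; a hedged sketch of its assumed layout, with values taken from the standard VGG depths rather than from this diff:

# Assumed layout of NET_CONFIG (not shown in this excerpt):
NET_CONFIG = {
    11: [1, 1, 2, 2, 2],  # convs per block for VGG11
    13: [2, 2, 2, 2, 2],
    16: [2, 2, 3, 3, 3],
    19: [2, 2, 4, 4, 4],
}

# Direct construction, bypassing the factories. With stop_grad_layers=2,
# conv_block_1 and conv_block_2 are frozen (param.trainable=False).
model = VGGNet(config=NET_CONFIG[16], stop_grad_layers=2, class_num=1000)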
+def _load_pretrained(pretrained, model, model_url, use_ssld):
+    if pretrained is False:
+        pass
+    elif pretrained is True:
+        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
+    elif isinstance(pretrained, str):
+        load_dygraph_pretrain(model, pretrained)
+    else:
+        raise RuntimeError(
+            "pretrained type is not available. Please use `string` or `boolean` type."
+        )
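This helper gives every factory below the same three loading modes; any other type raises `RuntimeError`. A sketch of the resulting call patterns (the local path is hypothetical):

model = VGG16(pretrained=False)                 # random init, nothing loaded
model = VGG16(pretrained=True)                  # download weights from MODEL_URLS["VGG16"]
model = VGG16(pretrained="./VGG16_pretrained")  # hypothetical local weights path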
+def VGG11(pretrained=False, use_ssld=False, **kwargs):
+    """
+    VGG11
+    Args:
+        pretrained: bool=False or str. If `True`, load pretrained parameters; `False` otherwise.
+                    If str, it is taken as the path of the pretrained model.
+        use_ssld: bool=False. Whether to use the distillation-pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `VGG11` model depending on args.
+    """
+    model = VGGNet(config=NET_CONFIG[11], **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["VGG11"], use_ssld)
+    return model
+
+
+def VGG13(pretrained=False, use_ssld=False, **kwargs):
+    """
+    VGG13
+    Args:
+        pretrained: bool=False or str. If `True`, load pretrained parameters; `False` otherwise.
+                    If str, it is taken as the path of the pretrained model.
+        use_ssld: bool=False. Whether to use the distillation-pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `VGG13` model depending on args.
+    """
+    model = VGGNet(config=NET_CONFIG[13], **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["VGG13"], use_ssld)
+    return model
+
+
+def VGG16(pretrained=False, use_ssld=False, **kwargs):
+    """
+    VGG16
+    Args:
+        pretrained: bool=False or str. If `True`, load pretrained parameters; `False` otherwise.
+                    If str, it is taken as the path of the pretrained model.
+        use_ssld: bool=False. Whether to use the distillation-pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `VGG16` model depending on args.
+    """
+    model = VGGNet(config=NET_CONFIG[16], **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["VGG16"], use_ssld)
+    return model
+
+
+def VGG19(pretrained=False, use_ssld=False, **kwargs):
+    """
+    VGG19
+    Args:
+        pretrained: bool=False or str. If `True`, load pretrained parameters; `False` otherwise.
+                    If str, it is taken as the path of the pretrained model.
+        use_ssld: bool=False. Whether to use the distillation-pretrained model when pretrained=True.
+    Returns:
+        model: nn.Layer. Specific `VGG19` model depending on args.
+    """
+    model = VGGNet(config=NET_CONFIG[19], **kwargs)
+    _load_pretrained(pretrained, model, MODEL_URLS["VGG19"], use_ssld)
+    return model
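A minimal smoke test for the new factories, assuming `ppcls` is on the import path; the input must be 224x224 because `fc1` expects a flattened 7*7*512 feature map:

import paddle
from ppcls.arch.backbone.legendary_models.vgg import VGG19

model = VGG19(pretrained=False, class_num=1000)
model.eval()
x = paddle.rand([1, 3, 224, 224])
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # [1, 1000]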