PaddlePaddle / PaddleClas · commit fa6ce23f
Commit fa6ce23f, authored on Oct 22, 2020 by littletomatodonkey, committed via GitHub on Oct 22, 2020.

Merge pull request #331 from weisy11/dygraph

add ghostnet and modify shufflenet

Parents: 4f885038, 2ec1d73e
Showing 13 changed files with 724 additions and 193 deletions (+724, -193):
configs/GhostNet/GhostNet_x0_5.yaml               +74   -0
configs/GhostNet/GhostNet_x1_0.yaml               +74   -0
configs/GhostNet/GhostNet_x1_3.yaml               +75   -0
configs/ShuffleNet/ShuffleNetV2.yaml              +1    -1
configs/ShuffleNet/ShuffleNetV2_swish.yaml        +1    -1
configs/ShuffleNet/ShuffleNetV2_x0_25.yaml        +1    -1
configs/ShuffleNet/ShuffleNetV2_x0_33.yaml        +1    -1
configs/ShuffleNet/ShuffleNetV2_x0_5.yaml         +1    -1
configs/ShuffleNet/ShuffleNetV2_x1_5.yaml         +1    -1
configs/ShuffleNet/ShuffleNetV2_x2_0.yaml         +1    -1
ppcls/modeling/architectures/__init__.py          +1    -0
ppcls/modeling/architectures/ghostnet.py          +335  -0
ppcls/modeling/architectures/shufflenet_v2.py     +158  -186
configs/GhostNet/GhostNet_x0_5.yaml (new file, mode 100644)

```yaml
mode: 'train'
ARCHITECTURE:
    name: 'GhostNet_x0_5'
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 360
topk: 5
image_shape: [3, 224, 224]
use_mix: False
ls_epsilon: 0.1

LEARNING_RATE:
    function: 'CosineWarmup'
    params:
        lr: 0.8

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.0000400

TRAIN:
    batch_size: 2048
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:

VALID:
    batch_size: 64
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
```
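For orientation, `NormalizeImage` above first scales the raw 8-bit pixels by `1./255.` and then standardizes each channel with the listed ImageNet mean and std, i.e. `out = (pixel / 255 - mean) / std`. A quick numeric sanity check (a sketch of the arithmetic implied by the config, not the library implementation):

```python
# Sanity check of the NormalizeImage arithmetic from the config above.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
pixel = 128  # one raw 8-bit value in the R channel
out_r = (pixel / 255.0 - mean[0]) / std[0]
print(round(out_r, 4))  # 0.0741
```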
configs/GhostNet/GhostNet_x1_0.yaml (new file, mode 100644)

```yaml
mode: 'train'
ARCHITECTURE:
    name: 'GhostNet_x1_0'
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 360
topk: 5
image_shape: [3, 224, 224]
use_mix: False
ls_epsilon: 0.1

LEARNING_RATE:
    function: 'CosineWarmup'
    params:
        lr: 0.4

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.0000400

TRAIN:
    batch_size: 1024
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:

VALID:
    batch_size: 64
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
```
configs/GhostNet/GhostNet_x1_3.yaml (new file, mode 100644)

```yaml
mode: 'train'
ARCHITECTURE:
    name: 'GhostNet_x1_3'
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 360
topk: 5
image_shape: [3, 224, 224]
use_mix: False
ls_epsilon: 0.1

LEARNING_RATE:
    function: 'CosineWarmup'
    params:
        lr: 0.4

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.0000400

TRAIN:
    batch_size: 1024
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - AutoAugment:
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:

VALID:
    batch_size: 64
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
```
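Two things differ between the three GhostNet configs: `GhostNet_x1_3` adds an `AutoAugment` step to the training transforms (hence its +75 vs. +74 line count), and the learning rate stays proportional to the batch size (0.8 at 2048, 0.4 at 1024). A small check of that linear-scaling observation (the pairing is read off the configs above; the rule itself is not stated in the repo):

```python
# lr / batch_size is identical across the three GhostNet configs above.
configs = {
    "GhostNet_x0_5": (2048, 0.8),
    "GhostNet_x1_0": (1024, 0.4),
    "GhostNet_x1_3": (1024, 0.4),
}
for name, (batch_size, lr) in configs.items():
    print(name, lr / batch_size)  # 0.000390625 in every case
```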
configs/ShuffleNet/ShuffleNetV2.yaml

```diff
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.5
         warmup_epoch: 5
```
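The remaining ShuffleNet configs receive the same one-line change, shown below: the scheduler switches from plain `Cosine` to `CosineWarmup`, which prepends a warmup phase of `warmup_epoch` epochs before cosine decay. A minimal sketch of such a schedule, assuming linear warmup (the exact step-wise formula lives in PaddleClas's optimizer module and may differ in detail):

```python
import math

def cosine_warmup_lr(base_lr, epoch, total_epochs, warmup_epoch=5):
    """Sketch: linear warmup for `warmup_epoch` epochs, then cosine decay to 0."""
    if epoch < warmup_epoch:
        return base_lr * (epoch + 1) / warmup_epoch
    progress = (epoch - warmup_epoch) / (total_epochs - warmup_epoch)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# With base_lr=0.5 as above and a hypothetical 240-epoch run:
print([round(cosine_warmup_lr(0.5, e, 240), 3) for e in range(6)])
# [0.1, 0.2, 0.3, 0.4, 0.5, 0.5] -- ramp over epochs 0-4, then decay begins
```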
configs/ShuffleNet/ShuffleNetV2_swish.yaml

```diff
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.5
         warmup_epoch: 5
```
configs/ShuffleNet/ShuffleNetV2_x0_25.yaml

```diff
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.5
         warmup_epoch: 5
```
configs/ShuffleNet/ShuffleNetV2_x0_33.yaml

```diff
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.5
         warmup_epoch: 5
```
configs/ShuffleNet/ShuffleNetV2_x0_5.yaml

```diff
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.5
         warmup_epoch: 5
```
configs/ShuffleNet/ShuffleNetV2_x1_5.yaml

```diff
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.25
         warmup_epoch: 5
```
configs/ShuffleNet/ShuffleNetV2_x2_0.yaml

```diff
@@ -14,7 +14,7 @@ topk: 5
 image_shape: [3, 224, 224]
 
 LEARNING_RATE:
-    function: 'Cosine'
+    function: 'CosineWarmup'
     params:
         lr: 0.25
         warmup_epoch: 5
```
ppcls/modeling/architectures/__init__.py

```diff
@@ -28,6 +28,7 @@ from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44
 from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
 from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50
 from .googlenet import GoogLeNet
+from .ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3
 from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1
 from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
 from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
```
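With this import in place, the GhostNet factories are exposed alongside the other backbones, so a config's `ARCHITECTURE.name` can be resolved from `ppcls.modeling.architectures`. A minimal usage sketch (assumes a PaddleClas checkout on this branch and a paddle 2.0-style dygraph install):

```python
import paddle
from ppcls.modeling.architectures import GhostNet_x1_0

model = GhostNet_x1_0()            # factory added by this commit
x = paddle.rand([1, 3, 224, 224])  # matches image_shape in the configs
logits = model(x)
print(logits.shape)                # [1, 1000] for the default class_dim
```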
ppcls/modeling/architectures/ghostnet.py (new file, mode 100644)

```python
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, AdaptiveAvgPool2d, Linear
from paddle.fluid.regularizer import L2DecayRegularizer
from paddle.nn.initializer import Uniform


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 groups=1,
                 act="relu",
                 name=None):
        super(ConvBNLayer, self).__init__()
        self._conv = Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(
                initializer=nn.initializer.MSRA(), name=name + "_weights"),
            bias_attr=False)
        bn_name = name + "_bn"
        # In the old version, moving_variance_name was name + "_variance"
        self._batch_norm = BatchNorm(
            num_channels=out_channels,
            act=act,
            param_attr=ParamAttr(
                name=bn_name + "_scale",
                regularizer=L2DecayRegularizer(regularization_coeff=0.0)),
            bias_attr=ParamAttr(
                name=bn_name + "_offset",
                regularizer=L2DecayRegularizer(regularization_coeff=0.0)),
            moving_mean_name=bn_name + "_mean",
            moving_variance_name=name +
            "_variance"  # wrong due to an old typo, will be fixed later.
        )

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y


class SEBlock(nn.Layer):
    def __init__(self, num_channels, reduction_ratio=4, name=None):
        super(SEBlock, self).__init__()
        self.pool2d_gap = AdaptiveAvgPool2d(1)
        self._num_channels = num_channels
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        med_ch = num_channels // reduction_ratio
        self.squeeze = Linear(
            num_channels,
            med_ch,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name=name + "_1_weights"),
            bias_attr=ParamAttr(name=name + "_1_offset"))
        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = Linear(
            med_ch,
            num_channels,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name=name + "_2_weights"),
            bias_attr=ParamAttr(name=name + "_2_offset"))

    def forward(self, inputs):
        pool = self.pool2d_gap(inputs)
        pool = paddle.reshape(pool, shape=[-1, self._num_channels])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
        excitation = paddle.fluid.layers.clip(x=excitation, min=0, max=1)
        excitation = paddle.reshape(
            excitation, shape=[-1, self._num_channels, 1, 1])
        out = inputs * excitation
        return out


class GhostModule(nn.Layer):
    def __init__(self,
                 in_channels,
                 output_channels,
                 kernel_size=1,
                 ratio=2,
                 dw_size=3,
                 stride=1,
                 relu=True,
                 name=None):
        super(GhostModule, self).__init__()
        init_channels = int(math.ceil(output_channels / ratio))
        new_channels = int(init_channels * (ratio - 1))
        self.primary_conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=init_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=1,
            act="relu" if relu else None,
            name=name + "_primary_conv")
        self.cheap_operation = ConvBNLayer(
            in_channels=init_channels,
            out_channels=new_channels,
            kernel_size=dw_size,
            stride=1,
            groups=init_channels,
            act="relu" if relu else None,
            name=name + "_cheap_operation")

    def forward(self, inputs):
        x = self.primary_conv(inputs)
        y = self.cheap_operation(x)
        out = paddle.concat([x, y], axis=1)
        return out


class GhostBottleneck(nn.Layer):
    def __init__(self,
                 in_channels,
                 hidden_dim,
                 output_channels,
                 kernel_size,
                 stride,
                 use_se,
                 name=None):
        super(GhostBottleneck, self).__init__()
        self._stride = stride
        self._use_se = use_se
        self._num_channels = in_channels
        self._output_channels = output_channels
        self.ghost_module_1 = GhostModule(
            in_channels=in_channels,
            output_channels=hidden_dim,
            kernel_size=1,
            stride=1,
            relu=True,
            name=name + "_ghost_module_1")
        if stride == 2:
            self.depthwise_conv = ConvBNLayer(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=kernel_size,
                stride=stride,
                groups=hidden_dim,
                act=None,
                name=name +
                "_depthwise_depthwise"  # looks strange due to an old typo, will be fixed later.
            )
        if use_se:
            self.se_block = SEBlock(num_channels=hidden_dim, name=name + "_se")
        self.ghost_module_2 = GhostModule(
            in_channels=hidden_dim,
            output_channels=output_channels,
            kernel_size=1,
            relu=False,
            name=name + "_ghost_module_2")
        if stride != 1 or in_channels != output_channels:
            self.shortcut_depthwise = ConvBNLayer(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=kernel_size,
                stride=stride,
                groups=in_channels,
                act=None,
                name=name +
                "_shortcut_depthwise_depthwise"  # looks strange due to an old typo, will be fixed later.
            )
            self.shortcut_conv = ConvBNLayer(
                in_channels=in_channels,
                out_channels=output_channels,
                kernel_size=1,
                stride=1,
                groups=1,
                act=None,
                name=name + "_shortcut_conv")

    def forward(self, inputs):
        x = self.ghost_module_1(inputs)
        if self._stride == 2:
            x = self.depthwise_conv(x)
        if self._use_se:
            x = self.se_block(x)
        x = self.ghost_module_2(x)
        if self._stride == 1 and self._num_channels == self._output_channels:
            shortcut = inputs
        else:
            shortcut = self.shortcut_depthwise(inputs)
            shortcut = self.shortcut_conv(shortcut)
        return paddle.elementwise_add(x=x, y=shortcut, axis=-1)


class GhostNet(nn.Layer):
    def __init__(self, scale, class_dim=1000):
        super(GhostNet, self).__init__()
        self.cfgs = [
            # k, t, c, SE, s
            [3, 16, 16, 0, 1],
            [3, 48, 24, 0, 2],
            [3, 72, 24, 0, 1],
            [5, 72, 40, 1, 2],
            [5, 120, 40, 1, 1],
            [3, 240, 80, 0, 2],
            [3, 200, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 480, 112, 1, 1],
            [3, 672, 112, 1, 1],
            [5, 672, 160, 1, 2],
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1],
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1]
        ]
        self.scale = scale
        output_channels = int(self._make_divisible(16 * self.scale, 4))
        self.conv1 = ConvBNLayer(
            in_channels=3,
            out_channels=output_channels,
            kernel_size=3,
            stride=2,
            groups=1,
            act="relu",
            name="conv1")
        # build inverted residual blocks
        idx = 0
        self.ghost_bottleneck_list = []
        for k, exp_size, c, use_se, s in self.cfgs:
            in_channels = output_channels
            output_channels = int(self._make_divisible(c * self.scale, 4))
            hidden_dim = int(self._make_divisible(exp_size * self.scale, 4))
            ghost_bottleneck = self.add_sublayer(
                name="_ghostbottleneck_" + str(idx),
                sublayer=GhostBottleneck(
                    in_channels=in_channels,
                    hidden_dim=hidden_dim,
                    output_channels=output_channels,
                    kernel_size=k,
                    stride=s,
                    use_se=use_se,
                    name="_ghostbottleneck_" + str(idx)))
            self.ghost_bottleneck_list.append(ghost_bottleneck)
            idx += 1
        # build last several layers
        in_channels = output_channels
        output_channels = int(self._make_divisible(exp_size * self.scale, 4))
        self.conv_last = ConvBNLayer(
            in_channels=in_channels,
            out_channels=output_channels,
            kernel_size=1,
            stride=1,
            groups=1,
            act="relu",
            name="conv_last")
        self.pool2d_gap = AdaptiveAvgPool2d(1)
        in_channels = output_channels
        self._fc0_output_channels = 1280
        self.fc_0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=self._fc0_output_channels,
            kernel_size=1,
            stride=1,
            act="relu",
            name="fc_0")
        self.dropout = nn.Dropout(p=0.2)
        stdv = 1.0 / math.sqrt(self._fc0_output_channels * 1.0)
        self.fc_1 = Linear(
            self._fc0_output_channels,
            class_dim,
            weight_attr=ParamAttr(
                name="fc_1_weights", initializer=Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(name="fc_1_offset"))

    def forward(self, inputs):
        x = self.conv1(inputs)
        for ghost_bottleneck in self.ghost_bottleneck_list:
            x = ghost_bottleneck(x)
        x = self.conv_last(x)
        x = self.pool2d_gap(x)
        x = self.fc_0(x)
        x = self.dropout(x)
        x = paddle.reshape(x, shape=[-1, self._fc0_output_channels])
        x = self.fc_1(x)
        return x

    def _make_divisible(self, v, divisor, min_value=None):
        """
        This function is taken from the original tf repo.
        It ensures that all layers have a channel number that is divisible by 8
        It can be seen here:
        https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
        """
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v


def GhostNet_x0_5(**args):
    model = GhostNet(scale=0.5)
    return model


def GhostNet_x1_0(**args):
    model = GhostNet(scale=1.0)
    return model


def GhostNet_x1_3(**args):
    model = GhostNet(scale=1.3)
    return model
```
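`_make_divisible` rounds every scaled channel count to a multiple of 4 (the divisor passed throughout this file), never rounding down by more than 10%. A worked check for the x0.5 scale used by `GhostNet_x0_5` (a standalone re-statement of the method above, not new repo code):

```python
def make_divisible(v, divisor=4, min_value=None):
    # Standalone copy of GhostNet._make_divisible for a quick check.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # never round down by more than 10%
        new_v += divisor
    return new_v

assert make_divisible(16 * 0.5) == 8    # stem channels for GhostNet_x0_5
assert make_divisible(24 * 0.5) == 12   # c=24 bottleneck output at scale 0.5
assert make_divisible(72 * 0.5) == 36   # hidden dim for the t=72 blocks
```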
ppcls/modeling/architectures/shufflenet_v2.py

```diff
@@ -16,15 +16,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-import paddle
-from paddle import ParamAttr
-import paddle.nn as nn
-import paddle.nn.functional as F
-from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
-from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
+from paddle import ParamAttr, reshape, transpose, concat, split
+from paddle.nn import Layer, Conv2d, MaxPool2d, AdaptiveAvgPool2d, BatchNorm, Linear
 from paddle.nn.initializer import MSRA
-import math
+from paddle.nn.functional import swish
 
 __all__ = [
     "ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33", "ShuffleNetV2_x0_5",
@@ -34,188 +29,176 @@ __all__ = [
 
 def channel_shuffle(x, groups):
-    batchsize, num_channels, height, width = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
+    batch_size, num_channels, height, width = x.shape[0:4]
     channels_per_group = num_channels // groups
 
     # reshape
-    x = paddle.reshape(x=x, shape=[batchsize, groups, channels_per_group, height, width])
+    x = reshape(x=x, shape=[batch_size, groups, channels_per_group, height, width])
 
     # transpose
-    x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
+    x = transpose(x=x, perm=[0, 2, 1, 3, 4])
 
     # flatten
-    x = paddle.reshape(x=x, shape=[batchsize, num_channels, height, width])
+    x = reshape(x=x, shape=[batch_size, num_channels, height, width])
     return x
 
 
-class ConvBNLayer(nn.Layer):
-    def __init__(self,
-                 num_channels,
-                 filter_size,
-                 num_filters,
-                 stride,
-                 padding,
-                 channels=None,
-                 num_groups=1,
-                 if_act=True,
-                 act='relu',
-                 name=None):
+class ConvBNLayer(Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 groups=1,
+                 act=None,
+                 name=None, ):
         super(ConvBNLayer, self).__init__()
-        self._if_act = if_act
-        assert act in ['relu', 'swish'], \
-            "supported act are {} but your act is {}".format(['relu', 'swish'], act)
-        self._act = act
         self._conv = Conv2d(
-            in_channels=num_channels,
-            out_channels=num_filters,
-            kernel_size=filter_size,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
             stride=stride,
             padding=padding,
-            groups=num_groups,
+            groups=groups,
             weight_attr=ParamAttr(initializer=MSRA(), name=name + "_weights"),
             bias_attr=False)
 
         self._batch_norm = BatchNorm(
-            num_filters,
+            out_channels,
             param_attr=ParamAttr(name=name + "_bn_scale"),
             bias_attr=ParamAttr(name=name + "_bn_offset"),
             act=act,
             moving_mean_name=name + "_bn_mean",
             moving_variance_name=name + "_bn_variance")
 
-    def forward(self, inputs, if_act=True):
+    def forward(self, inputs):
         y = self._conv(inputs)
         y = self._batch_norm(y)
-        if self._if_act:
-            y = F.relu(y) if self._act == 'relu' else F.swish(y)
         return y
 
 
-class InvertedResidualUnit(nn.Layer):
+class InvertedResidual(Layer):
     def __init__(self,
-                 num_channels,
-                 num_filters,
+                 in_channels,
+                 out_channels,
                  stride,
-                 benchmodel,
-                 act='relu',
+                 act="relu",
                  name=None):
-        super(InvertedResidualUnit, self).__init__()
-        assert stride in [1, 2], \
-            "supported stride are {} but your stride is {}".format([1, 2], stride)
-        self.benchmodel = benchmodel
-        oup_inc = num_filters // 2
-        inp = num_channels
-        if benchmodel == 1:
-            self._conv_pw = ConvBNLayer(
-                num_channels=num_channels // 2,
-                num_filters=oup_inc,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                num_groups=1,
-                if_act=True,
-                act=act,
-                name='stage_' + name + '_conv1')
-            self._conv_dw = ConvBNLayer(
-                num_channels=oup_inc,
-                num_filters=oup_inc,
-                filter_size=3,
-                stride=stride,
-                padding=1,
-                num_groups=oup_inc,
-                if_act=False,
-                act=act,
-                name='stage_' + name + '_conv2')
-            self._conv_linear = ConvBNLayer(
-                num_channels=oup_inc,
-                num_filters=oup_inc,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                num_groups=1,
-                if_act=True,
-                act=act,
-                name='stage_' + name + '_conv3')
-        else:
-            # branch1
-            self._conv_dw_1 = ConvBNLayer(
-                num_channels=num_channels,
-                num_filters=inp,
-                filter_size=3,
-                stride=stride,
-                padding=1,
-                num_groups=inp,
-                if_act=False,
-                act=act,
-                name='stage_' + name + '_conv4')
-            self._conv_linear_1 = ConvBNLayer(
-                num_channels=inp,
-                num_filters=oup_inc,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                num_groups=1,
-                if_act=True,
-                act=act,
-                name='stage_' + name + '_conv5')
-            # branch2
-            self._conv_pw_2 = ConvBNLayer(
-                num_channels=num_channels,
-                num_filters=oup_inc,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                num_groups=1,
-                if_act=True,
-                act=act,
-                name='stage_' + name + '_conv1')
-            self._conv_dw_2 = ConvBNLayer(
-                num_channels=oup_inc,
-                num_filters=oup_inc,
-                filter_size=3,
-                stride=stride,
-                padding=1,
-                num_groups=oup_inc,
-                if_act=False,
-                act=act,
-                name='stage_' + name + '_conv2')
-            self._conv_linear_2 = ConvBNLayer(
-                num_channels=oup_inc,
-                num_filters=oup_inc,
-                filter_size=1,
-                stride=1,
-                padding=0,
-                num_groups=1,
-                if_act=True,
-                act=act,
-                name='stage_' + name + '_conv3')
+        super(InvertedResidual, self).__init__()
+        self._conv_pw = ConvBNLayer(
+            in_channels=in_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act,
+            name='stage_' + name + '_conv1')
+        self._conv_dw = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            groups=out_channels // 2,
+            act=None,
+            name='stage_' + name + '_conv2')
+        self._conv_linear = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act,
+            name='stage_' + name + '_conv3')
 
     def forward(self, inputs):
-        if self.benchmodel == 1:
-            x1, x2 = paddle.split(
-                inputs,
-                num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
-                axis=1)
-            x2 = self._conv_pw(x2)
-            x2 = self._conv_dw(x2)
-            x2 = self._conv_linear(x2)
-            out = paddle.concat([x1, x2], axis=1)
-        else:
-            x1 = self._conv_dw_1(inputs)
-            x1 = self._conv_linear_1(x1)
-            x2 = self._conv_pw_2(inputs)
-            x2 = self._conv_dw_2(x2)
-            x2 = self._conv_linear_2(x2)
-            out = paddle.concat([x1, x2], axis=1)
+        x1, x2 = split(
+            inputs,
+            num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
+            axis=1)
+        x2 = self._conv_pw(x2)
+        x2 = self._conv_dw(x2)
+        x2 = self._conv_linear(x2)
+        out = concat([x1, x2], axis=1)
         return channel_shuffle(out, 2)
 
 
+class InvertedResidualDS(Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 act="relu",
+                 name=None):
+        super(InvertedResidualDS, self).__init__()
+        # branch1
+        self._conv_dw_1 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            groups=in_channels,
+            act=None,
+            name='stage_' + name + '_conv4')
+        self._conv_linear_1 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act,
+            name='stage_' + name + '_conv5')
+        # branch2
+        self._conv_pw_2 = ConvBNLayer(
+            in_channels=in_channels,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act,
+            name='stage_' + name + '_conv1')
+        self._conv_dw_2 = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            groups=out_channels // 2,
+            act=None,
+            name='stage_' + name + '_conv2')
+        self._conv_linear_2 = ConvBNLayer(
+            in_channels=out_channels // 2,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            act=act,
+            name='stage_' + name + '_conv3')
+
+    def forward(self, inputs):
+        x1 = self._conv_dw_1(inputs)
+        x1 = self._conv_linear_1(x1)
+        x2 = self._conv_pw_2(inputs)
+        x2 = self._conv_dw_2(x2)
+        x2 = self._conv_linear_2(x2)
+        out = concat([x1, x2], axis=1)
+        return channel_shuffle(out, 2)
+
+
-class ShuffleNet(nn.Layer):
-    def __init__(self, class_dim=1000, scale=1.0, act='relu'):
+class ShuffleNet(Layer):
+    def __init__(self, class_dim=1000, scale=1.0, act="relu"):
         super(ShuffleNet, self).__init__()
         self.scale = scale
         self.class_dim = class_dim
@@ -238,58 +221,47 @@ class ShuffleNet(nn.Layer):
                       "] is not implemented!")
         # 1. conv1
         self._conv1 = ConvBNLayer(
-            num_channels=3,
-            num_filters=stage_out_channels[1],
-            filter_size=3,
+            in_channels=3,
+            out_channels=stage_out_channels[1],
+            kernel_size=3,
             stride=2,
             padding=1,
-            if_act=True,
             act=act,
             name='stage1_conv')
         self._max_pool = MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # 2. bottleneck sequences
         self._block_list = []
-        i = 1
-        in_c = int(32 * scale)
-        for idxstage in range(len(stage_repeats)):
-            numrepeat = stage_repeats[idxstage]
-            output_channel = stage_out_channels[idxstage + 2]
-            for i in range(numrepeat):
+        for stage_id, num_repeat in enumerate(stage_repeats):
+            for i in range(num_repeat):
                 if i == 0:
                     block = self.add_sublayer(
-                        str(idxstage + 2) + '_' + str(i + 1),
-                        InvertedResidualUnit(
-                            num_channels=stage_out_channels[idxstage + 1],
-                            num_filters=output_channel,
+                        name=str(stage_id + 2) + '_' + str(i + 1),
+                        sublayer=InvertedResidualDS(
+                            in_channels=stage_out_channels[stage_id + 1],
+                            out_channels=stage_out_channels[stage_id + 2],
                             stride=2,
-                            benchmodel=2,
                             act=act,
-                            name=str(idxstage + 2) + '_' + str(i + 1)))
-                    self._block_list.append(block)
+                            name=str(stage_id + 2) + '_' + str(i + 1)))
                 else:
                     block = self.add_sublayer(
-                        str(idxstage + 2) + '_' + str(i + 1),
-                        InvertedResidualUnit(
-                            num_channels=output_channel,
-                            num_filters=output_channel,
+                        name=str(stage_id + 2) + '_' + str(i + 1),
+                        sublayer=InvertedResidual(
+                            in_channels=stage_out_channels[stage_id + 2],
+                            out_channels=stage_out_channels[stage_id + 2],
                             stride=1,
-                            benchmodel=1,
                             act=act,
-                            name=str(idxstage + 2) + '_' + str(i + 1)))
-                    self._block_list.append(block)
+                            name=str(stage_id + 2) + '_' + str(i + 1)))
+                self._block_list.append(block)
         # 3. last_conv
         self._last_conv = ConvBNLayer(
-            num_channels=stage_out_channels[-2],
-            num_filters=stage_out_channels[-1],
-            filter_size=1,
+            in_channels=stage_out_channels[-2],
+            out_channels=stage_out_channels[-1],
+            kernel_size=1,
             stride=1,
             padding=0,
-            if_act=True,
            act=act,
            name='conv5')
         # 4. pool
         self._pool2d_avg = AdaptiveAvgPool2d(1)
         self._out_c = stage_out_channels[-1]
@@ -307,13 +279,13 @@ class ShuffleNet(nn.Layer):
             y = inv(y)
         y = self._last_conv(y)
         y = self._pool2d_avg(y)
-        y = paddle.reshape(y, shape=[-1, self._out_c])
+        y = reshape(y, shape=[-1, self._out_c])
         y = self._fc(y)
         return y
 
 
 def ShuffleNetV2_x0_25(**args):
-    model = ShuffleNetV2(scale=0.25, **args)
+    model = ShuffleNet(scale=0.25, **args)
     return model
@@ -343,5 +315,5 @@ def ShuffleNetV2_x2_0(**args):
 
 def ShuffleNetV2_swish(**args):
-    model = ShuffleNet(scale=1.0, act='swish', **args)
+    model = ShuffleNet(scale=1.0, act="swish", **args)
     return model
```
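To see what the refactored `channel_shuffle` does, here is a tiny worked example with `groups=2` (a sketch using the top-level `paddle` ops this refactor switches to, not code from the commit):

```python
import paddle
from paddle import reshape, transpose

x = paddle.arange(4, dtype="float32").reshape([1, 4, 1, 1])  # channels 0,1,2,3
batch_size, num_channels, height, width = x.shape[0:4]
groups = 2
# Same reshape -> transpose -> reshape sequence as channel_shuffle above.
x = reshape(x, shape=[batch_size, groups, num_channels // groups, height, width])
x = transpose(x, perm=[0, 2, 1, 3, 4])
x = reshape(x, shape=[batch_size, num_channels, height, width])
print(x.numpy().flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]: channels interleaved across the two groups
```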