Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
92a162fb
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 2 年 前同步成功
通知
118
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
92a162fb
编写于
7月 01, 2020
作者:
S
shippingwang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add efficientnet-lite model
上级
9d3f36b7
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
733 addition
and
10 deletion
+733
-10
configs/EfficientNet/EfficientLite0.yaml
configs/EfficientNet/EfficientLite0.yaml
+91
-0
configs/EfficientNet/EfficientNetB0.yaml
configs/EfficientNet/EfficientNetB0.yaml
+1
-1
ppcls/modeling/architectures/__init__.py
ppcls/modeling/architectures/__init__.py
+3
-0
ppcls/modeling/architectures/efficientnetlite.py
ppcls/modeling/architectures/efficientnetlite.py
+627
-0
ppcls/modeling/architectures/layers.py
ppcls/modeling/architectures/layers.py
+11
-9
未找到文件。
configs/EfficientNet/EfficientLite0.yaml
0 → 100644
浏览文件 @
92a162fb
mode
:
'
train'
ARCHITECTURE
:
name
:
"
EfficientNetLite0"
params
:
is_test
:
False
padding_type
:
"
SAME"
override_params
:
drop_connect_rate
:
0.1
fix_head_stem
:
True
relu_fn
:
True
pretrained_model
:
"
"
model_save_dir
:
"
./output/"
classes_num
:
1000
total_images
:
1281167
save_interval
:
1
validate
:
True
valid_interval
:
1
epochs
:
360
topk
:
5
image_shape
:
[
3
,
224
,
224
]
use_ema
:
True
ema_decay
:
0.9999
use_aa
:
True
ls_epsilon
:
0.1
LEARNING_RATE
:
function
:
'
ExponentialWarmup'
params
:
lr
:
0.032
OPTIMIZER
:
function
:
'
RMSProp'
params
:
momentum
:
0.9
rho
:
0.9
epsilon
:
0.001
regularizer
:
function
:
'
L2'
factor
:
0.00001
TRAIN
:
batch_size
:
512
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/train_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
RandCropImage
:
size
:
224
interpolation
:
2
-
RandFlipImage
:
flip_code
:
1
-
AutoAugment
:
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
VALID
:
batch_size
:
128
num_workers
:
4
file_list
:
"
./dataset/ILSVRC2012/val_list.txt"
data_dir
:
"
./dataset/ILSVRC2012/"
shuffle_seed
:
0
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
False
channel_first
:
False
-
ResizeImage
:
interpolation
:
2
resize_short
:
412
-
CropImage
:
size
:
256
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
configs/EfficientNet/EfficientNetB0.yaml
浏览文件 @
92a162fb
...
...
@@ -46,7 +46,7 @@ TRAIN:
transforms
:
-
DecodeImage
:
to_rgb
:
True
to_np
:
Fals
to_np
:
Fals
e
channel_first
:
False
-
RandCropImage
:
size
:
224
...
...
ppcls/modeling/architectures/__init__.py
浏览文件 @
92a162fb
...
...
@@ -37,6 +37,9 @@ from .squeezenet import SqueezeNet1_0, SqueezeNet1_1
from
.darknet
import
DarkNet53
from
.resnext101_wsl
import
ResNeXt101_32x8d_wsl
,
ResNeXt101_32x16d_wsl
,
ResNeXt101_32x32d_wsl
,
ResNeXt101_32x48d_wsl
,
Fix_ResNeXt101_32x48d_wsl
from
.efficientnet
import
EfficientNet
,
EfficientNetB0
,
EfficientNetB0_small
,
EfficientNetB1
,
EfficientNetB2
,
EfficientNetB3
,
EfficientNetB4
,
EfficientNetB5
,
EfficientNetB6
,
EfficientNetB7
# Export every EfficientNet-Lite entry point; EfficientNetLite3 was previously
# left out, so it could not be resolved by name from this package.
from .efficientnetlite import EfficientNetLite, EfficientNetLite0, EfficientNetLite1, EfficientNetLite2, EfficientNetLite3, EfficientNetLite4
from
.res2net
import
Res2Net50_48w_2s
,
Res2Net50_26w_4s
,
Res2Net50_14w_8s
,
Res2Net50_26w_6s
,
Res2Net50_26w_8s
,
Res2Net101_26w_4s
,
Res2Net152_26w_4s
from
.res2net_vd
import
Res2Net50_vd_48w_2s
,
Res2Net50_vd_26w_4s
,
Res2Net50_vd_14w_8s
,
Res2Net50_vd_26w_6s
,
Res2Net50_vd_26w_8s
,
Res2Net101_vd_26w_4s
,
Res2Net152_vd_26w_4s
,
Res2Net200_vd_26w_4s
from
.hrnet
import
HRNet_W18_C
,
HRNet_W30_C
,
HRNet_W32_C
,
HRNet_W40_C
,
HRNet_W44_C
,
HRNet_W48_C
,
HRNet_W60_C
,
HRNet_W64_C
,
SE_HRNet_W18_C
,
SE_HRNet_W30_C
,
SE_HRNet_W32_C
,
SE_HRNet_W40_C
,
SE_HRNet_W44_C
,
SE_HRNet_W48_C
,
SE_HRNet_W60_C
,
SE_HRNet_W64_C
...
...
ppcls/modeling/architectures/efficientnetlite.py
0 → 100644
浏览文件 @
92a162fb
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
collections
import
re
import
math
import
copy
import
paddle.fluid
as
fluid
from
.layers
import
conv2d
,
init_batch_norm_layer
,
init_fc_layer
# Public names of this module: the generic EfficientNetLite builder class and
# one factory function per variant (lite0 .. lite4).
__all__ = [
    'EfficientNetLite', 'EfficientNetLite0', 'EfficientNetLite1',
    'EfficientNetLite2', 'EfficientNetLite3', 'EfficientNetLite4'
]
# Network-wide hyper-parameters shared by every block of the model.
GlobalParams = collections.namedtuple(
    'GlobalParams',
    (
        'batch_norm_momentum',
        'batch_norm_epsilon',
        'dropout_rate',
        'num_classes',
        'width_coefficient',
        'depth_coefficient',
        'depth_divisor',
        'min_depth',
        'drop_connect_rate',
        'fix_head_stem',
        'relu_fn',
        'local_pooling',
    ))

# Per-stage description of an MBConv block group.
BlockArgs = collections.namedtuple(
    'BlockArgs',
    (
        'kernel_size',
        'num_repeat',
        'input_filters',
        'output_filters',
        'expand_ratio',
        'id_skip',
        'stride',
        'se_ratio',
    ))

# Every field defaults to None so the tuples can be built from a subset of
# keyword arguments.
GlobalParams.__new__.__defaults__ = tuple([None] * len(GlobalParams._fields))
BlockArgs.__new__.__defaults__ = tuple([None] * len(BlockArgs._fields))
def efficientnet_lite_params(model_name):
    """Look up the scaling coefficients of an EfficientNet-Lite variant.

    Args:
        model_name: full variant name, e.g. 'efficientnet-lite0'.

    Returns:
        A tuple (width, depth, resolution, dropout) for the variant.

    Raises:
        KeyError: if model_name is not one of the five known variants.
    """
    variant_names = (
        'efficientnet-lite0',
        'efficientnet-lite1',
        'efficientnet-lite2',
        'efficientnet-lite3',
        'efficientnet-lite4',
    )
    # Coefficients: width, depth, resolution, dropout
    coefficients = (
        (1.0, 1.0, 224, 0.2),
        (1.0, 1.1, 240, 0.2),
        (1.1, 1.2, 260, 0.3),
        (1.2, 1.4, 280, 0.3),
        (1.4, 1.8, 300, 0.3),
    )
    lookup = dict(zip(variant_names, coefficients))
    return lookup[model_name]
def efficientnet_lite(width_coefficient=None,
                      depth_coefficient=None,
                      dropout_rate=0.2,
                      drop_connect_rate=0.2):
    """Build the decoded block list and global parameters of a Lite model.

    Args:
        width_coefficient: channel multiplier stored in the global params.
        depth_coefficient: repeat multiplier stored in the global params.
        dropout_rate: dropout rate stored in the global params.
        drop_connect_rate: drop-connect rate stored in the global params.

    Returns:
        A tuple (blocks_args, global_params): the list produced by
        BlockDecoder.decode for the seven stage strings below, and a
        GlobalParams instance.
    """
    stage_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    decoded_blocks = BlockDecoder.decode(stage_strings)

    settings = {
        'batch_norm_momentum': 0.99,
        'batch_norm_epsilon': 1e-3,
        'dropout_rate': dropout_rate,
        'drop_connect_rate': drop_connect_rate,
        'num_classes': 1000,
        'width_coefficient': width_coefficient,
        'depth_coefficient': depth_coefficient,
        'depth_divisor': 8,
        'min_depth': None,
        # FOR LITE, use relu6 for easier quantization
        'relu_fn': True,
        # FOR LITE, don't scale the head and stem
        'fix_head_stem': True,
        # FOR LITE
        'local_pooling': True,
    }
    return decoded_blocks, GlobalParams(**settings)
def get_model_params(model_name, override_params):
    """Return (blocks_args, global_params) for a named Lite model.

    Args:
        model_name: variant name; must start with 'efficientnet-lite'.
        override_params: optional mapping of GlobalParams field names to
            replacement values; ignored when falsy.

    Returns:
        A tuple (blocks_args, global_params).

    Raises:
        NotImplementedError: if model_name does not start with
            'efficientnet-lite'.
    """
    # Guard clause: only the Lite family is handled here.
    if not model_name.startswith('efficientnet-lite'):
        raise NotImplementedError('model name is not pre-defined: %s' %
                                  model_name)

    # The third coefficient (input resolution) is not used here.
    width, depth, _, dropout = efficientnet_lite_params(model_name)
    blocks_args, global_params = efficientnet_lite(
        width_coefficient=width,
        depth_coefficient=depth,
        dropout_rate=dropout)

    if override_params:
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params
def round_filters(filters, global_params, skip=False):
    """Scale a filter count by the width coefficient and round it.

    Args:
        filters: base number of filters.
        global_params: object exposing width_coefficient, depth_divisor and
            min_depth.
        skip: when True, return filters untouched.

    Returns:
        filters unchanged when skip is set or the width coefficient is
        falsy; otherwise the scaled count rounded to a multiple of
        depth_divisor, as an int.
    """
    width_mult = global_params.width_coefficient
    if skip or not width_mult:
        return filters

    step = global_params.depth_divisor
    lower_bound = global_params.min_depth or step
    scaled = filters * width_mult

    # Round to the nearest multiple of step, never below lower_bound.
    rounded = max(lower_bound, int(scaled + step / 2) // step * step)
    # prevent rounding down by more than 10%
    if rounded < 0.9 * scaled:
        rounded += step
    return int(rounded)
def round_repeats(repeats, global_params, skip=False):
    """Scale a block repeat count by the depth coefficient, rounding up.

    Args:
        repeats: base number of repeats.
        global_params: object exposing depth_coefficient.
        skip: when True, return repeats untouched.

    Returns:
        repeats unchanged when skip is set or the depth coefficient is
        falsy; otherwise ceil(depth_coefficient * repeats) as an int.
    """
    depth_mult = global_params.depth_coefficient
    if not skip and depth_mult:
        return int(math.ceil(depth_mult * repeats))
    return repeats
class EfficientNetLite():
    """EfficientNet-Lite graph builder based on the paddle.fluid layer API.

    The constructor resolves the block list and global parameters for the
    requested variant; `net` then assembles stem, MBConv blocks, head,
    global pooling, dropout and the final fully connected layer.
    """

    def __init__(self,
                 name='lite0',
                 padding_type='SAME',
                 override_params=None,
                 is_test=False,
                 # For Lite, Don't use SE
                 use_se=False):
        """Resolve the configuration of one Lite variant.

        Args:
            name: variant suffix, one of 'lite0' .. 'lite4'.
            padding_type: padding mode forwarded to every convolution.
            override_params: optional mapping of GlobalParams fields to
                override, forwarded to get_model_params.
            is_test: stored on the instance; used to disable drop-connect.
            use_se: whether squeeze-and-excitation is applied in the blocks.
        """
        valid_names = ['lite' + str(i) for i in range(5)]
        # The message previously mentioned 'b0~b7', which does not match the
        # names accepted here.
        assert name in valid_names, \
            'efficientlite name should be in lite0~lite4'

        model_name = 'efficientnet-' + name
        self._blocks_args, self._global_params = get_model_params(
            model_name, override_params)
        print("global_params", self._global_params)
        self._bn_mom = self._global_params.batch_norm_momentum
        self._bn_eps = self._global_params.batch_norm_epsilon
        self.is_test = is_test
        self.padding_type = padding_type
        self.use_se = use_se
        self._relu_fn = self._global_params.relu_fn
        self._fix_head_stem = self._global_params.fix_head_stem
        self.local_pooling = self._global_params.local_pooling
        # NCHW spatial: HW
        self._spatial_dims = [2, 3]

    @staticmethod
    def _first_stride(stride):
        """Return the scalar stride from either an int or a list/tuple."""
        if isinstance(stride, (list, tuple)):
            return stride[0]
        return stride

    def _activate(self, tensor):
        """Apply relu6 when relu_fn is set, swish otherwise."""
        if self._relu_fn:
            return fluid.layers.relu6(tensor)
        return fluid.layers.swish(tensor)

    def net(self, input, class_dim=1000, is_test=False):
        """Build the full classification network.

        Args:
            input: input image tensor.
            class_dim: number of output classes of the final fc layer.
            is_test: forwarded to extract_features.

        Returns:
            The output of the final fully connected layer.
        """
        conv = self.extract_features(input, is_test=is_test)

        out_channels = round_filters(1280, self._global_params,
                                     self._fix_head_stem)
        conv = self.conv_bn_layer(
            conv,
            num_filters=out_channels,
            filter_size=1,
            bn_act='relu6' if self._relu_fn else 'swish',  # for lite
            bn_mom=self._bn_mom,
            bn_eps=self._bn_eps,
            padding_type=self.padding_type,
            name='',
            conv_name='_conv_head',
            bn_name='_bn1')

        pool = fluid.layers.pool2d(
            input=conv,
            pool_type='avg',
            global_pooling=True,
            use_cudnn=False)

        if self._global_params.dropout_rate:
            pool = fluid.layers.dropout(
                pool,
                self._global_params.dropout_rate,
                dropout_implementation='upscale_in_train')

        param_attr, bias_attr = init_fc_layer(class_dim, '_fc')
        out = fluid.layers.fc(pool,
                              class_dim,
                              name='_fc',
                              param_attr=param_attr,
                              bias_attr=bias_attr)
        return out

    def _drop_connect(self, inputs, prob, is_test):
        """Randomly zero whole samples with probability prob (training only).

        Returns inputs unchanged when is_test is truthy; otherwise the kept
        samples are rescaled by 1 / (1 - prob).
        """
        if is_test:
            return inputs
        keep_prob = 1.0 - prob
        inputs_shape = fluid.layers.shape(inputs)
        # One random value per sample, broadcast over the other three axes.
        random_tensor = keep_prob + fluid.layers.uniform_random(
            shape=[inputs_shape[0], 1, 1, 1], min=0., max=1.)
        binary_tensor = fluid.layers.floor(random_tensor)
        output = inputs / keep_prob * binary_tensor
        return output

    def _expand_conv_norm(self, inputs, block_args, is_test, name=None):
        """1x1 expansion convolution followed by batch norm (no activation).

        When expand_ratio is 1 there is nothing to expand and inputs is
        returned unchanged (the original code returned an unbound local in
        that case).
        """
        if block_args.expand_ratio == 1:
            return inputs

        # number of output channels
        oup = block_args.input_filters * block_args.expand_ratio
        return self.conv_bn_layer(
            inputs,
            num_filters=oup,
            filter_size=1,
            bn_act=None,
            bn_mom=self._bn_mom,
            bn_eps=self._bn_eps,
            padding_type=self.padding_type,
            name=name,
            conv_name=name + '_expand_conv',
            bn_name='_bn0')

    def _depthwise_conv_norm(self, inputs, block_args, is_test, name=None):
        """Depthwise convolution followed by batch norm (no activation)."""
        k = block_args.kernel_size
        s = self._first_stride(block_args.stride)
        # number of output channels; groups == channels makes it depthwise
        oup = block_args.input_filters * block_args.expand_ratio
        conv = self.conv_bn_layer(
            inputs,
            num_filters=oup,
            filter_size=k,
            stride=s,
            num_groups=oup,
            bn_act=None,
            padding_type=self.padding_type,
            bn_mom=self._bn_mom,
            bn_eps=self._bn_eps,
            name=name,
            use_cudnn=False,
            conv_name=name + '_depthwise_conv',
            bn_name='_bn1')
        return conv

    def _project_conv_norm(self, inputs, block_args, is_test, name=None):
        """1x1 projection convolution to output_filters, then batch norm."""
        final_oup = block_args.output_filters
        conv = self.conv_bn_layer(
            inputs,
            num_filters=final_oup,
            filter_size=1,
            bn_act=None,
            padding_type=self.padding_type,
            bn_mom=self._bn_mom,
            bn_eps=self._bn_eps,
            name=name,
            conv_name=name + '_project_conv',
            bn_name='_bn2')
        return conv

    def conv_bn_layer(self,
                      input,
                      filter_size,
                      num_filters,
                      stride=1,
                      num_groups=1,
                      padding_type="SAME",
                      conv_act=None,
                      bn_act='relu6',
                      use_cudnn=True,
                      use_bn=True,
                      bn_mom=0.9,
                      bn_eps=1e-05,
                      use_bias=False,
                      name=None,
                      conv_name=None,
                      bn_name=None):
        """Convolution optionally followed by batch normalization.

        Args:
            input: input tensor.
            filter_size, num_filters, stride, num_groups, padding_type,
                conv_act, use_cudnn, use_bias: forwarded to conv2d.
            bn_act, bn_mom, bn_eps: activation, momentum and epsilon of the
                batch norm layer.
            use_bn: when exactly False, the convolution output is returned
                without batch norm.
            name: prefix prepended to bn_name to form the batch norm name.
            conv_name: parameter name of the convolution.
            bn_name: suffix of the batch norm name.

        Returns:
            The convolution output, or the batch norm output when use_bn is
            not False.
        """
        conv = conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            groups=num_groups,
            act=conv_act,
            padding_type=padding_type,
            use_cudnn=use_cudnn,
            name=conv_name,
            use_bias=use_bias)

        if use_bn is False:
            return conv

        bn_name = name + bn_name
        param_attr, bias_attr = init_batch_norm_layer(bn_name)
        return fluid.layers.batch_norm(
            input=conv,
            act=bn_act,
            momentum=bn_mom,
            epsilon=bn_eps,
            name=bn_name,
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance',
            param_attr=param_attr,
            bias_attr=bias_attr)

    def _conv_stem_norm(self, inputs, is_test):
        """3x3 stride-2 stem convolution followed by batch norm."""
        out_channels = round_filters(32, self._global_params,
                                     self._fix_head_stem)
        bn = self.conv_bn_layer(
            inputs,
            num_filters=out_channels,
            filter_size=3,
            stride=2,
            bn_act=None,
            bn_mom=self._bn_mom,
            padding_type=self.padding_type,
            bn_eps=self._bn_eps,
            name='',
            conv_name='_conv_stem',
            bn_name='_bn0')
        return bn

    def mb_conv_block(self,
                      inputs,
                      block_args,
                      is_test=False,
                      drop_connect_rate=None,
                      name=None):
        """Build one MBConv block.

        The block is: optional expansion, depthwise convolution, optional
        squeeze-and-excitation, projection, and an identity skip connection
        (with optional drop-connect) when the stride is 1 and the input and
        output filter counts match.

        Args:
            inputs: input tensor.
            block_args: BlockArgs describing the block.
            is_test: forwarded to the convolution helpers.
            drop_connect_rate: drop-connect probability, or a falsy value to
                disable it.
            name: prefix for the parameter names of this block.

        Returns:
            The output tensor of the block.
        """
        # number of output channels of the expansion / depthwise stage
        oup = block_args.input_filters * block_args.expand_ratio
        has_se = self.use_se and (block_args.se_ratio is not None) and (
            0 < block_args.se_ratio <= 1)
        id_skip = block_args.id_skip  # skip connection and drop connect

        # Expansion and Depthwise Convolution
        conv = inputs
        if block_args.expand_ratio != 1:
            conv = self._activate(
                self._expand_conv_norm(conv, block_args, is_test, name))
        conv = self._activate(
            self._depthwise_conv_norm(conv, block_args, is_test, name))

        # Squeeze and Excitation
        if has_se:
            num_squeezed_channels = max(
                1, int(block_args.input_filters * block_args.se_ratio))
            conv = self.se_block(conv, num_squeezed_channels, oup, name)

        conv = self._project_conv_norm(conv, block_args, is_test, name)

        # Skip connection and drop connect.
        # block_args.stride is a one-element list when it comes from
        # BlockDecoder and the int 1 for repeated blocks, so it is normalized
        # before the comparison (a list never equals the int 1).
        input_filters = block_args.input_filters
        output_filters = block_args.output_filters
        if id_skip and \
                self._first_stride(block_args.stride) == 1 and \
                input_filters == output_filters:
            if drop_connect_rate:
                # NOTE(review): this uses the instance-level is_test flag,
                # not the is_test argument - confirm that is intended.
                conv = self._drop_connect(conv, drop_connect_rate,
                                          self.is_test)
            conv = fluid.layers.elementwise_add(conv, inputs)

        return conv

    def se_block(self, inputs, num_squeezed_channels, oup, name):
        """Squeeze-and-excitation: rescale inputs by a learned channel gate.

        Args:
            inputs: input tensor.
            num_squeezed_channels: channels of the reduce convolution.
            oup: channels of the expand convolution.
            name: prefix for the parameter names.

        Returns:
            inputs multiplied element-wise by the sigmoid gate.
        """
        if self.local_pooling:
            shape = inputs.shape
            # Average pooling with a window covering the whole feature map;
            # pool_type must be given explicitly because pool2d defaults to
            # max pooling.
            x_squeezed = fluid.layers.pool2d(
                input=inputs,
                pool_size=[
                    shape[self._spatial_dims[0]],
                    shape[self._spatial_dims[1]]
                ],
                pool_type='avg',
                pool_stride=[1, 1],
                pool_padding='VALID')
        else:
            # same as tf: reduce_sum
            x_squeezed = fluid.layers.pool2d(
                input=inputs,
                pool_type='avg',
                global_pooling=True,
                use_cudnn=False)
        x_squeezed = conv2d(
            x_squeezed,
            num_filters=num_squeezed_channels,
            filter_size=1,
            use_bias=True,
            padding_type=self.padding_type,
            act='relu6' if self._relu_fn else 'swish',
            name=name + '_se_reduce')
        x_squeezed = conv2d(
            x_squeezed,
            num_filters=oup,
            filter_size=1,
            use_bias=True,
            padding_type=self.padding_type,
            name=name + '_se_expand')
        # se_out = inputs * sigmoid(x_squeezed)
        se_out = fluid.layers.elementwise_mul(
            inputs, fluid.layers.sigmoid(x_squeezed), axis=-1)
        return se_out

    def _scale_block_args(self, index, block_args):
        """Apply width/depth scaling to the block group at position index.

        Filters are always rounded with round_filters. The repeat count is
        kept as-is for the first and last groups when fix_head_stem is set,
        and scaled with round_repeats otherwise.
        """
        last_index = len(self._blocks_args) - 1
        if self._fix_head_stem and (index == 0 or index == last_index):
            num_repeat = block_args.num_repeat
        else:
            num_repeat = round_repeats(block_args.num_repeat,
                                       self._global_params)
        return block_args._replace(
            input_filters=round_filters(block_args.input_filters,
                                        self._global_params),
            output_filters=round_filters(block_args.output_filters,
                                         self._global_params),
            # Lite
            num_repeat=num_repeat)

    def extract_features(self, inputs, is_test):
        """ Returns output of the final convolution layer """
        # The swish result used to be discarded, leaving conv unbound when
        # relu_fn is falsy; _activate returns the activated tensor in both
        # cases.
        conv = self._activate(self._conv_stem_norm(inputs, is_test=is_test))

        scaled_blocks = [
            self._scale_block_args(i, block_args)
            for i, block_args in enumerate(self._blocks_args)
        ]
        # Total number of MBConv blocks; each group contributes at least one.
        block_size = sum(
            max(block_args.num_repeat, 1) for block_args in scaled_blocks)

        def scaled_drop_rate(block_index):
            # Drop-connect rate grows linearly with the block index.
            rate = self._global_params.drop_connect_rate
            if rate:
                rate *= float(block_index) / block_size
            return rate

        idx = 0
        for block_args in scaled_blocks:
            # The first block needs to take care of stride,
            # and filter size increase.
            conv = self.mb_conv_block(conv, block_args, is_test,
                                      scaled_drop_rate(idx),
                                      '_blocks.' + str(idx) + '.')
            idx += 1

            # Remaining repeats keep the channel count and use stride 1.
            if block_args.num_repeat > 1:
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                conv = self.mb_conv_block(conv, block_args, is_test,
                                          scaled_drop_rate(idx),
                                          '_blocks.' + str(idx) + '.')
                idx += 1

        return conv

    def shortcut(self, input, data_residual):
        """Element-wise sum of input and data_residual."""
        return fluid.layers.elementwise_add(input, data_residual)
class BlockDecoder(object):
    """
    Block Decoder, straight from the official TensorFlow repository.

    Converts between the compact string notation of a block group
    (e.g. 'r1_k3_s11_e1_i32_o16_se0.25') and BlockArgs namedtuples.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """ Gets a block through a string notation of arguments.

        Each '_'-separated token is split into a letter key and the value
        starting at its first digit; tokens without a digit (such as
        'noskip') carry no value.

        Raises:
            AssertionError: if block_string is not a str or its stride token
                is missing or is not one digit / two equal digits.
        """
        assert isinstance(block_string, str)

        options = {}
        for op in block_string.split('_'):
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: a missing 's' token now fails the assertion instead
        # of raising KeyError while evaluating the check.
        stride = options.get('s', '')
        single = len(stride) == 1
        square = len(stride) == 2 and stride[0] == stride[1]
        assert (single or square)

        return BlockArgs(
            kernel_size=int(options['k']),
            num_repeat=int(options['r']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            expand_ratio=int(options['e']),
            id_skip=('noskip' not in block_string),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=[int(stride[0])])

    @staticmethod
    def _encode_block_string(block):
        """Encodes a block to a string.

        The stride is read from block.stride (the BlockArgs field name),
        falling back to block.strides for objects that only expose that
        attribute. It may be an int or a list/tuple of one or two ints.
        """
        stride = block.stride if hasattr(block, 'stride') else block.strides
        if isinstance(stride, (list, tuple)):
            stride_h, stride_w = stride[0], stride[-1]
        else:
            stride_h = stride_w = stride

        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (stride_h, stride_w),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters,
        ]
        # se_ratio is None when the block has no squeeze-and-excitation;
        # comparing None with a number raises TypeError on Python 3.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """
        Decode a list of string notations to specify blocks in the network.

        string_list: list of strings, each string is a notation of block
        return
            list of BlockArgs namedtuples of block args
        """
        assert isinstance(string_list, list)
        return [
            BlockDecoder._decode_block_string(block_string)
            for block_string in string_list
        ]

    @staticmethod
    def encode(blocks_args):
        """
        Encodes a list of BlockArgs to a list of strings.

        :param blocks_args: a list of BlockArgs namedtuples of block args
        :return: a list of strings, each string is a notation of block
        """
        return [
            BlockDecoder._encode_block_string(block) for block in blocks_args
        ]
def EfficientNetLite0(is_test=False,
                      padding_type='SAME',
                      override_params=None,
                      use_se=True):
    """Create an EfficientNetLite configured as the 'lite0' variant.

    All arguments are forwarded unchanged to the EfficientNetLite
    constructor.
    NOTE(review): use_se defaults to True here while EfficientNetLite itself
    defaults to use_se=False - confirm which default is intended.
    """
    options = dict(
        is_test=is_test,
        padding_type=padding_type,
        override_params=override_params,
        use_se=use_se)
    return EfficientNetLite(name='lite0', **options)
def EfficientNetLite1(is_test=False,
                      padding_type='SAME',
                      override_params=None,
                      use_se=True):
    """Create an EfficientNetLite configured as the 'lite1' variant.

    All arguments are forwarded unchanged to the EfficientNetLite
    constructor.
    NOTE(review): use_se defaults to True here while EfficientNetLite itself
    defaults to use_se=False - confirm which default is intended.
    """
    options = dict(
        is_test=is_test,
        padding_type=padding_type,
        override_params=override_params,
        use_se=use_se)
    return EfficientNetLite(name='lite1', **options)
def EfficientNetLite2(is_test=False,
                      padding_type='SAME',
                      override_params=None,
                      use_se=True):
    """Create an EfficientNetLite configured as the 'lite2' variant.

    All arguments are forwarded unchanged to the EfficientNetLite
    constructor.
    NOTE(review): use_se defaults to True here while EfficientNetLite itself
    defaults to use_se=False - confirm which default is intended.
    """
    options = dict(
        is_test=is_test,
        padding_type=padding_type,
        override_params=override_params,
        use_se=use_se)
    return EfficientNetLite(name='lite2', **options)
def EfficientNetLite3(is_test=False,
                      padding_type='SAME',
                      override_params=None,
                      use_se=True):
    """Create an EfficientNetLite configured as the 'lite3' variant.

    All arguments are forwarded unchanged to the EfficientNetLite
    constructor.
    NOTE(review): use_se defaults to True here while EfficientNetLite itself
    defaults to use_se=False - confirm which default is intended.
    """
    options = dict(
        is_test=is_test,
        padding_type=padding_type,
        override_params=override_params,
        use_se=use_se)
    return EfficientNetLite(name='lite3', **options)
def EfficientNetLite4(is_test=False,
                      padding_type='SAME',
                      override_params=None,
                      use_se=True):
    """Create an EfficientNetLite configured as the 'lite4' variant.

    All arguments are forwarded unchanged to the EfficientNetLite
    constructor.
    NOTE(review): use_se defaults to True here while EfficientNetLite itself
    defaults to use_se=False - confirm which default is intended.
    """
    options = dict(
        is_test=is_test,
        padding_type=padding_type,
        override_params=override_params,
        use_se=use_se)
    return EfficientNetLite(name='lite4', **options)
ppcls/modeling/architectures/layers.py
浏览文件 @
92a162fb
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
Licensed under the Apache License, Version 2.0 (the "License");
#
you may not use this file except in compliance with the License.
#
You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
#
Unless required by applicable law or agreed to in writing, software
#
distributed under the License is distributed on an "AS IS" BASIS,
#
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
See the License for the specific language governing permissions and
#
limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -242,6 +242,8 @@ def conv2d(input,
conv
=
fluid
.
layers
.
sigmoid
(
conv
,
name
=
name
+
'_sigmoid'
)
elif
act
==
'swish'
:
conv
=
fluid
.
layers
.
swish
(
conv
,
name
=
name
+
'_swish'
)
elif
act
==
'relu6'
:
conv
=
fluid
.
layers
.
relu6
(
conv
,
name
=
name
+
'_relu6'
)
elif
act
==
None
:
conv
=
conv
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录