Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
8064bab9
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8064bab9
编写于
9月 11, 2020
作者:
W
wubinghong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add drop-connected in efficientnet & refine the bifpn
上级
ae38108a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
276 addition
and
133 deletion
+276
-133
configs/efficientdet_d0.yml
configs/efficientdet_d0.yml
+3
-5
ppdet/modeling/architectures/efficientdet.py
ppdet/modeling/architectures/efficientdet.py
+1
-1
ppdet/modeling/backbones/bifpn.py
ppdet/modeling/backbones/bifpn.py
+149
-74
ppdet/modeling/backbones/efficientnet.py
ppdet/modeling/backbones/efficientnet.py
+123
-53
未找到文件。
configs/efficientdet_d0.yml
浏览文件 @
8064bab9
...
...
@@ -19,9 +19,7 @@ EfficientDet:
box_loss_weight
:
50.
EfficientNet
:
# norm_type: sync_bn
# TODO
norm_type
:
bn
norm_type
:
sync_bn
scale
:
b0
use_se
:
true
...
...
@@ -39,9 +37,9 @@ EfficientHead:
alpha
:
0.25
delta
:
0.1
output_decoder
:
score_thresh
:
0.0
5
# originally 0.
score_thresh
:
0.0
nms_thresh
:
0.5
pre_nms_top_n
:
1000
# originally
5000
pre_nms_top_n
:
5000
detections_per_im
:
100
nms_eta
:
1.0
...
...
ppdet/modeling/architectures/efficientdet.py
浏览文件 @
8064bab9
...
...
@@ -64,7 +64,7 @@ class EfficientDet(object):
mixed_precision_enabled
=
mixed_precision_global_state
()
is
not
None
if
mixed_precision_enabled
:
im
=
fluid
.
layers
.
cast
(
im
,
'float16'
)
body_feats
=
self
.
backbone
(
im
)
body_feats
=
self
.
backbone
(
im
,
mode
)
if
mixed_precision_enabled
:
body_feats
=
[
fluid
.
layers
.
cast
(
f
,
'float32'
)
for
f
in
body_feats
]
body_feats
=
self
.
fpn
(
body_feats
)
...
...
ppdet/modeling/backbones/bifpn.py
浏览文件 @
8064bab9
...
...
@@ -41,7 +41,8 @@ class FusionConv(object):
groups
=
self
.
num_chan
,
param_attr
=
ParamAttr
(
initializer
=
Xavier
(),
name
=
name
+
'_dw_w'
),
bias_attr
=
False
)
bias_attr
=
False
,
use_cudnn
=
False
)
# pointwise
x
=
fluid
.
layers
.
conv2d
(
x
,
...
...
@@ -66,58 +67,87 @@ class FusionConv(object):
class
BiFPNCell
(
object
):
def
__init__
(
self
,
num_chan
,
levels
=
5
):
def
__init__
(
self
,
num_chan
,
levels
=
5
,
inputs_layer_num
=
3
):
"""
# Node id starts from the input features and monotonically increase whenever
# [Node NO.] Here is an example for level P3 - P7:
# {3: [0, 8],
# 4: [1, 7, 9],
# 5: [2, 6, 10],
# 6: [3, 5, 11],
# 7: [4, 12]}
# [Related Edge]
# {'feat_level': 6, 'inputs_offsets': [3, 4]}, # for P6'
# {'feat_level': 5, 'inputs_offsets': [2, 5]}, # for P5'
# {'feat_level': 4, 'inputs_offsets': [1, 6]}, # for P4'
# {'feat_level': 3, 'inputs_offsets': [0, 7]}, # for P3"
# {'feat_level': 4, 'inputs_offsets': [1, 7, 8]}, # for P4"
# {'feat_level': 5, 'inputs_offsets': [2, 6, 9]}, # for P5"
# {'feat_level': 6, 'inputs_offsets': [3, 5, 10]}, # for P6"
# {'feat_level': 7, 'inputs_offsets': [4, 11]}, # for P7"
P7 (4) --------------> P7" (12)
|----------| ↑
↓ |
P6 (3) --> P6' (5) --> P6" (11)
|----------|----------↑↑
↓ |
P5 (2) --> P5' (6) --> P5" (10)
|----------|----------↑↑
↓ |
P4 (1) --> P4' (7) --> P4" (9)
|----------|----------↑↑
|----------↓|
P3 (0) --------------> P3" (8)
"""
super
(
BiFPNCell
,
self
).
__init__
()
self
.
levels
=
levels
self
.
num_chan
=
num_chan
num_trigates
=
levels
-
2
num_bigates
=
levels
self
.
inputs_layer_num
=
inputs_layer_num
# Learnable weights of [P4", P5", P6"]
self
.
trigates
=
fluid
.
layers
.
create_parameter
(
shape
=
[
num_trigates
,
3
],
shape
=
[
levels
-
2
,
3
],
dtype
=
'float32'
,
default_initializer
=
fluid
.
initializer
.
Constant
(
1.
))
# Learnable weights of [P6', P5', P4', P3", P7"]
self
.
bigates
=
fluid
.
layers
.
create_parameter
(
shape
=
[
num_bigate
s
,
2
],
shape
=
[
level
s
,
2
],
dtype
=
'float32'
,
default_initializer
=
fluid
.
initializer
.
Constant
(
1.
))
self
.
eps
=
1e-4
def
__call__
(
self
,
inputs
,
cell_name
=
''
):
def
__call__
(
self
,
inputs
,
cell_name
=
''
,
is_first_time
=
False
,
p4_2_p5_2
=
[]
):
assert
len
(
inputs
)
==
self
.
levels
assert
((
is_first_time
)
and
(
len
(
p4_2_p5_2
)
!=
0
))
or
((
not
is_first_time
)
and
(
len
(
p4_2_p5_2
)
==
0
))
# upsample operator
def
upsample
(
feat
):
return
fluid
.
layers
.
resize_nearest
(
feat
,
scale
=
2.
)
# downsample operator
def
downsample
(
feat
):
return
fluid
.
layers
.
pool2d
(
feat
,
pool_type
=
'max'
,
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
'SAME'
)
return
fluid
.
layers
.
pool2d
(
feat
,
pool_type
=
'max'
,
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
'SAME'
)
# 3x3 fuse conv after OP combine
fuse_conv
=
FusionConv
(
self
.
num_chan
)
#
n
ormalize weight
#
N
ormalize weight
trigates
=
fluid
.
layers
.
relu
(
self
.
trigates
)
bigates
=
fluid
.
layers
.
relu
(
self
.
bigates
)
trigates
/=
fluid
.
layers
.
reduce_sum
(
trigates
,
dim
=
1
,
keep_dim
=
True
)
+
self
.
eps
bigates
/=
fluid
.
layers
.
reduce_sum
(
bigates
,
dim
=
1
,
keep_dim
=
True
)
+
self
.
eps
trigates
/=
fluid
.
layers
.
reduce_sum
(
trigates
,
dim
=
1
,
keep_dim
=
True
)
+
self
.
eps
bigates
/=
fluid
.
layers
.
reduce_sum
(
bigates
,
dim
=
1
,
keep_dim
=
True
)
+
self
.
eps
feature_maps
=
list
(
inputs
)
# make a copy
feature_maps
=
list
(
inputs
)
# make a copy
, 依次是 [P3, P4, P5, P6, P7]
# top down path
for
l
in
range
(
self
.
levels
-
1
):
p
=
self
.
levels
-
l
-
2
w1
=
fluid
.
layers
.
slice
(
bigates
,
axes
=
[
0
,
1
],
starts
=
[
l
,
0
],
ends
=
[
l
+
1
,
1
])
w2
=
fluid
.
layers
.
slice
(
bigates
,
axes
=
[
0
,
1
],
starts
=
[
l
,
1
],
ends
=
[
l
+
1
,
2
])
above
=
upsample
(
feature_maps
[
p
+
1
])
feature_maps
[
p
]
=
fuse_conv
(
w1
*
above
+
w2
*
inputs
[
p
],
name
=
'{}_tb_{}'
.
format
(
cell_name
,
l
))
w1
=
fluid
.
layers
.
slice
(
bigates
,
axes
=
[
0
,
1
],
starts
=
[
l
,
0
],
ends
=
[
l
+
1
,
1
])
w2
=
fluid
.
layers
.
slice
(
bigates
,
axes
=
[
0
,
1
],
starts
=
[
l
,
1
],
ends
=
[
l
+
1
,
2
])
above_layer
=
upsample
(
feature_maps
[
p
+
1
])
feature_maps
[
p
]
=
fuse_conv
(
w1
*
above_layer
+
w2
*
inputs
[
p
],
name
=
'{}_tb_{}'
.
format
(
cell_name
,
l
))
# bottom up path
for
l
in
range
(
1
,
self
.
levels
):
p
=
l
...
...
@@ -125,22 +155,26 @@ class BiFPNCell(object):
below
=
downsample
(
feature_maps
[
p
-
1
])
if
p
==
self
.
levels
-
1
:
# handle P7
w1
=
fluid
.
layers
.
slice
(
bigates
,
axes
=
[
0
,
1
],
starts
=
[
p
,
0
],
ends
=
[
p
+
1
,
1
])
w2
=
fluid
.
layers
.
slice
(
bigates
,
axes
=
[
0
,
1
],
starts
=
[
p
,
1
],
ends
=
[
p
+
1
,
2
])
feature_maps
[
p
]
=
fuse_conv
(
w1
*
below
+
w2
*
inputs
[
p
],
name
=
name
)
w1
=
fluid
.
layers
.
slice
(
bigates
,
axes
=
[
0
,
1
],
starts
=
[
p
,
0
],
ends
=
[
p
+
1
,
1
])
w2
=
fluid
.
layers
.
slice
(
bigates
,
axes
=
[
0
,
1
],
starts
=
[
p
,
1
],
ends
=
[
p
+
1
,
2
])
feature_maps
[
p
]
=
fuse_conv
(
w1
*
below
+
w2
*
inputs
[
p
],
name
=
name
)
else
:
w1
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
0
],
ends
=
[
p
,
1
])
w2
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
1
],
ends
=
[
p
,
2
])
w3
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
2
],
ends
=
[
p
,
3
])
feature_maps
[
p
]
=
fuse_conv
(
w1
*
feature_maps
[
p
]
+
w2
*
below
+
w3
*
inputs
[
p
],
name
=
name
)
if
is_first_time
:
if
p
<
self
.
inputs_layer_num
:
w1
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
0
],
ends
=
[
p
,
1
])
w2
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
1
],
ends
=
[
p
,
2
])
w3
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
2
],
ends
=
[
p
,
3
])
feature_maps
[
p
]
=
fuse_conv
(
w1
*
feature_maps
[
p
]
+
w2
*
below
+
w3
*
p4_2_p5_2
[
p
-
1
],
name
=
name
)
else
:
# For P6"
w1
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
0
],
ends
=
[
p
,
1
])
w2
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
1
],
ends
=
[
p
,
2
])
w3
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
2
],
ends
=
[
p
,
3
])
feature_maps
[
p
]
=
fuse_conv
(
w1
*
feature_maps
[
p
]
+
w2
*
below
+
w3
*
inputs
[
p
],
name
=
name
)
else
:
w1
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
0
],
ends
=
[
p
,
1
])
w2
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
1
],
ends
=
[
p
,
2
])
w3
=
fluid
.
layers
.
slice
(
trigates
,
axes
=
[
0
,
1
],
starts
=
[
p
-
1
,
2
],
ends
=
[
p
,
3
])
feature_maps
[
p
]
=
fuse_conv
(
w1
*
feature_maps
[
p
]
+
w2
*
below
+
w3
*
inputs
[
p
],
name
=
name
)
return
feature_maps
...
...
@@ -163,40 +197,81 @@ class BiFPN(object):
def
__call__
(
self
,
inputs
):
feats
=
[]
# NOTE add two extra levels
for
idx
in
range
(
self
.
levels
):
if
idx
<=
len
(
inputs
):
if
idx
==
len
(
inputs
):
feat
=
inputs
[
-
1
]
else
:
feat
=
inputs
[
idx
]
if
feat
.
shape
[
1
]
!=
self
.
num_chan
:
feat
=
fluid
.
layers
.
conv2d
(
feat
,
self
.
num_chan
,
filter_size
=
1
,
padding
=
'SAME'
,
param_attr
=
ParamAttr
(
initializer
=
Xavier
()),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
)))
feat
=
fluid
.
layers
.
batch_norm
(
feat
,
momentum
=
0.997
,
epsilon
=
1e-04
,
param_attr
=
ParamAttr
(
initializer
=
Constant
(
1.0
),
regularizer
=
L2Decay
(
0.
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
)))
if
idx
>=
len
(
inputs
):
feat
=
fluid
.
layers
.
pool2d
(
# Squeeze the channel with 1x1 conv
for
idx
in
range
(
len
(
inputs
)):
if
inputs
[
idx
].
shape
[
1
]
!=
self
.
num_chan
:
feat
=
fluid
.
layers
.
conv2d
(
inputs
[
idx
],
self
.
num_chan
,
filter_size
=
1
,
padding
=
'SAME'
,
param_attr
=
ParamAttr
(
initializer
=
Xavier
()),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
)),
name
=
'resample_conv_{}'
.
format
(
idx
))
feat
=
fluid
.
layers
.
batch_norm
(
feat
,
pool_type
=
'max'
,
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
'SAME'
)
momentum
=
0.997
,
epsilon
=
1e-04
,
param_attr
=
ParamAttr
(
initializer
=
Constant
(
1.0
),
regularizer
=
L2Decay
(
0.
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
)),
name
=
'resample_bn_{}'
.
format
(
idx
))
else
:
feat
=
inputs
[
idx
]
feats
.
append
(
feat
)
# Build additional input features that are not from backbone.
# P_7 layer we just use pool2d without conv layer & bn, for the same channel with P_6.
# https://github.com/google/automl/blob/master/efficientdet/keras/efficientdet_keras.py#L820
for
idx
in
range
(
len
(
inputs
),
self
.
levels
):
if
feats
[
-
1
].
shape
[
1
]
!=
self
.
num_chan
:
feat
=
fluid
.
layers
.
conv2d
(
feats
[
-
1
],
self
.
num_chan
,
filter_size
=
1
,
padding
=
'SAME'
,
param_attr
=
ParamAttr
(
initializer
=
Xavier
()),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
)),
name
=
'resample_conv_{}'
.
format
(
idx
))
feat
=
fluid
.
layers
.
batch_norm
(
feat
,
momentum
=
0.997
,
epsilon
=
1e-04
,
param_attr
=
ParamAttr
(
initializer
=
Constant
(
1.0
),
regularizer
=
L2Decay
(
0.
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
)),
name
=
'resample_bn_{}'
.
format
(
idx
))
feat
=
fluid
.
layers
.
pool2d
(
feat
,
pool_type
=
'max'
,
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
'SAME'
,
name
=
'resample_downsample_{}'
.
format
(
idx
))
feats
.
append
(
feat
)
# Handle the p4_2 and p5_2 with another 1x1 conv & bn layer
p4_2_p5_2
=
[]
for
idx
in
range
(
1
,
len
(
inputs
)):
feat
=
fluid
.
layers
.
conv2d
(
inputs
[
idx
],
self
.
num_chan
,
filter_size
=
1
,
padding
=
'SAME'
,
param_attr
=
ParamAttr
(
initializer
=
Xavier
()),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
)),
name
=
'resample2_conv_{}'
.
format
(
idx
))
feat
=
fluid
.
layers
.
batch_norm
(
feat
,
momentum
=
0.997
,
epsilon
=
1e-04
,
param_attr
=
ParamAttr
(
initializer
=
Constant
(
1.0
),
regularizer
=
L2Decay
(
0.
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
)),
name
=
'resample2_bn_{}'
.
format
(
idx
))
p4_2_p5_2
.
append
(
feat
)
biFPN
=
BiFPNCell
(
self
.
num_chan
,
self
.
levels
)
# BiFPN, repeated
biFPN
=
BiFPNCell
(
self
.
num_chan
,
self
.
levels
,
len
(
inputs
))
for
r
in
range
(
self
.
repeat
):
feats
=
biFPN
(
feats
,
'bifpn_{}'
.
format
(
r
))
if
r
==
0
:
feats
=
biFPN
(
feats
,
cell_name
=
'bifpn_{}'
.
format
(
r
),
is_first_time
=
True
,
p4_2_p5_2
=
p4_2_p5_2
)
else
:
feats
=
biFPN
(
feats
,
cell_name
=
'bifpn_{}'
.
format
(
r
))
return
feats
ppdet/modeling/backbones/efficientnet.py
浏览文件 @
8064bab9
...
...
@@ -28,12 +28,15 @@ __all__ = ['EfficientNet']
GlobalParams
=
collections
.
namedtuple
(
'GlobalParams'
,
[
'batch_norm_momentum'
,
'batch_norm_epsilon'
,
'width_coefficient'
,
'depth_coefficient'
,
'depth_divisor'
'depth_coefficient'
,
'depth_divisor'
,
'min_depth'
,
'drop_connect_rate'
,
'relu_fn'
,
'batch_norm'
,
'use_se'
,
'local_pooling'
,
'condconv_num_experts'
,
'clip_projection_output'
,
'blocks_args'
,
'fix_head_stem'
])
BlockArgs
=
collections
.
namedtuple
(
'BlockArgs'
,
[
'kernel_size'
,
'num_repeat'
,
'input_filters'
,
'output_filters'
,
'expand_ratio'
,
'stride'
,
'se_ratio'
'expand_ratio'
,
'id_skip'
,
'stride'
,
'se_ratio'
,
'conv_type'
,
'fused_conv'
,
'super_pixel'
,
'condconv'
])
GlobalParams
.
__new__
.
__defaults__
=
(
None
,
)
*
len
(
GlobalParams
.
_fields
)
...
...
@@ -51,8 +54,8 @@ def _decode_block_string(block_string):
key
,
value
=
splits
[:
2
]
options
[
key
]
=
value
assert
((
's'
in
options
and
len
(
options
[
's'
])
==
1
)
or
(
len
(
options
[
's'
])
==
2
and
options
[
's'
][
0
]
==
options
[
's'
][
1
])
)
if
's'
not
in
options
or
len
(
options
[
's'
])
!=
2
:
raise
ValueError
(
'Strides options should be a pair of integers.'
)
return
BlockArgs
(
kernel_size
=
int
(
options
[
'k'
]),
...
...
@@ -60,8 +63,13 @@ def _decode_block_string(block_string):
input_filters
=
int
(
options
[
'i'
]),
output_filters
=
int
(
options
[
'o'
]),
expand_ratio
=
int
(
options
[
'e'
]),
id_skip
=
(
'noskip'
not
in
block_string
),
se_ratio
=
float
(
options
[
'se'
])
if
'se'
in
options
else
None
,
stride
=
int
(
options
[
's'
][
0
]))
stride
=
int
(
options
[
's'
][
0
]),
conv_type
=
int
(
options
[
'c'
])
if
'c'
in
options
else
0
,
fused_conv
=
int
(
options
[
'f'
])
if
'f'
in
options
else
0
,
super_pixel
=
int
(
options
[
'p'
])
if
'p'
in
options
else
0
,
condconv
=
(
'cc'
in
block_string
))
def
get_model_params
(
scale
):
...
...
@@ -88,37 +96,47 @@ def get_model_params(scale):
'b5'
:
(
1.6
,
2.2
),
'b6'
:
(
1.8
,
2.6
),
'b7'
:
(
2.0
,
3.1
),
'l2'
:
(
4.3
,
5.3
),
}
w
,
d
=
params_dict
[
scale
]
global_params
=
GlobalParams
(
blocks_args
=
block_strings
,
batch_norm_momentum
=
0.99
,
batch_norm_epsilon
=
1e-3
,
drop_connect_rate
=
0
if
scale
==
'b0'
else
0.2
,
width_coefficient
=
w
,
depth_coefficient
=
d
,
depth_divisor
=
8
)
depth_divisor
=
8
,
min_depth
=
None
,
fix_head_stem
=
False
,
use_se
=
True
,
clip_projection_output
=
False
)
return
block_args
,
global_params
def
round_filters
(
filters
,
global_params
):
def
round_filters
(
filters
,
global_params
,
skip
=
False
):
"""Round number of filters based on depth multiplier."""
multiplier
=
global_params
.
width_coefficient
if
not
multiplier
:
return
filters
divisor
=
global_params
.
depth_divisor
min_depth
=
global_params
.
min_depth
if
skip
or
not
multiplier
:
return
filters
filters
*=
multiplier
min_depth
=
divisor
new_filters
=
max
(
min_depth
,
int
(
filters
+
divisor
/
2
)
//
divisor
*
divisor
)
min_depth
=
min_depth
or
divisor
new_filters
=
max
(
min_depth
,
int
(
filters
+
divisor
/
2
)
//
divisor
*
divisor
)
if
new_filters
<
0.9
*
filters
:
# prevent rounding by more than 10%
new_filters
+=
divisor
return
int
(
new_filters
)
def
round_repeats
(
repeats
,
global_params
):
def
round_repeats
(
repeats
,
global_params
,
skip
=
False
):
"""Round number of filters based on depth multiplier."""
multiplier
=
global_params
.
depth_coefficient
if
not
multiplier
:
if
skip
or
not
multiplier
:
return
repeats
return
int
(
math
.
ceil
(
multiplier
*
repeats
))
...
...
@@ -130,7 +148,8 @@ def conv2d(inputs,
padding
=
'SAME'
,
groups
=
1
,
use_bias
=
False
,
name
=
'conv2d'
):
name
=
'conv2d'
,
use_cudnn
=
True
):
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_weights'
)
bias_attr
=
False
if
use_bias
:
...
...
@@ -145,7 +164,8 @@ def conv2d(inputs,
stride
=
stride
,
padding
=
padding
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
)
bias_attr
=
bias_attr
,
use_cudnn
=
use_cudnn
)
return
feats
...
...
@@ -163,6 +183,16 @@ def batch_norm(inputs, momentum, eps, name=None):
bias_attr
=
bias_attr
)
def
_drop_connect
(
inputs
,
prob
,
mode
):
if
mode
!=
'train'
:
return
inputs
keep_prob
=
1.0
-
prob
inputs_shape
=
fluid
.
layers
.
shape
(
inputs
)
random_tensor
=
keep_prob
+
fluid
.
layers
.
uniform_random
(
shape
=
[
inputs_shape
[
0
],
1
,
1
,
1
],
min
=
0.
,
max
=
1.
)
binary_tensor
=
fluid
.
layers
.
floor
(
random_tensor
)
output
=
inputs
/
keep_prob
*
binary_tensor
return
output
def
mb_conv_block
(
inputs
,
input_filters
,
output_filters
,
...
...
@@ -171,30 +201,37 @@ def mb_conv_block(inputs,
stride
,
momentum
,
eps
,
block_arg
,
drop_connect_rate
,
mode
,
se_ratio
=
None
,
name
=
None
):
feats
=
inputs
num_filters
=
input_filters
*
expand_ratio
# Expansion
if
expand_ratio
!=
1
:
feats
=
conv2d
(
feats
,
num_filters
,
1
,
name
=
name
+
'_expand_conv'
)
feats
=
batch_norm
(
feats
,
momentum
,
eps
,
name
=
name
+
'_bn0'
)
feats
=
fluid
.
layers
.
swish
(
feats
)
# Depthwise Convolution
feats
=
conv2d
(
feats
,
num_filters
,
kernel_size
,
stride
,
groups
=
num_filters
,
name
=
name
+
'_depthwise_conv'
)
name
=
name
+
'_depthwise_conv'
,
use_cudnn
=
False
)
feats
=
batch_norm
(
feats
,
momentum
,
eps
,
name
=
name
+
'_bn1'
)
feats
=
fluid
.
layers
.
swish
(
feats
)
# Squeeze and Excitation
if
se_ratio
is
not
None
:
filter_squeezed
=
max
(
1
,
int
(
input_filters
*
se_ratio
))
squeezed
=
fluid
.
layers
.
pool2d
(
feats
,
pool_type
=
'avg'
,
global_pooling
=
True
)
feats
,
pool_type
=
'avg'
,
global_pooling
=
True
,
use_cudnn
=
True
)
squeezed
=
conv2d
(
squeezed
,
filter_squeezed
,
...
...
@@ -206,10 +243,14 @@ def mb_conv_block(inputs,
squeezed
,
num_filters
,
1
,
use_bias
=
True
,
name
=
name
+
'_se_expand'
)
feats
=
feats
*
fluid
.
layers
.
sigmoid
(
squeezed
)
# Project_conv_norm
feats
=
conv2d
(
feats
,
output_filters
,
1
,
name
=
name
+
'_project_conv'
)
feats
=
batch_norm
(
feats
,
momentum
,
eps
,
name
=
name
+
'_bn2'
)
if
stride
==
1
and
input_filters
==
output_filters
:
# Skip connection and drop connect
if
block_arg
.
id_skip
and
block_arg
.
stride
==
1
and
input_filters
==
output_filters
:
if
drop_connect_rate
:
feats
=
_drop_connect
(
feats
,
drop_connect_rate
,
mode
)
feats
=
fluid
.
layers
.
elementwise_add
(
feats
,
inputs
)
return
feats
...
...
@@ -227,7 +268,10 @@ class EfficientNet(object):
"""
__shared__
=
[
'norm_type'
]
def
__init__
(
self
,
scale
=
'b0'
,
use_se
=
True
,
norm_type
=
'bn'
):
def
__init__
(
self
,
scale
=
'b0'
,
use_se
=
True
,
norm_type
=
'bn'
):
assert
scale
in
[
'b'
+
str
(
i
)
for
i
in
range
(
8
)],
\
"valid scales are b0 - b7"
assert
norm_type
in
[
'bn'
,
'sync_bn'
],
\
...
...
@@ -238,54 +282,80 @@ class EfficientNet(object):
self
.
scale
=
scale
self
.
use_se
=
use_se
def
__call__
(
self
,
inputs
):
def
__call__
(
self
,
inputs
,
mode
):
assert
mode
in
[
'train'
,
'test'
],
\
"only 'train' and 'test' mode are supported"
blocks_args
,
global_params
=
get_model_params
(
self
.
scale
)
momentum
=
global_params
.
batch_norm_momentum
eps
=
global_params
.
batch_norm_epsilon
num_filters
=
round_filters
(
32
,
global_params
)
feats
=
conv2d
(
inputs
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
2
,
name
=
'_conv_stem'
)
# Stem part.
num_filters
=
round_filters
(
blocks_args
[
0
].
input_filters
,
global_params
,
global_params
.
fix_head_stem
)
feats
=
conv2d
(
inputs
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
2
,
name
=
'_conv_stem'
)
feats
=
batch_norm
(
feats
,
momentum
=
momentum
,
eps
=
eps
,
name
=
'_bn0'
)
feats
=
fluid
.
layers
.
swish
(
feats
)
layer_count
=
0
# Builds blocks.
feature_maps
=
[]
for
b
,
block_arg
in
enumerate
(
blocks_args
):
for
r
in
range
(
block_arg
.
num_repeat
):
input_filters
=
round_filters
(
block_arg
.
input_filters
,
global_params
)
output_filters
=
round_filters
(
block_arg
.
output_filters
,
global_params
)
kernel_size
=
block_arg
.
kernel_size
stride
=
block_arg
.
stride
se_ratio
=
None
if
self
.
use_se
:
se_ratio
=
block_arg
.
se_ratio
if
r
>
0
:
input_filters
=
output_filters
stride
=
1
layer_count
=
0
num_blocks
=
sum
([
block_arg
.
num_repeat
for
block_arg
in
blocks_args
])
for
block_arg
in
blocks_args
:
# Update block input and output filters based on depth multiplier.
block_arg
=
block_arg
.
_replace
(
input_filters
=
round_filters
(
block_arg
.
input_filters
,
global_params
),
output_filters
=
round_filters
(
block_arg
.
output_filters
,
global_params
),
num_repeat
=
round_repeats
(
block_arg
.
num_repeat
,
global_params
))
# The first block needs to take care of stride,
# and filter size increase.
drop_connect_rate
=
global_params
.
drop_connect_rate
if
drop_connect_rate
:
drop_connect_rate
*=
float
(
layer_count
)
/
num_blocks
feats
=
mb_conv_block
(
feats
,
block_arg
.
input_filters
,
block_arg
.
output_filters
,
block_arg
.
expand_ratio
,
block_arg
.
kernel_size
,
block_arg
.
stride
,
momentum
,
eps
,
block_arg
,
drop_connect_rate
,
mode
,
se_ratio
=
block_arg
.
se_ratio
,
name
=
'_blocks.{}.'
.
format
(
layer_count
))
layer_count
+=
1
# Other block
if
block_arg
.
num_repeat
>
1
:
block_arg
=
block_arg
.
_replace
(
input_filters
=
block_arg
.
output_filters
,
stride
=
1
)
for
_
in
range
(
block_arg
.
num_repeat
-
1
):
drop_connect_rate
=
global_params
.
drop_connect_rate
if
drop_connect_rate
:
drop_connect_rate
*=
float
(
layer_count
)
/
num_blocks
feats
=
mb_conv_block
(
feats
,
input_filters
,
output_filters
,
block_arg
.
input_filters
,
block_arg
.
output_filters
,
block_arg
.
expand_ratio
,
kernel_size
,
stride
,
block_arg
.
kernel_size
,
block_arg
.
stride
,
momentum
,
eps
,
se_ratio
=
se_ratio
,
block_arg
,
drop_connect_rate
,
mode
,
se_ratio
=
block_arg
.
se_ratio
,
name
=
'_blocks.{}.'
.
format
(
layer_count
))
layer_count
+=
1
feature_maps
.
append
(
feats
)
return
list
(
feature_maps
[
i
]
for
i
in
[
2
,
4
,
6
])
return
list
(
feature_maps
[
i
]
for
i
in
[
2
,
4
,
6
])
# 1/8, 1/16, 1/32
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录