Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
204bcbdf
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
204bcbdf
编写于
10月 11, 2021
作者:
S
sucuicong
提交者:
GitHub
10月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
a version of ppyolo for EdgeBoard (#4243)
* a version of ppyolo for EdgeBoard * a version of ppyolo for EdgeBoard
上级
e83c3ecf
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
977 addition
and
4 deletion
+977
-4
ppdet/data/transform/operators.py
ppdet/data/transform/operators.py
+7
-4
static/configs/ppyolo/ppyolo_eb.yml
static/configs/ppyolo/ppyolo_eb.yml
+74
-0
static/configs/ppyolo/ppyolo_eb_voc.yml
static/configs/ppyolo/ppyolo_eb_voc.yml
+103
-0
static/ppdet/modeling/anchor_heads/__init__.py
static/ppdet/modeling/anchor_heads/__init__.py
+2
-0
static/ppdet/modeling/anchor_heads/eb_head.py
static/ppdet/modeling/anchor_heads/eb_head.py
+349
-0
static/ppdet/modeling/backbones/__init__.py
static/ppdet/modeling/backbones/__init__.py
+2
-0
static/ppdet/modeling/backbones/resnet_eb.py
static/ppdet/modeling/backbones/resnet_eb.py
+440
-0
未找到文件。
ppdet/data/transform/operators.py
浏览文件 @
204bcbdf
...
...
@@ -167,7 +167,7 @@ class DecodeCache(BaseOperator):
'''
super
(
DecodeCache
,
self
).
__init__
()
self
.
use_cache
=
False
if
cache_root
is
None
else
True
self
.
use_cache
=
False
if
cache_root
is
None
else
True
self
.
cache_root
=
cache_root
if
cache_root
is
not
None
:
...
...
@@ -175,7 +175,8 @@ class DecodeCache(BaseOperator):
def
apply
(
self
,
sample
,
context
=
None
):
if
self
.
use_cache
and
os
.
path
.
exists
(
self
.
cache_path
(
self
.
cache_root
,
sample
[
'im_file'
])):
if
self
.
use_cache
and
os
.
path
.
exists
(
self
.
cache_path
(
self
.
cache_root
,
sample
[
'im_file'
])):
path
=
self
.
cache_path
(
self
.
cache_root
,
sample
[
'im_file'
])
im
=
self
.
load
(
path
)
...
...
@@ -191,7 +192,8 @@ class DecodeCache(BaseOperator):
sample
[
'ori_image'
]
=
im
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
if
self
.
use_cache
and
not
os
.
path
.
exists
(
self
.
cache_path
(
self
.
cache_root
,
sample
[
'im_file'
])):
if
self
.
use_cache
and
not
os
.
path
.
exists
(
self
.
cache_path
(
self
.
cache_root
,
sample
[
'im_file'
])):
path
=
self
.
cache_path
(
self
.
cache_root
,
sample
[
'im_file'
])
self
.
dump
(
im
,
path
)
...
...
@@ -212,7 +214,7 @@ class DecodeCache(BaseOperator):
def
load
(
path
):
with
open
(
path
,
'rb'
)
as
f
:
im
=
pickle
.
load
(
f
)
return
im
return
im
@
staticmethod
def
dump
(
obj
,
path
):
...
...
@@ -227,6 +229,7 @@ class DecodeCache(BaseOperator):
finally
:
MUTEX
.
release
()
@
register_op
class
Permute
(
BaseOperator
):
def
__init__
(
self
):
...
...
static/configs/ppyolo/ppyolo_eb.yml
0 → 100644
浏览文件 @
204bcbdf
architecture
:
YOLOv3
use_gpu
:
true
max_iters
:
500000
save_dir
:
output
snapshot_iter
:
10000
metric
:
COCO
pretrain_weights
:
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_vd_pretrained.tar
weights
:
output/ppyolo_eb/best_model
num_classes
:
80
use_fine_grained_loss
:
true
log_iter
:
1000
use_ema
:
true
ema_decay
:
0.9998
YOLOv3
:
backbone
:
ResNet_EB
yolo_head
:
EBHead
ResNet_EB
:
norm_type
:
sync_bn
freeze_at
:
0
freeze_norm
:
false
norm_decay
:
0.
depth
:
34
variant
:
d
feature_maps
:
[
3
,
4
,
5
]
EBHead
:
anchor_masks
:
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
anchors
:
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]]
norm_decay
:
0.
yolo_loss
:
YOLOv3Loss
nms
:
background_label
:
-1
keep_top_k
:
100
nms_threshold
:
0.45
nms_top_k
:
1000
normalized
:
false
score_threshold
:
0.01
YOLOv3Loss
:
ignore_thresh
:
0.7
label_smooth
:
false
use_fine_grained_loss
:
true
iou_loss
:
IouLoss
IouLoss
:
loss_weight
:
2.5
max_height
:
608
max_width
:
608
LearningRate
:
base_lr
:
0.001
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
-
320000
-
450000
-
!LinearWarmup
start_factor
:
0.
steps
:
4000
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0005
type
:
L2
_READER_
:
'
ppyolo_reader.yml'
static/configs/ppyolo/ppyolo_eb_voc.yml
0 → 100644
浏览文件 @
204bcbdf
architecture
:
YOLOv3
use_gpu
:
true
max_iters
:
70000
log_smooth_window
:
20
save_dir
:
output
snapshot_iter
:
3000
metric
:
VOC
map_type
:
integral
pretrain_weights
:
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet34_vd_pretrained.tar
weights
:
output/ppyolo_eb_voc/best_model
num_classes
:
20
use_fine_grained_loss
:
true
log_iter
:
1000
use_ema
:
true
ema_decay
:
0.9998
YOLOv3
:
backbone
:
ResNet_EB
yolo_head
:
EBHead
ResNet_EB
:
norm_type
:
sync_bn
freeze_at
:
0
freeze_norm
:
false
norm_decay
:
0.
depth
:
34
variant
:
d
feature_maps
:
[
3
,
4
,
5
]
EBHead
:
anchor_masks
:
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
anchors
:
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]]
norm_decay
:
0.
yolo_loss
:
YOLOv3Loss
nms
:
background_label
:
-1
keep_top_k
:
100
nms_threshold
:
0.45
nms_top_k
:
1000
normalized
:
false
score_threshold
:
0.01
YOLOv3Loss
:
ignore_thresh
:
0.7
label_smooth
:
false
use_fine_grained_loss
:
true
iou_loss
:
IouLoss
IouLoss
:
loss_weight
:
2.5
max_height
:
608
max_width
:
608
LearningRate
:
base_lr
:
0.001
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
-
35000
-
60000
-
!LinearWarmup
start_factor
:
0.
steps
:
4000
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0005
type
:
L2
_READER_
:
'
ppyolo_reader.yml'
TrainReader
:
dataset
:
!VOCDataSet
dataset_dir
:
dataset/voc
anno_path
:
trainval.txt
use_default_label
:
false
with_background
:
false
mixup_epoch
:
200
batch_size
:
8
EvalReader
:
inputs_def
:
image_shape
:
[
3
,
608
,
608
]
fields
:
[
'
image'
,
'
im_size'
,
'
im_id'
,
'
gt_bbox'
,
'
gt_class'
,
'
is_difficult'
]
num_max_boxes
:
50
dataset
:
!VOCDataSet
dataset_dir
:
dataset/voc
anno_path
:
test.txt
use_default_label
:
false
with_background
:
false
TestReader
:
dataset
:
!ImageFolder
use_default_label
:
false
with_background
:
false
static/ppdet/modeling/anchor_heads/__init__.py
浏览文件 @
204bcbdf
...
...
@@ -22,6 +22,7 @@ from . import corner_head
from
.
import
efficient_head
from
.
import
ttf_head
from
.
import
solov2_head
from
.
import
eb_head
from
.rpn_head
import
*
from
.yolo_head
import
*
...
...
@@ -31,3 +32,4 @@ from .corner_head import *
from
.efficient_head
import
*
from
.ttf_head
import
*
from
.solov2_head
import
*
from
.eb_head
import
*
static/ppdet/modeling/anchor_heads/eb_head.py
0 → 100644
浏览文件 @
204bcbdf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
ppdet.modeling.ops
import
MultiClassNMS
from
ppdet.modeling.losses.yolo_loss
import
YOLOv3Loss
from
ppdet.core.workspace
import
register
__all__
=
[
'EBHead'
]
@
register
class
EBHead
(
object
):
"""
Head block for pp-yolo-eb, ppyolo for EdgeBoard : https://ai.baidu.com/ai-doc/HWCE/Yk3b86gvp
Args:
norm_decay (float): weight decay for normalization layer weights
num_classes (int): number of output classes
anchors (list): anchors
anchor_masks (list): anchor masks
nms (object): an instance of `MultiClassNMS`
"""
__inject__
=
[
'yolo_loss'
,
'nms'
]
__shared__
=
[
'num_classes'
,
'weight_prefix_name'
]
def
__init__
(
self
,
norm_decay
=
0.
,
num_classes
=
80
,
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]],
anchor_masks
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]],
drop_block
=
False
,
block_size
=
3
,
keep_prob
=
0.9
,
yolo_loss
=
"YOLOv3Loss"
,
nms
=
MultiClassNMS
(
score_threshold
=
0.01
,
nms_top_k
=
1000
,
keep_top_k
=
100
,
nms_threshold
=
0.45
,
background_label
=-
1
).
__dict__
,
weight_prefix_name
=
''
):
self
.
norm_decay
=
norm_decay
self
.
num_classes
=
num_classes
self
.
anchor_masks
=
anchor_masks
self
.
_parse_anchors
(
anchors
)
self
.
yolo_loss
=
yolo_loss
self
.
nms
=
nms
self
.
prefix_name
=
weight_prefix_name
self
.
drop_block
=
drop_block
self
.
block_size
=
block_size
self
.
keep_prob
=
keep_prob
if
isinstance
(
nms
,
dict
):
self
.
nms
=
MultiClassNMS
(
**
nms
)
def
_conv_bn
(
self
,
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'leaky'
,
is_test
=
True
,
name
=
None
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
".conv.weights"
),
bias_attr
=
False
)
bn_name
=
name
+
".bn"
bn_param_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
self
.
norm_decay
),
name
=
bn_name
+
'.scale'
)
bn_bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
self
.
norm_decay
),
name
=
bn_name
+
'.offset'
)
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
None
,
is_test
=
is_test
,
param_attr
=
bn_param_attr
,
bias_attr
=
bn_bias_attr
,
moving_mean_name
=
bn_name
+
'.mean'
,
moving_variance_name
=
bn_name
+
'.var'
)
if
act
==
'leaky'
:
out
=
fluid
.
layers
.
leaky_relu
(
x
=
out
,
alpha
=
0.1
)
return
out
def
_detection_block
(
self
,
input
,
channel
,
is_test
=
True
,
name
=
None
):
assert
channel
%
2
==
0
,
\
"channel {} cannot be divided by 2 in detection block {}"
\
.
format
(
channel
,
name
)
conv
=
input
conv
=
self
.
_conv_bn
(
conv
,
channel
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
is_test
,
name
=
'{}.0'
.
format
(
name
))
for
j
in
range
(
4
):
conv
=
self
.
_conv_bn
(
conv
,
channel
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
is_test
=
is_test
,
name
=
'{}.{}.1'
.
format
(
name
,
j
))
if
j
==
1
:
route
=
conv
return
route
,
conv
def
_upsample
(
self
,
input
,
scale
=
2
,
name
=
None
):
out
=
fluid
.
layers
.
resize_nearest
(
input
=
input
,
scale
=
float
(
scale
),
name
=
name
)
return
out
def
_pool_concat
(
self
,
input
):
pool1
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'max'
)
pool2
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'avg'
)
out
=
fluid
.
layers
.
concat
(
input
=
[
pool1
,
pool2
],
axis
=
1
)
return
out
def
_parse_anchors
(
self
,
anchors
):
"""
Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors
"""
self
.
anchors
=
[]
self
.
mask_anchors
=
[]
assert
len
(
anchors
)
>
0
,
"ANCHORS not set."
assert
len
(
self
.
anchor_masks
)
>
0
,
"ANCHOR_MASKS not set."
for
anchor
in
anchors
:
assert
len
(
anchor
)
==
2
,
"anchor {} len should be 2"
.
format
(
anchor
)
self
.
anchors
.
extend
(
anchor
)
anchor_num
=
len
(
anchors
)
for
masks
in
self
.
anchor_masks
:
self
.
mask_anchors
.
append
([])
for
mask
in
masks
:
assert
mask
<
anchor_num
,
"anchor mask index overflow"
self
.
mask_anchors
[
-
1
].
extend
(
anchors
[
mask
])
def
_get_outputs
(
self
,
input
,
is_train
=
True
):
"""
Get ppyolo_eb head output
Args:
input (list): List of Variables, output of backbone stages
is_train (bool): whether in train or test mode
Returns:
outputs (list): Variables of each output layer
"""
outputs
=
[]
# get last out_layer_num blocks in reverse order
out_layer_num
=
len
(
self
.
anchor_masks
)
blocks
=
input
[
-
1
:
-
out_layer_num
-
1
:
-
1
]
filters_num1
=
blocks
[
1
].
shape
[
1
]
//
2
blk0
=
self
.
_pool_concat
(
blocks
[
2
])
blk0
=
self
.
_conv_bn
(
blk0
,
filters_num1
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
False
,
name
=
'channel_fusion_1'
)
blk1
=
fluid
.
layers
.
concat
(
input
=
[
blk0
,
blocks
[
1
]],
axis
=
1
)
filters_num2
=
blocks
[
0
].
shape
[
1
]
//
2
blk
=
self
.
_conv_bn
(
blk1
,
filters_num2
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
False
,
name
=
'channel_fusion_2'
)
blk2
=
self
.
_conv_bn
(
blk
,
filters_num2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
is_test
=
False
,
name
=
'feature_fusion'
)
blk2
=
self
.
_pool_concat
(
blk2
)
blk2
=
self
.
_conv_bn
(
blk2
,
filters_num2
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
False
,
name
=
'channel_fusion_3'
)
blk3
=
fluid
.
layers
.
concat
(
input
=
[
blk2
,
blocks
[
0
]],
axis
=
1
)
blocks
=
[
blk3
,
blk1
,
blocks
[
2
]]
route
=
None
for
i
,
block
in
enumerate
(
blocks
):
if
i
>
0
:
# perform concat in first 2 detection_block
block
=
fluid
.
layers
.
concat
(
input
=
[
route
,
block
],
axis
=
1
)
route
,
tip
=
self
.
_detection_block
(
block
,
channel
=
512
//
(
2
**
i
),
is_test
=
(
not
is_train
),
name
=
self
.
prefix_name
+
"yolo_block.{}"
.
format
(
i
))
# out channel number = mask_num * (5 + class_num)
num_filters
=
len
(
self
.
anchor_masks
[
i
])
*
(
self
.
num_classes
+
5
)
with
fluid
.
name_scope
(
'yolo_output'
):
block_out
=
fluid
.
layers
.
conv2d
(
input
=
tip
,
num_filters
=
num_filters
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
self
.
prefix_name
+
"yolo_output.{}.conv.weights"
.
format
(
i
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
),
name
=
self
.
prefix_name
+
"yolo_output.{}.conv.bias"
.
format
(
i
)))
outputs
.
append
(
block_out
)
if
i
<
len
(
blocks
)
-
1
:
# do not perform upsample in the last detection_block
route
=
self
.
_conv_bn
(
input
=
route
,
ch_out
=
256
//
(
2
**
i
),
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
(
not
is_train
),
name
=
self
.
prefix_name
+
"yolo_transition.{}"
.
format
(
i
))
# upsample
route
=
self
.
_upsample
(
route
)
return
outputs
def
get_loss
(
self
,
input
,
gt_box
,
gt_label
,
gt_score
,
targets
):
"""
Get final loss of network of ppyolo_eb.
Args:
input (list): List of Variables, output of backbone stages
gt_box (Variable): The ground-truth boudding boxes.
gt_label (Variable): The ground-truth class labels.
gt_score (Variable): The ground-truth boudding boxes mixup scores.
targets ([Variables]): List of Variables, the targets for yolo
loss calculatation.
Returns:
loss (Variable): The loss Variable of ppyolo_eb network.
"""
outputs
=
self
.
_get_outputs
(
input
,
is_train
=
True
)
return
self
.
yolo_loss
(
outputs
,
gt_box
,
gt_label
,
gt_score
,
targets
,
self
.
anchors
,
self
.
anchor_masks
,
self
.
mask_anchors
,
self
.
num_classes
,
self
.
prefix_name
)
def
get_prediction
(
self
,
input
,
im_size
,
exclude_nms
=
False
):
"""
Get prediction result of ppyolo_eb network
Args:
input (list): List of Variables, output of backbone stages
im_size (Variable): Variable of size([h, w]) of each image
Returns:
pred (Variable): The prediction result after non-max suppress.
"""
outputs
=
self
.
_get_outputs
(
input
,
is_train
=
False
)
boxes
=
[]
scores
=
[]
downsample
=
32
for
i
,
output
in
enumerate
(
outputs
):
box
,
score
=
fluid
.
layers
.
yolo_box
(
x
=
output
,
img_size
=
im_size
,
anchors
=
self
.
mask_anchors
[
i
],
class_num
=
self
.
num_classes
,
conf_thresh
=
self
.
nms
.
score_threshold
,
downsample_ratio
=
downsample
,
name
=
self
.
prefix_name
+
"yolo_box"
+
str
(
i
))
boxes
.
append
(
box
)
scores
.
append
(
fluid
.
layers
.
transpose
(
score
,
perm
=
[
0
,
2
,
1
]))
downsample
//=
2
yolo_boxes
=
fluid
.
layers
.
concat
(
boxes
,
axis
=
1
)
yolo_scores
=
fluid
.
layers
.
concat
(
scores
,
axis
=
2
)
# Only for benchmark, postprocess(NMS) is not needed
if
exclude_nms
:
return
{
'bbox'
:
yolo_boxes
,
'score'
:
yolo_scores
}
pred
=
self
.
nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
)
return
{
'bbox'
:
pred
}
static/ppdet/modeling/backbones/__init__.py
浏览文件 @
204bcbdf
...
...
@@ -35,6 +35,7 @@ from . import bifpn
from
.
import
cspdarknet
from
.
import
acfpn
from
.
import
ghostnet
from
.
import
resnet_eb
from
.resnet
import
*
from
.resnext
import
*
...
...
@@ -57,3 +58,4 @@ from .bifpn import *
from
.cspdarknet
import
*
from
.acfpn
import
*
from
.ghostnet
import
*
from
.resnet_eb
import
*
static/ppdet/modeling/backbones/resnet_eb.py
0 → 100644
浏览文件 @
204bcbdf
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
OrderedDict
from
paddle
import
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.framework
import
Variable
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.initializer
import
Constant
from
ppdet.core.workspace
import
register
,
serializable
from
numbers
import
Integral
from
.name_adapter
import
NameAdapter
__all__
=
[
'ResNet_EB'
]
@
register
@
serializable
class
ResNet_EB
(
object
):
"""
modified ResNet, especially for EdgeBoard: https://ai.baidu.com/ai-doc/HWCE/Yk3b86gvp
"""
__shared__
=
[
'norm_type'
,
'freeze_norm'
,
'weight_prefix_name'
]
def
__init__
(
self
,
depth
=
50
,
freeze_at
=
2
,
norm_type
=
'affine_channel'
,
freeze_norm
=
True
,
norm_decay
=
0.
,
variant
=
'b'
,
feature_maps
=
[
2
,
3
,
4
,
5
],
weight_prefix_name
=
''
,
lr_mult_list
=
[
1.
,
1.
,
1.
,
1.
]):
super
(
ResNet_EB
,
self
).
__init__
()
if
isinstance
(
feature_maps
,
Integral
):
feature_maps
=
[
feature_maps
]
assert
depth
in
[
18
,
34
,
50
,
101
,
152
,
200
],
\
"depth {} not in [18, 34, 50, 101, 152, 200]"
assert
variant
in
[
'a'
,
'b'
,
'c'
,
'd'
],
"invalid ResNet variant"
assert
0
<=
freeze_at
<=
4
,
"freeze_at should be 0, 1, 2, 3 or 4"
assert
len
(
feature_maps
)
>
0
,
"need one or more feature maps"
assert
norm_type
in
[
'bn'
,
'sync_bn'
,
'affine_channel'
]
assert
len
(
lr_mult_list
)
==
4
,
"lr_mult_list length must be 4 but got {}"
.
format
(
len
(
lr_mult_list
))
self
.
depth
=
depth
self
.
freeze_at
=
freeze_at
self
.
norm_type
=
norm_type
self
.
norm_decay
=
norm_decay
self
.
freeze_norm
=
freeze_norm
self
.
variant
=
variant
self
.
_model_type
=
'ResNet'
self
.
feature_maps
=
feature_maps
self
.
depth_cfg
=
{
18
:
([
2
,
2
,
2
,
2
],
self
.
basicblock
),
34
:
([
3
,
4
,
6
,
3
],
self
.
basicblock
),
50
:
([
3
,
4
,
6
,
3
],
self
.
bottleneck
),
101
:
([
3
,
4
,
23
,
3
],
self
.
bottleneck
),
152
:
([
3
,
8
,
36
,
3
],
self
.
bottleneck
),
200
:
([
3
,
12
,
48
,
3
],
self
.
bottleneck
),
}
self
.
stage_filters
=
[
64
,
128
,
256
,
512
]
self
.
_c1_out_chan_num
=
64
self
.
na
=
NameAdapter
(
self
)
self
.
prefix_name
=
weight_prefix_name
self
.
lr_mult_list
=
lr_mult_list
# var denoting curr stage
self
.
stage_num
=
-
1
def
_conv_offset
(
self
,
input
,
filter_size
,
stride
,
padding
,
act
=
None
,
name
=
None
):
out_channel
=
filter_size
*
filter_size
*
3
out
=
fluid
.
layers
.
conv2d
(
input
,
num_filters
=
out_channel
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
param_attr
=
ParamAttr
(
initializer
=
Constant
(
0.0
),
name
=
name
+
".w_0"
),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
0.0
),
name
=
name
+
".b_0"
),
act
=
act
,
name
=
name
)
return
out
def
_conv_norm
(
self
,
input
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
None
,
dcn_v2
=
False
):
_name
=
self
.
prefix_name
+
name
if
self
.
prefix_name
!=
''
else
name
# need fine lr for distilled model, default as 1.0
lr_mult
=
1.0
mult_idx
=
max
(
self
.
stage_num
-
2
,
0
)
mult_idx
=
min
(
self
.
stage_num
-
2
,
3
)
lr_mult
=
self
.
lr_mult_list
[
mult_idx
]
if
not
dcn_v2
:
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
_name
+
"_weights"
,
learning_rate
=
lr_mult
),
bias_attr
=
False
,
name
=
_name
+
'.conv2d.output.1'
)
else
:
# select deformable conv"
offset_mask
=
self
.
_conv_offset
(
input
=
input
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
act
=
None
,
name
=
_name
+
"_conv_offset"
)
offset_channel
=
filter_size
**
2
*
2
mask_channel
=
filter_size
**
2
offset
,
mask
=
fluid
.
layers
.
split
(
input
=
offset_mask
,
num_or_sections
=
[
offset_channel
,
mask_channel
],
dim
=
1
)
mask
=
fluid
.
layers
.
sigmoid
(
mask
)
conv
=
fluid
.
layers
.
deformable_conv
(
input
=
input
,
offset
=
offset
,
mask
=
mask
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
deformable_groups
=
1
,
im2col_step
=
1
,
param_attr
=
ParamAttr
(
name
=
_name
+
"_weights"
,
learning_rate
=
lr_mult
),
bias_attr
=
False
,
name
=
_name
+
".conv2d.output.1"
)
bn_name
=
self
.
na
.
fix_conv_norm_name
(
name
)
bn_name
=
self
.
prefix_name
+
bn_name
if
self
.
prefix_name
!=
''
else
bn_name
norm_lr
=
0.
if
self
.
freeze_norm
else
lr_mult
norm_decay
=
self
.
norm_decay
pattr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
,
learning_rate
=
norm_lr
,
regularizer
=
L2Decay
(
norm_decay
))
battr
=
ParamAttr
(
name
=
bn_name
+
'_offset'
,
learning_rate
=
norm_lr
,
regularizer
=
L2Decay
(
norm_decay
))
if
self
.
norm_type
in
[
'bn'
,
'sync_bn'
]:
global_stats
=
True
if
self
.
freeze_norm
else
False
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
act
,
name
=
bn_name
+
'.output.1'
,
param_attr
=
pattr
,
bias_attr
=
battr
,
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
,
use_global_stats
=
global_stats
)
scale
=
fluid
.
framework
.
_get_var
(
pattr
.
name
)
bias
=
fluid
.
framework
.
_get_var
(
battr
.
name
)
elif
self
.
norm_type
==
'affine_channel'
:
scale
=
fluid
.
layers
.
create_parameter
(
shape
=
[
conv
.
shape
[
1
]],
dtype
=
conv
.
dtype
,
attr
=
pattr
,
default_initializer
=
fluid
.
initializer
.
Constant
(
1.
))
bias
=
fluid
.
layers
.
create_parameter
(
shape
=
[
conv
.
shape
[
1
]],
dtype
=
conv
.
dtype
,
attr
=
battr
,
default_initializer
=
fluid
.
initializer
.
Constant
(
0.
))
out
=
fluid
.
layers
.
affine_channel
(
x
=
conv
,
scale
=
scale
,
bias
=
bias
,
act
=
act
)
if
self
.
freeze_norm
:
scale
.
stop_gradient
=
True
bias
.
stop_gradient
=
True
return
out
def
_shortcut
(
self
,
input
,
ch_out
,
stride
,
is_first
,
name
):
max_pooling_in_short_cut
=
self
.
variant
==
'd'
ch_in
=
input
.
shape
[
1
]
# the naming rule is same as pretrained weight
name
=
self
.
na
.
fix_shortcut_name
(
name
)
std_senet
=
getattr
(
self
,
'std_senet'
,
False
)
if
ch_in
!=
ch_out
or
stride
!=
1
or
(
self
.
depth
<
50
and
is_first
):
if
std_senet
:
if
is_first
:
return
self
.
_conv_norm
(
input
,
ch_out
,
1
,
stride
,
name
=
name
)
else
:
return
self
.
_conv_norm
(
input
,
ch_out
,
3
,
stride
,
name
=
name
)
if
max_pooling_in_short_cut
and
not
is_first
:
input1
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
2
,
pool_stride
=
2
,
pool_padding
=
0
,
ceil_mode
=
True
,
pool_type
=
'max'
)
input2
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
2
,
pool_stride
=
2
,
pool_padding
=
0
,
ceil_mode
=
True
,
pool_type
=
'avg'
)
input
=
fluid
.
layers
.
elementwise_add
(
x
=
input1
,
y
=
input2
,
name
=
name
+
".pool.add"
)
return
self
.
_conv_norm
(
input
,
ch_out
,
1
,
1
,
name
=
name
)
return
self
.
_conv_norm
(
input
,
ch_out
,
1
,
stride
,
name
=
name
)
else
:
return
input
def
bottleneck
(
self
,
input
,
num_filters
,
stride
,
is_first
,
name
,
dcn_v2
=
False
,
gcb
=
False
,
gcb_name
=
None
):
assert
dcn_v2
is
False
,
"Not implemented in EdgeBoard yet."
assert
gcb
is
False
,
"Not implemented in EdgeBoard yet."
if
self
.
variant
==
'a'
:
stride1
,
stride2
=
stride
,
1
else
:
stride1
,
stride2
=
1
,
stride
# ResNeXt
groups
=
getattr
(
self
,
'groups'
,
1
)
group_width
=
getattr
(
self
,
'group_width'
,
-
1
)
if
groups
==
1
:
expand
=
4
elif
(
groups
*
group_width
)
==
256
:
expand
=
1
else
:
# FIXME hard code for now, handles 32x4d, 64x4d and 32x8d
num_filters
=
num_filters
//
2
expand
=
2
conv_name1
,
conv_name2
,
conv_name3
,
\
shortcut_name
=
self
.
na
.
fix_bottleneck_name
(
name
)
std_senet
=
getattr
(
self
,
'std_senet'
,
False
)
if
std_senet
:
conv_def
=
[
[
int
(
num_filters
/
2
),
1
,
stride1
,
'relu'
,
1
,
conv_name1
],
[
num_filters
,
3
,
stride2
,
'relu'
,
groups
,
conv_name2
],
[
num_filters
*
expand
,
1
,
1
,
None
,
1
,
conv_name3
]
]
else
:
conv_def
=
[[
num_filters
,
1
,
stride1
,
'relu'
,
1
,
conv_name1
],
[
num_filters
,
3
,
stride2
,
'relu'
,
groups
,
conv_name2
],
[
num_filters
*
expand
,
1
,
1
,
None
,
1
,
conv_name3
]]
residual
=
input
for
i
,
(
c
,
k
,
s
,
act
,
g
,
_name
)
in
enumerate
(
conv_def
):
residual
=
self
.
_conv_norm
(
input
=
residual
,
num_filters
=
c
,
filter_size
=
k
,
stride
=
s
,
act
=
act
,
groups
=
g
,
name
=
_name
,
dcn_v2
=
False
)
short
=
self
.
_shortcut
(
input
,
num_filters
*
expand
,
stride
,
is_first
=
is_first
,
name
=
shortcut_name
)
# Squeeze-and-Excitation
if
callable
(
getattr
(
self
,
'_squeeze_excitation'
,
None
)):
residual
=
self
.
_squeeze_excitation
(
input
=
residual
,
num_channels
=
num_filters
,
name
=
'fc'
+
name
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
residual
,
act
=
'relu'
,
name
=
name
+
".add.output.5"
)
def
basicblock
(
self
,
input
,
num_filters
,
stride
,
is_first
,
name
,
dcn_v2
=
False
,
gcb
=
False
,
gcb_name
=
None
):
assert
dcn_v2
is
False
,
"Not implemented in EdgeBoard yet."
assert
gcb
is
False
,
"Not implemented EdgeBoard yet."
conv0
=
self
.
_conv_norm
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
3
,
act
=
'relu'
,
stride
=
stride
,
name
=
name
+
"_branch2a"
)
conv1
=
self
.
_conv_norm
(
input
=
conv0
,
num_filters
=
num_filters
,
filter_size
=
3
,
act
=
None
,
name
=
name
+
"_branch2b"
)
short
=
self
.
_shortcut
(
input
,
num_filters
,
stride
,
is_first
,
name
=
name
+
"_branch1"
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv1
,
act
=
'relu'
)
def
layer_warp
(
self
,
input
,
stage_num
):
"""
Args:
input (Variable): input variable.
stage_num (int): the stage number, should be 2, 3, 4, 5
Returns:
The last variable in endpoint-th stage.
"""
assert
stage_num
in
[
2
,
3
,
4
,
5
]
self
.
stage_num
=
stage_num
stages
,
block_func
=
self
.
depth_cfg
[
self
.
depth
]
count
=
stages
[
stage_num
-
2
]
ch_out
=
self
.
stage_filters
[
stage_num
-
2
]
is_first
=
False
if
stage_num
!=
2
else
True
# Make the layer name and parameter name consistent
# with ImageNet pre-trained model
conv
=
input
for
i
in
range
(
count
):
conv_name
=
self
.
na
.
fix_layer_warp_name
(
stage_num
,
count
,
i
)
if
self
.
depth
<
50
:
is_first
=
True
if
i
==
0
and
stage_num
==
2
else
False
conv
=
block_func
(
input
=
conv
,
num_filters
=
ch_out
,
stride
=
2
if
i
==
0
and
stage_num
!=
2
else
1
,
is_first
=
is_first
,
name
=
conv_name
,
dcn_v2
=
False
,
gcb
=
False
,
gcb_name
=
None
)
return
conv
def
c1_stage
(
self
,
input
):
out_chan
=
self
.
_c1_out_chan_num
conv1_name
=
self
.
na
.
fix_c1_stage_name
()
if
self
.
variant
in
[
'c'
,
'd'
]:
conv_def
=
[
[
out_chan
//
2
,
3
,
2
,
"conv1_1"
],
[
out_chan
//
2
,
3
,
1
,
"conv1_2"
],
[
out_chan
,
3
,
1
,
"conv1_3"
],
]
else
:
conv_def
=
[[
out_chan
,
7
,
2
,
conv1_name
]]
for
(
c
,
k
,
s
,
_name
)
in
conv_def
:
input
=
self
.
_conv_norm
(
input
=
input
,
num_filters
=
c
,
filter_size
=
k
,
stride
=
s
,
act
=
'relu'
,
name
=
_name
)
output
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'max'
)
return
output
def
__call__
(
self
,
input
):
assert
isinstance
(
input
,
Variable
)
assert
not
(
set
(
self
.
feature_maps
)
-
set
([
2
,
3
,
4
,
5
])),
\
"feature maps {} not in [2, 3, 4, 5]"
.
format
(
self
.
feature_maps
)
res_endpoints
=
[]
res
=
input
feature_maps
=
self
.
feature_maps
severed_head
=
getattr
(
self
,
'severed_head'
,
False
)
if
not
severed_head
:
res
=
self
.
c1_stage
(
res
)
feature_maps
=
range
(
2
,
max
(
self
.
feature_maps
)
+
1
)
for
i
in
feature_maps
:
res
=
self
.
layer_warp
(
res
,
i
)
if
i
in
self
.
feature_maps
:
res_endpoints
.
append
(
res
)
if
self
.
freeze_at
>=
i
:
res
.
stop_gradient
=
True
return
OrderedDict
([(
'res{}_sum'
.
format
(
self
.
feature_maps
[
idx
]),
feat
)
for
idx
,
feat
in
enumerate
(
res_endpoints
)])
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录