PaddlePaddle / PaddleDetection · Commit 0e228b11

Initial implementation of `EfficientDet` (#492)

Author: Yang Zhang · Committed by: GitHub · Apr 28, 2020
Parent: 521a4a6a

Showing 13 changed files with 1368 additions and 6 deletions (+1368 −6)
configs/efficientdet_d0.yml                     +157 −0
docs/MODEL_ZOO.md                                 +8 −0
ppdet/data/transform/operators.py               +200 −0
ppdet/modeling/anchor_heads/__init__.py           +2 −0
ppdet/modeling/anchor_heads/efficient_head.py   +189 −0
ppdet/modeling/architectures/__init__.py          +2 −0
ppdet/modeling/architectures/efficientdet.py    +150 −0
ppdet/modeling/backbones/__init__.py              +4 −0
ppdet/modeling/backbones/bifpn.py               +202 −0
ppdet/modeling/backbones/efficientnet.py        +291 −0
ppdet/modeling/ops.py                            +95 −5
ppdet/optimizer.py                               +49 −0
tools/train.py                                   +19 −1
configs/efficientdet_d0.yml (new file, mode 100644):

```yaml
architecture: EfficientDet
max_iters: 281250
use_gpu: true
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB0_pretrained.tar
weights: output/efficientdet_d0/model_final
log_smooth_window: 20
snapshot_iter: 10000
metric: COCO
save_dir: output
num_classes: 81
use_ema: true
ema_decay: 0.9998

EfficientDet:
  backbone: EfficientNet
  fpn: BiFPN
  efficient_head: EfficientHead
  anchor_grid: AnchorGrid
  box_loss_weight: 50.

EfficientNet:
  # norm_type: sync_bn
  # TODO
  norm_type: bn
  scale: b0
  use_se: true

BiFPN:
  num_chan: 64
  repeat: 3
  levels: 5

EfficientHead:
  repeat: 3
  num_chan: 64
  prior_prob: 0.01
  num_anchors: 9
  gamma: 1.5
  alpha: 0.25
  delta: 0.1
  output_decoder:
    score_thresh: 0.05   # originally 0.
    nms_thresh: 0.5
    pre_nms_top_n: 1000  # originally 5000
    detections_per_im: 100
    nms_eta: 1.0

AnchorGrid:
  anchor_base_scale: 4
  num_scales: 3
  aspect_ratios: [[1, 1], [1.4, 0.7], [0.7, 1.4]]

LearningRate:
  base_lr: 0.16
  schedulers:
  - !CosineDecayWithSkip
    total_steps: 281250
    skip_steps: 938
  - !LinearWarmup
    start_factor: 0.05
    steps: 938

OptimizerBuilder:
  clip_grad_by_norm: 10.
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.00004
    type: L2

TrainReader:
  inputs_def:
    fields: ['image', 'im_id', 'fg_num', 'gt_label', 'gt_target']
  dataset:
    !COCODataSet
    image_dir: train2017
    anno_path: annotations/instances_train2017.json
    dataset_dir: dataset/coco
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !RandomFlipImage
    prob: 0.5
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - !RandomScaledCrop
    target_dim: 512
    scale_range: [.1, 2.]
    interp: 1
  - !Permute
    to_bgr: false
    channel_first: true
  - !TargetAssign
    image_size: 512
  batch_size: 16
  shuffle: true
  worker_num: 32
  bufsize: 16
  use_process: true
  drop_empty: false

EvalReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id']
  dataset:
    !COCODataSet
    image_dir: val2017
    anno_path: annotations/instances_val2017.json
    dataset_dir: dataset/coco
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - !ResizeAndPad
    target_dim: 512
    interp: 1
  - !Permute
    channel_first: true
    to_bgr: false
  drop_empty: false
  batch_size: 16
  shuffle: false
  worker_num: 2

TestReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id']
    image_shape: [3, 512, 512]
  dataset:
    !ImageFolder
    anno_path: annotations/instances_val2017.json
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - !ResizeAndPad
    target_dim: 512
    interp: 1
  - !Permute
    channel_first: true
    to_bgr: false
  batch_size: 16
  shuffle: false
```
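As a rough sanity check (not part of the commit), the schedule numbers above line up with the "300 epochs" entry in the model zoo below, assuming a global batch of 128 (8 GPUs × 16 images/GPU) and roughly 120k COCO train2017 images; both figures are assumptions, not values taken from the diff:

```python
# Back-of-the-envelope check of the efficientdet_d0.yml schedule.
images_per_epoch = 120000        # assumed round figure for COCO train2017
global_batch = 8 * 16            # assumed 8 GPUs x 16 images/GPU

steps_per_epoch = images_per_epoch / global_batch    # 937.5, ~= the 938 warmup steps
epochs = 281250 * global_batch / images_per_epoch    # = 300.0 epochs

print(steps_per_epoch, epochs)
```

So `skip_steps: 938` corresponds to roughly one epoch of linear warmup before the cosine decay kicks in.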
docs/MODEL_ZOO.md:

```diff
@@ -181,6 +181,14 @@ results of image size 608/416/320 above. Deformable conv is added on stage 5 of
 **Notes:** In RetinaNet, the base LR is changed to 0.01 for minibatch size 16.

+### EfficientDet
+
+| Scale           | Image/gpu | Lr schd    | Box AP | Download |
+| :-------------: | :-------: | :--------: | :----: | :------: |
+| EfficientDet-D0 | 16        | 300 epochs | 33.8   | [model](https://paddlemodels.bj.bcebos.com/object_detection/efficientdet_d0.pdparams) |
+
+**Notes:** base LR is 0.16 for minibatch size 128 (8x16).
+
 ### SSDLite

 | Backbone | Size | Image/gpu | Lr schd | Inf time (fps) | Box AP | Download | Configs |
```
ppdet/data/transform/operators.py:

```diff
@@ -37,6 +37,7 @@ import cv2
 from PIL import Image, ImageEnhance

 from ppdet.core.workspace import serializable
+from ppdet.modeling.ops import AnchorGrid
 from .op_helper import (satisfy_sample_constraint, filter_and_process,
                         generate_sample_bbox, clip_bbox, data_anchor_sampling,
```

The second hunk (`@@ -1971,3 +1972,202 @@ class CornerRatio(BaseOperator):`) appends three new data operators after `CornerRatio`:

```python
@register_op
class RandomScaledCrop(BaseOperator):
    """Resize image and bbox based on long side (with optional random scaling),
    then crop or pad image to target size.

    Args:
        target_dim (int): target size.
        scale_range (list): random scale range.
        interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
    """

    def __init__(self,
                 target_dim=512,
                 scale_range=[.1, 2.],
                 interp=cv2.INTER_LINEAR):
        super(RandomScaledCrop, self).__init__()
        self.target_dim = target_dim
        self.scale_range = scale_range
        self.interp = interp

    def __call__(self, sample, context=None):
        w = sample['w']
        h = sample['h']
        random_scale = np.random.uniform(*self.scale_range)
        dim = self.target_dim
        random_dim = int(dim * random_scale)
        dim_max = max(h, w)
        scale = random_dim / dim_max
        resize_w = int(round(w * scale))
        resize_h = int(round(h * scale))
        offset_x = int(max(0, np.random.uniform(0., resize_w - dim)))
        offset_y = int(max(0, np.random.uniform(0., resize_h - dim)))
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            scale_array = np.array([scale, scale] * 2, dtype=np.float32)
            shift_array = np.array([offset_x, offset_y] * 2, dtype=np.float32)
            boxes = sample['gt_bbox'] * scale_array - shift_array
            boxes = np.clip(boxes, 0, dim - 1)
            # filter boxes with no area
            area = np.prod(boxes[..., 2:] - boxes[..., :2], axis=1)
            valid = (area > 1.).nonzero()[0]
            sample['gt_bbox'] = boxes[valid]
            sample['gt_class'] = sample['gt_class'][valid]
        img = sample['image']
        img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interp)
        img = np.array(img)
        canvas = np.zeros((dim, dim, 3), dtype=img.dtype)
        canvas[:min(dim, resize_h), :min(dim, resize_w), :] = img[
            offset_y:offset_y + dim, offset_x:offset_x + dim, :]
        sample['h'] = dim
        sample['w'] = dim
        sample['image'] = canvas
        sample['im_info'] = [resize_h, resize_w, scale]
        return sample


@register_op
class ResizeAndPad(BaseOperator):
    """Resize image and bbox, then pad image to target size.

    Args:
        target_dim (int): target size
        interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
    """

    def __init__(self, target_dim=512, interp=cv2.INTER_LINEAR):
        super(ResizeAndPad, self).__init__()
        self.target_dim = target_dim
        self.interp = interp

    def __call__(self, sample, context=None):
        w = sample['w']
        h = sample['h']
        interp = self.interp
        dim = self.target_dim
        dim_max = max(h, w)
        scale = self.target_dim / dim_max
        resize_w = int(round(w * scale))
        resize_h = int(round(h * scale))
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            scale_array = np.array([scale, scale] * 2, dtype=np.float32)
            sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0,
                                        dim - 1)
        img = sample['image']
        img = cv2.resize(img, (resize_w, resize_h), interpolation=interp)
        img = np.array(img)
        canvas = np.zeros((dim, dim, 3), dtype=img.dtype)
        canvas[:resize_h, :resize_w, :] = img
        sample['h'] = dim
        sample['w'] = dim
        sample['image'] = canvas
        sample['im_info'] = [resize_h, resize_w, scale]
        return sample


@register_op
class TargetAssign(BaseOperator):
    """Assign regression target and labels.

    Args:
        image_size (int or list): input image size, a single integer or list of
            [h, w]. Default: 512
        min_level (int): min level of the feature pyramid. Default: 3
        max_level (int): max level of the feature pyramid. Default: 7
        anchor_base_scale (int): base anchor scale. Default: 4
        num_scales (int): number of anchor scales. Default: 3
        aspect_ratios (list): aspect ratios.
            Default: [(1, 1), (1.4, 0.7), (0.7, 1.4)]
        match_threshold (float): threshold for foreground IoU. Default: 0.5
    """

    def __init__(self,
                 image_size=512,
                 min_level=3,
                 max_level=7,
                 anchor_base_scale=4,
                 num_scales=3,
                 aspect_ratios=[(1, 1), (1.4, 0.7), (0.7, 1.4)],
                 match_threshold=0.5):
        super(TargetAssign, self).__init__()
        assert image_size % 2 ** max_level == 0, \
            "image size should be multiple of the max level stride"
        self.image_size = image_size
        self.min_level = min_level
        self.max_level = max_level
        self.anchor_base_scale = anchor_base_scale
        self.num_scales = num_scales
        self.aspect_ratios = aspect_ratios
        self.match_threshold = match_threshold

    @property
    def anchors(self):
        if not hasattr(self, '_anchors'):
            anchor_grid = AnchorGrid(self.image_size, self.min_level,
                                     self.max_level, self.anchor_base_scale,
                                     self.num_scales, self.aspect_ratios)
            self._anchors = np.concatenate(anchor_grid.generate())
        return self._anchors

    def iou_matrix(self, a, b):
        tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
        br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
        area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
        area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
        area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
        area_o = (area_a[:, np.newaxis] + area_b - area_i)
        # return area_i / (area_o + 1e-10)
        return np.where(area_i == 0., np.zeros_like(area_i), area_i / area_o)

    def match(self, anchors, gt_boxes):
        # XXX put smaller matrix first would be a little bit faster
        mat = self.iou_matrix(gt_boxes, anchors)
        max_anchor_for_each_gt = mat.argmax(axis=1)
        max_for_each_anchor = mat.max(axis=0)
        anchor_to_gt = mat.argmax(axis=0)
        anchor_to_gt[max_for_each_anchor < self.match_threshold] = -1
        # XXX ensure each gt has at least one anchor assigned,
        # see `force_match_for_each_row` in TF implementation
        one_hot = np.zeros_like(mat)
        one_hot[np.arange(mat.shape[0]), max_anchor_for_each_gt] = 1.
        max_anchor_indices = one_hot.sum(axis=0).nonzero()[0]
        max_gt_indices = one_hot.argmax(axis=0)[max_anchor_indices]
        anchor_to_gt[max_anchor_indices] = max_gt_indices
        return anchor_to_gt

    def encode(self, anchors, boxes):
        wha = anchors[..., 2:] - anchors[..., :2] + 1
        ca = anchors[..., :2] + wha * .5
        whb = boxes[..., 2:] - boxes[..., :2] + 1
        cb = boxes[..., :2] + whb * .5
        offsets = np.empty_like(anchors)
        offsets[..., :2] = (cb - ca) / wha
        offsets[..., 2:] = np.log(whb / wha)
        return offsets

    def __call__(self, sample, context=None):
        gt_boxes = sample['gt_bbox']
        gt_labels = sample['gt_class']
        labels = np.full((self.anchors.shape[0], 1), 0, dtype=np.int32)
        targets = np.full((self.anchors.shape[0], 4), 0., dtype=np.float32)
        sample['gt_label'] = labels
        sample['gt_target'] = targets
        if len(gt_boxes) < 1:
            sample['fg_num'] = np.array(0, dtype=np.int32)
            return sample
        anchor_to_gt = self.match(self.anchors, gt_boxes)
        matched_indices = (anchor_to_gt >= 0).nonzero()[0]
        labels[matched_indices] = gt_labels[anchor_to_gt[matched_indices]]
        matched_boxes = gt_boxes[anchor_to_gt[matched_indices]]
        matched_anchors = self.anchors[matched_indices]
        matched_targets = self.encode(matched_anchors, matched_boxes)
        targets[matched_indices] = matched_targets
        sample['fg_num'] = np.array(len(matched_targets), dtype=np.int32)
        return sample
```
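For reference, `TargetAssign.encode` above implements the familiar anchor-offset parameterization: center offsets normalized by anchor size, plus log-scale width/height ratios (widths and heights use the +1 pixel convention of the code, and unlike some RetinaNet variants there is no extra sigma scaling):

```math
t_x = \frac{x_b - x_a}{w_a}, \quad
t_y = \frac{y_b - y_a}{h_a}, \quad
t_w = \log\frac{w_b}{w_a}, \quad
t_h = \log\frac{h_b}{h_a}
```

where $(x_a, y_a, w_a, h_a)$ describe the anchor and $(x_b, y_b, w_b, h_b)$ the matched ground-truth box.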
ppdet/modeling/anchor_heads/__init__.py:

```diff
@@ -19,9 +19,11 @@ from . import yolo_head
 from . import retina_head
 from . import fcos_head
 from . import corner_head
+from . import efficient_head

 from .rpn_head import *
 from .yolo_head import *
 from .retina_head import *
 from .fcos_head import *
 from .corner_head import *
+from .efficient_head import *
```
ppdet/modeling/anchor_heads/efficient_head.py (new file, mode 100644):

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import TruncatedNormal, Constant
from paddle.fluid.regularizer import L2Decay

from ppdet.modeling.ops import RetinaOutputDecoder
from ppdet.core.workspace import register

__all__ = ['EfficientHead']


@register
class EfficientHead(object):
    """
    EfficientDet Head

    Args:
        output_decoder (object): `RetinaOutputDecoder` instance.
        repeat (int): Number of convolution layers.
        num_chan (int): Number of octave output channels.
        prior_prob (float): Initial value of the class prediction layer bias.
        num_anchors (int): Number of anchors per cell.
        num_classes (int): Number of classes.
        gamma (float): Gamma parameter for focal loss.
        alpha (float): Alpha parameter for focal loss.
        delta (float): Delta parameter for huber (smooth L1) loss.
    """
    __inject__ = ['output_decoder']
    __shared__ = ['num_classes']

    def __init__(self,
                 output_decoder=RetinaOutputDecoder().__dict__,
                 repeat=3,
                 num_chan=64,
                 prior_prob=0.01,
                 num_anchors=9,
                 num_classes=81,
                 gamma=1.5,
                 alpha=0.25,
                 delta=0.1):
        super(EfficientHead, self).__init__()
        self.output_decoder = output_decoder
        self.repeat = repeat
        self.num_chan = num_chan
        self.prior_prob = prior_prob
        self.num_anchors = num_anchors
        self.num_classes = num_classes
        self.gamma = gamma
        self.alpha = alpha
        self.delta = delta
        if isinstance(output_decoder, dict):
            self.output_decoder = RetinaOutputDecoder(**output_decoder)

    def _get_output(self, body_feats):
        def separable_conv(inputs, num_chan, bias_init=None, name=''):
            dw_conv_name = name + '_dw'
            pw_conv_name = name + '_pw'
            in_chan = inputs.shape[1]
            fan_in = np.sqrt(1. / (in_chan * 3 * 3))
            feat = fluid.layers.conv2d(
                input=inputs,
                num_filters=in_chan,
                groups=in_chan,
                filter_size=3,
                stride=1,
                padding='SAME',
                param_attr=ParamAttr(
                    name=dw_conv_name + '_w',
                    initializer=TruncatedNormal(scale=fan_in)),
                bias_attr=False)
            fan_in = np.sqrt(1. / in_chan)
            feat = fluid.layers.conv2d(
                input=feat,
                num_filters=num_chan,
                filter_size=1,
                stride=1,
                param_attr=ParamAttr(
                    name=pw_conv_name + '_w',
                    initializer=TruncatedNormal(scale=fan_in)),
                bias_attr=ParamAttr(
                    name=pw_conv_name + '_b',
                    initializer=bias_init,
                    regularizer=L2Decay(0.)))
            return feat

        def subnet(inputs, prefix, level):
            feat = inputs
            for i in range(self.repeat):
                # NOTE share weight across FPN levels
                conv_name = '{}_pred_conv_{}'.format(prefix, i)
                feat = separable_conv(feat, self.num_chan, name=conv_name)
                # NOTE batch norm params are not shared
                bn_name = '{}_pred_bn_{}_{}'.format(prefix, level, i)
                feat = fluid.layers.batch_norm(
                    input=feat,
                    act='swish',
                    momentum=0.997,
                    epsilon=1e-4,
                    moving_mean_name=bn_name + '_mean',
                    moving_variance_name=bn_name + '_variance',
                    param_attr=ParamAttr(
                        name=bn_name + '_w',
                        initializer=Constant(value=1.),
                        regularizer=L2Decay(0.)),
                    bias_attr=ParamAttr(
                        name=bn_name + '_b',
                        regularizer=L2Decay(0.)))
            return feat

        cls_preds = []
        box_preds = []
        for l, feat in enumerate(body_feats):
            cls_out = subnet(feat, 'cls', l)
            box_out = subnet(feat, 'box', l)

            bias_init = float(-np.log((1 - self.prior_prob) / self.prior_prob))
            bias_init = Constant(value=bias_init)
            cls_pred = separable_conv(
                cls_out,
                self.num_anchors * (self.num_classes - 1),
                bias_init=bias_init,
                name='cls_pred')
            cls_pred = fluid.layers.transpose(cls_pred, perm=[0, 2, 3, 1])
            cls_pred = fluid.layers.reshape(
                cls_pred, shape=(0, -1, self.num_classes - 1))
            cls_preds.append(cls_pred)

            box_pred = separable_conv(
                box_out, self.num_anchors * 4, name='box_pred')
            box_pred = fluid.layers.transpose(box_pred, perm=[0, 2, 3, 1])
            box_pred = fluid.layers.reshape(box_pred, shape=(0, -1, 4))
            box_preds.append(box_pred)

        return cls_preds, box_preds

    def get_prediction(self, body_feats, anchors, im_info):
        cls_preds, box_preds = self._get_output(body_feats)
        cls_preds = [fluid.layers.sigmoid(pred) for pred in cls_preds]
        pred_result = self.output_decoder(
            bboxes=box_preds,
            scores=cls_preds,
            anchors=anchors,
            im_info=im_info)
        return {'bbox': pred_result}

    def get_loss(self, body_feats, gt_labels, gt_targets, fg_num):
        cls_preds, box_preds = self._get_output(body_feats)
        fg_num = fluid.layers.reduce_sum(fg_num, name='fg_num')
        fg_num.stop_gradient = True
        cls_pred = fluid.layers.concat(cls_preds, axis=1)
        box_pred = fluid.layers.concat(box_preds, axis=1)

        cls_pred_reshape = fluid.layers.reshape(
            cls_pred, shape=(-1, self.num_classes - 1))
        gt_labels_reshape = fluid.layers.reshape(gt_labels, shape=(-1, 1))
        loss_cls = fluid.layers.sigmoid_focal_loss(
            x=cls_pred_reshape,
            label=gt_labels_reshape,
            fg_num=fg_num,
            gamma=self.gamma,
            alpha=self.alpha)
        loss_cls = fluid.layers.reduce_sum(loss_cls)

        loss_bbox = fluid.layers.huber_loss(
            input=box_pred, label=gt_targets, delta=self.delta)
        mask = fluid.layers.expand(gt_labels, expand_times=[1, 1, 4]) > 0
        loss_bbox *= fluid.layers.cast(mask, 'float32')
        loss_bbox = fluid.layers.reduce_sum(loss_bbox) / (fg_num * 4)
        return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox}
```
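For context (not part of the diff), the two losses in `get_loss` have the standard forms below; the element-wise computation itself is delegated to `fluid.layers.sigmoid_focal_loss` and `fluid.layers.huber_loss`, with γ = 1.5, α = 0.25 and δ = 0.1 from the D0 config. Both terms are normalized by the foreground anchor count `fg_num` (the box term additionally by the 4 coordinates):

```math
\mathrm{FL}(p_t) = -\alpha_t\,(1 - p_t)^{\gamma} \log p_t,
\qquad
L_\delta(x) =
\begin{cases}
\tfrac{1}{2} x^2, & |x| \le \delta \\
\delta\,(|x| - \tfrac{1}{2}\delta), & \text{otherwise}
\end{cases}
```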
ppdet/modeling/architectures/__init__.py:

```diff
@@ -22,6 +22,7 @@ from . import cascade_rcnn_cls_aware
 from . import yolov3
 from . import ssd
 from . import retinanet
+from . import efficientdet
 from . import blazeface
 from . import faceboxes
 from . import fcos
@@ -35,6 +36,7 @@ from .cascade_rcnn_cls_aware import *
 from .yolov3 import *
 from .ssd import *
 from .retinanet import *
+from .efficientdet import *
 from .blazeface import *
 from .faceboxes import *
 from .fcos import *
```
ppdet/modeling/architectures/efficientdet.py (new file, mode 100644):

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division

from collections import OrderedDict

import paddle.fluid as fluid

from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register

__all__ = ['EfficientDet']


@register
class EfficientDet(object):
    """
    EfficientDet architecture, see https://arxiv.org/abs/1911.09070

    Args:
        backbone (object): backbone instance
        fpn (object): feature pyramid network instance
        efficient_head (object): `EfficientHead` instance
        anchor_grid (object): `AnchorGrid` instance
        box_loss_weight (float): weight applied to the box regression loss
    """

    __category__ = 'architecture'
    __inject__ = ['backbone', 'fpn', 'efficient_head', 'anchor_grid']

    def __init__(self,
                 backbone,
                 fpn,
                 efficient_head,
                 anchor_grid,
                 box_loss_weight=50.):
        super(EfficientDet, self).__init__()
        self.backbone = backbone
        self.fpn = fpn
        self.efficient_head = efficient_head
        self.anchor_grid = anchor_grid
        self.box_loss_weight = box_loss_weight

    def build(self, feed_vars, mode='train'):
        im = feed_vars['image']
        if mode == 'train':
            gt_labels = feed_vars['gt_label']
            gt_targets = feed_vars['gt_target']
            fg_num = feed_vars['fg_num']
        else:
            im_info = feed_vars['im_info']

        mixed_precision_enabled = mixed_precision_global_state() is not None
        if mixed_precision_enabled:
            im = fluid.layers.cast(im, 'float16')
        body_feats = self.backbone(im)
        if mixed_precision_enabled:
            body_feats = [fluid.layers.cast(f, 'float32') for f in body_feats]
        body_feats = self.fpn(body_feats)

        # XXX not used for training, but the parameters are needed when
        # exporting inference model
        anchors = self.anchor_grid()

        if mode == 'train':
            loss = self.efficient_head.get_loss(body_feats, gt_labels,
                                                gt_targets, fg_num)
            loss_cls = loss['loss_cls']
            loss_bbox = loss['loss_bbox']
            total_loss = loss_cls + self.box_loss_weight * loss_bbox
            loss.update({'loss': total_loss})
            return loss
        else:
            pred = self.efficient_head.get_prediction(body_feats, anchors,
                                                      im_info)
            return pred

    def _inputs_def(self, image_shape):
        im_shape = [None] + image_shape
        inputs_def = {
            'image': {'shape': im_shape, 'dtype': 'float32'},
            'im_info': {'shape': [None, 3], 'dtype': 'float32'},
            'im_id': {'shape': [None, 1], 'dtype': 'int64'},
            'im_shape': {'shape': [None, 3], 'dtype': 'float32'},
            'fg_num': {'shape': [None, 1], 'dtype': 'int32'},
            'gt_label': {'shape': [None, None, 1], 'dtype': 'int32'},
            'gt_target': {'shape': [None, None, 4], 'dtype': 'float32'},
        }
        return inputs_def

    def build_inputs(self,
                     image_shape=[3, None, None],
                     fields=['image', 'im_info', 'im_id', 'fg_num',
                             'gt_label', 'gt_target'],
                     use_dataloader=True,
                     iterable=False):
        inputs_def = self._inputs_def(image_shape)
        feed_vars = OrderedDict([(key, fluid.data(
            name=key,
            shape=inputs_def[key]['shape'],
            dtype=inputs_def[key]['dtype'])) for key in fields])
        loader = fluid.io.DataLoader.from_generator(
            feed_list=list(feed_vars.values()),
            capacity=16,
            use_double_buffer=True,
            iterable=iterable) if use_dataloader else None
        return feed_vars, loader

    def train(self, feed_vars):
        return self.build(feed_vars, 'train')

    def eval(self, feed_vars):
        return self.build(feed_vars, 'test')

    def test(self, feed_vars):
        return self.build(feed_vars, 'test')
```
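The training objective assembled in `build` is simply a weighted sum of the two head losses, with `box_loss_weight: 50.` from the D0 config:

```math
L = L_{\mathrm{cls}} + w_{\mathrm{box}} \cdot L_{\mathrm{bbox}}, \qquad w_{\mathrm{box}} = 50
```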
ppdet/modeling/backbones/__init__.py:

```diff
@@ -30,6 +30,8 @@ from . import hrnet
 from . import hrfpn
 from . import bfp
 from . import hourglass
+from . import efficientnet
+from . import bifpn

 from .resnet import *
 from .resnext import *
@@ -47,3 +49,5 @@ from .hrnet import *
 from .hrfpn import *
 from .bfp import *
 from .hourglass import *
+from .efficientnet import *
+from .bifpn import *
```
ppdet/modeling/backbones/bifpn.py (new file, mode 100644):

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import Constant, Xavier

from ppdet.core.workspace import register

__all__ = ['BiFPN']


class FusionConv(object):
    def __init__(self, num_chan):
        super(FusionConv, self).__init__()
        self.num_chan = num_chan

    def __call__(self, inputs, name=''):
        x = fluid.layers.swish(inputs)
        # depthwise
        x = fluid.layers.conv2d(
            x,
            self.num_chan,
            filter_size=3,
            padding='SAME',
            groups=self.num_chan,
            param_attr=ParamAttr(
                initializer=Xavier(), name=name + '_dw_w'),
            bias_attr=False)
        # pointwise
        x = fluid.layers.conv2d(
            x,
            self.num_chan,
            filter_size=1,
            param_attr=ParamAttr(
                initializer=Xavier(), name=name + '_pw_w'),
            bias_attr=ParamAttr(
                regularizer=L2Decay(0.), name=name + '_pw_b'))
        # bn + act
        x = fluid.layers.batch_norm(
            x,
            momentum=0.997,
            epsilon=1e-04,
            param_attr=ParamAttr(
                initializer=Constant(1.0),
                regularizer=L2Decay(0.),
                name=name + '_bn_w'),
            bias_attr=ParamAttr(
                regularizer=L2Decay(0.), name=name + '_bn_b'))
        return x


class BiFPNCell(object):
    def __init__(self, num_chan, levels=5):
        super(BiFPNCell, self).__init__()
        self.levels = levels
        self.num_chan = num_chan
        num_trigates = levels - 2
        num_bigates = levels
        self.trigates = fluid.layers.create_parameter(
            shape=[num_trigates, 3],
            dtype='float32',
            default_initializer=fluid.initializer.Constant(1.))
        self.bigates = fluid.layers.create_parameter(
            shape=[num_bigates, 2],
            dtype='float32',
            default_initializer=fluid.initializer.Constant(1.))
        self.eps = 1e-4

    def __call__(self, inputs, cell_name=''):
        assert len(inputs) == self.levels

        def upsample(feat):
            return fluid.layers.resize_nearest(feat, scale=2.)

        def downsample(feat):
            return fluid.layers.pool2d(
                feat,
                pool_type='max',
                pool_size=3,
                pool_stride=2,
                pool_padding='SAME')

        fuse_conv = FusionConv(self.num_chan)

        # normalize weight
        trigates = fluid.layers.relu(self.trigates)
        bigates = fluid.layers.relu(self.bigates)
        trigates /= fluid.layers.reduce_sum(
            trigates, dim=1, keep_dim=True) + self.eps
        bigates /= fluid.layers.reduce_sum(
            bigates, dim=1, keep_dim=True) + self.eps

        feature_maps = list(inputs)  # make a copy
        # top down path
        for l in range(self.levels - 1):
            p = self.levels - l - 2
            w1 = fluid.layers.slice(
                bigates, axes=[0, 1], starts=[l, 0], ends=[l + 1, 1])
            w2 = fluid.layers.slice(
                bigates, axes=[0, 1], starts=[l, 1], ends=[l + 1, 2])
            above = upsample(feature_maps[p + 1])
            feature_maps[p] = fuse_conv(
                w1 * above + w2 * inputs[p],
                name='{}_tb_{}'.format(cell_name, l))
        # bottom up path
        for l in range(1, self.levels):
            p = l
            name = '{}_bt_{}'.format(cell_name, l)
            below = downsample(feature_maps[p - 1])
            if p == self.levels - 1:
                # handle P7
                w1 = fluid.layers.slice(
                    bigates, axes=[0, 1], starts=[p, 0], ends=[p + 1, 1])
                w2 = fluid.layers.slice(
                    bigates, axes=[0, 1], starts=[p, 1], ends=[p + 1, 2])
                feature_maps[p] = fuse_conv(
                    w1 * below + w2 * inputs[p], name=name)
            else:
                w1 = fluid.layers.slice(
                    trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1])
                w2 = fluid.layers.slice(
                    trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2])
                w3 = fluid.layers.slice(
                    trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3])
                feature_maps[p] = fuse_conv(
                    w1 * feature_maps[p] + w2 * below + w3 * inputs[p],
                    name=name)
        return feature_maps


@register
class BiFPN(object):
    """
    Bidirectional Feature Pyramid Network, see https://arxiv.org/abs/1911.09070

    Args:
        num_chan (int): number of feature channels
        repeat (int): number of repeats of the BiFPN module
        levels (int): number of FPN levels, default: 5
    """

    def __init__(self, num_chan, repeat=3, levels=5):
        super(BiFPN, self).__init__()
        self.num_chan = num_chan
        self.repeat = repeat
        self.levels = levels

    def __call__(self, inputs):
        feats = []
        # NOTE add two extra levels
        for idx in range(self.levels):
            if idx <= len(inputs):
                if idx == len(inputs):
                    feat = inputs[-1]
                else:
                    feat = inputs[idx]

                if feat.shape[1] != self.num_chan:
                    feat = fluid.layers.conv2d(
                        feat,
                        self.num_chan,
                        filter_size=1,
                        padding='SAME',
                        param_attr=ParamAttr(initializer=Xavier()),
                        bias_attr=ParamAttr(regularizer=L2Decay(0.)))
                    feat = fluid.layers.batch_norm(
                        feat,
                        momentum=0.997,
                        epsilon=1e-04,
                        param_attr=ParamAttr(
                            initializer=Constant(1.0),
                            regularizer=L2Decay(0.)),
                        bias_attr=ParamAttr(regularizer=L2Decay(0.)))

            if idx >= len(inputs):
                feat = fluid.layers.pool2d(
                    feat,
                    pool_type='max',
                    pool_size=3,
                    pool_stride=2,
                    pool_padding='SAME')

            feats.append(feat)

        biFPN = BiFPNCell(self.num_chan, self.levels)
        for r in range(self.repeat):
            feats = biFPN(feats, 'bifpn_{}'.format(r))
        return feats
```
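For reference, the learnable `bigates`/`trigates` above implement the paper's "fast normalized fusion": each fused node is a ReLU-gated convex combination of its two or three inputs, followed by the depthwise-separable `FusionConv`. With ε = 1e-4 as in the code:

```math
O = \sum_i \frac{\mathrm{ReLU}(w_i)}{\epsilon + \sum_j \mathrm{ReLU}(w_j)} \cdot I_i
```

The ReLU keeps every weight non-negative, so the fusion stays a (near-)convex combination without the cost of a softmax.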
ppdet/modeling/backbones/efficientnet.py (new file, mode 100644):

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division

import collections
import math
import re

from paddle import fluid
from paddle.fluid.regularizer import L2Decay

from ppdet.core.workspace import register

__all__ = ['EfficientNet']

GlobalParams = collections.namedtuple('GlobalParams', [
    'batch_norm_momentum', 'batch_norm_epsilon', 'width_coefficient',
    'depth_coefficient', 'depth_divisor'
])

BlockArgs = collections.namedtuple('BlockArgs', [
    'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
    'expand_ratio', 'stride', 'se_ratio'
])

GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)


def _decode_block_string(block_string):
    assert isinstance(block_string, str)
    ops = block_string.split('_')
    options = {}
    for op in ops:
        splits = re.split(r'(\d.*)', op)
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value

    assert (('s' in options and len(options['s']) == 1) or
            (len(options['s']) == 2 and options['s'][0] == options['s'][1]))

    return BlockArgs(
        kernel_size=int(options['k']),
        num_repeat=int(options['r']),
        input_filters=int(options['i']),
        output_filters=int(options['o']),
        expand_ratio=int(options['e']),
        se_ratio=float(options['se']) if 'se' in options else None,
        stride=int(options['s'][0]))


def get_model_params(scale):
    block_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    block_args = []
    for block_string in block_strings:
        block_args.append(_decode_block_string(block_string))

    params_dict = {
        # width, depth
        'b0': (1.0, 1.0),
        'b1': (1.0, 1.1),
        'b2': (1.1, 1.2),
        'b3': (1.2, 1.4),
        'b4': (1.4, 1.8),
        'b5': (1.6, 2.2),
        'b6': (1.8, 2.6),
        'b7': (2.0, 3.1),
    }

    w, d = params_dict[scale]

    global_params = GlobalParams(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        width_coefficient=w,
        depth_coefficient=d,
        depth_divisor=8)

    return block_args, global_params


def round_filters(filters, global_params):
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters
    divisor = global_params.depth_divisor
    filters *= multiplier
    min_depth = divisor
    new_filters = max(min_depth,
                      int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, global_params):
    multiplier = global_params.depth_coefficient
    if not multiplier:
        return repeats
    return int(math.ceil(multiplier * repeats))


def conv2d(inputs,
           num_filters,
           filter_size,
           stride=1,
           padding='SAME',
           groups=1,
           use_bias=False,
           name='conv2d'):
    param_attr = fluid.ParamAttr(name=name + '_weights')
    bias_attr = False
    if use_bias:
        bias_attr = fluid.ParamAttr(
            name=name + '_offset', regularizer=L2Decay(0.))
    feats = fluid.layers.conv2d(
        inputs,
        num_filters,
        filter_size,
        groups=groups,
        name=name,
        stride=stride,
        padding=padding,
        param_attr=param_attr,
        bias_attr=bias_attr)
    return feats


def batch_norm(inputs, momentum, eps, name=None):
    param_attr = fluid.ParamAttr(
        name=name + '_scale', regularizer=L2Decay(0.))
    bias_attr = fluid.ParamAttr(
        name=name + '_offset', regularizer=L2Decay(0.))
    return fluid.layers.batch_norm(
        input=inputs,
        momentum=momentum,
        epsilon=eps,
        name=name,
        moving_mean_name=name + '_mean',
        moving_variance_name=name + '_variance',
        param_attr=param_attr,
        bias_attr=bias_attr)


def mb_conv_block(inputs,
                  input_filters,
                  output_filters,
                  expand_ratio,
                  kernel_size,
                  stride,
                  momentum,
                  eps,
                  se_ratio=None,
                  name=None):
    feats = inputs
    num_filters = input_filters * expand_ratio

    if expand_ratio != 1:
        feats = conv2d(feats, num_filters, 1, name=name + '_expand_conv')
        feats = batch_norm(feats, momentum, eps, name=name + '_bn0')
        feats = fluid.layers.swish(feats)

    feats = conv2d(
        feats,
        num_filters,
        kernel_size,
        stride,
        groups=num_filters,
        name=name + '_depthwise_conv')
    feats = batch_norm(feats, momentum, eps, name=name + '_bn1')
    feats = fluid.layers.swish(feats)

    if se_ratio is not None:
        filter_squeezed = max(1, int(input_filters * se_ratio))
        squeezed = fluid.layers.pool2d(
            feats, pool_type='avg', global_pooling=True)
        squeezed = conv2d(
            squeezed,
            filter_squeezed,
            1,
            use_bias=True,
            name=name + '_se_reduce')
        squeezed = fluid.layers.swish(squeezed)
        squeezed = conv2d(
            squeezed, num_filters, 1, use_bias=True, name=name + '_se_expand')
        feats = feats * fluid.layers.sigmoid(squeezed)

    feats = conv2d(feats, output_filters, 1, name=name + '_project_conv')
    feats = batch_norm(feats, momentum, eps, name=name + '_bn2')

    if stride == 1 and input_filters == output_filters:
        feats = fluid.layers.elementwise_add(feats, inputs)

    return feats


@register
class EfficientNet(object):
    """
    EfficientNet, see https://arxiv.org/abs/1905.11946

    Args:
        scale (str): compounding scale factor, 'b0' - 'b7'.
        use_se (bool): use squeeze and excite module.
        norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
    """
    __shared__ = ['norm_type']

    def __init__(self, scale='b0', use_se=True, norm_type='bn'):
        assert scale in ['b' + str(i) for i in range(8)], \
            "valid scales are b0 - b7"
        assert norm_type in ['bn', 'sync_bn'], \
            "only 'bn' and 'sync_bn' are supported"
        super(EfficientNet, self).__init__()
        self.norm_type = norm_type
        self.scale = scale
        self.use_se = use_se

    def __call__(self, inputs):
        blocks_args, global_params = get_model_params(self.scale)
        momentum = global_params.batch_norm_momentum
        eps = global_params.batch_norm_epsilon

        num_filters = round_filters(32, global_params)
        feats = conv2d(
            inputs,
            num_filters=num_filters,
            filter_size=3,
            stride=2,
            name='_conv_stem')
        feats = batch_norm(feats, momentum=momentum, eps=eps, name='_bn0')
        feats = fluid.layers.swish(feats)

        layer_count = 0
        feature_maps = []

        for b, block_arg in enumerate(blocks_args):
            for r in range(block_arg.num_repeat):
                input_filters = round_filters(block_arg.input_filters,
                                              global_params)
                output_filters = round_filters(block_arg.output_filters,
                                               global_params)
                kernel_size = block_arg.kernel_size
                stride = block_arg.stride
                se_ratio = None
                if self.use_se:
                    se_ratio = block_arg.se_ratio

                if r > 0:
                    input_filters = output_filters
                    stride = 1

                feats = mb_conv_block(
                    feats,
                    input_filters,
                    output_filters,
                    block_arg.expand_ratio,
                    kernel_size,
                    stride,
                    momentum,
                    eps,
                    se_ratio=se_ratio,
                    name='_blocks.{}.'.format(layer_count))

                layer_count += 1
            feature_maps.append(feats)

        return list(feature_maps[i] for i in [2, 4, 6])
```
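A minimal standalone sketch of the compound-scaling arithmetic in `round_filters`/`round_repeats` above (pure Python, no Paddle dependency); the 'b4' coefficients are taken from `params_dict`:

```python
import math

def round_filters(filters, multiplier, divisor=8):
    # Snap scaled channel counts to a multiple of `divisor`,
    # never rounding down by more than 10%.
    filters *= multiplier
    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:
        new_filters += divisor
    return int(new_filters)

def round_repeats(repeats, multiplier):
    # Depth scaling always rounds up.
    return int(math.ceil(multiplier * repeats))

# 'b4' scaling: width 1.4, depth 1.8 (from params_dict above)
print(round_filters(32, 1.4))   # stem: 32 -> 48 channels
print(round_repeats(3, 1.8))    # a 3-repeat stage -> 6 repeats
```

For 'b0' both coefficients are 1.0, so the D0 backbone uses the base channel and repeat counts unchanged.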
ppdet/modeling/ops.py:

```diff
@@ -18,17 +18,19 @@ import math
 import six
 from paddle import fluid
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.initializer import NumpyArrayInitializer
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.regularizer import L2Decay

 from ppdet.core.workspace import register, serializable
 from ppdet.utils.bbox_utils import bbox_overlaps, box_to_delta

 __all__ = [
-    'AnchorGenerator', 'DropBlock', 'RPNTargetAssign', 'GenerateProposals',
-    'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner', 'RoIAlign', 'RoIPool',
-    'MultiBoxHead', 'SSDLiteMultiBoxHead', 'SSDOutputDecoder',
-    'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm', 'DeformConvNorm',
-    'MultiClassSoftNMS', 'LibraBBoxAssigner'
+    'AnchorGenerator', 'AnchorGrid', 'DropBlock', 'RPNTargetAssign',
+    'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner',
+    'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead',
+    'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder',
+    'ConvNorm', 'DeformConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner'
 ]
```

The second hunk (`@@ -324,6 +326,94 @@ class AnchorGenerator(object):`) inserts a new `AnchorGrid` class between `AnchorGenerator` and `RPNTargetAssign`:

```python
@register
@serializable
class AnchorGrid(object):
    """Generate anchor grid

    Args:
        image_size (int or list): input image size, may be a single integer or
            list of [h, w]. Default: 512
        min_level (int): min level of the feature pyramid. Default: 3
        max_level (int): max level of the feature pyramid. Default: 7
        anchor_base_scale: base anchor scale. Default: 4
        num_scales: number of anchor scales. Default: 3
        aspect_ratios: aspect ratios. default: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
    """

    def __init__(self,
                 image_size=512,
                 min_level=3,
                 max_level=7,
                 anchor_base_scale=4,
                 num_scales=3,
                 aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]):
        super(AnchorGrid, self).__init__()
        if isinstance(image_size, Integral):
            self.image_size = [image_size, image_size]
        else:
            self.image_size = image_size
        for dim in self.image_size:
            assert dim % 2 ** max_level == 0, \
                "image size should be multiple of the max level stride"
        self.min_level = min_level
        self.max_level = max_level
        self.anchor_base_scale = anchor_base_scale
        self.num_scales = num_scales
        self.aspect_ratios = aspect_ratios

    @property
    def base_cell(self):
        if not hasattr(self, '_base_cell'):
            self._base_cell = self.make_cell()
        return self._base_cell

    def make_cell(self):
        scales = [2**(i / self.num_scales) for i in range(self.num_scales)]
        scales = np.array(scales)
        ratios = np.array(self.aspect_ratios)
        ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1)
        hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1)
        anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))
        return anchors

    def make_grid(self, stride):
        cell = self.base_cell * stride * self.anchor_base_scale
        x_steps = np.arange(stride // 2, self.image_size[1], stride)
        y_steps = np.arange(stride // 2, self.image_size[0], stride)
        offset_x, offset_y = np.meshgrid(x_steps, y_steps)
        offset_x = offset_x.flatten()
        offset_y = offset_y.flatten()
        offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1)
        offsets = offsets[:, np.newaxis, :]
        return (cell + offsets).reshape(-1, 4)

    def generate(self):
        return [
            self.make_grid(2**l)
            for l in range(self.min_level, self.max_level + 1)
        ]

    def __call__(self):
        if not hasattr(self, '_anchor_vars'):
            anchor_vars = []
            helper = LayerHelper('anchor_grid')
            for idx, l in enumerate(
                    range(self.min_level, self.max_level + 1)):
                stride = 2**l
                anchors = self.make_grid(stride)
                var = helper.create_parameter(
                    attr=ParamAttr(name='anchors_{}'.format(idx)),
                    shape=anchors.shape,
                    dtype='float32',
                    stop_gradient=True,
                    default_initializer=NumpyArrayInitializer(anchors))
                anchor_vars.append(var)
                var.persistable = True
            self._anchor_vars = anchor_vars

        return self._anchor_vars
```
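A quick sketch (not part of the commit) of how many anchors `AnchorGrid` emits for the 512×512 D0 configuration, i.e. levels 3-7 (strides 8-128) with 9 anchors per cell; this just mirrors the arithmetic of `make_grid`:

```python
image_size = 512
num_anchors_per_cell = 3 * 3      # num_scales x len(aspect_ratios)

total = 0
for level in range(3, 8):         # min_level .. max_level
    stride = 2 ** level
    cells = (image_size // stride) ** 2   # 64^2, 32^2, ..., 4^2
    total += cells * num_anchors_per_cell

print(total)                      # 49104 anchors across the pyramid
```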
ppdet/optimizer.py:

```diff
@@ -16,12 +16,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import math
 import logging

 from paddle import fluid

 import paddle.fluid.optimizer as optimizer
 import paddle.fluid.regularizer as regularizer
+from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+from paddle.fluid.layers.ops import cos

 from ppdet.core.workspace import register, serializable
```

The second hunk (`@@ -75,6 +78,50 @@ class CosineDecay(object):`) adds a new scheduler after `CosineDecay` (whose `__call__` reads `lr = fluid.layers.cosine_decay(base_lr, 1, self.max_iters)`):

```python
@serializable
class CosineDecayWithSkip(object):
    """
    Cosine decay, with explicit support for warm up

    Args:
        total_steps (int): total steps over which to apply the decay
        skip_steps (int): skip some steps at the beginning, e.g., warm up
    """

    def __init__(self, total_steps, skip_steps=None):
        super(CosineDecayWithSkip, self).__init__()
        assert (not skip_steps or skip_steps > 0), \
            "skip steps must be greater than zero"
        assert total_steps > 0, "total step must be greater than zero"
        assert (not skip_steps or skip_steps < total_steps), \
            "skip steps must be smaller than total steps"
        self.total_steps = total_steps
        self.skip_steps = skip_steps

    def __call__(self, base_lr=None, learning_rate=None):
        steps = _decay_step_counter()
        total = self.total_steps
        if self.skip_steps is not None:
            total -= self.skip_steps

        lr = fluid.layers.tensor.create_global_var(
            shape=[1],
            value=base_lr,
            dtype='float32',
            persistable=True,
            name="learning_rate")

        def decay():
            cos_lr = base_lr * .5 * (cos(steps * (math.pi / total)) + 1)
            fluid.layers.tensor.assign(input=cos_lr, output=lr)

        if self.skip_steps is None:
            decay()
        else:
            skipped = steps >= self.skip_steps
            fluid.layers.cond(skipped, decay)
        return lr
```

The third hunk threads a gradient-clipping knob through `OptimizerBuilder`:

```diff
@@ -140,10 +187,12 @@ class OptimizerBuilder():
     __category__ = 'optim'

     def __init__(self,
+                 clip_grad_by_norm=None,
                  regularizer={'type': 'L2',
                               'factor': .0001},
                  optimizer={'type': 'Momentum',
                             'momentum': .9}):
+        self.clip_grad_by_norm = clip_grad_by_norm
         self.regularizer = regularizer
         self.optimizer = optimizer
```
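In formula terms, `decay()` above evaluates, with T = `total_steps`, s₀ = `skip_steps` and s the unshifted global step counter:

```math
\mathrm{lr}(s) = \frac{\mathrm{base\_lr}}{2}
\left( \cos\!\left( \frac{s\,\pi}{T - s_0} \right) + 1 \right),
\qquad s \ge s_0
```

For s < s₀ the variable simply holds `base_lr`, which is why the D0 config pairs this scheduler with a `!LinearWarmup` over the same 938 steps.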
tools/train.py:

```diff
@@ -38,6 +38,8 @@ set_paddle_flags(
 )

 from paddle import fluid
+from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+from paddle.fluid.optimizer import ExponentialMovingAverage

 from ppdet.experimental import mixed_precision_context
 from ppdet.core.workspace import load_config, merge_config, create
@@ -124,10 +126,21 @@ def main():
                     loss *= ctx.get_loss_scale_var()
                 lr = lr_builder()
                 optimizer = optim_builder(lr)
-                optimizer.minimize(loss)
+                clip = None
+                if optim_builder.clip_grad_by_norm is not None:
+                    clip = fluid.clip.GradientClipByGlobalNorm(
+                        clip_norm=optim_builder.clip_grad_by_norm)
+                optimizer.minimize(loss, grad_clip=clip)
+
                 if FLAGS.fp16:
                     loss /= ctx.get_loss_scale_var()

+            if 'use_ema' in cfg and cfg['use_ema']:
+                global_steps = _decay_step_counter()
+                ema = ExponentialMovingAverage(
+                    cfg['ema_decay'], thres_steps=global_steps)
+                ema.update()
+
     # parse train fetches
     train_keys, train_values, _ = parse_fetches(train_fetches)
     train_values.append(lr)
@@ -265,6 +278,8 @@ def main():
         if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
                 and (not FLAGS.dist or trainer_id == 0):
             save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
+            if 'use_ema' in cfg and cfg['use_ema']:
+                exe.run(ema.apply_program)
             checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))

             if FLAGS.eval:
@@ -299,6 +314,9 @@ def main():
                 logger.info("Best test box ap: {}, in iter: {}".format(
                     best_box_ap_list[0], best_box_ap_list[1]))

+            if 'use_ema' in cfg and cfg['use_ema']:
+                exe.run(ema.restore_program)
+
         train_loader.reset()
```
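For context (not part of the diff): `ExponentialMovingAverage` keeps shadow copies of the trainable parameters and updates them once per step,

```math
\theta_{\mathrm{EMA}} \leftarrow d\,\theta_{\mathrm{EMA}} + (1 - d)\,\theta,
\qquad d = 0.9998 \;\; (\texttt{ema\_decay})
```

Passing `thres_steps` lets Paddle ramp the effective decay up from a smaller value early in training, so the shadow weights are not dominated by the random initialization. `ema.apply_program` swaps the shadow weights in before each snapshot (and the evaluation that follows it), and `ema.restore_program` swaps the raw training weights back afterwards.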