s920243400 / PaddleDetection (forked from PaddlePaddle / PaddleDetection, in sync with the fork source)
Commit beaa62a7
Authored Jul 07, 2020 by longxiang

update yolov3

Parent: a66dfe9c
Showing 5 changed files with 908 additions and 527 deletions (+908 -527)
configs/ppyolo/ppyolo.yml  +91 -0
configs/ppyolo/ppyolo_lb.yml  +91 -0
configs/ppyolo/ppyolo_reader.yml  +111 -0
ppdet/modeling/anchor_heads/yolo_head.py  +590 -526
ppdet/modeling/ops.py  +25 -1
configs/ppyolo/ppyolo.yml  (0 → 100644) @ beaa62a7

architecture: YOLOv3
use_gpu: true
max_iters: 500000
log_smooth_window: 100
log_iter: 100
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights: output/ppyolo/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998

YOLOv3:
  backbone: ResNet
  yolo_head: YOLOv3Head
  use_fine_grained_loss: true

ResNet:
  norm_type: sync_bn
  freeze_at: 0
  freeze_norm: false
  norm_decay: 0.
  depth: 50
  feature_maps: [3, 4, 5]
  variant: d
  dcn_v2_stages: [5]

YOLOv3Head:
  anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  anchors: [[10, 13], [16, 30], [33, 23],
            [30, 61], [62, 45], [59, 119],
            [116, 90], [156, 198], [373, 326]]
  norm_decay: 0.
  coord_conv: true
  iou_aware: true
  iou_aware_factor: 0.4
  scale_x_y: 1.05
  spp: true
  yolo_loss: YOLOv3Loss
  nms:
    background_label: -1
    keep_top_k: 100
    # nms_threshold: 0.45
    # nms_top_k: 1000
    normalized: false
    score_threshold: 0.01
  drop_block: true

YOLOv3Loss:
  batch_size: 24
  ignore_thresh: 0.7
  scale_x_y: 1.05
  label_smooth: false
  use_fine_grained_loss: true
  iou_loss: IouLoss
  iou_aware_loss: IouAwareLoss

IouLoss:
  loss_weight: 2.5
  max_height: 608
  max_width: 608

IouAwareLoss:
  loss_weight: 1.0
  max_height: 608
  max_width: 608

LearningRate:
  base_lr: 0.00333
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 400000
    - 450000
  - !LinearWarmup
    start_factor: 0.
    steps: 4000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2

_READER_: 'ppyolo_reader.yml'
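
Note on `use_ema` / `ema_decay`: the config above enables an exponential moving average of the weights with decay 0.9998. The sketch below shows the plain EMA recurrence only, as an illustration of what that decay value means. The function name `ema_update` and the dict-of-floats weights are assumptions made for the example; the trainer's real EMA code is not part of this commit and may differ (for instance by adding bias correction).

```python
def ema_update(shadow, params, decay=0.9998):
    """One plain EMA step: shadow <- decay * shadow + (1 - decay) * param.

    `shadow` and `params` are hypothetical dicts mapping a weight name to a float;
    a real trainer would hold tensors instead.
    """
    return {name: decay * shadow[name] + (1.0 - decay) * params[name]
            for name in params}


# Start the average at 0 and feed a constant weight of 1.0 for a few steps.
shadow = {"w": 0.0}
for _ in range(3):
    shadow = ema_update(shadow, {"w": 1.0})

# With decay d and a constant input of 1.0, n steps give 1 - d**n.
print(shadow["w"], 1.0 - 0.9998 ** 3)
```

With a decay this close to 1 the average moves very slowly, which is why it is usually paired with long schedules such as the `max_iters` values in these configs.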
configs/ppyolo/ppyolo_lb.yml  (0 → 100644) @ beaa62a7

architecture: YOLOv3
use_gpu: true
max_iters: 250000
log_smooth_window: 100
log_iter: 100
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights: output/ppyolo_lb/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998

YOLOv3:
  backbone: ResNet
  yolo_head: YOLOv3Head
  use_fine_grained_loss: true

ResNet:
  norm_type: sync_bn
  freeze_at: 0
  freeze_norm: false
  norm_decay: 0.
  depth: 50
  feature_maps: [3, 4, 5]
  variant: d
  dcn_v2_stages: [5]

YOLOv3Head:
  anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  anchors: [[10, 13], [16, 30], [33, 23],
            [30, 61], [62, 45], [59, 119],
            [116, 90], [156, 198], [373, 326]]
  norm_decay: 0.
  coord_conv: true
  iou_aware: true
  iou_aware_factor: 0.4
  scale_x_y: 1.05
  spp: true
  yolo_loss: YOLOv3Loss
  nms:
    background_label: -1
    keep_top_k: 100
    # nms_threshold: 0.45
    # nms_top_k: 1000
    normalized: false
    score_threshold: 0.01
  drop_block: true

YOLOv3Loss:
  batch_size: 24
  ignore_thresh: 0.7
  scale_x_y: 1.05
  label_smooth: false
  use_fine_grained_loss: true
  iou_loss: IouLoss
  iou_aware_loss: IouAwareLoss

IouLoss:
  loss_weight: 2.5
  max_height: 608
  max_width: 608

IouAwareLoss:
  loss_weight: 1.0
  max_height: 608
  max_width: 608

LearningRate:
  base_lr: 0.01
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 150000
    - 200000
  - !LinearWarmup
    start_factor: 0.
    steps: 4000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0005
    type: L2

_READER_: 'ppyolo_reader.yml'
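
Note on `LearningRate`: ppyolo.yml and ppyolo_lb.yml differ only in `max_iters`, `weights`, `base_lr` and the decay milestones. The sketch below evaluates the combined schedule (piecewise decay by `gamma` at each milestone, with a linear ramp over the first `steps` iterations) for both files. The function `learning_rate` is a hypothetical helper written for this note, and it assumes the warmup ramps linearly from `start_factor * base_lr` to `base_lr` and that a milestone takes effect at the iteration equal to it; the framework's scheduler is the authority on both points.

```python
def learning_rate(it, base_lr, milestones, gamma=0.1,
                  warmup_steps=4000, start_factor=0.0):
    """Learning rate at iteration `it` for piecewise decay plus linear warmup."""
    # Piecewise decay: multiply by gamma once for every milestone already reached.
    lr = base_lr * gamma ** sum(it >= m for m in milestones)
    # Linear warmup: scale from start_factor up to 1 over the first warmup_steps.
    if it < warmup_steps:
        alpha = it / warmup_steps
        lr *= start_factor * (1.0 - alpha) + alpha
    return lr


# Values copied from the two config files in this commit.
SCHEDULES = {
    "ppyolo.yml": dict(base_lr=0.00333, milestones=[400000, 450000]),
    "ppyolo_lb.yml": dict(base_lr=0.01, milestones=[150000, 200000]),
}

for name, cfg in SCHEDULES.items():
    for it in (0, 2000, 4000, cfg["milestones"][0], cfg["milestones"][1]):
        print(name, it, learning_rate(it, **cfg))
```

Printing a handful of iterations like this is a quick way to check that the milestones sit inside `max_iters` (500000 and 250000 respectively) before starting a long run.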
configs/ppyolo/ppyolo_reader.yml  (0 → 100644) @ beaa62a7

TrainReader:
  inputs_def:
    fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
    num_max_boxes: 50
  dataset:
    !COCODataSet
      image_dir: train2017
      anno_path: annotations/instances_train2017.json
      dataset_dir: dataset/coco
      with_background: false
  sample_transforms:
    - !DecodeImage
      to_rgb: True
      with_mixup: True
    - !MixupImage
      alpha: 1.5
      beta: 1.5
    - !ColorDistort {}
    - !RandomExpand
      fill_value: [123.675, 116.28, 103.53]
    - !RandomCrop {}
    - !RandomFlipImage
      is_normalized: false
    - !NormalizeBox {}
    - !PadBox
      num_max_boxes: 50
    - !BboxXYXY2XYWH {}
  batch_transforms:
  - !RandomShape
    sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
    random_inter: True
  - !NormalizeImage
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
    is_scale: True
    is_channel_first: false
  - !Permute
    to_bgr: false
    channel_first: True
  # Gt2YoloTarget is only used when use_fine_grained_loss set as true,
  # this operator will be deleted automatically if use_fine_grained_loss
  # is set as false
  - !Gt2YoloTarget
    anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    anchors: [[10, 13], [16, 30], [33, 23],
              [30, 61], [62, 45], [59, 119],
              [116, 90], [156, 198], [373, 326]]
    downsample_ratios: [32, 16, 8]
  batch_size: 24
  shuffle: true
  # mixup_epoch: 250
  mixup_epoch: 25000
  drop_last: true
  worker_num: 8
  bufsize: 4
  use_process: true

EvalReader:
  inputs_def:
    fields: ['image', 'im_size', 'im_id']
    num_max_boxes: 50
  dataset:
    !COCODataSet
      image_dir: val2017
      anno_path: annotations/instances_val2017.json
      dataset_dir: dataset/coco
      with_background: false
  sample_transforms:
    - !DecodeImage
      to_rgb: True
    - !ResizeImage
      target_size: 608
      interp: 2
    - !NormalizeImage
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
      is_scale: True
      is_channel_first: false
    - !PadBox
      num_max_boxes: 50
    - !Permute
      to_bgr: false
      channel_first: True
  batch_size: 8
  drop_empty: false
  worker_num: 8
  bufsize: 4

TestReader:
  inputs_def:
    image_shape: [3, 608, 608]
    fields: ['image', 'im_size', 'im_id']
  dataset:
    !ImageFolder
      anno_path: annotations/instances_val2017.json
      with_background: false
  sample_transforms:
    - !DecodeImage
      to_rgb: True
    - !ResizeImage
      target_size: 608
      interp: 2
    - !NormalizeImage
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
      is_scale: True
      is_channel_first: false
    - !Permute
      to_bgr: false
      channel_first: True
  batch_size: 1
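
Note on `Gt2YoloTarget`: `anchor_masks`, `anchors` and `downsample_ratios` together decide which anchors are used on which output grid. The sketch below pairs them up for a given square input size, which may help when reading the `RandomShape` sizes (320 to 608) against the three strides. `head_layout` is a hypothetical helper for this note, and it assumes the i-th mask goes with the i-th downsample ratio, as the matching list order in the config suggests.

```python
# Values copied from the Gt2YoloTarget entry in ppyolo_reader.yml.
ANCHORS = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
           [59, 119], [116, 90], [156, 198], [373, 326]]
ANCHOR_MASKS = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
DOWNSAMPLE_RATIOS = [32, 16, 8]


def head_layout(input_size):
    """List stride, grid size and anchors per output level for a square input."""
    layout = []
    for mask, stride in zip(ANCHOR_MASKS, DOWNSAMPLE_RATIOS):
        layout.append({
            "stride": stride,
            "grid": input_size // stride,          # cells per side at this level
            "anchors": [ANCHORS[i] for i in mask]  # anchor (w, h) pairs in pixels
        })
    return layout


for size in (320, 608):
    for level in head_layout(size):
        print(size, level)
```

Every size in the `RandomShape` list is a multiple of 32, so each level always gets a whole number of grid cells.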
ppdet/modeling/anchor_heads/yolo_head.py @ beaa62a7

...
@@ -21,6 +21,7 @@ from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import MultiClassNMS, MultiClassSoftNMS
from ppdet.modeling.ops import MultiClassMatrixNMS
from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
from ppdet.core.workspace import register
from ppdet.modeling.ops import DropBlock
...
@@ -56,11 +57,13 @@ class YOLOv3Head(object):
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 drop_block=False,
                 coord_conv=False,
                 iou_aware=False,
                 iou_aware_factor=0.4,
                 block_size=3,
                 keep_prob=0.9,
                 yolo_loss="YOLOv3Loss",
                 spp=False,
                 nms=MultiClassNMS(
                     score_threshold=0.01,
                     nms_top_k=1000,
...
@@ -81,24 +84,45 @@ class YOLOv3Head(object):
        self.prefix_name = weight_prefix_name
        self.drop_block = drop_block
        self.iou_aware = iou_aware
        self.coord_conv = coord_conv
        self.iou_aware_factor = iou_aware_factor
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.use_spp = spp
        if isinstance(nms, dict):
-           self.nms = MultiClassNMS(**nms)
+           self.nms = MultiClassMatrixNMS(**nms)
        self.downsample = downsample
        self.scale_x_y = scale_x_y
        self.clip_bbox = clip_bbox

    def _add_coord(self, input):
        input_shape = fluid.layers.shape(input)
        b = input_shape[0]
        h = input_shape[2]
        w = input_shape[3]

        x_range = fluid.layers.range(0, w, 1, 'float32') / (w - 1.)
        x_range = x_range * 2. - 1.
        x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
        x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
        x_range.stop_gradient = True
        y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
        y_range.stop_gradient = True

        return fluid.layers.concat([input, x_range, y_range], axis=1)

    def _conv_bn(self,
                 input,
                 ch_out,
                 filter_size,
                 stride,
                 padding,
                 coord_conv=False,
                 act='leaky',
                 is_test=True,
                 name=None):
        if coord_conv:
            input = self._add_coord(input)
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=ch_out,
...
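
Note on `_add_coord`: the method appends two extra channels holding x and y positions scaled to [-1, 1], so the following convolution can see where each cell sits. A framework-free sketch of the same idea is below, on a single image stored as nested lists in channel, height, width order. `add_coord` is a hypothetical stand-in written for this note, not the Paddle code. One deliberate difference: it builds the y channel from `h` directly, whereas the hunk above obtains `y_range` by transposing `x_range`, which appears to assume square feature maps (h == w).

```python
def add_coord(feature):
    """Append normalized x and y coordinate channels to a C x H x W nested list.

    Requires H > 1 and W > 1, since positions are divided by (size - 1),
    as in the `/ (w - 1.)` step of the method above.
    """
    h, w = len(feature[0]), len(feature[0][0])
    xs = [2.0 * j / (w - 1) - 1.0 for j in range(w)]  # -1 at left edge, +1 at right
    ys = [2.0 * i / (h - 1) - 1.0 for i in range(h)]  # -1 at top edge, +1 at bottom
    x_chan = [[xs[j] for j in range(w)] for _ in range(h)]
    y_chan = [[ys[i] for _ in range(w)] for i in range(h)]
    return feature + [x_chan, y_chan]


# One channel of zeros, 2 rows by 3 columns.
feature = [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
out = add_coord(feature)
print(len(out))  # channel count grows by two
print(out[1])    # x channel
print(out[2])    # y channel
```

Because two channels are added, any convolution called with `coord_conv=True` receives `C + 2` input channels rather than `C`.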
@@ -117,6 +141,7 @@ class YOLOv3Head(object):
        out = fluid.layers.batch_norm(
            input=conv,
            act=None,
            is_test=is_test,
            param_attr=bn_param_attr,
            bias_attr=bn_bias_attr,
            moving_mean_name=bn_name + '.mean',
...
@@ -126,6 +151,32 @@ class YOLOv3Head(object):
            out = fluid.layers.leaky_relu(x=out, alpha=0.1)
        return out

    def _spp_module(self, input, is_test=True, name=""):
        output1 = input
        output2 = fluid.layers.pool2d(
            input=output1,
            pool_size=5,
            pool_stride=1,
            pool_padding=2,
            ceil_mode=False,
            pool_type='max')
        output3 = fluid.layers.pool2d(
            input=output1,
            pool_size=9,
            pool_stride=1,
            pool_padding=4,
            ceil_mode=False,
            pool_type='max')
        output4 = fluid.layers.pool2d(
            input=output1,
            pool_size=13,
            pool_stride=1,
            pool_padding=6,
            ceil_mode=False,
            pool_type='max')
        output = fluid.layers.concat(
            input=[output1, output2, output3, output4], axis=1)
        return output

    def _detection_block(self, input, channel, is_test=True, name=None):
        assert channel % 2 == 0, \
            "channel {} cannot be divided by 2 in detection block {}" \
...
@@ -139,8 +190,19 @@ class YOLOv3Head(object):
                filter_size=1,
                stride=1,
                padding=0,
                coord_conv=True,
                is_test=is_test,
                name='{}.{}.0'.format(name, j))
            if self.use_spp and channel == 512 and j == 1:
                conv = self._spp_module(conv, is_test=is_test, name="spp")
                conv = self._conv_bn(
                    conv,
                    512,
                    filter_size=1,
                    stride=1,
                    padding=0,
                    is_test=is_test,
                    name='{}.{}.spp.conv'.format(name, j))
            conv = self._conv_bn(
                conv,
                channel * 2,
...
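
Note on the SPP branch: `_spp_module` runs three stride-1 max pools (5, 9 and 13 wide, padded so the spatial size is unchanged) and concatenates them with the input along the channel axis, so the channel count is multiplied by four before the 1x1 `_conv_bn` brings it back to 512. The sketch below reproduces that shape behaviour on nested lists. `max_pool_same` and `spp` are hypothetical helpers for this note, and the sketch assumes padded positions never win the max, which it gets by clipping each window at the border.

```python
def max_pool_same(channel, k):
    """Stride-1 max pool over an H x W nested list, output the same size as input."""
    h, w = len(channel), len(channel[0])
    r = k // 2  # matches pool_padding 2, 4, 6 for pool_size 5, 9, 13
    out = []
    for i in range(h):
        row = []
        for j in range(w):
            # Clip the k x k window at the borders instead of padding.
            window = [channel[y][x]
                      for y in range(max(0, i - r), min(h, i + r + 1))
                      for x in range(max(0, j - r), min(w, j + r + 1))]
            row.append(max(window))
        out.append(row)
    return out


def spp(feature, kernel_sizes=(5, 9, 13)):
    """Concatenate the input with one pooled copy per kernel size (channel axis)."""
    out = list(feature)
    for k in kernel_sizes:
        out += [max_pool_same(ch, k) for ch in feature]
    return out


# One 7 x 7 channel with a single active cell in the middle.
feature = [[[0] * 7 for _ in range(7)]]
feature[0][3][3] = 1
pooled = spp(feature)
print(len(pooled))  # 1 input channel becomes 4
for row in pooled[1]:
    print(row)      # the 5-wide pool spreads the active cell over a 5 x 5 patch
```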
@@ -168,6 +230,7 @@ class YOLOv3Head(object):
            filter_size=1,
            stride=1,
            padding=0,
            coord_conv=True,
            is_test=is_test,
            name='{}.2'.format(name))
        tip = self._conv_bn(
...
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
coord_conv
=
True
,
is_test
=
is_test
,
name
=
'{}.tip'
.
format
(
name
))
return
route
,
tip
...
...
ppdet/modeling/ops.py @ beaa62a7

...
@@ -30,9 +30,33 @@ __all__ = [
     'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner',
     'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead',
     'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm',
-    'DeformConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner'
+    'DeformConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner',
+    'MultiClassMatrixNMS'
 ]

+@register
+@serializable
+class MultiClassMatrixNMS(object):
+    __op__ = fluid.layers.matrix_nms
+    __append_doc__ = True
+
+    def __init__(self,
+                 score_threshold=.05,
+                 post_threshold=.01,
+                 nms_top_k=-1,
+                 keep_top_k=100,
+                 use_gaussian=False,
+                 gaussian_sigma=2.0,
+                 normalized=False,
+                 background_label=0):
+        super(MultiClassMatrixNMS, self).__init__()
+        self.score_threshold = score_threshold
+        self.nms_top_k = nms_top_k
+        self.keep_top_k = keep_top_k
+        self.score_threshold = score_threshold
+        self.post_threshold = post_threshold
+        self.use_gaussian = use_gaussian
+        self.normalized = normalized
+        self.background_label = background_label
+
 def _conv_offset(input, filter_size, stride, padding, act=None, name=None):
     out_channel = filter_size * filter_size * 3
...
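
Note on `MultiClassMatrixNMS`: the class wraps `fluid.layers.matrix_nms`. Matrix NMS does not discard overlapping boxes outright; it lowers each box's score according to how much it overlaps higher-scoring boxes, then drops boxes whose decayed score falls below `post_threshold`. The single-class sketch below follows the commonly described form of that decay; `iou` and `matrix_nms` here are hypothetical pure-Python helpers written for this note, and the operator's exact behaviour (tie handling, threshold comparison, `nms_top_k` / `keep_top_k` truncation) is not reproduced. As shown in the hunk above, `__init__` accepts `gaussian_sigma` but no assignment to `self.gaussian_sigma` appears, and `self.score_threshold` is assigned twice, which may be worth a second look.

```python
import math


def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def matrix_nms(boxes, scores, use_gaussian=False, gaussian_sigma=2.0,
               post_threshold=0.01):
    """Return (original_index, decayed_score) pairs in descending original score."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    n = len(order)
    # ious[i][j]: overlap between rank i and rank j, filled for i < j only.
    ious = [[iou(boxes[order[i]], boxes[order[j]]) if i < j else 0.0
             for j in range(n)] for i in range(n)]
    # comp[i]: largest overlap of rank i with any higher-ranked box.
    comp = [max([ious[k][i] for k in range(i)], default=0.0) for i in range(n)]
    kept = []
    for j in range(n):
        decay = 1.0
        for i in range(j):
            if use_gaussian:
                d = math.exp((comp[i] ** 2 - ious[i][j] ** 2) * gaussian_sigma)
            else:
                # Small floor avoids dividing by zero for exact duplicate boxes.
                d = (1.0 - ious[i][j]) / max(1.0 - comp[i], 1e-9)
            decay = min(decay, d)
        score = scores[order[j]] * decay
        if score >= post_threshold:
            kept.append((order[j], score))
    return kept


boxes = [(0, 0, 10, 10), (0, 0, 10, 5), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
for index, score in matrix_nms(boxes, scores):
    print(index, score)
```

In this example the second box overlaps the first with IoU 0.5, so its score is halved under the linear decay, while the distant third box keeps its score.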