Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
8ef4a22d
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8ef4a22d
编写于
6月 30, 2020
作者:
F
FDInSky
提交者:
GitHub
6月 30, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test=dygraph add yolov3 model part (#995)
add yolov3
上级
e9d7f8a8
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
645 addition
and
40 deletion
+645
-40
ppdet/modeling/__init__.py
ppdet/modeling/__init__.py
+6
-6
ppdet/modeling/anchor.py
ppdet/modeling/anchor.py
+93
-18
ppdet/modeling/architecture/__init__.py
ppdet/modeling/architecture/__init__.py
+2
-0
ppdet/modeling/architecture/faster_rcnn.py
ppdet/modeling/architecture/faster_rcnn.py
+6
-4
ppdet/modeling/architecture/mask_rcnn.py
ppdet/modeling/architecture/mask_rcnn.py
+5
-3
ppdet/modeling/architecture/meta_arch.py
ppdet/modeling/architecture/meta_arch.py
+2
-1
ppdet/modeling/architecture/yolo.py
ppdet/modeling/architecture/yolo.py
+68
-0
ppdet/modeling/backbone/__init__.py
ppdet/modeling/backbone/__init__.py
+2
-0
ppdet/modeling/backbone/darknet.py
ppdet/modeling/backbone/darknet.py
+150
-0
ppdet/modeling/head/__init__.py
ppdet/modeling/head/__init__.py
+2
-0
ppdet/modeling/head/rpn_head.py
ppdet/modeling/head/rpn_head.py
+3
-4
ppdet/modeling/head/yolo_head.py
ppdet/modeling/head/yolo_head.py
+203
-0
ppdet/modeling/ops.py
ppdet/modeling/ops.py
+103
-4
未找到文件。
ppdet/modeling/__init__.py
浏览文件 @
8ef4a22d
from
.
import
architecture
from
.
import
backbone
from
.
import
head
from
.
import
ops
from
.
import
anchor
from
.
import
backbone
from
.
import
head
from
.
import
architecture
from
.architecture
import
*
from
.backbone
import
*
from
.head
import
*
from
.ops
import
*
from
.anchor
import
*
from
.backbone
import
*
from
.head
import
*
from
.architecture
import
*
ppdet/modeling/anchor.py
浏览文件 @
8ef4a22d
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Layer
from
paddle.fluid.dygraph.base
import
to_variable
from
ppdet.core.workspace
import
register
from
ppdet.modeling.ops
import
(
AnchorGenerator
,
RPNAnchorTargetGenerator
,
ProposalGenerator
,
ProposalTargetGenerator
,
MaskTargetGenerator
,
DecodeClipNms
)
from
ppdet.modeling.ops
import
(
AnchorGeneratorYOLO
,
AnchorTargetGeneratorYOLO
,
AnchorGeneratorRPN
,
AnchorTargetGeneratorRPN
,
ProposalGenerator
,
ProposalTargetGenerator
,
MaskTargetGenerator
,
DecodeClipNms
,
YOLOBox
,
MultiClassNMS
)
# TODO: modify here into ppdet.modeling.ops like DecodeClipNms
from
ppdet.py_op.post_process
import
mask_post_process
@
register
class
BBoxPostProcess
(
Layer
):
class
BBoxPostProcess
(
object
):
def
__init__
(
self
,
decode
=
None
,
clip
=
None
,
...
...
@@ -39,6 +37,49 @@ class BBoxPostProcess(Layer):
return
outs
@register
class BBoxPostProcessYOLO(object):
    # Post-processing for YOLO: decodes every head output with `yolo_box`,
    # concatenates the per-scale results and filters them with `nms`.
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 decode=None,
                 clip=None,
                 yolo_box=YOLOBox().__dict__,
                 nms=MultiClassNMS().__dict__):
        # yolo_box / nms may be given either as a kwargs dict (the default,
        # taken from a default-constructed op) or as an already-built op.
        super(BBoxPostProcessYOLO, self).__init__()
        self.num_classes = num_classes
        self.decode = decode
        self.clip = clip
        self.nms = nms
        # Store the argument first so that a pre-built op is kept as-is;
        # previously the attribute was only set in the dict case, leaving
        # `self.yolo_box` undefined for non-dict arguments.
        self.yolo_box = yolo_box
        if isinstance(yolo_box, dict):
            self.yolo_box = YOLOBox(**yolo_box)
        if isinstance(nms, dict):
            self.nms = MultiClassNMS(**nms)

    def __call__(self, inputs):
        # Reads inputs['yolo_outs'], inputs['im_size'] and
        # inputs['mask_anchors']; returns a dict with 'predicted_bbox' and
        # 'predicted_bbox_nums'.
        # TODO: split yolo_box into 2 steps
        # decode
        # clip
        boxes_list = []
        scores_list = []
        for i, out in enumerate(inputs['yolo_outs']):
            boxes, scores = self.yolo_box(out, inputs['im_size'],
                                          inputs['mask_anchors'][i], i,
                                          "yolo_box_" + str(i))

            boxes_list.append(boxes)
            # swap the last two axes so scores can be concatenated on axis 2
            scores_list.append(fluid.layers.transpose(scores, perm=[0, 2, 1]))
        yolo_boxes = fluid.layers.concat(boxes_list, axis=1)
        yolo_scores = fluid.layers.concat(scores_list, axis=2)
        nmsed_bbox = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
        # TODO: parse the lod of nmsed_bbox
        # default batch size is 1
        bbox_nums = np.array([0, int(nmsed_bbox.shape[0])], dtype=np.int32)
        outs = {"predicted_bbox_nums": bbox_nums, "predicted_bbox": nmsed_bbox}
        return outs
@
register
class
MaskPostProcess
(
object
):
__shared__
=
[
'num_classes'
]
...
...
@@ -58,20 +99,20 @@ class MaskPostProcess(object):
@
register
class
Anchor
(
object
):
class
Anchor
RPN
(
object
):
__inject__
=
[
'anchor_generator'
,
'anchor_target_generator'
]
def
__init__
(
self
,
anchor_type
=
'rpn'
,
anchor_generator
=
AnchorGenerator
().
__dict__
,
anchor_target_generator
=
RPNAnchorTargetGenerator
().
__dict__
):
super
(
Anchor
,
self
).
__init__
()
anchor_generator
=
AnchorGenerator
RPN
().
__dict__
,
anchor_target_generator
=
AnchorTargetGeneratorRPN
().
__dict__
):
super
(
Anchor
RPN
,
self
).
__init__
()
self
.
anchor_generator
=
anchor_generator
self
.
anchor_target_generator
=
anchor_target_generator
if
isinstance
(
anchor_generator
,
dict
):
self
.
anchor_generator
=
AnchorGenerator
(
**
anchor_generator
)
self
.
anchor_generator
=
AnchorGenerator
RPN
(
**
anchor_generator
)
if
isinstance
(
anchor_target_generator
,
dict
):
self
.
anchor_target_generator
=
RPNAnchorTargetGenerator
(
self
.
anchor_target_generator
=
AnchorTargetGeneratorRPN
(
**
anchor_target_generator
)
def
__call__
(
self
,
inputs
):
...
...
@@ -85,7 +126,6 @@ class Anchor(object):
return
outs
def
generate_anchors_target
(
self
,
inputs
):
# TODO: add yolo anchor targets
rpn_rois_score
=
fluid
.
layers
.
transpose
(
inputs
[
'rpn_rois_score'
],
perm
=
[
0
,
2
,
3
,
1
])
rpn_rois_delta
=
fluid
.
layers
.
transpose
(
...
...
@@ -96,7 +136,6 @@ class Anchor(object):
x
=
rpn_rois_delta
,
shape
=
(
0
,
-
1
,
4
))
anchor
=
fluid
.
layers
.
reshape
(
inputs
[
'anchor'
],
shape
=
(
-
1
,
4
))
#var = fluid.layers.reshape(inputs['var'], shape=(-1, 4))
score_pred
,
roi_pred
,
score_tgt
,
roi_tgt
,
roi_weight
=
self
.
anchor_target_generator
(
bbox_pred
=
rpn_rois_delta
,
...
...
@@ -114,9 +153,45 @@ class Anchor(object):
}
return
outs
def
post_process
(
self
,
):
# TODO: whether move bbox post process to here
pass
@register
class AnchorYOLO(object):
    # Bundles the YOLO anchor generator, the anchor-target generator and the
    # bbox post-process step behind one object.
    __inject__ = [
        'anchor_generator', 'anchor_target_generator', 'anchor_post_process'
    ]

    def __init__(self,
                 anchor_generator=AnchorGeneratorYOLO().__dict__,
                 anchor_target_generator=AnchorTargetGeneratorYOLO().__dict__,
                 anchor_post_process=BBoxPostProcessYOLO().__dict__):
        super(AnchorYOLO, self).__init__()
        # Each member is either a kwargs dict (built here) or a ready object
        # (stored unchanged).
        self.anchor_generator = (AnchorGeneratorYOLO(**anchor_generator)
                                 if isinstance(anchor_generator, dict) else
                                 anchor_generator)
        self.anchor_target_generator = (
            AnchorTargetGeneratorYOLO(**anchor_target_generator)
            if isinstance(anchor_target_generator, dict) else
            anchor_target_generator)
        self.anchor_post_process = (BBoxPostProcessYOLO(**anchor_post_process)
                                    if isinstance(anchor_post_process, dict)
                                    else anchor_post_process)

    def __call__(self, inputs):
        # Calling the module is the same as generating anchors.
        return self.generate_anchors(inputs)

    def generate_anchors(self, inputs):
        # Run the generator on the head outputs and attach this module so
        # later stages can reach it through the shared dict.
        anchor_info = self.anchor_generator(inputs['yolo_outs'])
        anchor_info['anchor_module'] = self
        return anchor_info

    def generate_anchors_target(self, inputs):
        # `inputs` is accepted but not forwarded; the target generator is
        # called without arguments.
        return self.anchor_target_generator()

    def post_process(self, inputs):
        return self.anchor_post_process(inputs)
@
register
...
...
ppdet/modeling/architecture/__init__.py
浏览文件 @
8ef4a22d
...
...
@@ -8,7 +8,9 @@
from
.
import
meta_arch
from
.
import
faster_rcnn
from
.
import
mask_rcnn
from
.
import
yolo
from
.meta_arch
import
*
from
.faster_rcnn
import
*
from
.mask_rcnn
import
*
from
.yolo
import
*
ppdet/modeling/architecture/faster_rcnn.py
浏览文件 @
8ef4a22d
...
...
@@ -27,7 +27,8 @@ class FasterRCNN(BaseArch):
backbone
,
rpn_head
,
bbox_head
,
rpn_only
=
False
):
rpn_only
=
False
,
mode
=
'train'
):
super
(
FasterRCNN
,
self
).
__init__
()
self
.
anchor
=
anchor
self
.
proposal
=
proposal
...
...
@@ -35,10 +36,11 @@ class FasterRCNN(BaseArch):
self
.
rpn_head
=
rpn_head
self
.
bbox_head
=
bbox_head
self
.
rpn_only
=
rpn_only
self
.
mode
=
mode
def
forward
(
self
,
inputs
,
inputs_keys
,
mode
=
'train'
):
def
forward
(
self
,
inputs
,
inputs_keys
):
self
.
gbd
=
self
.
build_inputs
(
inputs
,
inputs_keys
)
self
.
gbd
[
'mode'
]
=
mode
self
.
gbd
[
'mode'
]
=
self
.
mode
# Backbone
bb_out
=
self
.
backbone
(
self
.
gbd
)
...
...
@@ -89,8 +91,8 @@ class FasterRCNN(BaseArch):
def
infer
(
self
,
inputs
):
outs
=
{
"bbox_nums"
:
inputs
[
'predicted_bbox_nums'
].
numpy
(),
"bbox"
:
inputs
[
'predicted_bbox'
].
numpy
(),
"bbox_nums"
:
inputs
[
'predicted_bbox_nums'
].
numpy
(),
'im_id'
:
inputs
[
'im_id'
].
numpy
(),
'im_shape'
:
inputs
[
'im_shape'
].
numpy
()
}
...
...
ppdet/modeling/architecture/mask_rcnn.py
浏览文件 @
8ef4a22d
...
...
@@ -32,7 +32,8 @@ class MaskRCNN(BaseArch):
rpn_head
,
bbox_head
,
mask_head
,
rpn_only
=
False
):
rpn_only
=
False
,
mode
=
'train'
):
super
(
MaskRCNN
,
self
).
__init__
()
self
.
anchor
=
anchor
...
...
@@ -42,8 +43,9 @@ class MaskRCNN(BaseArch):
self
.
rpn_head
=
rpn_head
self
.
bbox_head
=
bbox_head
self
.
mask_head
=
mask_head
self
.
mode
=
mode
def
forward
(
self
,
inputs
,
inputs_keys
,
mode
=
'train'
):
def
forward
(
self
,
inputs
,
inputs_keys
):
self
.
gbd
=
self
.
build_inputs
(
inputs
,
inputs_keys
)
self
.
gbd
[
'mode'
]
=
mode
...
...
@@ -112,8 +114,8 @@ class MaskRCNN(BaseArch):
def
infer
(
self
,
inputs
):
outs
=
{
'bbox_nums'
:
inputs
[
'predicted_bbox_nums'
].
numpy
(),
'bbox'
:
inputs
[
'predicted_bbox'
].
numpy
(),
'bbox_nums'
:
inputs
[
'predicted_bbox_nums'
].
numpy
(),
'mask'
:
inputs
[
'predicted_mask'
].
numpy
(),
'im_id'
:
inputs
[
'im_id'
].
numpy
(),
'im_shape'
:
inputs
[
'im_shape'
].
numpy
()
...
...
ppdet/modeling/architecture/meta_arch.py
浏览文件 @
8ef4a22d
...
...
@@ -16,8 +16,9 @@ __all__ = ['BaseArch']
@
register
class
BaseArch
(
Layer
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
mode
=
'train'
,
*
args
,
**
kwargs
):
super
(
BaseArch
,
self
).
__init__
()
self
.
mode
=
mode
def
forward
(
self
,
inputs
,
inputs_keys
,
mode
=
'train'
):
raise
NotImplementedError
(
"Should implement forward method!"
)
...
...
ppdet/modeling/architecture/yolo.py
0 → 100644
浏览文件 @
8ef4a22d
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
fluid
from
ppdet.core.workspace
import
register
from
.meta_arch
import
BaseArch
__all__
=
[
'YOLOv3'
]
@register
class YOLOv3(BaseArch):
    # YOLOv3 architecture: backbone -> yolo head -> anchor module, then
    # either the training loss or the inference outputs depending on `mode`.
    __category__ = 'architecture'
    __inject__ = [
        'anchor',
        'backbone',
        'yolo_head',
    ]

    def __init__(self, anchor, backbone, yolo_head, mode='train'):
        super(YOLOv3, self).__init__()
        self.anchor = anchor
        self.backbone = backbone
        self.yolo_head = yolo_head
        self.mode = mode

    def forward(self, inputs, inputs_keys):
        # `self.gbd` is the shared dict every stage reads from and updates.
        # build_inputs is not defined in this class (presumably inherited
        # from BaseArch -- verify).
        self.gbd = self.build_inputs(inputs, inputs_keys)
        self.gbd['mode'] = self.mode

        # Backbone
        bb_out = self.backbone(self.gbd)
        self.gbd.update(bb_out)

        # YOLO Head
        yolo_head_out = self.yolo_head(self.gbd)
        self.gbd.update(yolo_head_out)

        # Anchor
        anchor_out = self.anchor(self.gbd)
        self.gbd.update(anchor_out)

        if self.gbd['mode'] == 'infer':
            bbox_out = self.anchor.post_process(self.gbd)
            self.gbd.update(bbox_out)

        # result
        if self.gbd['mode'] == 'train':
            return self.loss(self.gbd)
        elif self.gbd['mode'] == 'infer':
            return self.infer(self.gbd)
        else:
            # A plain string cannot be raised in Python 3; use a real
            # exception carrying the same message.
            raise ValueError("Now, only support train or infer mode!")

    def loss(self, inputs):
        # Delegates to the head and wraps the result under the 'loss' key.
        yolo_loss = self.yolo_head.loss(inputs)
        out = {'loss': yolo_loss, }
        return out

    def infer(self, inputs):
        # 'bbox' is converted with .numpy(); 'bbox_nums' is passed through
        # unchanged.
        outs = {
            "bbox": inputs['predicted_bbox'].numpy(),
            "bbox_nums": inputs['predicted_bbox_nums']
        }
        # NOTE(review): looks like leftover debug output -- confirm whether
        # it should go through logging or be removed.
        print(outs['bbox_nums'])
        return outs
ppdet/modeling/backbone/__init__.py
浏览文件 @
8ef4a22d
from
.
import
resnet
from
.
import
darknet
from
.resnet
import
*
from
.darknet
import
*
ppdet/modeling/backbone/darknet.py
0 → 100755
浏览文件 @
8ef4a22d
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Layer
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.dygraph.nn
import
Conv2D
,
BatchNorm
from
ppdet.core.workspace
import
register
,
serializable
__all__
=
[
'DarkNet'
,
'ConvBNLayer'
]
class ConvBNLayer(Layer):
    """Bias-free Conv2D followed by BatchNorm and an optional leaky ReLU.

    The leaky ReLU (alpha=0.1) is applied only when ``act == 'leaky'``;
    any other value leaves the batch-norm output untouched.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size=3,
                 stride=1,
                 groups=1,
                 padding=0,
                 act="leaky"):
        super(ConvBNLayer, self).__init__()

        conv_weight_attr = ParamAttr(
            initializer=fluid.initializer.Normal(0., 0.02))
        self.conv = Conv2D(
            num_channels=ch_in,
            num_filters=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            groups=groups,
            param_attr=conv_weight_attr,
            bias_attr=False,
            act=None)

        # Both batch-norm parameters carry an L2Decay(0.) regularizer.
        bn_scale_attr = ParamAttr(
            initializer=fluid.initializer.Normal(0., 0.02),
            regularizer=L2Decay(0.))
        bn_offset_attr = ParamAttr(
            initializer=fluid.initializer.Constant(0.0),
            regularizer=L2Decay(0.))
        self.batch_norm = BatchNorm(
            num_channels=ch_out,
            param_attr=bn_scale_attr,
            bias_attr=bn_offset_attr)

        self.act = act

    def forward(self, inputs):
        """Apply conv -> batch norm -> (optional) leaky ReLU."""
        normed = self.batch_norm(self.conv(inputs))
        if self.act != 'leaky':
            return normed
        return fluid.layers.leaky_relu(x=normed, alpha=0.1)
class DownSample(Layer):
    """A single ConvBNLayer whose defaults are 3x3, stride 2, padding 1."""

    def __init__(self, ch_in, ch_out, filter_size=3, stride=2, padding=1):
        super(DownSample, self).__init__()

        conv_kwargs = dict(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding)
        self.conv_bn_layer = ConvBNLayer(**conv_kwargs)
        self.ch_out = ch_out

    def forward(self, inputs):
        """Run the wrapped conv/bn layer on ``inputs``."""
        return self.conv_bn_layer(inputs)
class BasicBlock(Layer):
    """Residual unit: 1x1 conv to ``ch_out``, 3x3 conv to ``ch_out * 2``,
    then element-wise addition with the block input."""

    def __init__(self, ch_in, ch_out):
        super(BasicBlock, self).__init__()

        # 1x1 conv
        self.conv1 = ConvBNLayer(
            ch_in=ch_in, ch_out=ch_out, filter_size=1, stride=1, padding=0)
        # 3x3 conv, doubles the channel count
        self.conv2 = ConvBNLayer(
            ch_in=ch_out,
            ch_out=ch_out * 2,
            filter_size=3,
            stride=1,
            padding=1)

    def forward(self, inputs):
        """Return ``inputs + conv2(conv1(inputs))``."""
        branch = self.conv2(self.conv1(inputs))
        return fluid.layers.elementwise_add(x=inputs, y=branch, act=None)
class Blocks(Layer):
    """A chain of ``count`` BasicBlocks.

    The first block maps ``ch_in`` -> ``ch_out``; the remaining
    ``count - 1`` blocks are built with ``ch_out * 2`` input channels and
    registered as sublayers named ``basic_block_<i>``.
    """

    def __init__(self, ch_in, ch_out, count):
        super(Blocks, self).__init__()

        self.basicblock0 = BasicBlock(ch_in, ch_out)
        self.res_out_list = [
            self.add_sublayer("basic_block_%d" % (idx),
                              BasicBlock(ch_out * 2, ch_out))
            for idx in range(1, count)
        ]
        self.ch_out = ch_out

    def forward(self, inputs):
        """Feed ``inputs`` through every block in order."""
        feat = self.basicblock0(inputs)
        for res_block in self.res_out_list:
            feat = res_block(feat)
        return feat
# Per-depth stage configuration. DarkNet reads the first five entries for the
# selected depth and passes each one as `count` (number of BasicBlocks) to
# the corresponding Blocks stage.
DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}
@register
@serializable
class DarkNet(Layer):
    # DarkNet backbone: a stem conv plus downsample, followed by the stages
    # listed in DarkNet_cfg with a DownSample between consecutive stages.
    def __init__(self, depth=53, mode='train'):
        super(DarkNet, self).__init__()
        self.depth = depth
        self.mode = mode
        self.stages = DarkNet_cfg[self.depth][0:5]

        # stem
        self.conv0 = ConvBNLayer(
            ch_in=3, ch_out=32, filter_size=3, stride=1, padding=1)
        self.downsample0 = DownSample(ch_in=32, ch_out=32 * 2)

        # All stages are created first, then all downsample layers, so the
        # sublayer registration order stays stage_0..stage_N followed by
        # stage_0_downsample..stage_{N-1}_downsample.
        stage_in_channels = [64, 128, 256, 512, 1024]
        self.darknet53_conv_block_list = [
            self.add_sublayer(
                "stage_%d" % (idx),
                Blocks(
                    int(stage_in_channels[idx]), 32 * (2**idx), num_blocks))
            for idx, num_blocks in enumerate(self.stages)
        ]
        self.downsample_list = [
            self.add_sublayer(
                "stage_%d_downsample" % idx,
                DownSample(
                    ch_in=32 * (2**(idx + 1)), ch_out=32 * (2**(idx + 2))))
            for idx in range(len(self.stages) - 1)
        ]

    def forward(self, inputs):
        # Reads inputs['image']; returns {'darknet_outs': [...]} holding the
        # last three stage outputs, deepest stage first.
        feat = self.downsample0(self.conv0(inputs['image']))

        last_stage = len(self.stages) - 1
        stage_outs = []
        for idx, stage_block in enumerate(self.darknet53_conv_block_list):
            feat = stage_block(feat)
            stage_outs.append(feat)
            # no downsample after the final stage
            if idx < last_stage:
                feat = self.downsample_list[idx](feat)

        return {'darknet_outs': stage_outs[-1:-4:-1]}
ppdet/modeling/head/__init__.py
浏览文件 @
8ef4a22d
from
.
import
rpn_head
from
.
import
bbox_head
from
.
import
mask_head
from
.
import
yolo_head
from
.rpn_head
import
*
from
.bbox_head
import
*
from
.mask_head
import
*
from
.yolo_head
import
*
ppdet/modeling/head/rpn_head.py
浏览文件 @
8ef4a22d
...
...
@@ -6,14 +6,13 @@ from paddle.fluid.regularizer import L2Decay
from
paddle.fluid.dygraph.nn
import
Conv2D
from
ppdet.core.workspace
import
register
from
..ops
import
RPNAnchorTargetGenerator
@
register
class
RPNFeat
(
Layer
):
def
__init__
(
self
,
feat_in
=
1024
,
feat_out
=
1024
):
super
(
RPNFeat
,
self
).
__init__
()
self
.
rpn_conv
=
fluid
.
dygraph
.
Conv2D
(
self
.
rpn_conv
=
Conv2D
(
num_channels
=
1024
,
num_filters
=
1024
,
filter_size
=
3
,
...
...
@@ -45,7 +44,7 @@ class RPNHead(Layer):
self
.
rpn_feat
=
RPNFeat
(
**
rpn_feat
)
# rpn roi classification scores
self
.
rpn_rois_score
=
fluid
.
dygraph
.
Conv2D
(
self
.
rpn_rois_score
=
Conv2D
(
num_channels
=
1024
,
num_filters
=
1
*
self
.
anchor_per_position
,
filter_size
=
1
,
...
...
@@ -61,7 +60,7 @@ class RPNHead(Layer):
regularizer
=
L2Decay
(
0.
)))
# rpn roi bbox regression deltas
self
.
rpn_rois_delta
=
fluid
.
dygraph
.
Conv2D
(
self
.
rpn_rois_delta
=
Conv2D
(
num_channels
=
1024
,
num_filters
=
4
*
self
.
anchor_per_position
,
filter_size
=
1
,
...
...
ppdet/modeling/head/yolo_head.py
0 → 100644
浏览文件 @
8ef4a22d
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Layer
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.initializer
import
Normal
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.dygraph.nn
import
Conv2D
,
BatchNorm
from
ppdet.core.workspace
import
register
from
..backbone.darknet
import
ConvBNLayer
class YoloDetBlock(Layer):
    """Detection block made of alternating 1x1 and 3x3 ConvBNLayers.

    ``forward`` returns two tensors: ``route`` (output of the last 1x1
    conv, ``channel`` channels) and ``tip`` (a 3x3 conv applied on
    ``route``, ``channel * 2`` channels).
    """

    def __init__(self, ch_in, channel):
        super(YoloDetBlock, self).__init__()

        assert channel % 2 == 0, \
            "channel {} cannot be divided by 2".format(channel)

        def squeeze(in_channels):
            # 1x1 conv down to `channel` channels
            return ConvBNLayer(
                ch_in=in_channels,
                ch_out=channel,
                filter_size=1,
                stride=1,
                padding=0)

        def expand():
            # 3x3 conv from `channel` up to `channel * 2` channels
            return ConvBNLayer(
                ch_in=channel,
                ch_out=channel * 2,
                filter_size=3,
                stride=1,
                padding=1)

        # Construction order is kept: conv0..conv3, route, tip.
        self.conv0 = squeeze(ch_in)
        self.conv1 = expand()
        self.conv2 = squeeze(channel * 2)
        self.conv3 = expand()
        self.route = squeeze(channel * 2)
        self.tip = expand()

    def forward(self, inputs):
        """Return ``(route, tip)`` for ``inputs``."""
        feat = inputs
        for conv in (self.conv0, self.conv1, self.conv2, self.conv3):
            feat = conv(feat)
        route = self.route(feat)
        return route, self.tip(route)
class Upsample(Layer):
    """Nearest-neighbour upsampling by an integer ``scale`` (default 2)."""

    def __init__(self, scale=2):
        super(Upsample, self).__init__()
        self.scale = scale

    def forward(self, inputs):
        """Resize ``inputs`` using a target size computed at run time."""
        # get dynamic upsample output shape: entries 2..3 of the input shape
        spatial = fluid.layers.slice(
            fluid.layers.shape(inputs), axes=[0], starts=[2], ends=[4])
        spatial.stop_gradient = True
        target = fluid.layers.cast(spatial, dtype='int32') * self.scale
        target.stop_gradient = True

        # resize by actual_shape
        return fluid.layers.resize_nearest(
            input=inputs, scale=self.scale, actual_shape=target)
@register
class YOLOFeat(Layer):
    # Builds three YoloDetBlocks (one per backbone output) and, for the first
    # two, a 1x1 "route" conv whose upsampled output is concatenated with the
    # next backbone output.
    def __init__(self, feat_in_list=[1024, 768, 384]):
        super(YOLOFeat, self).__init__()
        self.feat_in_list = feat_in_list
        self.yolo_blocks = []
        self.route_blocks = []
        for level in range(3):
            width = 512 // (2**level)
            det_block = self.add_sublayer(
                "yolo_det_block_%d" % (level),
                YoloDetBlock(
                    feat_in_list[level], channel=width))
            self.yolo_blocks.append(det_block)

            # the last level has no route branch
            if level < 2:
                route_conv = self.add_sublayer(
                    "route_%d" % level,
                    ConvBNLayer(
                        ch_in=width,
                        ch_out=256 // (2**level),
                        filter_size=1,
                        stride=1,
                        padding=0))
                self.route_blocks.append(route_conv)
        self.upsample = Upsample()

    def forward(self, inputs):
        # Reads inputs['darknet_outs']; returns {'yolo_feat': [tip, ...]}.
        tips = []
        carried = None  # route tensor from the previous level
        for level, feat in enumerate(inputs['darknet_outs']):
            if level > 0:
                feat = fluid.layers.concat(input=[carried, feat], axis=1)
            carried, tip = self.yolo_blocks[level](feat)
            tips.append(tip)

            if level < 2:
                carried = self.upsample(self.route_blocks[level](carried))

        return {'yolo_feat': tips}
@register
class YOLOv3Head(Layer):
    # YOLOv3 head: runs the feature module, applies one 1x1 output conv per
    # scale, and computes the summed yolov3 loss for training.
    __shared__ = ['num_classes']
    __inject__ = ['yolo_feat']

    def __init__(
            self,
            num_classes=80,
            anchor_per_position=3,
            mode='train',
            yolo_feat=YOLOFeat().__dict__, ):
        super(YOLOv3Head, self).__init__()
        self.num_classes = num_classes
        self.anchor_per_position = anchor_per_position
        self.mode = mode
        # yolo_feat is either a kwargs dict (built here) or a ready module.
        self.yolo_feat = yolo_feat
        if isinstance(yolo_feat, dict):
            self.yolo_feat = YOLOFeat(**yolo_feat)

        self.yolo_outs = []
        for i in range(3):
            # TODO: optim here
            #num_filters = len(cfg.anchor_masks[i]) * (self.num_classes + 5)
            num_filters = self.anchor_per_position * (self.num_classes + 5)
            yolo_out = self.add_sublayer(
                "yolo_out_%d" % (i),
                Conv2D(
                    num_channels=1024 // (2**i),
                    num_filters=num_filters,
                    filter_size=1,
                    stride=1,
                    padding=0,
                    act=None,
                    param_attr=ParamAttr(
                        initializer=fluid.initializer.Normal(0., 0.02)),
                    bias_attr=ParamAttr(
                        initializer=fluid.initializer.Constant(0.0),
                        regularizer=L2Decay(0.))))
            self.yolo_outs.append(yolo_out)

    def forward(self, inputs):
        # Returns the feature module's output dict extended with
        # 'yolo_outs' (one tensor per entry of outs['yolo_feat']).
        outs = self.yolo_feat(inputs)
        x = outs['yolo_feat']
        yolo_out_list = []
        for i, yolo_f in enumerate(x):
            yolo_out = self.yolo_outs[i](yolo_f)
            yolo_out_list.append(yolo_out)
        outs.update({"yolo_outs": yolo_out_list})
        return outs

    def loss(self, inputs):
        # Sums the mean yolov3_loss over all entries of inputs['yolo_outs'].
        # Raises TypeError when inputs['anchor_module'] is not callable.
        anchor_module = inputs['anchor_module']
        if not callable(anchor_module):
            # Previously this case fell through and failed later with an
            # unbound `yolo_targets`; fail early with a clear message.
            raise TypeError(
                "inputs['anchor_module'] must be a callable anchor module "
                "providing generate_anchors_target()")
        yolo_targets = anchor_module.generate_anchors_target(inputs)

        yolo_losses = []
        for i, out in enumerate(inputs['yolo_outs']):
            # TODO: split yolov3_loss into small ops
            # 1. compute target 2. loss
            loss = fluid.layers.yolov3_loss(
                x=out,
                gt_box=inputs['gt_bbox'],
                gt_label=inputs['gt_class'],
                gt_score=inputs['gt_score'],
                anchors=inputs['anchors'],
                anchor_mask=inputs['anchor_masks'][i],
                class_num=self.num_classes,
                ignore_thresh=yolo_targets['ignore_thresh'],
                # the ratio is halved for each successive output
                downsample_ratio=yolo_targets['downsample_ratio'] // 2**i,
                use_label_smooth=yolo_targets['label_smooth'],
                name='yolo_loss_' + str(i))
            loss = fluid.layers.reduce_mean(loss)
            yolo_losses.append(loss)
        yolo_loss = sum(yolo_losses)
        return yolo_loss
ppdet/modeling/ops.py
浏览文件 @
8ef4a22d
...
...
@@ -9,13 +9,13 @@ from ppdet.py_op.post_process import bbox_post_process
@
register
@
serializable
class
AnchorGenerator
(
object
):
class
AnchorGenerator
RPN
(
object
):
def
__init__
(
self
,
anchor_sizes
=
[
32
,
64
,
128
,
256
,
512
],
aspect_ratios
=
[
0.5
,
1.0
,
2.0
],
stride
=
[
16.0
,
16.0
],
variance
=
[
1.0
,
1.0
,
1.0
,
1.0
]):
super
(
AnchorGenerator
,
self
).
__init__
()
super
(
AnchorGenerator
RPN
,
self
).
__init__
()
self
.
anchor_sizes
=
anchor_sizes
self
.
aspect_ratios
=
aspect_ratios
self
.
stride
=
stride
...
...
@@ -33,7 +33,7 @@ class AnchorGenerator(object):
@
register
@
serializable
class
RPNAnchorTargetGenerator
(
object
):
class
AnchorTargetGeneratorRPN
(
object
):
def
__init__
(
self
,
batch_size_per_im
=
256
,
straddle_thresh
=
0.
,
...
...
@@ -41,7 +41,7 @@ class RPNAnchorTargetGenerator(object):
positive_overlap
=
0.7
,
negative_overlap
=
0.3
,
use_random
=
True
):
super
(
RPNAnchorTargetGenerator
,
self
).
__init__
()
super
(
AnchorTargetGeneratorRPN
,
self
).
__init__
()
self
.
batch_size_per_im
=
batch_size_per_im
self
.
straddle_thresh
=
straddle_thresh
self
.
fg_fraction
=
fg_fraction
...
...
@@ -79,6 +79,57 @@ class RPNAnchorTargetGenerator(object):
return
pred_cls_logits
,
pred_bbox_pred
,
tgt_labels
,
tgt_bboxes
,
bbox_inside_weights
@register
@serializable
class AnchorGeneratorYOLO(object):
    # Holds the flat anchor list and the per-output anchor masks, and expands
    # the masks into flat per-output anchor lists on call.
    def __init__(self,
                 anchors=[
                     10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90,
                     156, 198, 373, 326
                 ],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]):
        super(AnchorGeneratorYOLO, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks

    def __call__(self, yolo_outs):
        # Mask index m selects the two consecutive entries anchors[2*m] and
        # anchors[2*m + 1]; one flat list is produced per element of
        # yolo_outs (only its length is used).
        mask_anchors = [[
            self.anchors[2 * m + offset]
            for m in self.anchor_masks[level] for offset in (0, 1)
        ] for level, _ in enumerate(yolo_outs)]
        return {
            "anchors": self.anchors,
            "anchor_masks": self.anchor_masks,
            "mask_anchors": mask_anchors
        }
@register
@serializable
class AnchorTargetGeneratorYOLO(object):
    # Carries the loss hyper-parameters for YOLO and hands them back as a
    # dict when called.
    def __init__(self, ignore_thresh=0.7, downsample_ratio=32,
                 label_smooth=True):
        super(AnchorTargetGeneratorYOLO, self).__init__()
        self.ignore_thresh = ignore_thresh
        self.downsample_ratio = downsample_ratio
        self.label_smooth = label_smooth

    def __call__(self, ):
        # TODO: split yolov3_loss into here
        return {
            'ignore_thresh': self.ignore_thresh,
            'downsample_ratio': self.downsample_ratio,
            'label_smooth': self.label_smooth
        }
@
register
@
serializable
class
ProposalGenerator
(
object
):
...
...
@@ -284,6 +335,54 @@ class DecodeClipNms(object):
return
outs
@register
@serializable
class MultiClassNMS(object):
    # Configuration holder tied to fluid.layers.multiclass_nms through
    # `__op__`. No __call__ is defined here; presumably the registration
    # decorators make instances callable via `__op__` -- verify in
    # ppdet.core.workspace.
    __op__ = fluid.layers.multiclass_nms
    __append_doc__ = True

    def __init__(self,
                 score_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 nms_threshold=.5,
                 normalized=False,
                 nms_eta=1.0,
                 background_label=0):
        # Every argument is stored unchanged under the same attribute name.
        super(MultiClassNMS, self).__init__()
        self.score_threshold = score_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.nms_threshold = nms_threshold
        self.normalized = normalized
        self.nms_eta = nms_eta
        self.background_label = background_label
@register
@serializable
class YOLOBox(object):
    # Thin wrapper around fluid.layers.yolo_box that stores the decoding
    # options and scales the downsample ratio by the output stage.
    __shared__ = ['num_classes']

    def __init__(
            self,
            num_classes=80,
            conf_thresh=0.005,
            downsample_ratio=32,
            clip_bbox=True, ):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio
        self.clip_bbox = clip_bbox

    def __call__(self, x, img_size, anchors, stage=0, name=None):
        # The configured ratio is floor-divided by 2**stage, so stage 0 uses
        # it unchanged and each later stage halves it.
        stage_ratio = self.downsample_ratio // 2**stage
        return fluid.layers.yolo_box(x, img_size, anchors, self.num_classes,
                                     self.conf_thresh, stage_ratio,
                                     self.clip_bbox, name)
@
register
@
serializable
class
AnchorGrid
(
object
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录