Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
06e6afcf
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
06e6afcf
编写于
11月 06, 2020
作者:
W
wangguanzhong
提交者:
GitHub
11月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update roi extractor & post_process (#1664)
上级
8a878423
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
358 addition
and
271 deletion
+358
-271
configs/mask_rcnn_r50_fpn_1x.yml
configs/mask_rcnn_r50_fpn_1x.yml
+46
-40
configs/yolov3_darknet.yml
configs/yolov3_darknet.yml
+18
-16
ppdet/modeling/__init__.py
ppdet/modeling/__init__.py
+2
-0
ppdet/modeling/architecture/mask_rcnn.py
ppdet/modeling/architecture/mask_rcnn.py
+11
-3
ppdet/modeling/architecture/yolo.py
ppdet/modeling/architecture/yolo.py
+6
-4
ppdet/modeling/bbox.py
ppdet/modeling/bbox.py
+6
-112
ppdet/modeling/head/__init__.py
ppdet/modeling/head/__init__.py
+2
-0
ppdet/modeling/head/bbox_head.py
ppdet/modeling/head/bbox_head.py
+39
-1
ppdet/modeling/head/roi_extractor.py
ppdet/modeling/head/roi_extractor.py
+72
-0
ppdet/modeling/layers.py
ppdet/modeling/layers.py
+100
-62
ppdet/modeling/mask.py
ppdet/modeling/mask.py
+2
-30
ppdet/modeling/ops.py
ppdet/modeling/ops.py
+3
-2
ppdet/modeling/post_process.py
ppdet/modeling/post_process.py
+50
-0
ppdet/utils/eval_utils.py
ppdet/utils/eval_utils.py
+1
-1
未找到文件。
configs/mask_rcnn_r50_fpn_1x.yml
浏览文件 @
06e6afcf
...
@@ -13,7 +13,7 @@ load_static_weights: True
...
@@ -13,7 +13,7 @@ load_static_weights: True
# Model Achitecture
# Model Achitecture
MaskRCNN
:
MaskRCNN
:
# model anchor info flow
# model anchor info flow
anchor
:
Anchor
RPN
anchor
:
Anchor
proposal
:
Proposal
proposal
:
Proposal
mask
:
Mask
mask
:
Mask
# model feat info flow
# model feat info flow
...
@@ -22,6 +22,9 @@ MaskRCNN:
...
@@ -22,6 +22,9 @@ MaskRCNN:
rpn_head
:
RPNHead
rpn_head
:
RPNHead
bbox_head
:
BBoxHead
bbox_head
:
BBoxHead
mask_head
:
MaskHead
mask_head
:
MaskHead
# post process
bbox_post_process
:
BBoxPostProcess
mask_post_process
:
MaskPostProcess
ResNet
:
ResNet
:
# index 0 stands for res2
# index 0 stands for res2
...
@@ -38,7 +41,6 @@ FPN:
...
@@ -38,7 +41,6 @@ FPN:
max_level
:
4
max_level
:
4
spatial_scale
:
[
0.25
,
0.125
,
0.0625
,
0.03125
]
spatial_scale
:
[
0.25
,
0.125
,
0.0625
,
0.03125
]
RPNHead
:
RPNHead
:
rpn_feat
:
rpn_feat
:
name
:
RPNFeat
name
:
RPNFeat
...
@@ -47,33 +49,7 @@ RPNHead:
...
@@ -47,33 +49,7 @@ RPNHead:
anchor_per_position
:
3
anchor_per_position
:
3
rpn_channel
:
256
rpn_channel
:
256
BBoxHead
:
Anchor
:
bbox_feat
:
name
:
BBoxFeat
roi_extractor
:
name
:
RoIExtractor
resolution
:
7
sampling_ratio
:
2
head_feat
:
name
:
TwoFCHead
in_dim
:
256
mlp_dim
:
1024
in_feat
:
1024
MaskHead
:
mask_feat
:
name
:
MaskFeat
num_convs
:
4
feat_in
:
256
feat_out
:
256
mask_roi_extractor
:
name
:
RoIExtractor
resolution
:
14
sampling_ratio
:
2
share_bbox_feat
:
False
feat_in
:
256
AnchorRPN
:
anchor_generator
:
anchor_generator
:
name
:
AnchorGeneratorRPN
name
:
AnchorGeneratorRPN
aspect_ratios
:
[
0.5
,
1.0
,
2.0
]
aspect_ratios
:
[
0.5
,
1.0
,
2.0
]
...
@@ -104,11 +80,27 @@ Proposal:
...
@@ -104,11 +80,27 @@ Proposal:
bg_thresh_lo
:
[
0.0
,]
bg_thresh_lo
:
[
0.0
,]
fg_thresh
:
[
0.5
,]
fg_thresh
:
[
0.5
,]
fg_fraction
:
0.25
fg_fraction
:
0.25
bbox_post_process
:
# used in infer
name
:
BBoxPostProcess
BBoxHead
:
# decode -> clip -> nms
bbox_feat
:
decode_clip_nms
:
name
:
BBoxFeat
name
:
DecodeClipNms
roi_extractor
:
name
:
RoIAlign
resolution
:
7
sampling_ratio
:
2
head_feat
:
name
:
TwoFCHead
in_dim
:
256
mlp_dim
:
1024
in_feat
:
1024
BBoxPostProcess
:
decode
:
name
:
RCNNBox
num_classes
:
81
batch_size
:
1
nms
:
name
:
MultiClassNMS
keep_top_k
:
100
keep_top_k
:
100
score_threshold
:
0.05
score_threshold
:
0.05
nms_threshold
:
0.5
nms_threshold
:
0.5
...
@@ -117,8 +109,22 @@ Mask:
...
@@ -117,8 +109,22 @@ Mask:
mask_target_generator
:
mask_target_generator
:
name
:
MaskTargetGenerator
name
:
MaskTargetGenerator
mask_resolution
:
28
mask_resolution
:
28
mask_post_process
:
name
:
MaskPostProcess
MaskHead
:
mask_feat
:
name
:
MaskFeat
num_convs
:
4
feat_in
:
256
feat_out
:
256
mask_roi_extractor
:
name
:
RoIAlign
resolution
:
14
sampling_ratio
:
2
share_bbox_feat
:
False
feat_in
:
256
MaskPostProcess
:
mask_resolution
:
28
mask_resolution
:
28
...
...
configs/yolov3_darknet.yml
浏览文件 @
06e6afcf
...
@@ -15,6 +15,7 @@ YOLOv3:
...
@@ -15,6 +15,7 @@ YOLOv3:
anchor
:
AnchorYOLO
anchor
:
AnchorYOLO
backbone
:
DarkNet
backbone
:
DarkNet
yolo_head
:
YOLOv3Head
yolo_head
:
YOLOv3Head
post_process
:
BBoxPostProcess
DarkNet
:
DarkNet
:
depth
:
53
depth
:
53
...
@@ -29,15 +30,8 @@ YOLOv3Head:
...
@@ -29,15 +30,8 @@ YOLOv3Head:
label_smooth
:
true
label_smooth
:
true
anchor_per_position
:
3
anchor_per_position
:
3
AnchorYOLO
:
BBoxPostProcess
:
anchor_generator
:
decode
:
name
:
AnchorGeneratorYOLO
anchors
:
[
10
,
13
,
16
,
30
,
33
,
23
,
30
,
61
,
62
,
45
,
59
,
119
,
116
,
90
,
156
,
198
,
373
,
326
]
anchor_masks
:
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
anchor_post_process
:
name
:
BBoxPostProcessYOLO
# decode -> clip
yolo_box
:
name
:
YOLOBox
name
:
YOLOBox
conf_thresh
:
0.005
conf_thresh
:
0.005
downsample_ratio
:
32
downsample_ratio
:
32
...
@@ -51,6 +45,14 @@ AnchorYOLO:
...
@@ -51,6 +45,14 @@ AnchorYOLO:
normalized
:
false
normalized
:
false
background_label
:
-1
background_label
:
-1
AnchorYOLO
:
anchor_generator
:
name
:
AnchorGeneratorYOLO
anchors
:
[
10
,
13
,
16
,
30
,
33
,
23
,
30
,
61
,
62
,
45
,
59
,
119
,
116
,
90
,
156
,
198
,
373
,
326
]
anchor_masks
:
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
LearningRate
:
LearningRate
:
base_lr
:
0.001
base_lr
:
0.001
schedulers
:
schedulers
:
...
...
ppdet/modeling/__init__.py
浏览文件 @
06e6afcf
...
@@ -5,6 +5,7 @@ from . import backbone
...
@@ -5,6 +5,7 @@ from . import backbone
from
.
import
neck
from
.
import
neck
from
.
import
head
from
.
import
head
from
.
import
architecture
from
.
import
architecture
from
.
import
post_process
from
.ops
import
*
from
.ops
import
*
from
.bbox
import
*
from
.bbox
import
*
...
@@ -13,3 +14,4 @@ from .backbone import *
...
@@ -13,3 +14,4 @@ from .backbone import *
from
.neck
import
*
from
.neck
import
*
from
.head
import
*
from
.head
import
*
from
.architecture
import
*
from
.architecture
import
*
from
.post_process
import
*
ppdet/modeling/architecture/mask_rcnn.py
浏览文件 @
06e6afcf
...
@@ -21,6 +21,8 @@ class MaskRCNN(BaseArch):
...
@@ -21,6 +21,8 @@ class MaskRCNN(BaseArch):
'rpn_head'
,
'rpn_head'
,
'bbox_head'
,
'bbox_head'
,
'mask_head'
,
'mask_head'
,
'bbox_post_process'
,
'mask_post_process'
,
]
]
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -31,6 +33,8 @@ class MaskRCNN(BaseArch):
...
@@ -31,6 +33,8 @@ class MaskRCNN(BaseArch):
rpn_head
,
rpn_head
,
bbox_head
,
bbox_head
,
mask_head
,
mask_head
,
bbox_post_process
,
mask_post_process
,
neck
=
None
):
neck
=
None
):
super
(
MaskRCNN
,
self
).
__init__
()
super
(
MaskRCNN
,
self
).
__init__
()
self
.
anchor
=
anchor
self
.
anchor
=
anchor
...
@@ -41,6 +45,8 @@ class MaskRCNN(BaseArch):
...
@@ -41,6 +45,8 @@ class MaskRCNN(BaseArch):
self
.
rpn_head
=
rpn_head
self
.
rpn_head
=
rpn_head
self
.
bbox_head
=
bbox_head
self
.
bbox_head
=
bbox_head
self
.
mask_head
=
mask_head
self
.
mask_head
=
mask_head
self
.
bbox_post_process
=
bbox_post_process
self
.
mask_post_process
=
mask_post_process
def
model_arch
(
self
):
def
model_arch
(
self
):
# Backbone
# Backbone
...
@@ -72,9 +78,11 @@ class MaskRCNN(BaseArch):
...
@@ -72,9 +78,11 @@ class MaskRCNN(BaseArch):
rois_has_mask_int32
=
None
rois_has_mask_int32
=
None
if
self
.
inputs
[
'mode'
]
==
'infer'
:
if
self
.
inputs
[
'mode'
]
==
'infer'
:
# Refine bbox by the output from bbox_head at test stage
bbox_pred
,
bboxes
=
self
.
bbox_head
.
get_prediction
(
self
.
bboxes
=
self
.
proposal
.
post_process
(
self
.
inputs
,
self
.
bbox_head_out
,
rois
)
self
.
bbox_head_out
,
rois
)
# Refine bbox by the output from bbox_head at test stage
self
.
bboxes
=
self
.
bbox_post_process
(
bbox_pred
,
bboxes
,
self
.
inputs
[
'im_info'
])
else
:
else
:
# Proposal RoI for Mask branch
# Proposal RoI for Mask branch
# bboxes update at training stage only
# bboxes update at training stage only
...
@@ -111,7 +119,7 @@ class MaskRCNN(BaseArch):
...
@@ -111,7 +119,7 @@ class MaskRCNN(BaseArch):
return
loss
return
loss
def
infer
(
self
,
):
def
infer
(
self
,
):
mask
=
self
.
mask
.
post_process
(
self
.
bboxes
,
self
.
mask_head_out
,
mask
=
self
.
mask
_
post_process
(
self
.
bboxes
,
self
.
mask_head_out
,
self
.
inputs
[
'im_info'
])
self
.
inputs
[
'im_info'
])
bbox
,
bbox_num
=
self
.
bboxes
bbox
,
bbox_num
=
self
.
bboxes
output
=
{
output
=
{
...
...
ppdet/modeling/architecture/yolo.py
浏览文件 @
06e6afcf
...
@@ -15,13 +15,15 @@ class YOLOv3(BaseArch):
...
@@ -15,13 +15,15 @@ class YOLOv3(BaseArch):
'anchor'
,
'anchor'
,
'backbone'
,
'backbone'
,
'yolo_head'
,
'yolo_head'
,
'post_process'
,
]
]
def
__init__
(
self
,
anchor
,
backbone
,
yolo_head
):
def
__init__
(
self
,
anchor
,
backbone
,
yolo_head
,
post_process
):
super
(
YOLOv3
,
self
).
__init__
()
super
(
YOLOv3
,
self
).
__init__
()
self
.
anchor
=
anchor
self
.
anchor
=
anchor
self
.
backbone
=
backbone
self
.
backbone
=
backbone
self
.
yolo_head
=
yolo_head
self
.
yolo_head
=
yolo_head
self
.
post_process
=
post_process
def
model_arch
(
self
,
):
def
model_arch
(
self
,
):
# Backbone
# Backbone
...
@@ -40,11 +42,11 @@ class YOLOv3(BaseArch):
...
@@ -40,11 +42,11 @@ class YOLOv3(BaseArch):
return
yolo_loss
return
yolo_loss
def
infer
(
self
,
):
def
infer
(
self
,
):
bbox
,
bbox_num
=
self
.
anchor
.
post_process
(
bbox
,
bbox_num
=
self
.
post_process
(
self
.
inputs
[
'im_size'
],
self
.
yolo_head_out
,
self
.
mask_anchors
)
self
.
yolo_head_out
,
self
.
mask_anchors
,
self
.
inputs
[
'im_size'
]
)
outs
=
{
outs
=
{
"bbox"
:
bbox
.
numpy
(),
"bbox"
:
bbox
.
numpy
(),
"bbox_num"
:
bbox_num
,
"bbox_num"
:
bbox_num
.
numpy
()
,
'im_id'
:
self
.
inputs
[
'im_id'
].
numpy
()
'im_id'
:
self
.
inputs
[
'im_id'
].
numpy
()
}
}
return
outs
return
outs
ppdet/modeling/bbox.py
浏览文件 @
06e6afcf
...
@@ -8,105 +8,11 @@ from . import ops
...
@@ -8,105 +8,11 @@ from . import ops
@
register
@
register
class
BBoxPostProcess
(
object
):
class
Anchor
(
object
):
__shared__
=
[
'num_classes'
]
__inject__
=
[
'decode_clip_nms'
]
def
__init__
(
self
,
decode_clip_nms
,
num_classes
=
81
,
cls_agnostic
=
False
,
decode
=
None
,
clip
=
None
,
nms
=
None
,
score_stage
=
[
0
,
1
,
2
],
delta_stage
=
[
2
]):
super
(
BBoxPostProcess
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
decode
=
decode
self
.
clip
=
clip
self
.
nms
=
nms
self
.
decode_clip_nms
=
decode_clip_nms
self
.
score_stage
=
score_stage
self
.
delta_stage
=
delta_stage
self
.
out_dim
=
2
if
cls_agnostic
else
num_classes
self
.
cls_agnostic
=
cls_agnostic
def
__call__
(
self
,
inputs
,
bboxheads
,
rois
):
# TODO: split into 3 steps
# TODO: modify related ops for deploying
# decode
# clip
# nms
if
isinstance
(
rois
,
tuple
):
proposal
,
proposal_num
=
rois
score
,
delta
=
bboxheads
[
0
]
bbox_prob
=
fluid
.
layers
.
softmax
(
score
)
delta
=
fluid
.
layers
.
reshape
(
delta
,
(
-
1
,
self
.
out_dim
,
4
))
else
:
num_stage
=
len
(
rois
)
proposal_list
=
[]
prob_list
=
[]
delta_list
=
[]
for
stage
,
(
proposals
,
bboxhead
)
in
zip
(
rois
,
bboxheads
):
score
,
delta
=
bboxhead
proposal
,
proposal_num
=
proposals
if
stage
in
self
.
score_stage
:
bbox_prob
=
fluid
.
layers
.
softmax
(
score
)
prob_list
.
append
(
bbox_prob
)
if
stage
in
self
.
delta_stage
:
proposal_list
.
append
(
proposal
)
delta_list
.
append
(
delta
)
bbox_prob
=
fluid
.
layers
.
mean
(
prob_list
)
delta
=
fluid
.
layers
.
mean
(
delta_list
)
proposal
=
fluid
.
layers
.
mean
(
proposal_list
)
delta
=
fluid
.
layers
.
reshape
(
delta
,
(
-
1
,
self
.
out_dim
,
4
))
if
self
.
cls_agnostic
:
delta
=
delta
[:,
1
:
2
,
:]
delta
=
fluid
.
layers
.
expand
(
delta
,
[
1
,
self
.
num_classes
,
1
])
bboxes
=
(
proposal
,
proposal_num
)
bboxes
,
bbox_nums
=
self
.
decode_clip_nms
(
bboxes
,
bbox_prob
,
delta
,
inputs
[
'im_info'
])
return
bboxes
,
bbox_nums
@
register
class
BBoxPostProcessYOLO
(
object
):
__shared__
=
[
'num_classes'
]
__inject__
=
[
'yolo_box'
,
'nms'
]
def
__init__
(
self
,
yolo_box
,
nms
,
num_classes
=
80
,
decode
=
None
,
clip
=
None
):
super
(
BBoxPostProcessYOLO
,
self
).
__init__
()
self
.
yolo_box
=
yolo_box
self
.
nms
=
nms
self
.
num_classes
=
num_classes
self
.
decode
=
decode
self
.
clip
=
clip
def
__call__
(
self
,
im_size
,
yolo_head_out
,
mask_anchors
):
# TODO: split yolo_box into 2 steps
# decode
# clip
boxes_list
=
[]
scores_list
=
[]
for
i
,
head_out
in
enumerate
(
yolo_head_out
):
boxes
,
scores
=
self
.
yolo_box
(
head_out
,
im_size
,
mask_anchors
[
i
],
self
.
num_classes
,
i
)
boxes_list
.
append
(
boxes
)
scores_list
.
append
(
paddle
.
transpose
(
scores
,
perm
=
[
0
,
2
,
1
]))
yolo_boxes
=
paddle
.
concat
(
boxes_list
,
axis
=
1
)
yolo_scores
=
paddle
.
concat
(
scores_list
,
axis
=
2
)
bbox
,
bbox_num
=
self
.
nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
)
return
bbox
,
bbox_num
@
register
class
AnchorRPN
(
object
):
__inject__
=
[
'anchor_generator'
,
'anchor_target_generator'
]
__inject__
=
[
'anchor_generator'
,
'anchor_target_generator'
]
def
__init__
(
self
,
anchor_generator
,
anchor_target_generator
):
def
__init__
(
self
,
anchor_generator
,
anchor_target_generator
):
super
(
Anchor
RPN
,
self
).
__init__
()
super
(
Anchor
,
self
).
__init__
()
self
.
anchor_generator
=
anchor_generator
self
.
anchor_generator
=
anchor_generator
self
.
anchor_target_generator
=
anchor_target_generator
self
.
anchor_target_generator
=
anchor_target_generator
...
@@ -167,32 +73,24 @@ class AnchorRPN(object):
...
@@ -167,32 +73,24 @@ class AnchorRPN(object):
@
register
@
register
class
AnchorYOLO
(
object
):
class
AnchorYOLO
(
object
):
__inject__
=
[
'anchor_generator'
,
'anchor_post_process'
]
__inject__
=
[
'anchor_generator'
]
def
__init__
(
self
,
anchor_generator
,
anchor_post_process
):
def
__init__
(
self
,
anchor_generator
):
super
(
AnchorYOLO
,
self
).
__init__
()
super
(
AnchorYOLO
,
self
).
__init__
()
self
.
anchor_generator
=
anchor_generator
self
.
anchor_generator
=
anchor_generator
self
.
anchor_post_process
=
anchor_post_process
def
__call__
(
self
):
def
__call__
(
self
):
return
self
.
anchor_generator
()
return
self
.
anchor_generator
()
def
post_process
(
self
,
im_size
,
yolo_head_out
,
mask_anchors
):
return
self
.
anchor_post_process
(
im_size
,
yolo_head_out
,
mask_anchors
)
@
register
@
register
class
Proposal
(
object
):
class
Proposal
(
object
):
__inject__
=
[
__inject__
=
[
'proposal_generator'
,
'proposal_target_generator'
]
'proposal_generator'
,
'proposal_target_generator'
,
'bbox_post_process'
]
def
__init__
(
self
,
proposal_generator
,
proposal_target_generator
,
def
__init__
(
self
,
proposal_generator
,
proposal_target_generator
):
bbox_post_process
):
super
(
Proposal
,
self
).
__init__
()
super
(
Proposal
,
self
).
__init__
()
self
.
proposal_generator
=
proposal_generator
self
.
proposal_generator
=
proposal_generator
self
.
proposal_target_generator
=
proposal_target_generator
self
.
proposal_target_generator
=
proposal_target_generator
self
.
bbox_post_process
=
bbox_post_process
def
generate_proposal
(
self
,
inputs
,
rpn_head_out
,
anchor_out
):
def
generate_proposal
(
self
,
inputs
,
rpn_head_out
,
anchor_out
):
rpn_rois_list
=
[]
rpn_rois_list
=
[]
...
@@ -294,7 +192,3 @@ class Proposal(object):
...
@@ -294,7 +192,3 @@ class Proposal(object):
def
get_proposals
(
self
):
def
get_proposals
(
self
):
return
self
.
proposals_list
return
self
.
proposals_list
def
post_process
(
self
,
inputs
,
bbox_head_out
,
rois
):
bboxes
=
self
.
bbox_post_process
(
inputs
,
bbox_head_out
,
rois
)
return
bboxes
ppdet/modeling/head/__init__.py
浏览文件 @
06e6afcf
...
@@ -2,8 +2,10 @@ from . import rpn_head
...
@@ -2,8 +2,10 @@ from . import rpn_head
from
.
import
bbox_head
from
.
import
bbox_head
from
.
import
mask_head
from
.
import
mask_head
from
.
import
yolo_head
from
.
import
yolo_head
from
.
import
roi_extractor
from
.rpn_head
import
*
from
.rpn_head
import
*
from
.bbox_head
import
*
from
.bbox_head
import
*
from
.mask_head
import
*
from
.mask_head
import
*
from
.yolo_head
import
*
from
.yolo_head
import
*
from
.roi_extractor
import
*
ppdet/modeling/head/bbox_head.py
浏览文件 @
06e6afcf
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Layer
from
paddle.fluid.dygraph
import
Layer
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.param_attr
import
ParamAttr
...
@@ -5,6 +6,7 @@ from paddle.fluid.initializer import Normal, Xavier
...
@@ -5,6 +6,7 @@ from paddle.fluid.initializer import Normal, Xavier
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
ppdet.core.workspace
import
register
from
ppdet.core.workspace
import
register
import
paddle.nn.functional
as
F
@
register
@
register
...
@@ -85,7 +87,9 @@ class BBoxHead(Layer):
...
@@ -85,7 +87,9 @@ class BBoxHead(Layer):
num_classes
=
81
,
num_classes
=
81
,
cls_agnostic
=
False
,
cls_agnostic
=
False
,
num_stages
=
1
,
num_stages
=
1
,
with_pool
=
False
):
with_pool
=
False
,
score_stage
=
[
0
,
1
,
2
],
delta_stage
=
[
2
]):
super
(
BBoxHead
,
self
).
__init__
()
super
(
BBoxHead
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
delta_dim
=
2
if
cls_agnostic
else
num_classes
self
.
delta_dim
=
2
if
cls_agnostic
else
num_classes
...
@@ -94,6 +98,8 @@ class BBoxHead(Layer):
...
@@ -94,6 +98,8 @@ class BBoxHead(Layer):
self
.
bbox_score_list
=
[]
self
.
bbox_score_list
=
[]
self
.
bbox_delta_list
=
[]
self
.
bbox_delta_list
=
[]
self
.
with_pool
=
with_pool
self
.
with_pool
=
with_pool
self
.
score_stage
=
score_stage
self
.
delta_stage
=
delta_stage
for
stage
in
range
(
num_stages
):
for
stage
in
range
(
num_stages
):
score_name
=
'bbox_score_{}'
.
format
(
stage
)
score_name
=
'bbox_score_{}'
.
format
(
stage
)
delta_name
=
'bbox_delta_{}'
.
format
(
stage
)
delta_name
=
'bbox_delta_{}'
.
format
(
stage
)
...
@@ -169,3 +175,35 @@ class BBoxHead(Layer):
...
@@ -169,3 +175,35 @@ class BBoxHead(Layer):
loss_bbox
[
cls_name
]
=
loss_bbox_cls
loss_bbox
[
cls_name
]
=
loss_bbox_cls
loss_bbox
[
reg_name
]
=
loss_bbox_reg
loss_bbox
[
reg_name
]
=
loss_bbox_reg
return
loss_bbox
return
loss_bbox
def
get_prediction
(
self
,
bbox_head_out
,
rois
):
if
len
(
bbox_head_out
)
==
1
:
proposal
,
proposal_num
=
rois
score
,
delta
=
bbox_head_out
[
0
]
bbox_prob
=
F
.
softmax
(
score
)
delta
=
paddle
.
reshape
(
delta
,
(
-
1
,
self
.
delta_dim
,
4
))
else
:
num_stage
=
len
(
rois
)
proposal_list
=
[]
prob_list
=
[]
delta_list
=
[]
for
stage
,
(
proposals
,
bboxhead
)
in
zip
(
rois
,
bboxheads
):
score
,
delta
=
bboxhead
proposal
,
proposal_num
=
proposals
if
stage
in
self
.
score_stage
:
bbox_prob
=
F
.
softmax
(
score
)
prob_list
.
append
(
bbox_prob
)
if
stage
in
self
.
delta_stage
:
proposal_list
.
append
(
proposal
)
delta_list
.
append
(
delta
)
bbox_prob
=
paddle
.
mean
(
paddle
.
stack
(
prob_list
),
axis
=
0
)
delta
=
paddle
.
mean
(
paddle
.
stack
(
delta_list
),
axis
=
0
)
proposal
=
paddle
.
mean
(
paddle
.
stack
(
proposal_list
),
axis
=
0
)
delta
=
paddle
.
reshape
(
delta
,
(
-
1
,
self
.
out_dim
,
4
))
if
self
.
cls_agnostic
:
N
,
C
,
M
=
delta
.
shape
delta
=
delta
[:,
1
:
2
,
:]
delta
=
paddle
.
expand
(
delta
,
[
N
,
self
.
num_classes
,
M
])
bboxes
=
(
proposal
,
proposal_num
)
bbox_pred
=
(
delta
,
bbox_prob
)
return
bbox_pred
,
bboxes
ppdet/modeling/head/roi_extractor.py
0 → 100644
浏览文件 @
06e6afcf
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
from
ppdet.core.workspace
import
register
from
ppdet.modeling
import
ops
@
register
class
RoIAlign
(
object
):
def
__init__
(
self
,
resolution
=
14
,
sampling_ratio
=
0
,
canconical_level
=
4
,
canonical_size
=
224
,
start_level
=
0
,
end_level
=
3
):
super
(
RoIAlign
,
self
).
__init__
()
self
.
resolution
=
resolution
self
.
sampling_ratio
=
sampling_ratio
self
.
canconical_level
=
canconical_level
self
.
canonical_size
=
canonical_size
self
.
start_level
=
start_level
self
.
end_level
=
end_level
def
__call__
(
self
,
feats
,
rois
,
spatial_scale
):
roi
,
rois_num
=
rois
cur_l
=
0
if
self
.
start_level
==
self
.
end_level
:
rois_feat
=
ops
.
roi_align
(
feats
[
self
.
start_level
],
roi
,
self
.
resolution
,
spatial_scale
,
rois_num
=
rois_num
)
return
rois_feat
offset
=
2
k_min
=
self
.
start_level
+
offset
k_max
=
self
.
end_level
+
offset
rois_dist
,
restore_index
,
rois_num_dist
=
ops
.
distribute_fpn_proposals
(
roi
,
k_min
,
k_max
,
self
.
canconical_level
,
self
.
canonical_size
,
rois_num
=
rois_num
)
rois_feat_list
=
[]
for
lvl
in
range
(
self
.
start_level
,
self
.
end_level
+
1
):
roi_feat
=
ops
.
roi_align
(
feats
[
lvl
],
rois_dist
[
lvl
],
self
.
resolution
,
spatial_scale
[
lvl
],
sampling_ratio
=
self
.
sampling_ratio
,
rois_num
=
rois_num_dist
[
lvl
])
rois_feat_list
.
append
(
roi_feat
)
rois_feat_shuffle
=
paddle
.
concat
(
rois_feat_list
)
rois_feat
=
paddle
.
gather
(
rois_feat_shuffle
,
restore_index
)
return
rois_feat
ppdet/modeling/layers.py
浏览文件 @
06e6afcf
...
@@ -14,12 +14,15 @@
...
@@ -14,12 +14,15 @@
import
numpy
as
np
import
numpy
as
np
from
numbers
import
Integral
from
numbers
import
Integral
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.dygraph.base
import
to_variable
from
ppdet.core.workspace
import
register
,
serializable
from
ppdet.core.workspace
import
register
,
serializable
from
ppdet.py_op.target
import
generate_rpn_anchor_target
,
generate_proposal_target
,
generate_mask_target
from
ppdet.py_op.target
import
generate_rpn_anchor_target
,
generate_proposal_target
,
generate_mask_target
from
ppdet.py_op.post_process
import
bbox_post_process
from
ppdet.py_op.post_process
import
bbox_post_process
from
.
import
ops
from
.
import
ops
import
paddle.nn.functional
as
F
@
register
@
register
...
@@ -278,58 +281,71 @@ class MaskTargetGenerator(object):
...
@@ -278,58 +281,71 @@ class MaskTargetGenerator(object):
@
register
@
register
class
RoIExtractor
(
object
):
@
serializable
class
RCNNBox
(
object
):
__shared__
=
[
'num_classes'
,
'batch_size'
]
def
__init__
(
self
,
def
__init__
(
self
,
resolution
=
14
,
num_classes
=
81
,
sampling_ratio
=
0
,
batch_size
=
1
,
canconical_level
=
4
,
prior_box_var
=
[
0.1
,
0.1
,
0.2
,
0.2
],
canonical_size
=
224
,
code_type
=
"decode_center_size"
,
start_level
=
0
,
box_normalized
=
False
,
end_level
=
3
):
axis
=
1
):
super
(
RoIExtractor
,
self
).
__init__
()
super
(
RCNNBox
,
self
).
__init__
()
self
.
resolution
=
resolution
self
.
num_classes
=
num_classes
self
.
sampling_ratio
=
sampling_ratio
self
.
batch_size
=
batch_size
self
.
canconical_level
=
canconical_level
self
.
prior_box_var
=
prior_box_var
self
.
canonical_size
=
canonical_size
self
.
code_type
=
code_type
self
.
start_level
=
start_level
self
.
box_normalized
=
box_normalized
self
.
end_level
=
end_level
self
.
axis
=
axis
def
__call__
(
self
,
feats
,
rois
,
spatial_scale
):
def
__call__
(
self
,
bbox_head_out
,
rois
,
im_shape
,
scale_factor
):
bbox_pred
,
cls_prob
=
bbox_head_out
roi
,
rois_num
=
rois
roi
,
rois_num
=
rois
cur_l
=
0
origin_shape
=
im_shape
/
scale_factor
if
self
.
start_level
==
self
.
end_level
:
scale_list
=
[]
rois_feat
=
ops
.
roi_align
(
origin_shape_list
=
[]
feats
[
self
.
start_level
],
for
idx
in
range
(
self
.
batch_size
):
roi
,
scale
=
scale_factor
[
idx
,
:]
self
.
resolution
,
rois_num_per_im
=
rois_num
[
idx
]
spatial_scale
,
expand_scale
=
paddle
.
expand
(
scale
,
[
rois_num_per_im
,
1
])
rois_num
=
rois_num
)
scale_list
.
append
(
expand_scale
)
return
rois_feat
expand_im_shape
=
paddle
.
expand
(
origin_shape
[
idx
,
:],
offset
=
2
[
rois_num_per_im
,
2
])
k_min
=
self
.
start_level
+
offset
origin_shape_list
.
append
(
expand_im_shape
)
k_max
=
self
.
end_level
+
offset
rois_dist
,
restore_index
,
rois_num_dist
=
ops
.
distribute_fpn_proposals
(
scale
=
paddle
.
concat
(
scale_list
)
roi
,
origin_shape
=
paddle
.
concat
(
origin_shape_list
)
k_min
,
k_max
,
bbox
=
roi
/
scale
self
.
canconical_level
,
bbox
=
ops
.
box_coder
(
self
.
canonical_size
,
prior_box
=
bbox
,
rois_num
=
rois_num
)
prior_box_var
=
self
.
prior_box_var
,
target_box
=
bbox_pred
,
rois_feat_list
=
[]
code_type
=
self
.
code_type
,
for
lvl
in
range
(
self
.
start_level
,
self
.
end_level
+
1
):
box_normalized
=
self
.
box_normalized
,
roi_feat
=
ops
.
roi_align
(
axis
=
self
.
axis
)
feats
[
lvl
],
# TODO: Updata box_clip
rois_dist
[
lvl
],
origin_h
=
origin_shape
[:,
0
]
-
1
self
.
resolution
,
origin_w
=
origin_shape
[:,
1
]
-
1
spatial_scale
[
lvl
],
zeros
=
paddle
.
zeros
(
origin_h
.
shape
,
'float32'
)
sampling_ratio
=
self
.
sampling_ratio
,
x1
=
paddle
.
maximum
(
rois_num
=
rois_num_dist
[
lvl
])
paddle
.
minimum
(
rois_feat_list
.
append
(
roi_feat
)
bbox
[:,
:,
0
],
origin_w
,
axis
=
0
),
zeros
,
axis
=
0
)
rois_feat_shuffle
=
fluid
.
layers
.
concat
(
rois_feat_list
)
y1
=
paddle
.
maximum
(
rois_feat
=
fluid
.
layers
.
gather
(
rois_feat_shuffle
,
restore_index
)
paddle
.
minimum
(
bbox
[:,
:,
1
],
origin_h
,
axis
=
0
),
zeros
,
axis
=
0
)
return
rois_feat
x2
=
paddle
.
maximum
(
paddle
.
minimum
(
bbox
[:,
:,
2
],
origin_w
,
axis
=
0
),
zeros
,
axis
=
0
)
y2
=
paddle
.
maximum
(
paddle
.
minimum
(
bbox
[:,
:,
3
],
origin_h
,
axis
=
0
),
zeros
,
axis
=
0
)
bbox
=
paddle
.
stack
([
x1
,
y1
,
x2
,
y2
],
axis
=-
1
)
bboxes
=
(
bbox
,
rois_num
)
return
bboxes
,
cls_prob
@
register
@
register
...
@@ -367,9 +383,6 @@ class DecodeClipNms(object):
...
@@ -367,9 +383,6 @@ class DecodeClipNms(object):
@
register
@
register
@
serializable
@
serializable
class
MultiClassNMS
(
object
):
class
MultiClassNMS
(
object
):
__op__
=
ops
.
multiclass_nms
__append_doc__
=
True
def
__init__
(
self
,
def
__init__
(
self
,
score_threshold
=
.
05
,
score_threshold
=
.
05
,
nms_top_k
=-
1
,
nms_top_k
=-
1
,
...
@@ -387,6 +400,13 @@ class MultiClassNMS(object):
...
@@ -387,6 +400,13 @@ class MultiClassNMS(object):
self
.
nms_eta
=
nms_eta
self
.
nms_eta
=
nms_eta
self
.
background_label
=
background_label
self
.
background_label
=
background_label
def
__call__
(
self
,
bboxes
,
score
):
kwargs
=
self
.
__dict__
.
copy
()
if
isinstance
(
bboxes
,
tuple
):
bboxes
,
bbox_num
=
bboxes
kwargs
.
update
({
'rois_num'
:
bbox_num
})
return
ops
.
multiclass_nms
(
bboxes
,
score
,
**
kwargs
)
@
register
@
register
@
serializable
@
serializable
...
@@ -417,19 +437,37 @@ class MatrixNMS(object):
...
@@ -417,19 +437,37 @@ class MatrixNMS(object):
@
register
@
register
@
serializable
@
serializable
class
YOLOBox
(
object
):
class
YOLOBox
(
object
):
def
__init__
(
__shared__
=
[
'num_classes'
]
self
,
def
__init__
(
self
,
num_classes
=
80
,
conf_thresh
=
0.005
,
conf_thresh
=
0.005
,
downsample_ratio
=
32
,
downsample_ratio
=
32
,
clip_bbox
=
True
,
):
clip_bbox
=
True
,
scale_x_y
=
1.
):
self
.
num_classes
=
num_classes
self
.
conf_thresh
=
conf_thresh
self
.
conf_thresh
=
conf_thresh
self
.
downsample_ratio
=
downsample_ratio
self
.
downsample_ratio
=
downsample_ratio
self
.
clip_bbox
=
clip_bbox
self
.
clip_bbox
=
clip_bbox
self
.
scale_x_y
=
scale_x_y
def
__call__
(
self
,
x
,
img_size
,
anchors
,
num_classes
,
stage
=
0
):
def
__call__
(
self
,
yolo_head_out
,
anchors
,
im_shape
,
scale_factor
=
None
):
outs
=
ops
.
yolo_box
(
x
,
img_size
,
anchors
,
num_classes
,
self
.
conf_thresh
,
boxes_list
=
[]
self
.
downsample_ratio
//
2
**
stage
,
self
.
clip_bbox
)
scores_list
=
[]
return
outs
if
scale_factor
is
not
None
:
origin_shape
=
im_shape
/
scale_factor
else
:
origin_shape
=
im_shape
for
i
,
head_out
in
enumerate
(
yolo_head_out
):
boxes
,
scores
=
ops
.
yolo_box
(
head_out
,
origin_shape
,
anchors
[
i
],
self
.
num_classes
,
self
.
conf_thresh
,
self
.
downsample_ratio
//
2
**
i
,
self
.
clip_bbox
,
self
.
scale_x_y
)
boxes_list
.
append
(
boxes
)
scores_list
.
append
(
paddle
.
transpose
(
scores
,
perm
=
[
0
,
2
,
1
]))
yolo_boxes
=
paddle
.
concat
(
boxes_list
,
axis
=
1
)
yolo_scores
=
paddle
.
concat
(
scores_list
,
axis
=
2
)
return
yolo_boxes
,
yolo_scores
@
register
@
register
...
...
ppdet/modeling/mask.py
浏览文件 @
06e6afcf
...
@@ -2,38 +2,14 @@ import numpy as np
...
@@ -2,38 +2,14 @@ import numpy as np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
ppdet.core.workspace
import
register
from
ppdet.core.workspace
import
register
# TODO: regitster mask_post_process op
from
ppdet.py_op.post_process
import
mask_post_process
@
register
class
MaskPostProcess
(
object
):
__shared__
=
[
'mask_resolution'
]
def
__init__
(
self
,
mask_resolution
=
28
,
binary_thresh
=
0.5
):
super
(
MaskPostProcess
,
self
).
__init__
()
self
.
mask_resolution
=
mask_resolution
self
.
binary_thresh
=
binary_thresh
def
__call__
(
self
,
bboxes
,
mask_head_out
,
im_info
):
# TODO: modify related ops for deploying
bboxes_np
=
(
i
.
numpy
()
for
i
in
bboxes
)
mask
=
mask_post_process
(
bboxes_np
,
mask_head_out
.
numpy
(),
im_info
.
numpy
(),
self
.
mask_resolution
,
self
.
binary_thresh
)
mask
=
{
'mask'
:
mask
}
return
mask
@
register
@
register
class
Mask
(
object
):
class
Mask
(
object
):
__inject__
=
[
'mask_target_generator'
,
'mask_post_process'
]
__inject__
=
[
'mask_target_generator'
]
def
__init__
(
self
,
mask_target_generator
,
mask_post_process
):
def
__init__
(
self
,
mask_target_generator
):
super
(
Mask
,
self
).
__init__
()
super
(
Mask
,
self
).
__init__
()
self
.
mask_target_generator
=
mask_target_generator
self
.
mask_target_generator
=
mask_target_generator
self
.
mask_post_process
=
mask_post_process
def
__call__
(
self
,
inputs
,
rois
,
targets
):
def
__call__
(
self
,
inputs
,
rois
,
targets
):
mask_rois
,
rois_has_mask_int32
=
self
.
generate_mask_target
(
inputs
,
rois
,
mask_rois
,
rois_has_mask_int32
=
self
.
generate_mask_target
(
inputs
,
rois
,
...
@@ -56,7 +32,3 @@ class Mask(object):
...
@@ -56,7 +32,3 @@ class Mask(object):
def
get_targets
(
self
):
def
get_targets
(
self
):
return
self
.
mask_int32
return
self
.
mask_int32
def
post_process
(
self
,
bboxes
,
mask_head_out
,
im_info
):
mask
=
self
.
mask_post_process
(
bboxes
,
mask_head_out
,
im_info
)
return
mask
ppdet/modeling/ops.py
浏览文件 @
06e6afcf
...
@@ -1337,8 +1337,9 @@ def box_coder(prior_box,
...
@@ -1337,8 +1337,9 @@ def box_coder(prior_box,
elif
isinstance
(
prior_box_var
,
list
):
elif
isinstance
(
prior_box_var
,
list
):
output_box
=
core
.
ops
.
box_coder
(
output_box
=
core
.
ops
.
box_coder
(
prior_box
,
target_box
,
"code_type"
,
code_type
,
"box_normalized"
,
prior_box
,
None
,
target_box
,
"code_type"
,
code_type
,
box_normalized
,
"axis"
,
axis
,
"variance"
,
prior_box_var
)
"box_normalized"
,
box_normalized
,
"axis"
,
axis
,
"variance"
,
prior_box_var
)
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"Input variance of box_coder must be Variable or list"
)
"Input variance of box_coder must be Variable or list"
)
...
...
ppdet/modeling/post_process.py
0 → 100644
浏览文件 @
06e6afcf
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
ppdet.core.workspace
import
register
from
ppdet.py_op.post_process
import
mask_post_process
from
.
import
ops
@
register
class
BBoxPostProcess
(
object
):
__inject__
=
[
'decode'
,
'nms'
]
def
__init__
(
self
,
decode
=
None
,
nms
=
None
):
super
(
BBoxPostProcess
,
self
).
__init__
()
self
.
decode
=
decode
self
.
nms
=
nms
def
__call__
(
self
,
head_out
,
rois
,
im_shape
,
scale_factor
=
None
):
# TODO: compatible for im_info
# remove after unify the im_shape. scale_factor
if
im_shape
.
shape
[
1
]
>
2
:
origin_shape
=
im_shape
[:,
:
2
]
scale_factor
=
im_shape
[:,
2
:]
else
:
origin_shape
=
im_shape
bboxes
,
score
=
self
.
decode
(
head_out
,
rois
,
origin_shape
,
scale_factor
)
bbox_pred
,
bbox_num
=
self
.
nms
(
bboxes
,
score
)
return
bbox_pred
,
bbox_num
@
register
class
MaskPostProcess
(
object
):
__shared__
=
[
'mask_resolution'
]
def
__init__
(
self
,
mask_resolution
=
28
,
binary_thresh
=
0.5
):
super
(
MaskPostProcess
,
self
).
__init__
()
self
.
mask_resolution
=
mask_resolution
self
.
binary_thresh
=
binary_thresh
def
__call__
(
self
,
bboxes
,
mask_head_out
,
im_info
):
# TODO: modify related ops for deploying
bboxes_np
=
(
i
.
numpy
()
for
i
in
bboxes
)
mask
=
mask_post_process
(
bboxes_np
,
mask_head_out
.
numpy
(),
im_info
.
numpy
(),
self
.
mask_resolution
,
self
.
binary_thresh
)
mask
=
{
'mask'
:
mask
}
return
mask
ppdet/utils/eval_utils.py
浏览文件 @
06e6afcf
...
@@ -85,7 +85,7 @@ def eval_results(res, metric, anno_file):
...
@@ -85,7 +85,7 @@ def eval_results(res, metric, anno_file):
json
.
dump
(
res
[
'mask'
],
f
)
json
.
dump
(
res
[
'mask'
],
f
)
logger
.
info
(
'The mask result is saved to mask.json.'
)
logger
.
info
(
'The mask result is saved to mask.json.'
)
seg_stats
=
cocoapi_eval
(
'mask.json'
,
'
mask
'
,
anno_file
=
anno_file
)
seg_stats
=
cocoapi_eval
(
'mask.json'
,
'
segm
'
,
anno_file
=
anno_file
)
eval_res
.
append
(
seg_stats
)
eval_res
.
append
(
seg_stats
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
else
:
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录