Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
c7c0568f
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
c7c0568f
编写于
5月 08, 2021
作者:
W
wangguanzhong
提交者:
GitHub
5月 08, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ags module (#2885)
上级
a06a6258
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
72 addition
and
8 deletion
+72
-8
ppdet/modeling/heads/ttf_head.py
ppdet/modeling/heads/ttf_head.py
+27
-2
ppdet/modeling/losses/iou_loss.py
ppdet/modeling/losses/iou_loss.py
+8
-2
static/ppdet/modeling/anchor_heads/ttf_head.py
static/ppdet/modeling/anchor_heads/ttf_head.py
+26
-2
static/ppdet/modeling/losses/giou_loss.py
static/ppdet/modeling/losses/giou_loss.py
+11
-2
未找到文件。
ppdet/modeling/heads/ttf_head.py
浏览文件 @
c7c0568f
...
...
@@ -200,6 +200,9 @@ class TTFHead(nn.Layer):
lite_head(bool): whether use lite version. False by default.
norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional.
bn by default
ags_module(bool): whether use AGS module to reweight location feature.
false by default.
"""
__shared__
=
[
'num_classes'
,
'down_ratio'
,
'norm_type'
]
...
...
@@ -218,7 +221,8 @@ class TTFHead(nn.Layer):
down_ratio
=
4
,
dcn_head
=
False
,
lite_head
=
False
,
norm_type
=
'bn'
):
norm_type
=
'bn'
,
ags_module
=
False
):
super
(
TTFHead
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
hm_head
=
HMHead
(
in_channels
,
hm_head_planes
,
num_classes
,
...
...
@@ -230,6 +234,7 @@ class TTFHead(nn.Layer):
self
.
wh_offset_base
=
wh_offset_base
self
.
down_ratio
=
down_ratio
self
.
ags_module
=
ags_module
@
classmethod
def
from_config
(
cls
,
cfg
,
input_shape
):
...
...
@@ -253,6 +258,12 @@ class TTFHead(nn.Layer):
target
=
paddle
.
gather_nd
(
target
,
index
)
return
pred
,
target
,
weight
def
filter_loc_by_weight
(
self
,
score
,
weight
):
index
=
paddle
.
nonzero
(
weight
>
0
)
index
.
stop_gradient
=
True
score
=
paddle
.
gather_nd
(
score
,
index
)
return
score
def
get_loss
(
self
,
pred_hm
,
pred_wh
,
target_hm
,
box_target
,
target_weight
):
pred_hm
=
paddle
.
clip
(
F
.
sigmoid
(
pred_hm
),
1e-4
,
1
-
1e-4
)
hm_loss
=
self
.
hm_loss
(
pred_hm
,
target_hm
)
...
...
@@ -274,10 +285,24 @@ class TTFHead(nn.Layer):
boxes
=
paddle
.
transpose
(
box_target
,
[
0
,
2
,
3
,
1
])
boxes
.
stop_gradient
=
True
if
self
.
ags_module
:
pred_hm_max
=
paddle
.
max
(
pred_hm
,
axis
=
1
,
keepdim
=
True
)
pred_hm_max_softmax
=
F
.
softmax
(
pred_hm_max
,
axis
=
1
)
pred_hm_max_softmax
=
paddle
.
transpose
(
pred_hm_max_softmax
,
[
0
,
2
,
3
,
1
])
pred_hm_max_softmax
=
self
.
filter_loc_by_weight
(
pred_hm_max_softmax
,
mask
)
else
:
pred_hm_max_softmax
=
None
pred_boxes
,
boxes
,
mask
=
self
.
filter_box_by_weight
(
pred_boxes
,
boxes
,
mask
)
mask
.
stop_gradient
=
True
wh_loss
=
self
.
wh_loss
(
pred_boxes
,
boxes
,
iou_weight
=
mask
.
unsqueeze
(
1
))
wh_loss
=
self
.
wh_loss
(
pred_boxes
,
boxes
,
iou_weight
=
mask
.
unsqueeze
(
1
),
loc_reweight
=
pred_hm_max_softmax
)
wh_loss
=
wh_loss
/
avg_factor
ttf_loss
=
{
'hm_loss'
:
hm_loss
,
'wh_loss'
:
wh_loss
}
...
...
ppdet/modeling/losses/iou_loss.py
浏览文件 @
c7c0568f
...
...
@@ -110,7 +110,7 @@ class GIoULoss(object):
return
iou
,
overlap
,
union
def
__call__
(
self
,
pbox
,
gbox
,
iou_weight
=
1.
):
def
__call__
(
self
,
pbox
,
gbox
,
iou_weight
=
1.
,
loc_reweight
=
None
):
x1
,
y1
,
x2
,
y2
=
paddle
.
split
(
pbox
,
num_or_sections
=
4
,
axis
=-
1
)
x1g
,
y1g
,
x2g
,
y2g
=
paddle
.
split
(
gbox
,
num_or_sections
=
4
,
axis
=-
1
)
box1
=
[
x1
,
y1
,
x2
,
y2
]
...
...
@@ -123,7 +123,13 @@ class GIoULoss(object):
area_c
=
(
xc2
-
xc1
)
*
(
yc2
-
yc1
)
+
self
.
eps
miou
=
iou
-
((
area_c
-
union
)
/
area_c
)
giou
=
1
-
miou
if
loc_reweight
is
not
None
:
loc_reweight
=
paddle
.
reshape
(
loc_reweight
,
shape
=
(
-
1
,
1
))
loc_thresh
=
0.9
giou
=
1
-
(
1
-
loc_thresh
)
*
miou
-
loc_thresh
*
miou
*
loc_reweight
else
:
giou
=
1
-
miou
if
self
.
reduction
==
'none'
:
loss
=
giou
elif
self
.
reduction
==
'sum'
:
...
...
static/ppdet/modeling/anchor_heads/ttf_head.py
浏览文件 @
c7c0568f
...
...
@@ -67,6 +67,8 @@ class TTFHead(object):
keep_prob(float): keep_prob parameter for drop_block. 0.9 by default.
fusion_method (string): Method to fusion upsample and lateral branch.
'add' and 'concat' are optional, add by default
ags_module(bool): whether use AGS module to reweight location feature.
false by default.
"""
__inject__
=
[
'wh_loss'
]
...
...
@@ -93,7 +95,8 @@ class TTFHead(object):
drop_block
=
False
,
block_size
=
3
,
keep_prob
=
0.9
,
fusion_method
=
'add'
):
fusion_method
=
'add'
,
ags_module
=
False
):
super
(
TTFHead
,
self
).
__init__
()
self
.
head_conv
=
head_conv
self
.
num_classes
=
num_classes
...
...
@@ -119,6 +122,7 @@ class TTFHead(object):
self
.
block_size
=
block_size
self
.
keep_prob
=
keep_prob
self
.
fusion_method
=
fusion_method
self
.
ags_module
=
ags_module
def
shortcut
(
self
,
x
,
out_c
,
layer_num
,
kernel_size
=
3
,
padding
=
1
,
name
=
None
):
...
...
@@ -359,6 +363,12 @@ class TTFHead(object):
target
=
fluid
.
layers
.
gather_nd
(
target
,
index
)
return
pred
,
target
,
weight
def
filter_loc_by_weight
(
self
,
score
,
weight
):
index
=
fluid
.
layers
.
where
(
weight
>
0
)
index
.
stop_gradient
=
True
score
=
fluid
.
layers
.
gather_nd
(
score
,
index
)
return
score
def
get_loss
(
self
,
pred_hm
,
pred_wh
,
target_hm
,
box_target
,
target_weight
):
try
:
pred_hm
=
paddle
.
clip
(
fluid
.
layers
.
sigmoid
(
pred_hm
),
1e-4
,
1
-
1e-4
)
...
...
@@ -387,11 +397,25 @@ class TTFHead(object):
boxes
=
fluid
.
layers
.
transpose
(
box_target
,
[
0
,
2
,
3
,
1
])
boxes
.
stop_gradient
=
True
if
self
.
ags_module
:
pred_hm_max
=
fluid
.
layers
.
reduce_max
(
pred_hm
,
dim
=
1
,
keep_dim
=
True
)
pred_hm_max_softmax
=
fluid
.
layers
.
softmax
(
pred_hm_max
,
axis
=
1
)
pred_hm_max_softmax
=
fluid
.
layers
.
transpose
(
pred_hm_max_softmax
,
[
0
,
2
,
3
,
1
])
pred_hm_max_softmax
=
self
.
filter_loc_by_weight
(
pred_hm_max_softmax
,
mask
)
else
:
pred_hm_max_softmax
=
None
pred_boxes
,
boxes
,
mask
=
self
.
filter_box_by_weight
(
pred_boxes
,
boxes
,
mask
)
mask
.
stop_gradient
=
True
wh_loss
=
self
.
wh_loss
(
pred_boxes
,
boxes
,
outside_weight
=
mask
,
use_transform
=
False
)
pred_boxes
,
boxes
,
loc_reweight
=
pred_hm_max_softmax
,
outside_weight
=
mask
,
use_transform
=
False
)
wh_loss
=
wh_loss
/
avg_factor
ttf_loss
=
{
'hm_loss'
:
hm_loss
,
'wh_loss'
:
wh_loss
}
...
...
static/ppdet/modeling/losses/giou_loss.py
浏览文件 @
c7c0568f
...
...
@@ -89,6 +89,7 @@ class GiouLoss(object):
inside_weight
=
None
,
outside_weight
=
None
,
bbox_reg_weight
=
[
0.1
,
0.1
,
0.2
,
0.2
],
loc_reweight
=
None
,
use_transform
=
True
):
eps
=
1.e-10
if
use_transform
:
...
...
@@ -134,11 +135,19 @@ class GiouLoss(object):
elif
outside_weight
is
not
None
:
iou_weights
=
outside_weight
if
loc_reweight
is
not
None
:
loc_reweight
=
fluid
.
layers
.
reshape
(
loc_reweight
,
shape
=
(
-
1
,
1
))
loc_thresh
=
0.9
giou
=
1
-
(
1
-
loc_thresh
)
*
miouk
-
loc_thresh
*
miouk
*
loc_reweight
else
:
giou
=
1
-
miouk
if
self
.
do_average
:
miouk
=
fluid
.
layers
.
reduce_mean
(
(
1
-
miouk
)
*
iou_weights
)
miouk
=
fluid
.
layers
.
reduce_mean
(
giou
*
iou_weights
)
else
:
iou_distance
=
fluid
.
layers
.
elementwise_mul
(
1
-
miouk
,
iou_weights
,
axis
=
0
)
giou
,
iou_weights
,
axis
=
0
)
miouk
=
fluid
.
layers
.
reduce_sum
(
iou_distance
)
if
self
.
use_class_weight
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录