Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
2f89e761
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2f89e761
编写于
6月 01, 2020
作者:
S
sunyanfang01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fasterrcnn loss
上级
948032a7
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
504 addition
and
32 deletion
+504
-32
paddlex/cv/models/faster_rcnn.py
paddlex/cv/models/faster_rcnn.py
+40
-12
paddlex/cv/nets/detection/bbox_head.py
paddlex/cv/nets/detection/bbox_head.py
+61
-18
paddlex/cv/nets/detection/faster_rcnn.py
paddlex/cv/nets/detection/faster_rcnn.py
+5
-2
paddlex/cv/nets/detection/loss/diou_loss.py
paddlex/cv/nets/detection/loss/diou_loss.py
+117
-0
paddlex/cv/nets/detection/loss/giou_loss.py
paddlex/cv/nets/detection/loss/giou_loss.py
+119
-0
paddlex/cv/nets/detection/nms.py
paddlex/cv/nets/detection/nms.py
+162
-0
未找到文件。
paddlex/cv/models/faster_rcnn.py
浏览文件 @
2f89e761
...
@@ -43,7 +43,8 @@ class FasterRCNN(BaseAPI):
...
@@ -43,7 +43,8 @@ class FasterRCNN(BaseAPI):
backbone
=
'ResNet50'
,
backbone
=
'ResNet50'
,
with_fpn
=
True
,
with_fpn
=
True
,
aspect_ratios
=
[
0.5
,
1.0
,
2.0
],
aspect_ratios
=
[
0.5
,
1.0
,
2.0
],
anchor_sizes
=
[
32
,
64
,
128
,
256
,
512
]):
anchor_sizes
=
[
32
,
64
,
128
,
256
,
512
],
bbox_loss_type
=
'SmoothL1Loss'
):
self
.
init_params
=
locals
()
self
.
init_params
=
locals
()
super
(
FasterRCNN
,
self
).
__init__
(
'detector'
)
super
(
FasterRCNN
,
self
).
__init__
(
'detector'
)
backbones
=
[
backbones
=
[
...
@@ -57,6 +58,7 @@ class FasterRCNN(BaseAPI):
...
@@ -57,6 +58,7 @@ class FasterRCNN(BaseAPI):
self
.
with_fpn
=
with_fpn
self
.
with_fpn
=
with_fpn
self
.
aspect_ratios
=
aspect_ratios
self
.
aspect_ratios
=
aspect_ratios
self
.
anchor_sizes
=
anchor_sizes
self
.
anchor_sizes
=
anchor_sizes
self
.
bbox_loss_type
=
bbox_loss_type
self
.
labels
=
None
self
.
labels
=
None
self
.
fixed_input_shape
=
None
self
.
fixed_input_shape
=
None
...
@@ -72,6 +74,8 @@ class FasterRCNN(BaseAPI):
...
@@ -72,6 +74,8 @@ class FasterRCNN(BaseAPI):
layers
=
50
layers
=
50
variant
=
'd'
variant
=
'd'
norm_type
=
'affine_channel'
norm_type
=
'affine_channel'
if
self
.
bbox_loss_type
!=
'SmoothL1Loss'
:
norm_type
=
'bn'
elif
backbone_name
==
'ResNet101'
:
elif
backbone_name
==
'ResNet101'
:
layers
=
101
layers
=
101
variant
=
'b'
variant
=
'b'
...
@@ -118,7 +122,8 @@ class FasterRCNN(BaseAPI):
...
@@ -118,7 +122,8 @@ class FasterRCNN(BaseAPI):
anchor_sizes
=
self
.
anchor_sizes
,
anchor_sizes
=
self
.
anchor_sizes
,
train_pre_nms_top_n
=
train_pre_nms_top_n
,
train_pre_nms_top_n
=
train_pre_nms_top_n
,
test_pre_nms_top_n
=
test_pre_nms_top_n
,
test_pre_nms_top_n
=
test_pre_nms_top_n
,
fixed_input_shape
=
self
.
fixed_input_shape
)
fixed_input_shape
=
self
.
fixed_input_shape
,
bbox_loss_type
=
self
.
bbox_loss_type
)
inputs
=
model
.
generate_inputs
()
inputs
=
model
.
generate_inputs
()
if
mode
==
'train'
:
if
mode
==
'train'
:
model_out
=
model
.
build_net
(
inputs
)
model_out
=
model
.
build_net
(
inputs
)
...
@@ -134,26 +139,49 @@ class FasterRCNN(BaseAPI):
...
@@ -134,26 +139,49 @@ class FasterRCNN(BaseAPI):
outputs
=
model
.
build_net
(
inputs
)
outputs
=
model
.
build_net
(
inputs
)
return
inputs
,
outputs
return
inputs
,
outputs
# def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
# lr_decay_epochs, lr_decay_gamma,
# num_steps_each_epoch):
# if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
# raise Exception("warmup_steps should less than {}".format(
# lr_decay_epochs[0] * num_steps_each_epoch))
# boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
# values = [(lr_decay_gamma**i) * learning_rate
# for i in range(len(lr_decay_epochs) + 1)]
# lr_decay = fluid.layers.piecewise_decay(
# boundaries=boundaries, values=values)
# lr_warmup = fluid.layers.linear_lr_warmup(
# learning_rate=lr_decay,
# warmup_steps=warmup_steps,
# start_lr=warmup_start_lr,
# end_lr=learning_rate)
# optimizer = fluid.optimizer.Momentum(
# learning_rate=lr_warmup,
# momentum=0.9,
# regularization=fluid.regularizer.L2Decay(1e-04))
# return optimizer
def
default_optimizer
(
self
,
learning_rate
,
warmup_steps
,
warmup_start_lr
,
def
default_optimizer
(
self
,
learning_rate
,
warmup_steps
,
warmup_start_lr
,
lr_decay_epochs
,
lr_decay_gamma
,
lr_decay_epochs
,
lr_decay_gamma
,
num_steps_each_epoch
):
num_steps_each_epoch
):
if
warmup_steps
>
lr_decay_epochs
[
0
]
*
num_steps_each_epoch
:
#
if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
raise
Exception
(
"warmup_steps should less than {}"
.
format
(
#
raise Exception("warmup_steps should less than {}".format(
lr_decay_epochs
[
0
]
*
num_steps_each_epoch
))
#
lr_decay_epochs[0] * num_steps_each_epoch))
boundaries
=
[
b
*
num_steps_each_epoch
for
b
in
lr_decay_epochs
]
boundaries
=
[
b
*
num_steps_each_epoch
for
b
in
lr_decay_epochs
]
values
=
[(
lr_decay_gamma
**
i
)
*
learning_rate
values
=
[(
lr_decay_gamma
**
i
)
*
learning_rate
for
i
in
range
(
len
(
lr_decay_epochs
)
+
1
)]
for
i
in
range
(
len
(
lr_decay_epochs
)
+
1
)]
lr_decay
=
fluid
.
layers
.
piecewise_decay
(
lr_decay
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
boundaries
,
values
=
values
)
boundaries
=
boundaries
,
values
=
values
)
lr_warmup
=
fluid
.
layers
.
linear_lr_warmup
(
#
lr_warmup = fluid.layers.linear_lr_warmup(
learning_rate
=
lr_decay
,
#
learning_rate=lr_decay,
warmup_steps
=
warmup_steps
,
#
warmup_steps=warmup_steps,
start_lr
=
warmup_start_lr
,
#
start_lr=warmup_start_lr,
end_lr
=
learning_rate
)
#
end_lr=learning_rate)
optimizer
=
fluid
.
optimizer
.
Momentum
(
optimizer
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
lr_warmup
,
#learning_rate=lr_warmup,
learning_rate
=
lr_decay
,
momentum
=
0.9
,
momentum
=
0.9
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
1e-04
))
regularization
=
fluid
.
regularizer
.
L2Decay
Regularizer
(
1e-04
))
return
optimizer
return
optimizer
def
train
(
self
,
def
train
(
self
,
...
...
paddlex/cv/nets/detection/bbox_head.py
浏览文件 @
2f89e761
...
@@ -24,6 +24,7 @@ from paddle.fluid.initializer import Normal, Xavier
...
@@ -24,6 +24,7 @@ from paddle.fluid.initializer import Normal, Xavier
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.initializer
import
MSRA
from
paddle.fluid.initializer
import
MSRA
__all__
=
[
'BBoxHead'
,
'TwoFCHead'
]
__all__
=
[
'BBoxHead'
,
'TwoFCHead'
]
...
@@ -82,7 +83,8 @@ class BBoxHead(object):
...
@@ -82,7 +83,8 @@ class BBoxHead(object):
background_label
=
0
,
background_label
=
0
,
#bbox_loss
#bbox_loss
sigma
=
1.0
,
sigma
=
1.0
,
num_classes
=
81
):
num_classes
=
81
,
bbox_loss_type
=
'SmoothL1Loss'
):
super
(
BBoxHead
,
self
).
__init__
()
super
(
BBoxHead
,
self
).
__init__
()
self
.
head
=
head
self
.
head
=
head
self
.
prior_box_var
=
prior_box_var
self
.
prior_box_var
=
prior_box_var
...
@@ -99,6 +101,7 @@ class BBoxHead(object):
...
@@ -99,6 +101,7 @@ class BBoxHead(object):
self
.
sigma
=
sigma
self
.
sigma
=
sigma
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
head_feat
=
None
self
.
head_feat
=
None
self
.
bbox_loss_type
=
bbox_loss_type
def
get_head_feat
(
self
,
input
=
None
):
def
get_head_feat
(
self
,
input
=
None
):
"""
"""
...
@@ -126,6 +129,7 @@ class BBoxHead(object):
...
@@ -126,6 +129,7 @@ class BBoxHead(object):
[N, num_anchors * 4, H, W].
[N, num_anchors * 4, H, W].
"""
"""
head_feat
=
self
.
get_head_feat
(
roi_feat
)
head_feat
=
self
.
get_head_feat
(
roi_feat
)
# when ResNetC5 output a single feature map
# when ResNetC5 output a single feature map
if
not
isinstance
(
self
.
head
,
TwoFCHead
):
if
not
isinstance
(
self
.
head
,
TwoFCHead
):
head_feat
=
fluid
.
layers
.
pool2d
(
head_feat
=
fluid
.
layers
.
pool2d
(
...
@@ -173,18 +177,50 @@ class BBoxHead(object):
...
@@ -173,18 +177,50 @@ class BBoxHead(object):
"""
"""
cls_score
,
bbox_pred
=
self
.
_get_output
(
roi_feat
)
cls_score
,
bbox_pred
=
self
.
_get_output
(
roi_feat
)
labels_int64
=
fluid
.
layers
.
cast
(
x
=
labels_int32
,
dtype
=
'int64'
)
labels_int64
=
fluid
.
layers
.
cast
(
x
=
labels_int32
,
dtype
=
'int64'
)
labels_int64
.
stop_gradient
=
True
labels_int64
.
stop_gradient
=
True
loss_cls
=
fluid
.
layers
.
softmax_with_cross_entropy
(
loss_cls
=
fluid
.
layers
.
softmax_with_cross_entropy
(
logits
=
cls_score
,
label
=
labels_int64
,
numeric_stable_mode
=
True
)
logits
=
cls_score
,
label
=
labels_int64
,
numeric_stable_mode
=
True
)
loss_cls
=
fluid
.
layers
.
reduce_mean
(
loss_cls
)
loss_cls
=
fluid
.
layers
.
reduce_mean
(
loss_cls
)
loss_bbox
=
fluid
.
layers
.
smooth_l1
(
if
self
.
bbox_loss_type
==
'CiouLoss'
:
x
=
bbox_pred
,
from
.loss.diou_loss
import
DiouLoss
y
=
bbox_targets
,
loss_obj
=
DiouLoss
(
loss_weight
=
10.
,
inside_weight
=
bbox_inside_weights
,
is_cls_agnostic
=
False
,
outside_weight
=
bbox_outside_weights
,
num_classes
=
self
.
num_classes
,
sigma
=
self
.
sigma
)
use_complete_iou_loss
=
True
)
loss_bbox
=
loss_obj
(
x
=
bbox_pred
,
y
=
bbox_targets
,
inside_weight
=
bbox_inside_weights
,
outside_weight
=
bbox_outside_weights
)
elif
self
.
bbox_loss_type
==
'DiouLoss'
:
from
.loss.diou_loss
import
DiouLoss
loss_obj
=
DiouLoss
(
loss_weight
=
12.
,
is_cls_agnostic
=
False
,
num_classes
=
self
.
num_classes
,
use_complete_iou_loss
=
False
)
loss_bbox
=
loss_obj
(
x
=
bbox_pred
,
y
=
bbox_targets
,
inside_weight
=
bbox_inside_weights
,
outside_weight
=
bbox_outside_weights
)
elif
self
.
bbox_loss_type
==
'GiouLoss'
:
from
.loss.giou_loss
import
GiouLoss
loss_obj
=
GiouLoss
(
loss_weight
=
10.
,
is_cls_agnostic
=
False
,
num_classes
=
self
.
num_classes
)
loss_bbox
=
loss_obj
(
x
=
bbox_pred
,
y
=
bbox_targets
,
inside_weight
=
bbox_inside_weights
,
outside_weight
=
bbox_outside_weights
)
else
:
loss_bbox
=
fluid
.
layers
.
smooth_l1
(
x
=
bbox_pred
,
y
=
bbox_targets
,
inside_weight
=
bbox_inside_weights
,
outside_weight
=
bbox_outside_weights
,
sigma
=
self
.
sigma
)
loss_bbox
=
fluid
.
layers
.
reduce_mean
(
loss_bbox
)
loss_bbox
=
fluid
.
layers
.
reduce_mean
(
loss_bbox
)
return
{
'loss_cls'
:
loss_cls
,
'loss_bbox'
:
loss_bbox
}
return
{
'loss_cls'
:
loss_cls
,
'loss_bbox'
:
loss_bbox
}
...
@@ -229,14 +265,21 @@ class BBoxHead(object):
...
@@ -229,14 +265,21 @@ class BBoxHead(object):
cliped_box
=
fluid
.
layers
.
box_clip
(
input
=
decoded_box
,
im_info
=
im_shape
)
cliped_box
=
fluid
.
layers
.
box_clip
(
input
=
decoded_box
,
im_info
=
im_shape
)
if
return_box_score
:
if
return_box_score
:
return
{
'bbox'
:
cliped_box
,
'score'
:
cls_prob
}
return
{
'bbox'
:
cliped_box
,
'score'
:
cls_prob
}
pred_result
=
fluid
.
layers
.
multiclass_nms
(
if
self
.
bbox_loss_type
==
'CiouLoss'
:
bboxes
=
cliped_box
,
from
.nms
import
MultiClassDiouNMS
scores
=
cls_prob
,
nms_obj
=
MultiClassDiouNMS
(
score_threshold
=
self
.
score_threshold
,
score_threshold
=
self
.
score_threshold
,
nms_threshold
=
self
.
nms_threshold
,
nms_top_k
=
self
.
nms_top_k
,
keep_top_k
=
self
.
keep_top_k
)
keep_top_k
=
self
.
keep_top_k
,
pred_result
=
nms_obj
(
bboxes
=
cliped_box
,
scores
=
cls_prob
)
nms_threshold
=
self
.
nms_threshold
,
else
:
normalized
=
self
.
normalized
,
pred_result
=
fluid
.
layers
.
multiclass_nms
(
nms_eta
=
self
.
nms_eta
,
bboxes
=
cliped_box
,
background_label
=
self
.
background_label
)
scores
=
cls_prob
,
score_threshold
=
self
.
score_threshold
,
nms_top_k
=
self
.
nms_top_k
,
keep_top_k
=
self
.
keep_top_k
,
nms_threshold
=
self
.
nms_threshold
,
normalized
=
self
.
normalized
,
nms_eta
=
self
.
nms_eta
,
background_label
=
self
.
background_label
)
return
{
'bbox'
:
pred_result
}
return
{
'bbox'
:
pred_result
}
paddlex/cv/nets/detection/faster_rcnn.py
浏览文件 @
2f89e761
...
@@ -70,6 +70,7 @@ class FasterRCNN(object):
...
@@ -70,6 +70,7 @@ class FasterRCNN(object):
keep_top_k
=
100
,
keep_top_k
=
100
,
nms_threshold
=
0.5
,
nms_threshold
=
0.5
,
score_threshold
=
0.05
,
score_threshold
=
0.05
,
bbox_loss_type
=
'SmoothL1Loss'
,
#bbox_assigner
#bbox_assigner
batch_size_per_im
=
512
,
batch_size_per_im
=
512
,
fg_fraction
=
.
25
,
fg_fraction
=
.
25
,
...
@@ -145,7 +146,8 @@ class FasterRCNN(object):
...
@@ -145,7 +146,8 @@ class FasterRCNN(object):
keep_top_k
=
keep_top_k
,
keep_top_k
=
keep_top_k
,
nms_threshold
=
nms_threshold
,
nms_threshold
=
nms_threshold
,
score_threshold
=
score_threshold
,
score_threshold
=
score_threshold
,
num_classes
=
num_classes
)
num_classes
=
num_classes
,
bbox_loss_type
=
bbox_loss_type
)
self
.
bbox_head
=
bbox_head
self
.
bbox_head
=
bbox_head
self
.
batch_size_per_im
=
batch_size_per_im
self
.
batch_size_per_im
=
batch_size_per_im
self
.
fg_fraction
=
fg_fraction
self
.
fg_fraction
=
fg_fraction
...
@@ -189,7 +191,6 @@ class FasterRCNN(object):
...
@@ -189,7 +191,6 @@ class FasterRCNN(object):
bbox_reg_weights
=
self
.
bbox_reg_weights
,
bbox_reg_weights
=
self
.
bbox_reg_weights
,
class_nums
=
self
.
num_classes
,
class_nums
=
self
.
num_classes
,
use_random
=
self
.
rpn_head
.
use_random
)
use_random
=
self
.
rpn_head
.
use_random
)
rois
=
outputs
[
0
]
rois
=
outputs
[
0
]
labels_int32
=
outputs
[
1
]
labels_int32
=
outputs
[
1
]
bbox_targets
=
outputs
[
2
]
bbox_targets
=
outputs
[
2
]
...
@@ -211,10 +212,12 @@ class FasterRCNN(object):
...
@@ -211,10 +212,12 @@ class FasterRCNN(object):
else
:
else
:
roi_feat
=
self
.
roi_extractor
(
body_feats
,
rois
,
spatial_scale
)
roi_feat
=
self
.
roi_extractor
(
body_feats
,
rois
,
spatial_scale
)
if
self
.
mode
==
'train'
:
if
self
.
mode
==
'train'
:
loss
=
self
.
bbox_head
.
get_loss
(
roi_feat
,
labels_int32
,
loss
=
self
.
bbox_head
.
get_loss
(
roi_feat
,
labels_int32
,
bbox_targets
,
bbox_inside_weights
,
bbox_targets
,
bbox_inside_weights
,
bbox_outside_weights
)
bbox_outside_weights
)
loss
.
update
(
rpn_loss
)
loss
.
update
(
rpn_loss
)
total_loss
=
fluid
.
layers
.
sum
(
list
(
loss
.
values
()))
total_loss
=
fluid
.
layers
.
sum
(
list
(
loss
.
values
()))
loss
.
update
({
'loss'
:
total_loss
})
loss
.
update
({
'loss'
:
total_loss
})
...
...
paddlex/cv/nets/detection/loss/diou_loss.py
0 → 100644
浏览文件 @
2f89e761
#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
from
paddle
import
fluid
from
.giou_loss
import
GiouLoss
class DiouLoss(GiouLoss):
    """Distance-IoU / Complete-IoU bounding-box regression loss.

    See https://arxiv.org/abs/1911.08287

    Args:
        loss_weight (float): diou loss weight, default as 10 in faster-rcnn
        is_cls_agnostic (bool): flag of class-agnostic; when True the loss is
            scaled by 2, otherwise by ``num_classes``
        num_classes (int): class num
        use_complete_iou_loss (bool): whether to use complete iou loss, i.e.
            add the aspect-ratio consistency term to the distance term
    """

    def __init__(self,
                 loss_weight=10.,
                 is_cls_agnostic=False,
                 num_classes=81,
                 use_complete_iou_loss=True):
        super(DiouLoss, self).__init__(
            loss_weight=loss_weight,
            is_cls_agnostic=is_cls_agnostic,
            num_classes=num_classes)
        self.use_complete_iou_loss = use_complete_iou_loss

    def __call__(self,
                 x,
                 y,
                 inside_weight=None,
                 outside_weight=None,
                 bbox_reg_weight=(0.1, 0.1, 0.2, 0.2)):
        """Build the DIoU (optionally CIoU) loss between two sets of deltas.

        Args:
            x: predicted regression deltas, decoded with ``bbox_transform``.
            y: target regression deltas, decoded the same way.
            inside_weight: optional per-coordinate weights; reshaped to
                (-1, 4) and averaged along dim 1. Only used when
                ``outside_weight`` is given as well.
            outside_weight: optional per-coordinate weights, handled like
                ``inside_weight``.
            bbox_reg_weight: the four multipliers (x, y, w, h) handed to
                ``bbox_transform``. The default is an immutable tuple so it
                cannot be shared and mutated across calls.

        Returns:
            The reduced loss, scaled by the class weight and ``loss_weight``.
        """
        eps = 1.e-10
        x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
        x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)

        # Box centers, computed from the decoded (not yet clamped) corners.
        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        cxg = (x1g + x2g) / 2
        cyg = (y1g + y2g) / 2

        # Clamped far corners of the prediction; the raw x2/y2 are kept for
        # the aspect-ratio term below.
        x2c = fluid.layers.elementwise_max(x1, x2)
        y2c = fluid.layers.elementwise_max(y1, y2)

        # A and B
        xkis1 = fluid.layers.elementwise_max(x1, x1g)
        ykis1 = fluid.layers.elementwise_max(y1, y1g)
        xkis2 = fluid.layers.elementwise_min(x2c, x2g)
        ykis2 = fluid.layers.elementwise_min(y2c, y2g)

        # A or B
        xc1 = fluid.layers.elementwise_min(x1, x1g)
        yc1 = fluid.layers.elementwise_min(y1, y1g)
        xc2 = fluid.layers.elementwise_max(x2c, x2g)
        yc2 = fluid.layers.elementwise_max(y2c, y2g)

        # Intersection area, zeroed where the boxes do not overlap.
        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
        intsctk = intsctk * fluid.layers.greater_than(
            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
        unionk = (x2c - x1) * (y2c - y1) + (x2g - x1g) * (y2g - y1g
                                                          ) - intsctk + eps
        iouk = intsctk / unionk

        # DIOU term: squared center distance over the squared diagonal of
        # the smallest box enclosing both.
        dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
        dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
        diou_term = (dist_intersection + eps) / (dist_union + eps)

        # CIOU term: widths/heights are only needed here, so they are built
        # only when the complete-IoU variant is requested.
        ciou_term = 0
        if self.use_complete_iou_loss:
            ar_gt = (x2g - x1g) / (y2g - y1g)
            ar_pred = (x2 - x1) / (y2 - y1)
            arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
            ar_loss = 4. / np.pi / np.pi * arctan * arctan
            alpha = ar_loss / (1 - iouk + ar_loss + eps)
            # alpha is treated as a constant weight: no gradient flows
            # through it.
            alpha.stop_gradient = True
            ciou_term = alpha * ar_loss

        iou_weights = 1
        if inside_weight is not None and outside_weight is not None:
            inside_weight = fluid.layers.reshape(inside_weight, shape=(-1, 4))
            outside_weight = fluid.layers.reshape(
                outside_weight, shape=(-1, 4))

            inside_weight = fluid.layers.reduce_mean(inside_weight, dim=1)
            outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1)

            iou_weights = inside_weight * outside_weight

        class_weight = 2 if self.is_cls_agnostic else self.num_classes
        diou = fluid.layers.reduce_mean(
            (1 - iouk + ciou_term + diou_term) * iou_weights) * class_weight

        return diou * self.loss_weight
paddlex/cv/nets/detection/loss/giou_loss.py
0 → 100644
浏览文件 @
2f89e761
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
from
paddle
import
fluid
class GiouLoss(object):
    """Generalized Intersection over Union loss.

    See https://arxiv.org/abs/1902.09630

    Args:
        loss_weight (float): giou loss weight, default as 10 in faster-rcnn
        is_cls_agnostic (bool): flag of class-agnostic; when True the loss is
            scaled by 2, otherwise by ``num_classes``
        num_classes (int): class num
    """

    def __init__(self, loss_weight=10., is_cls_agnostic=False, num_classes=81):
        super(GiouLoss, self).__init__()
        self.loss_weight = loss_weight
        self.is_cls_agnostic = is_cls_agnostic
        self.num_classes = num_classes

    # deltas: NxMx4
    def bbox_transform(self, deltas, weights):
        """Decode regression deltas into flattened corner coordinates.

        Args:
            deltas: tensor of deltas; reshaped here to (0, -1, 4) with the
                last axis holding (dx, dy, dw, dh).
            weights: four multipliers (wx, wy, ww, wh), one per delta.

        Returns:
            Tuple ``(x1, y1, x2, y2)`` of tensors, each reshaped to 1-D.
        """
        wx, wy, ww, wh = weights

        deltas = fluid.layers.reshape(deltas, shape=(0, -1, 4))

        dx = fluid.layers.slice(deltas, axes=[2], starts=[0], ends=[1]) * wx
        dy = fluid.layers.slice(deltas, axes=[2], starts=[1], ends=[2]) * wy
        dw = fluid.layers.slice(deltas, axes=[2], starts=[2], ends=[3]) * ww
        dh = fluid.layers.slice(deltas, axes=[2], starts=[3], ends=[4]) * wh

        # Bound the log-size deltas from above before exponentiating.
        max_log_size = np.log(1000. / 16)
        dw = fluid.layers.clip(dw, -1.e10, max_log_size)
        dh = fluid.layers.clip(dh, -1.e10, max_log_size)

        # The center is the (weighted) delta itself and the size is exp of
        # the (weighted) log-size delta.
        pred_ctr_x = dx
        pred_ctr_y = dy
        pred_w = fluid.layers.exp(dw)
        pred_h = fluid.layers.exp(dh)

        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h

        x1 = fluid.layers.reshape(x1, shape=(-1, ))
        y1 = fluid.layers.reshape(y1, shape=(-1, ))
        x2 = fluid.layers.reshape(x2, shape=(-1, ))
        y2 = fluid.layers.reshape(y2, shape=(-1, ))

        return x1, y1, x2, y2

    def __call__(self,
                 x,
                 y,
                 inside_weight=None,
                 outside_weight=None,
                 bbox_reg_weight=(0.1, 0.1, 0.2, 0.2)):
        """Build the GIoU loss between two sets of regression deltas.

        Args:
            x: predicted regression deltas, decoded with ``bbox_transform``.
            y: target regression deltas, decoded the same way.
            inside_weight: optional per-coordinate weights; reshaped to
                (-1, 4) and averaged along dim 1. Only used when
                ``outside_weight`` is given as well.
            outside_weight: optional per-coordinate weights, handled like
                ``inside_weight``.
            bbox_reg_weight: the four multipliers (x, y, w, h) handed to
                ``bbox_transform``. The default is an immutable tuple so it
                cannot be shared and mutated across calls.

        Returns:
            The reduced loss, scaled by the class weight and ``loss_weight``.
        """
        eps = 1.e-10
        x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
        x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)

        x2 = fluid.layers.elementwise_max(x1, x2)
        y2 = fluid.layers.elementwise_max(y1, y2)

        # Corners of the intersection rectangle.
        xkis1 = fluid.layers.elementwise_max(x1, x1g)
        ykis1 = fluid.layers.elementwise_max(y1, y1g)
        xkis2 = fluid.layers.elementwise_min(x2, x2g)
        ykis2 = fluid.layers.elementwise_min(y2, y2g)

        # Corners of the smallest rectangle enclosing both boxes.
        xc1 = fluid.layers.elementwise_min(x1, x1g)
        yc1 = fluid.layers.elementwise_min(y1, y1g)
        xc2 = fluid.layers.elementwise_max(x2, x2g)
        yc2 = fluid.layers.elementwise_max(y2, y2g)

        # Intersection area, zeroed where the boxes do not overlap.
        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
        intsctk = intsctk * fluid.layers.greater_than(
            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
                                                        ) - intsctk + eps
        iouk = intsctk / unionk

        area_c = (xc2 - xc1) * (yc2 - yc1) + eps
        miouk = iouk - ((area_c - unionk) / area_c)

        iou_weights = 1
        if inside_weight is not None and outside_weight is not None:
            inside_weight = fluid.layers.reshape(inside_weight, shape=(-1, 4))
            outside_weight = fluid.layers.reshape(
                outside_weight, shape=(-1, 4))

            inside_weight = fluid.layers.reduce_mean(inside_weight, dim=1)
            outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1)

            iou_weights = inside_weight * outside_weight

        class_weight = 2 if self.is_cls_agnostic else self.num_classes
        # Only the GIoU value feeds the returned loss; the plain-IoU
        # reduction that used to be built here was never read and is gone.
        miouk = fluid.layers.reduce_mean(
            (1 - miouk) * iou_weights) * class_weight

        return miouk * self.loss_weight
paddlex/cv/nets/detection/nms.py
0 → 100644
浏览文件 @
2f89e761
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
from
paddle
import
fluid
class MultiClassDiouNMS(object):
    """Per-class greedy NMS that uses (IoU - DIoU term) as the overlap score.

    The suppression itself runs in NumPy and is attached to the program
    through ``fluid.layers.py_func``.

    Args:
        score_threshold (float): boxes scoring below this are dropped before
            suppression.
        keep_top_k (int): cap on detections kept over all classes.
        nms_threshold (float): a box is suppressed when its
            ``IoU - DIoU term`` against a kept box exceeds this value.
        normalized (bool): when False, 1 is added to every box width/height.
        background_label (int): when 0, class column 0 is skipped.
    """

    def __init__(
            self,
            score_threshold=0.05,
            keep_top_k=100,
            nms_threshold=0.5,
            normalized=False,
            background_label=0, ):
        super(MultiClassDiouNMS, self).__init__()
        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold
        self.keep_top_k = keep_top_k
        self.normalized = normalized
        self.background_label = background_label

    @staticmethod
    def _calc_diou_term(dets1, dets2):
        """Return squared center distance over squared enclosing diagonal.

        ``dets1`` and ``dets2`` are ``[x1, y1, x2, y2]`` sequences whose items
        may be scalars or arrays (NumPy broadcasting applies).
        """
        eps = 1.e-10

        x1, y1, x2, y2 = dets1[0], dets1[1], dets1[2], dets1[3]
        x1g, y1g, x2g, y2g = dets2[0], dets2[1], dets2[2], dets2[3]

        # Centers use the corners as given, before x2/y2 are clamped.
        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        cxg = (x1g + x2g) / 2
        cyg = (y1g + y2g) / 2

        x2 = np.maximum(x1, x2)
        y2 = np.maximum(y1, y2)

        # A or B
        xc1 = np.minimum(x1, x1g)
        yc1 = np.minimum(y1, y1g)
        xc2 = np.maximum(x2, x2g)
        yc2 = np.maximum(y2, y2g)

        # DIOU term
        dist_intersection = (cx - cxg)**2 + (cy - cyg)**2
        dist_union = (xc2 - xc1)**2 + (yc2 - yc1)**2
        return (dist_intersection + eps) / (dist_union + eps)

    def _diou_nms_for_cls(self, dets, thres):
        """Greedy DIoU-NMS over one class.

        Args:
            dets: array with rows ``[score, x1, y1, x2, y2]``; rows are
                visited in the order given, so the caller sorts them by
                descending score.
            thres: suppression threshold on ``IoU - DIoU term``.

        Returns:
            The rows of ``dets`` that survive, in visiting order.
        """
        x1 = dets[:, 1]
        y1 = dets[:, 2]
        x2 = dets[:, 3]
        y2 = dets[:, 4]
        eta = 0 if self.normalized else 1
        areas = (x2 - x1 + eta) * (y2 - y1 + eta)
        order = np.arange(dets.shape[0])

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            rest = order[1:]

            xx1 = np.maximum(x1[i], x1[rest])
            yy1 = np.maximum(y1[i], y1[rest])
            xx2 = np.minimum(x2[i], x2[rest])
            yy2 = np.minimum(y2[i], y2[rest])

            w = np.maximum(0.0, xx2 - xx1 + eta)
            h = np.maximum(0.0, yy2 - yy1 + eta)
            inter = w * h
            ovr = inter / (areas[i] + areas[rest] - inter)

            diou_term = self._calc_diou_term(
                [x1[i], y1[i], x2[i], y2[i]],
                [x1[rest], y1[rest], x2[rest], y2[rest]])

            # Survivors are indexed relative to ``rest``, hence the +1.
            inds = np.where(ovr - diou_term <= thres)[0]
            order = order[inds + 1]

        return dets[keep]

    def _select_detections(self, bboxes, scores):
        """Run thresholding, per-class DIoU-NMS and the keep_top_k cap.

        Args:
            bboxes: array indexed as ``[box, class, coord]``.
            scores: array indexed as ``[box, class]``.

        Returns:
            float32 array with rows ``[class_id, score, x1, y1, x2, y2]``;
            shape ``(0, 6)`` when there is no class to process.
        """
        class_nums = scores.shape[-1]
        start_idx = 1 if self.background_label == 0 else 0

        cls_boxes = []
        cls_ids = []
        for j in range(start_idx, class_nums):
            inds = np.where(scores[:, j] >= self.score_threshold)[0]
            scores_j = scores[inds, j]
            rois_j = bboxes[inds, j, :]
            dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
                np.float32, copy=False)
            dets_j = dets_j[np.argsort(-dets_j[:, 0])]

            kept = self._diou_nms_for_cls(dets_j, thres=self.nms_threshold)
            cls_boxes.append(kept)
            cls_ids.append(np.array([j] * kept.shape[0]).reshape(-1, 1))

        if not cls_boxes:
            # np.vstack rejects an empty list; report "no detections".
            return np.zeros((0, 6), dtype=np.float32)

        cls_boxes = np.vstack(cls_boxes)
        cls_ids = np.vstack(cls_ids)
        pred_result = np.hstack([cls_ids, cls_boxes]).astype(np.float32)

        # Limit to max_per_image detections **over all classes**
        image_scores = cls_boxes[:, 0]
        if len(image_scores) > self.keep_top_k:
            image_thresh = np.sort(image_scores)[-self.keep_top_k]
            keep = np.where(image_scores >= image_thresh)[0]
            pred_result = pred_result[keep, :]

        return pred_result

    def __call__(self, bboxes, scores):
        """Append the DIoU-NMS py_func op and return its output variable.

        Args:
            bboxes: variable holding the boxes handed to the NumPy routine.
            scores: variable holding the matching class scores.

        Returns:
            The ``diou_nms_pred_result`` variable (float32, shape [-1, 6]).
        """

        # Kept as a local function that takes exactly the two py_func inputs.
        def _diou_nms(bboxes, scores):
            pred_result = self._select_detections(
                np.array(bboxes), np.array(scores))

            res = fluid.LoDTensor()
            res.set_lod([[0, pred_result.shape[0]]])
            if pred_result.shape[0] == 0:
                # Placeholder payload when nothing was detected; the LoD set
                # above still records zero rows.
                pred_result = np.array([[1]], dtype=np.float32)
            res.set(pred_result, fluid.CPUPlace())

            return res

        pred_result = fluid.default_main_program().current_block().create_var(
            name='diou_nms_pred_result',
            dtype='float32',
            shape=[-1, 6],
            lod_level=0)
        fluid.layers.py_func(
            func=_diou_nms, x=[bboxes, scores], out=pred_result)
        return pred_result
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录