Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cc01db60
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cc01db60
编写于
12月 28, 2018
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
calc valid gt before loss calc. test=develop
上级
32d533c2
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
36 addition
and
5 deletion
+36
-5
paddle/fluid/operators/yolov3_loss_op.h
paddle/fluid/operators/yolov3_loss_op.h
+36
-5
未找到文件。
paddle/fluid/operators/yolov3_loss_op.h
浏览文件 @
cc01db60
...
...
@@ -219,6 +219,22 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss,
}
}
template
<
typename
T
>
static
void
inline
GtValid
(
bool
*
valid
,
const
T
*
gtbox
,
const
int
n
,
const
int
b
)
{
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
j
=
0
;
j
<
b
;
j
++
)
{
if
(
LessEqualZero
(
gtbox
[
j
*
4
+
2
])
||
LessEqualZero
(
gtbox
[
j
*
4
+
3
]))
{
valid
[
j
]
=
false
;
}
else
{
valid
[
j
]
=
true
;
}
}
valid
+=
b
;
gtbox
+=
b
*
4
;
}
}
template
<
typename
T
>
class
Yolov3LossKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -257,20 +273,28 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int
*
gt_match_mask_data
=
gt_match_mask
->
mutable_data
<
int
>
({
n
,
b
},
ctx
.
GetPlace
());
// calc valid gt box mask, avoid calc duplicately in following code
Tensor
gt_valid_mask
;
bool
*
gt_valid_mask_data
=
gt_valid_mask
.
mutable_data
<
bool
>
({
n
,
b
},
ctx
.
GetPlace
());
GtValid
<
T
>
(
gt_valid_mask_data
,
gt_box_data
,
n
,
b
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
j
=
0
;
j
<
mask_num
;
j
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
// each predict box find a best match gt box, if overlap is bigger
// then ignore_thresh, ignore the objectness loss.
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
mask_num
,
an_stride
,
stride
,
0
);
Box
<
T
>
pred
=
GetYoloBox
(
input_data
,
anchors
,
l
,
k
,
anchor_mask
[
j
],
h
,
input_size
,
box_idx
,
stride
);
T
best_iou
=
0
;
for
(
int
t
=
0
;
t
<
b
;
t
++
)
{
Box
<
T
>
gt
=
GetGtBox
(
gt_box_data
,
i
,
b
,
t
);
if
(
LessEqualZero
<
T
>
(
gt
.
w
)
||
LessEqualZero
<
T
>
(
gt
.
h
))
{
if
(
!
gt_valid_mask_data
[
i
*
b
+
t
])
{
continue
;
}
Box
<
T
>
gt
=
GetGtBox
(
gt_box_data
,
i
,
b
,
t
);
T
iou
=
CalcBoxIoU
(
pred
,
gt
);
if
(
iou
>
best_iou
)
{
best_iou
=
iou
;
...
...
@@ -281,15 +305,18 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int
obj_idx
=
(
i
*
mask_num
+
j
)
*
stride
+
k
*
w
+
l
;
obj_mask_data
[
obj_idx
]
=
-
1
;
}
// TODO(dengkaipeng): all losses should be calculated if best IoU
// is bigger then truth thresh should be calculated here, but
// currently, truth thresh is an unreachable value as 1.0.
}
}
}
for
(
int
t
=
0
;
t
<
b
;
t
++
)
{
Box
<
T
>
gt
=
GetGtBox
(
gt_box_data
,
i
,
b
,
t
);
if
(
LessEqualZero
<
T
>
(
gt
.
w
)
||
LessEqualZero
<
T
>
(
gt
.
h
))
{
if
(
!
gt_valid_mask_data
[
i
*
b
+
t
])
{
gt_match_mask_data
[
i
*
b
+
t
]
=
-
1
;
continue
;
}
Box
<
T
>
gt
=
GetGtBox
(
gt_box_data
,
i
,
b
,
t
);
int
gi
=
static_cast
<
int
>
(
gt
.
x
*
w
);
int
gj
=
static_cast
<
int
>
(
gt
.
y
*
h
);
Box
<
T
>
gt_shift
=
gt
;
...
...
@@ -297,6 +324,9 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
gt_shift
.
y
=
0.0
;
T
best_iou
=
0.0
;
int
best_n
=
0
;
// each gt box find a best match anchor box as positive sample,
// for positive sample, all losses should be calculated, and for
// other samples, only objectness loss is required.
for
(
int
an_idx
=
0
;
an_idx
<
an_num
;
an_idx
++
)
{
Box
<
T
>
an_box
;
an_box
.
x
=
0.0
;
...
...
@@ -304,7 +334,8 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
an_box
.
w
=
anchors
[
2
*
an_idx
]
/
static_cast
<
T
>
(
input_size
);
an_box
.
h
=
anchors
[
2
*
an_idx
+
1
]
/
static_cast
<
T
>
(
input_size
);
float
iou
=
CalcBoxIoU
<
T
>
(
an_box
,
gt_shift
);
// TO DO: iou > 0.5 ?
// TODO(dengkaipeng): In paper, objectness loss is ignore when
// best IoU > 0.5, but darknet code didn't implement this.
if
(
iou
>
best_iou
)
{
best_iou
=
iou
;
best_n
=
an_idx
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录