Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
92b9ce34
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
92b9ce34
编写于
3月 15, 2019
作者:
X
Xin Pan
提交者:
GitHub
3月 15, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16073 from heavengate/yolov3_loss_imporve
Yolov3 loss: add mixup score and label smooth
上级
8ad672a2
2c0abba0
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
224 addition
and
69 deletion
+224
-69
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-1
paddle/fluid/operators/detection/yolov3_loss_op.cc
paddle/fluid/operators/detection/yolov3_loss_op.cc
+33
-0
paddle/fluid/operators/detection/yolov3_loss_op.h
paddle/fluid/operators/detection/yolov3_loss_op.h
+79
-26
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+31
-12
python/paddle/fluid/tests/test_detection.py
python/paddle/fluid/tests/test_detection.py
+10
-2
python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
+70
-28
未找到文件。
paddle/fluid/API.spec
浏览文件 @
92b9ce34
...
...
@@ -330,7 +330,7 @@ paddle.fluid.layers.generate_mask_labels (ArgSpec(args=['im_info', 'gt_classes',
paddle.fluid.layers.iou_similarity (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '587845f60c5d97ffdf2dfd21da52eca1'))
paddle.fluid.layers.box_coder (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0)), ('document', '032d0f4b7d8f6235ee5d91e473344f0e'))
paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '0e5ac2507723a0b5adec473f9556799b'))
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', '
name'], varargs=None, keywords=None, defaults=(None,)), ('document', '991e934c3e09abf0edec7c9c978b4691
'))
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', '
gtscore', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(None, True, None)), ('document', '57fa96922e42db8f064c3fb77f2255e8
'))
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '397e9e02b451d99c56e20f268fa03f2e'))
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
...
...
paddle/fluid/operators/detection/yolov3_loss_op.cc
浏览文件 @
92b9ce34
...
...
@@ -10,6 +10,7 @@
limitations under the License. */
#include "paddle/fluid/operators/detection/yolov3_loss_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
...
...
@@ -72,6 +73,18 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GT
(
class_num
,
0
,
"Attr(class_num) should be an integer greater then 0."
);
if
(
ctx
->
HasInput
(
"GTScore"
))
{
auto
dim_gtscore
=
ctx
->
GetInputDim
(
"GTScore"
);
PADDLE_ENFORCE_EQ
(
dim_gtscore
.
size
(),
2
,
"Input(GTScore) should be a 2-D tensor"
);
PADDLE_ENFORCE_EQ
(
dim_gtscore
[
0
],
dim_gtbox
[
0
],
"Input(GTBox) and Input(GTScore) dim[0] should be same"
);
PADDLE_ENFORCE_EQ
(
dim_gtscore
[
1
],
dim_gtbox
[
1
],
"Input(GTBox) and Input(GTScore) dim[1] should be same"
);
}
std
::
vector
<
int64_t
>
dim_out
({
dim_x
[
0
]});
ctx
->
SetOutputDim
(
"Loss"
,
framework
::
make_ddim
(
dim_out
));
...
...
@@ -112,6 +125,12 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
"This is a 2-D tensor with shape of [N, max_box_num], "
"and each element should be an integer to indicate the "
"box class id."
);
AddInput
(
"GTScore"
,
"The score of GTLabel, This is a 2-D tensor in same shape "
"GTLabel, and score values should in range (0, 1). This "
"input is for GTLabel score can be not 1.0 in image mixup "
"augmentation."
)
.
AsDispensable
();
AddOutput
(
"Loss"
,
"The output yolov3 loss tensor, "
"This is a 1-D tensor with shape of [N]"
);
...
...
@@ -143,6 +162,9 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
float
>
(
"ignore_thresh"
,
"The ignore threshold to ignore confidence loss."
)
.
SetDefault
(
0.7
);
AddAttr
<
bool
>
(
"use_label_smooth"
,
"Whether to use label smooth. Default True."
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
This operator generates yolov3 loss based on given predict result and ground
truth boxes.
...
...
@@ -204,6 +226,15 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
loss = (loss_{xy} + loss_{wh}) * weight_{box}
+ loss_{conf} + loss_{class}
$$
While :attr:`use_label_smooth` is set to be :attr:`True`, the classification
target will be smoothed when calculating classification loss, target of
positive samples will be smoothed to :math:`1.0 - 1.0 / class\_num` and target of
negetive samples will be smoothed to :math:`1.0 / class\_num`.
While :attr:`GTScore` is given, which means the mixup score of ground truth
boxes, all losses incured by a ground truth box will be multiplied by its
mixup score.
)DOC"
);
}
};
...
...
@@ -240,6 +271,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op
->
SetInput
(
"X"
,
Input
(
"X"
));
op
->
SetInput
(
"GTBox"
,
Input
(
"GTBox"
));
op
->
SetInput
(
"GTLabel"
,
Input
(
"GTLabel"
));
op
->
SetInput
(
"GTScore"
,
Input
(
"GTScore"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Loss"
),
OutputGrad
(
"Loss"
));
op
->
SetInput
(
"ObjectnessMask"
,
Output
(
"ObjectnessMask"
));
op
->
SetInput
(
"GTMatchMask"
,
Output
(
"GTMatchMask"
));
...
...
@@ -249,6 +281,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"GTBox"
),
{});
op
->
SetOutput
(
framework
::
GradVarName
(
"GTLabel"
),
{});
op
->
SetOutput
(
framework
::
GradVarName
(
"GTScore"
),
{});
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
};
...
...
paddle/fluid/operators/detection/yolov3_loss_op.h
浏览文件 @
92b9ce34
...
...
@@ -37,8 +37,8 @@ static T SigmoidCrossEntropy(T x, T label) {
}
template
<
typename
T
>
static
T
L
2
Loss
(
T
x
,
T
y
)
{
return
0.5
*
(
y
-
x
)
*
(
y
-
x
);
static
T
L
1
Loss
(
T
x
,
T
y
)
{
return
std
::
abs
(
y
-
x
);
}
template
<
typename
T
>
...
...
@@ -47,8 +47,8 @@ static T SigmoidCrossEntropyGrad(T x, T label) {
}
template
<
typename
T
>
static
T
L
2
LossGrad
(
T
x
,
T
y
)
{
return
x
-
y
;
static
T
L
1
LossGrad
(
T
x
,
T
y
)
{
return
x
>
y
?
1.0
:
-
1.0
;
}
static
int
GetMaskIndex
(
std
::
vector
<
int
>
mask
,
int
val
)
{
...
...
@@ -121,47 +121,49 @@ template <typename T>
static
void
CalcBoxLocationLoss
(
T
*
loss
,
const
T
*
input
,
Box
<
T
>
gt
,
std
::
vector
<
int
>
anchors
,
int
an_idx
,
int
box_idx
,
int
gi
,
int
gj
,
int
grid_size
,
int
input_size
,
int
stride
)
{
int
input_size
,
int
stride
,
T
score
)
{
T
tx
=
gt
.
x
*
grid_size
-
gi
;
T
ty
=
gt
.
y
*
grid_size
-
gj
;
T
tw
=
std
::
log
(
gt
.
w
*
input_size
/
anchors
[
2
*
an_idx
]);
T
th
=
std
::
log
(
gt
.
h
*
input_size
/
anchors
[
2
*
an_idx
+
1
]);
T
scale
=
(
2.0
-
gt
.
w
*
gt
.
h
);
T
scale
=
(
2.0
-
gt
.
w
*
gt
.
h
)
*
score
;
loss
[
0
]
+=
SigmoidCrossEntropy
<
T
>
(
input
[
box_idx
],
tx
)
*
scale
;
loss
[
0
]
+=
SigmoidCrossEntropy
<
T
>
(
input
[
box_idx
+
stride
],
ty
)
*
scale
;
loss
[
0
]
+=
L
2
Loss
<
T
>
(
input
[
box_idx
+
2
*
stride
],
tw
)
*
scale
;
loss
[
0
]
+=
L
2
Loss
<
T
>
(
input
[
box_idx
+
3
*
stride
],
th
)
*
scale
;
loss
[
0
]
+=
L
1
Loss
<
T
>
(
input
[
box_idx
+
2
*
stride
],
tw
)
*
scale
;
loss
[
0
]
+=
L
1
Loss
<
T
>
(
input
[
box_idx
+
3
*
stride
],
th
)
*
scale
;
}
template
<
typename
T
>
static
void
CalcBoxLocationLossGrad
(
T
*
input_grad
,
const
T
loss
,
const
T
*
input
,
Box
<
T
>
gt
,
std
::
vector
<
int
>
anchors
,
int
an_idx
,
int
box_idx
,
int
gi
,
int
gj
,
int
grid_size
,
int
input_size
,
int
stride
)
{
int
grid_size
,
int
input_size
,
int
stride
,
T
score
)
{
T
tx
=
gt
.
x
*
grid_size
-
gi
;
T
ty
=
gt
.
y
*
grid_size
-
gj
;
T
tw
=
std
::
log
(
gt
.
w
*
input_size
/
anchors
[
2
*
an_idx
]);
T
th
=
std
::
log
(
gt
.
h
*
input_size
/
anchors
[
2
*
an_idx
+
1
]);
T
scale
=
(
2.0
-
gt
.
w
*
gt
.
h
);
T
scale
=
(
2.0
-
gt
.
w
*
gt
.
h
)
*
score
;
input_grad
[
box_idx
]
=
SigmoidCrossEntropyGrad
<
T
>
(
input
[
box_idx
],
tx
)
*
scale
*
loss
;
input_grad
[
box_idx
+
stride
]
=
SigmoidCrossEntropyGrad
<
T
>
(
input
[
box_idx
+
stride
],
ty
)
*
scale
*
loss
;
input_grad
[
box_idx
+
2
*
stride
]
=
L
2
LossGrad
<
T
>
(
input
[
box_idx
+
2
*
stride
],
tw
)
*
scale
*
loss
;
L
1
LossGrad
<
T
>
(
input
[
box_idx
+
2
*
stride
],
tw
)
*
scale
*
loss
;
input_grad
[
box_idx
+
3
*
stride
]
=
L
2
LossGrad
<
T
>
(
input
[
box_idx
+
3
*
stride
],
th
)
*
scale
*
loss
;
L
1
LossGrad
<
T
>
(
input
[
box_idx
+
3
*
stride
],
th
)
*
scale
*
loss
;
}
template
<
typename
T
>
static
inline
void
CalcLabelLoss
(
T
*
loss
,
const
T
*
input
,
const
int
index
,
const
int
label
,
const
int
class_num
,
const
int
stride
)
{
const
int
stride
,
const
T
pos
,
const
T
neg
,
T
score
)
{
for
(
int
i
=
0
;
i
<
class_num
;
i
++
)
{
T
pred
=
input
[
index
+
i
*
stride
];
loss
[
0
]
+=
SigmoidCrossEntropy
<
T
>
(
pred
,
(
i
==
label
)
?
1.0
:
0.0
)
;
loss
[
0
]
+=
SigmoidCrossEntropy
<
T
>
(
pred
,
(
i
==
label
)
?
pos
:
neg
)
*
score
;
}
}
...
...
@@ -169,11 +171,13 @@ template <typename T>
static
inline
void
CalcLabelLossGrad
(
T
*
input_grad
,
const
T
loss
,
const
T
*
input
,
const
int
index
,
const
int
label
,
const
int
class_num
,
const
int
stride
)
{
const
int
stride
,
const
T
pos
,
const
T
neg
,
T
score
)
{
for
(
int
i
=
0
;
i
<
class_num
;
i
++
)
{
T
pred
=
input
[
index
+
i
*
stride
];
input_grad
[
index
+
i
*
stride
]
=
SigmoidCrossEntropyGrad
<
T
>
(
pred
,
(
i
==
label
)
?
1.0
:
0.0
)
*
loss
;
SigmoidCrossEntropyGrad
<
T
>
(
pred
,
(
i
==
label
)
?
pos
:
neg
)
*
score
*
loss
;
}
}
...
...
@@ -188,8 +192,8 @@ static inline void CalcObjnessLoss(T* loss, const T* input, const T* objness,
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
T
obj
=
objness
[
k
*
w
+
l
];
if
(
obj
>
1e-5
)
{
// positive sample: obj =
1
loss
[
i
]
+=
SigmoidCrossEntropy
<
T
>
(
input
[
k
*
w
+
l
],
1.0
);
// positive sample: obj =
mixup score
loss
[
i
]
+=
SigmoidCrossEntropy
<
T
>
(
input
[
k
*
w
+
l
],
1.0
)
*
obj
;
}
else
if
(
obj
>
-
0.5
)
{
// negetive sample: obj = 0
loss
[
i
]
+=
SigmoidCrossEntropy
<
T
>
(
input
[
k
*
w
+
l
],
0.0
);
...
...
@@ -215,7 +219,8 @@ static inline void CalcObjnessLossGrad(T* input_grad, const T* loss,
T
obj
=
objness
[
k
*
w
+
l
];
if
(
obj
>
1e-5
)
{
input_grad
[
k
*
w
+
l
]
=
SigmoidCrossEntropyGrad
<
T
>
(
input
[
k
*
w
+
l
],
1.0
)
*
loss
[
i
];
SigmoidCrossEntropyGrad
<
T
>
(
input
[
k
*
w
+
l
],
1.0
)
*
obj
*
loss
[
i
];
}
else
if
(
obj
>
-
0.5
)
{
input_grad
[
k
*
w
+
l
]
=
SigmoidCrossEntropyGrad
<
T
>
(
input
[
k
*
w
+
l
],
0.0
)
*
loss
[
i
];
...
...
@@ -252,6 +257,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
gt_box
=
ctx
.
Input
<
Tensor
>
(
"GTBox"
);
auto
*
gt_label
=
ctx
.
Input
<
Tensor
>
(
"GTLabel"
);
auto
*
gt_score
=
ctx
.
Input
<
Tensor
>
(
"GTScore"
);
auto
*
loss
=
ctx
.
Output
<
Tensor
>
(
"Loss"
);
auto
*
objness_mask
=
ctx
.
Output
<
Tensor
>
(
"ObjectnessMask"
);
auto
*
gt_match_mask
=
ctx
.
Output
<
Tensor
>
(
"GTMatchMask"
);
...
...
@@ -260,6 +266,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int
class_num
=
ctx
.
Attr
<
int
>
(
"class_num"
);
float
ignore_thresh
=
ctx
.
Attr
<
float
>
(
"ignore_thresh"
);
int
downsample_ratio
=
ctx
.
Attr
<
int
>
(
"downsample_ratio"
);
bool
use_label_smooth
=
ctx
.
Attr
<
bool
>
(
"use_label_smooth"
);
const
int
n
=
input
->
dims
()[
0
];
const
int
h
=
input
->
dims
()[
2
];
...
...
@@ -272,6 +279,13 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
const
int
stride
=
h
*
w
;
const
int
an_stride
=
(
class_num
+
5
)
*
stride
;
T
label_pos
=
1.0
;
T
label_neg
=
0.0
;
if
(
use_label_smooth
)
{
label_pos
=
1.0
-
1.0
/
static_cast
<
T
>
(
class_num
);
label_neg
=
1.0
/
static_cast
<
T
>
(
class_num
);
}
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
gt_box_data
=
gt_box
->
data
<
T
>
();
const
int
*
gt_label_data
=
gt_label
->
data
<
int
>
();
...
...
@@ -283,6 +297,19 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int
*
gt_match_mask_data
=
gt_match_mask
->
mutable_data
<
int
>
({
n
,
b
},
ctx
.
GetPlace
());
const
T
*
gt_score_data
;
if
(
!
gt_score
)
{
Tensor
gtscore
;
gtscore
.
mutable_data
<
T
>
({
n
,
b
},
ctx
.
GetPlace
());
math
::
SetConstant
<
platform
::
CPUDeviceContext
,
T
>
()(
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>(),
&
gtscore
,
static_cast
<
T
>
(
1.0
));
gt_score
=
&
gtscore
;
gt_score_data
=
gtscore
.
data
<
T
>
();
}
else
{
gt_score_data
=
gt_score
->
data
<
T
>
();
}
// calc valid gt box mask, avoid calc duplicately in following code
Tensor
gt_valid_mask
;
bool
*
gt_valid_mask_data
=
...
...
@@ -355,19 +382,20 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
int
mask_idx
=
GetMaskIndex
(
anchor_mask
,
best_n
);
gt_match_mask_data
[
i
*
b
+
t
]
=
mask_idx
;
if
(
mask_idx
>=
0
)
{
T
score
=
gt_score_data
[
i
*
b
+
t
];
int
box_idx
=
GetEntryIndex
(
i
,
mask_idx
,
gj
*
w
+
gi
,
mask_num
,
an_stride
,
stride
,
0
);
CalcBoxLocationLoss
<
T
>
(
loss_data
+
i
,
input_data
,
gt
,
anchors
,
best_n
,
box_idx
,
gi
,
gj
,
h
,
input_size
,
stride
);
box_idx
,
gi
,
gj
,
h
,
input_size
,
stride
,
score
);
int
obj_idx
=
(
i
*
mask_num
+
mask_idx
)
*
stride
+
gj
*
w
+
gi
;
obj_mask_data
[
obj_idx
]
=
1.0
;
obj_mask_data
[
obj_idx
]
=
score
;
int
label
=
gt_label_data
[
i
*
b
+
t
];
int
label_idx
=
GetEntryIndex
(
i
,
mask_idx
,
gj
*
w
+
gi
,
mask_num
,
an_stride
,
stride
,
5
);
CalcLabelLoss
<
T
>
(
loss_data
+
i
,
input_data
,
label_idx
,
label
,
class_num
,
stride
);
class_num
,
stride
,
label_pos
,
label_neg
,
score
);
}
}
}
...
...
@@ -384,6 +412,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
gt_box
=
ctx
.
Input
<
Tensor
>
(
"GTBox"
);
auto
*
gt_label
=
ctx
.
Input
<
Tensor
>
(
"GTLabel"
);
auto
*
gt_score
=
ctx
.
Input
<
Tensor
>
(
"GTScore"
);
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
loss_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
));
auto
*
objness_mask
=
ctx
.
Input
<
Tensor
>
(
"ObjectnessMask"
);
...
...
@@ -392,6 +421,7 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
auto
anchor_mask
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"anchor_mask"
);
int
class_num
=
ctx
.
Attr
<
int
>
(
"class_num"
);
int
downsample_ratio
=
ctx
.
Attr
<
int
>
(
"downsample_ratio"
);
bool
use_label_smooth
=
ctx
.
Attr
<
bool
>
(
"use_label_smooth"
);
const
int
n
=
input_grad
->
dims
()[
0
];
const
int
c
=
input_grad
->
dims
()[
1
];
...
...
@@ -404,6 +434,13 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
const
int
stride
=
h
*
w
;
const
int
an_stride
=
(
class_num
+
5
)
*
stride
;
T
label_pos
=
1.0
;
T
label_neg
=
0.0
;
if
(
use_label_smooth
)
{
label_pos
=
1.0
-
1.0
/
static_cast
<
T
>
(
class_num
);
label_neg
=
1.0
/
static_cast
<
T
>
(
class_num
);
}
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
gt_box_data
=
gt_box
->
data
<
T
>
();
const
int
*
gt_label_data
=
gt_label
->
data
<
int
>
();
...
...
@@ -414,25 +451,41 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
input_grad
->
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
memset
(
input_grad_data
,
0
,
input_grad
->
numel
()
*
sizeof
(
T
));
const
T
*
gt_score_data
;
if
(
!
gt_score
)
{
Tensor
gtscore
;
gtscore
.
mutable_data
<
T
>
({
n
,
b
},
ctx
.
GetPlace
());
math
::
SetConstant
<
platform
::
CPUDeviceContext
,
T
>
()(
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>(),
&
gtscore
,
static_cast
<
T
>
(
1.0
));
gt_score
=
&
gtscore
;
gt_score_data
=
gtscore
.
data
<
T
>
();
}
else
{
gt_score_data
=
gt_score
->
data
<
T
>
();
}
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
t
=
0
;
t
<
b
;
t
++
)
{
int
mask_idx
=
gt_match_mask_data
[
i
*
b
+
t
];
if
(
mask_idx
>=
0
)
{
T
score
=
gt_score_data
[
i
*
b
+
t
];
Box
<
T
>
gt
=
GetGtBox
(
gt_box_data
,
i
,
b
,
t
);
int
gi
=
static_cast
<
int
>
(
gt
.
x
*
w
);
int
gj
=
static_cast
<
int
>
(
gt
.
y
*
h
);
int
box_idx
=
GetEntryIndex
(
i
,
mask_idx
,
gj
*
w
+
gi
,
mask_num
,
an_stride
,
stride
,
0
);
CalcBoxLocationLossGrad
<
T
>
(
input_grad_data
,
loss_grad_data
[
i
],
input_data
,
gt
,
anchors
,
anchor_mask
[
mask_idx
],
box_idx
,
gi
,
gj
,
h
,
input_size
,
stride
);
CalcBoxLocationLossGrad
<
T
>
(
input_grad_data
,
loss_grad_data
[
i
],
input_data
,
gt
,
anchors
,
anchor_mask
[
mask_idx
],
box_idx
,
gi
,
gj
,
h
,
input_size
,
stride
,
score
);
int
label
=
gt_label_data
[
i
*
b
+
t
];
int
label_idx
=
GetEntryIndex
(
i
,
mask_idx
,
gj
*
w
+
gi
,
mask_num
,
an_stride
,
stride
,
5
);
CalcLabelLossGrad
<
T
>
(
input_grad_data
,
loss_grad_data
[
i
],
input_data
,
label_idx
,
label
,
class_num
,
stride
);
label_idx
,
label
,
class_num
,
stride
,
label_pos
,
label_neg
,
score
);
}
}
}
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
92b9ce34
...
...
@@ -515,6 +515,8 @@ def yolov3_loss(x,
class_num
,
ignore_thresh
,
downsample_ratio
,
gtscore
=
None
,
use_label_smooth
=
True
,
name
=
None
):
"""
${comment}
...
...
@@ -533,28 +535,35 @@ def yolov3_loss(x,
class_num (int): ${class_num_comment}
ignore_thresh (float): ${ignore_thresh_comment}
downsample_ratio (int): ${downsample_ratio_comment}
name (string): the name of yolov3 loss
name (string): the name of yolov3 loss. Default None.
gtscore (Variable): mixup score of ground truth boxes, shoud be in shape
of [N, B]. Default None.
use_label_smooth (bool): ${use_label_smooth_comment}
Returns:
Variable: A 1-D tensor with shape [
1
], the value of yolov3 loss
Variable: A 1-D tensor with shape [
N
], the value of yolov3 loss
Raises:
TypeError: Input x of yolov3_loss must be Variable
TypeError: Input gtbox of yolov3_loss must be Variable"
TypeError: Input gtlabel of yolov3_loss must be Variable"
TypeError: Input gtbox of yolov3_loss must be Variable
TypeError: Input gtlabel of yolov3_loss must be Variable
TypeError: Input gtscore of yolov3_loss must be None or Variable
TypeError: Attr anchors of yolov3_loss must be list or tuple
TypeError: Attr class_num of yolov3_loss must be an integer
TypeError: Attr ignore_thresh of yolov3_loss must be a float number
TypeError: Attr use_label_smooth of yolov3_loss must be a bool value
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32')
gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32')
gtbox = fluid.layers.data(name='gtbox', shape=[6, 4], dtype='float32')
gtlabel = fluid.layers.data(name='gtlabel', shape=[6], dtype='int32')
gtscore = fluid.layers.data(name='gtscore', shape=[6], dtype='float32')
anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
anchor_mask = [0, 1, 2]
loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel, anchors=anchors,
loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, gtlabel=gtlabel,
gtscore=gtscore, anchors=anchors,
anchor_mask=anchor_mask, class_num=80,
ignore_thresh=0.7, downsample_ratio=32)
"""
...
...
@@ -566,6 +575,8 @@ def yolov3_loss(x,
raise
TypeError
(
"Input gtbox of yolov3_loss must be Variable"
)
if
not
isinstance
(
gtlabel
,
Variable
):
raise
TypeError
(
"Input gtlabel of yolov3_loss must be Variable"
)
if
gtscore
is
not
None
and
not
isinstance
(
gtscore
,
Variable
):
raise
TypeError
(
"Input gtscore of yolov3_loss must be Variable"
)
if
not
isinstance
(
anchors
,
list
)
and
not
isinstance
(
anchors
,
tuple
):
raise
TypeError
(
"Attr anchors of yolov3_loss must be list or tuple"
)
if
not
isinstance
(
anchor_mask
,
list
)
and
not
isinstance
(
anchor_mask
,
tuple
):
...
...
@@ -575,6 +586,9 @@ def yolov3_loss(x,
if
not
isinstance
(
ignore_thresh
,
float
):
raise
TypeError
(
"Attr ignore_thresh of yolov3_loss must be a float number"
)
if
not
isinstance
(
use_label_smooth
,
bool
):
raise
TypeError
(
"Attr use_label_smooth of yolov3_loss must be a bool value"
)
if
name
is
None
:
loss
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
...
...
@@ -585,21 +599,26 @@ def yolov3_loss(x,
objectness_mask
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int32'
)
gt_match_mask
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int32'
)
inputs
=
{
"X"
:
x
,
"GTBox"
:
gtbox
,
"GTLabel"
:
gtlabel
,
}
if
gtscore
:
inputs
[
"GTScore"
]
=
gtscore
attrs
=
{
"anchors"
:
anchors
,
"anchor_mask"
:
anchor_mask
,
"class_num"
:
class_num
,
"ignore_thresh"
:
ignore_thresh
,
"downsample_ratio"
:
downsample_ratio
,
"use_label_smooth"
:
use_label_smooth
,
}
helper
.
append_op
(
type
=
'yolov3_loss'
,
inputs
=
{
"X"
:
x
,
"GTBox"
:
gtbox
,
"GTLabel"
:
gtlabel
,
},
inputs
=
inputs
,
outputs
=
{
'Loss'
:
loss
,
'ObjectnessMask'
:
objectness_mask
,
...
...
python/paddle/fluid/tests/test_detection.py
浏览文件 @
92b9ce34
...
...
@@ -476,8 +476,16 @@ class TestYoloDetection(unittest.TestCase):
x
=
layers
.
data
(
name
=
'x'
,
shape
=
[
30
,
7
,
7
],
dtype
=
'float32'
)
gtbox
=
layers
.
data
(
name
=
'gtbox'
,
shape
=
[
10
,
4
],
dtype
=
'float32'
)
gtlabel
=
layers
.
data
(
name
=
'gtlabel'
,
shape
=
[
10
],
dtype
=
'int32'
)
loss
=
layers
.
yolov3_loss
(
x
,
gtbox
,
gtlabel
,
[
10
,
13
,
30
,
13
],
[
0
,
1
],
10
,
0.7
,
32
)
gtscore
=
layers
.
data
(
name
=
'gtscore'
,
shape
=
[
10
],
dtype
=
'float32'
)
loss
=
layers
.
yolov3_loss
(
x
,
gtbox
,
gtlabel
,
[
10
,
13
,
30
,
13
],
[
0
,
1
],
10
,
0.7
,
32
,
gtscore
=
gtscore
,
use_label_smooth
=
False
)
self
.
assertIsNotNone
(
loss
)
...
...
python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py
浏览文件 @
92b9ce34
...
...
@@ -23,8 +23,8 @@ from op_test import OpTest
from
paddle.fluid
import
core
def
l
2
loss
(
x
,
y
):
return
0.5
*
(
y
-
x
)
*
(
y
-
x
)
def
l
1
loss
(
x
,
y
):
return
abs
(
x
-
y
)
def
sce
(
x
,
label
):
...
...
@@ -66,7 +66,7 @@ def batch_xywh_box_iou(box1, box2):
return
inter_area
/
union
def
YOLOv3Loss
(
x
,
gtbox
,
gtlabel
,
attrs
):
def
YOLOv3Loss
(
x
,
gtbox
,
gtlabel
,
gtscore
,
attrs
):
n
,
c
,
h
,
w
=
x
.
shape
b
=
gtbox
.
shape
[
1
]
anchors
=
attrs
[
'anchors'
]
...
...
@@ -75,21 +75,21 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
mask_num
=
len
(
anchor_mask
)
class_num
=
attrs
[
"class_num"
]
ignore_thresh
=
attrs
[
'ignore_thresh'
]
downsample
=
attrs
[
'downsample'
]
input_size
=
downsample
*
h
downsample_ratio
=
attrs
[
'downsample_ratio'
]
use_label_smooth
=
attrs
[
'use_label_smooth'
]
input_size
=
downsample_ratio
*
h
x
=
x
.
reshape
((
n
,
mask_num
,
5
+
class_num
,
h
,
w
)).
transpose
((
0
,
1
,
3
,
4
,
2
))
loss
=
np
.
zeros
((
n
)).
astype
(
'float32'
)
label_pos
=
1.0
-
1.0
/
class_num
if
use_label_smooth
else
1.0
label_neg
=
1.0
/
class_num
if
use_label_smooth
else
0.0
pred_box
=
x
[:,
:,
:,
:,
:
4
].
copy
()
grid_x
=
np
.
tile
(
np
.
arange
(
w
).
reshape
((
1
,
w
)),
(
h
,
1
))
grid_y
=
np
.
tile
(
np
.
arange
(
h
).
reshape
((
h
,
1
)),
(
1
,
w
))
pred_box
[:,
:,
:,
:,
0
]
=
(
grid_x
+
sigmoid
(
pred_box
[:,
:,
:,
:,
0
]))
/
w
pred_box
[:,
:,
:,
:,
1
]
=
(
grid_y
+
sigmoid
(
pred_box
[:,
:,
:,
:,
1
]))
/
h
x
[:,
:,
:,
:,
5
:]
=
np
.
where
(
x
[:,
:,
:,
:,
5
:]
<
-
0.5
,
x
[:,
:,
:,
:,
5
:],
np
.
ones_like
(
x
[:,
:,
:,
:,
5
:])
*
1.0
/
class_num
)
mask_anchors
=
[]
for
m
in
anchor_mask
:
mask_anchors
.
append
((
anchors
[
2
*
m
],
anchors
[
2
*
m
+
1
]))
...
...
@@ -138,21 +138,22 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
ty
=
gtbox
[
i
,
j
,
1
]
*
w
-
gj
tw
=
np
.
log
(
gtbox
[
i
,
j
,
2
]
*
input_size
/
mask_anchors
[
an_idx
][
0
])
th
=
np
.
log
(
gtbox
[
i
,
j
,
3
]
*
input_size
/
mask_anchors
[
an_idx
][
1
])
scale
=
(
2.0
-
gtbox
[
i
,
j
,
2
]
*
gtbox
[
i
,
j
,
3
])
scale
=
(
2.0
-
gtbox
[
i
,
j
,
2
]
*
gtbox
[
i
,
j
,
3
])
*
gtscore
[
i
,
j
]
loss
[
i
]
+=
sce
(
x
[
i
,
an_idx
,
gj
,
gi
,
0
],
tx
)
*
scale
loss
[
i
]
+=
sce
(
x
[
i
,
an_idx
,
gj
,
gi
,
1
],
ty
)
*
scale
loss
[
i
]
+=
l
2
loss
(
x
[
i
,
an_idx
,
gj
,
gi
,
2
],
tw
)
*
scale
loss
[
i
]
+=
l
2
loss
(
x
[
i
,
an_idx
,
gj
,
gi
,
3
],
th
)
*
scale
loss
[
i
]
+=
l
1
loss
(
x
[
i
,
an_idx
,
gj
,
gi
,
2
],
tw
)
*
scale
loss
[
i
]
+=
l
1
loss
(
x
[
i
,
an_idx
,
gj
,
gi
,
3
],
th
)
*
scale
objness
[
i
,
an_idx
*
h
*
w
+
gj
*
w
+
gi
]
=
1.0
objness
[
i
,
an_idx
*
h
*
w
+
gj
*
w
+
gi
]
=
gtscore
[
i
,
j
]
for
label_idx
in
range
(
class_num
):
loss
[
i
]
+=
sce
(
x
[
i
,
an_idx
,
gj
,
gi
,
5
+
label_idx
],
float
(
label_idx
==
gtlabel
[
i
,
j
]))
loss
[
i
]
+=
sce
(
x
[
i
,
an_idx
,
gj
,
gi
,
5
+
label_idx
],
label_pos
if
label_idx
==
gtlabel
[
i
,
j
]
else
label_neg
)
*
gtscore
[
i
,
j
]
for
j
in
range
(
mask_num
*
h
*
w
):
if
objness
[
i
,
j
]
>
0
:
loss
[
i
]
+=
sce
(
pred_obj
[
i
,
j
],
1.0
)
loss
[
i
]
+=
sce
(
pred_obj
[
i
,
j
],
1.0
)
*
objness
[
i
,
j
]
elif
objness
[
i
,
j
]
==
0
:
loss
[
i
]
+=
sce
(
pred_obj
[
i
,
j
],
0.0
)
...
...
@@ -176,7 +177,8 @@ class TestYolov3LossOp(OpTest):
"anchor_mask"
:
self
.
anchor_mask
,
"class_num"
:
self
.
class_num
,
"ignore_thresh"
:
self
.
ignore_thresh
,
"downsample"
:
self
.
downsample
,
"downsample_ratio"
:
self
.
downsample_ratio
,
"use_label_smooth"
:
self
.
use_label_smooth
,
}
self
.
inputs
=
{
...
...
@@ -184,7 +186,14 @@ class TestYolov3LossOp(OpTest):
'GTBox'
:
gtbox
.
astype
(
'float32'
),
'GTLabel'
:
gtlabel
.
astype
(
'int32'
),
}
loss
,
objness
,
gt_matches
=
YOLOv3Loss
(
x
,
gtbox
,
gtlabel
,
self
.
attrs
)
gtscore
=
np
.
ones
(
self
.
gtbox_shape
[:
2
]).
astype
(
'float32'
)
if
self
.
gtscore
:
gtscore
=
np
.
random
.
random
(
self
.
gtbox_shape
[:
2
]).
astype
(
'float32'
)
self
.
inputs
[
'GTScore'
]
=
gtscore
loss
,
objness
,
gt_matches
=
YOLOv3Loss
(
x
,
gtbox
,
gtlabel
,
gtscore
,
self
.
attrs
)
self
.
outputs
=
{
'Loss'
:
loss
,
'ObjectnessMask'
:
objness
,
...
...
@@ -193,24 +202,57 @@ class TestYolov3LossOp(OpTest):
def
test_check_output
(
self
):
place
=
core
.
CPUPlace
()
self
.
check_output_with_place
(
place
,
atol
=
1
e-3
)
self
.
check_output_with_place
(
place
,
atol
=
2
e-3
)
def
test_check_grad_ignore_gtbox
(
self
):
place
=
core
.
CPUPlace
()
self
.
check_grad_with_place
(
place
,
[
'X'
],
'Loss'
,
no_grad_set
=
set
([
"GTBox"
,
"GTLabel"
]),
max_relative_error
=
0.3
)
self
.
check_grad_with_place
(
place
,
[
'X'
],
'Loss'
,
max_relative_error
=
0.2
)
def
initTestCase
(
self
):
self
.
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
,
30
,
61
,
62
,
45
,
59
,
119
,
116
,
90
,
156
,
198
,
373
,
326
]
self
.
anchor_mask
=
[
0
,
1
,
2
]
self
.
class_num
=
5
self
.
ignore_thresh
=
0.7
self
.
downsample_ratio
=
32
self
.
x_shape
=
(
3
,
len
(
self
.
anchor_mask
)
*
(
5
+
self
.
class_num
),
5
,
5
)
self
.
gtbox_shape
=
(
3
,
5
,
4
)
self
.
gtscore
=
True
self
.
use_label_smooth
=
True
class
TestYolov3LossWithoutLabelSmooth
(
TestYolov3LossOp
):
def
initTestCase
(
self
):
self
.
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
,
30
,
61
,
62
,
45
,
59
,
119
,
116
,
90
,
156
,
198
,
373
,
326
]
self
.
anchor_mask
=
[
0
,
1
,
2
]
self
.
class_num
=
5
self
.
ignore_thresh
=
0.7
self
.
downsample_ratio
=
32
self
.
x_shape
=
(
3
,
len
(
self
.
anchor_mask
)
*
(
5
+
self
.
class_num
),
5
,
5
)
self
.
gtbox_shape
=
(
3
,
5
,
4
)
self
.
gtscore
=
True
self
.
use_label_smooth
=
False
class
TestYolov3LossNoGTScore
(
TestYolov3LossOp
):
def
initTestCase
(
self
):
self
.
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
]
self
.
anchor_mask
=
[
1
,
2
]
self
.
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
,
30
,
61
,
62
,
45
,
59
,
119
,
116
,
90
,
156
,
198
,
373
,
326
]
self
.
anchor_mask
=
[
0
,
1
,
2
]
self
.
class_num
=
5
self
.
ignore_thresh
=
0.
5
self
.
downsample
=
32
self
.
ignore_thresh
=
0.
7
self
.
downsample
_ratio
=
32
self
.
x_shape
=
(
3
,
len
(
self
.
anchor_mask
)
*
(
5
+
self
.
class_num
),
5
,
5
)
self
.
gtbox_shape
=
(
3
,
5
,
4
)
self
.
gtscore
=
False
self
.
use_label_smooth
=
True
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录