Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b154470c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b154470c
编写于
6月 09, 2021
作者:
W
wangxinxin08
提交者:
GitHub
6月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add two attributes for yolo box (#33400)
* add two attributes for yolo box
上级
ddc95a01
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
202 addition
and
38 deletion
+202
-38
paddle/fluid/operators/detection/yolo_box_op.cc
paddle/fluid/operators/detection/yolo_box_op.cc
+58
-9
paddle/fluid/operators/detection/yolo_box_op.cu
paddle/fluid/operators/detection/yolo_box_op.cu
+17
-8
paddle/fluid/operators/detection/yolo_box_op.h
paddle/fluid/operators/detection/yolo_box_op.h
+29
-8
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+7
-1
python/paddle/fluid/tests/unittests/test_yolo_box_op.py
python/paddle/fluid/tests/unittests/test_yolo_box_op.py
+70
-7
python/paddle/vision/ops.py
python/paddle/vision/ops.py
+21
-5
未找到文件。
paddle/fluid/operators/detection/yolo_box_op.cc
浏览文件 @
b154470c
...
...
@@ -11,6 +11,7 @@
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -31,11 +32,35 @@ class YoloBoxOp : public framework::OperatorWithKernel {
auto
anchors
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"anchors"
);
int
anchor_num
=
anchors
.
size
()
/
2
;
auto
class_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"class_num"
);
auto
iou_aware
=
ctx
->
Attrs
().
Get
<
bool
>
(
"iou_aware"
);
auto
iou_aware_factor
=
ctx
->
Attrs
().
Get
<
float
>
(
"iou_aware_factor"
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"Input(X) should be a 4-D tensor."
"But received X dimension(%s)"
,
dim_x
.
size
()));
if
(
iou_aware
)
{
PADDLE_ENFORCE_EQ
(
dim_x
[
1
],
anchor_num
*
(
6
+
class_num
),
platform
::
errors
::
InvalidArgument
(
"Input(X) dim[1] should be equal to (anchor_mask_number * (6 "
"+ class_num)) while iou_aware is true."
"But received dim[1](%s) != (anchor_mask_number * "
"(6+class_num)(%s)."
,
dim_x
[
1
],
anchor_num
*
(
6
+
class_num
)));
PADDLE_ENFORCE_GE
(
iou_aware_factor
,
0
,
platform
::
errors
::
InvalidArgument
(
"Attr(iou_aware_factor) should greater than or equal to 0."
"But received iou_aware_factor (%s)"
,
iou_aware_factor
));
PADDLE_ENFORCE_LE
(
iou_aware_factor
,
1
,
platform
::
errors
::
InvalidArgument
(
"Attr(iou_aware_factor) should less than or equal to 1."
"But received iou_aware_factor (%s)"
,
iou_aware_factor
));
}
else
{
PADDLE_ENFORCE_EQ
(
dim_x
[
1
],
anchor_num
*
(
5
+
class_num
),
platform
::
errors
::
InvalidArgument
(
...
...
@@ -44,6 +69,7 @@ class YoloBoxOp : public framework::OperatorWithKernel {
"But received dim[1](%s) != (anchor_mask_number * "
"(5+class_num)(%s)."
,
dim_x
[
1
],
anchor_num
*
(
5
+
class_num
)));
}
PADDLE_ENFORCE_EQ
(
dim_imgsize
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"Input(ImgSize) should be a 2-D tensor."
...
...
@@ -140,6 +166,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"Scale the center point of decoded bounding "
"box. Default 1.0"
)
.
SetDefault
(
1.
);
AddAttr
<
bool
>
(
"iou_aware"
,
"Whether use iou aware. Default false."
)
.
SetDefault
(
false
);
AddAttr
<
float
>
(
"iou_aware_factor"
,
"iou aware factor. Default 0.5."
)
.
SetDefault
(
0.5
);
AddComment
(
R"DOC(
This operator generates YOLO detection boxes from output of YOLOv3 network.
...
...
@@ -147,7 +177,8 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
should be the same, H and W specify the grid size, each grid point predict
given number boxes, this given number, which following will be represented as S,
is specified by the number of anchors. In the second dimension(the channel
dimension), C should be equal to S * (5 + class_num), class_num is the object
dimension), C should be equal to S * (5 + class_num) if :attr:`iou_aware` is false,
otherwise C should be equal to S * (6 + class_num). class_num is the object
category number of source dataset(such as 80 in coco dataset), so the
second(channel) dimension, apart from 4 box location coordinates x, y, w, h,
also includes confidence score of the box and class one-hot key of each anchor
...
...
@@ -183,6 +214,15 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
score_{pred} = score_{conf} * score_{class}
$$
where the confidence scores follow the formula below
.. math::
score_{conf} = \begin{cases}
obj, & \text{if } iou\_aware == false \\
obj^{1 - iou\_aware\_factor} * iou^{iou\_aware\_factor}, & \text{otherwise}
\end{cases}
)DOC"
);
}
};
...
...
@@ -197,3 +237,12 @@ REGISTER_OPERATOR(
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
yolo_box
,
ops
::
YoloBoxKernel
<
float
>
,
ops
::
YoloBoxKernel
<
double
>
);
REGISTER_OP_VERSION
(
yolo_box
)
.
AddCheckpoint
(
R"ROC(
Upgrade yolo box to add new attribute [iou_aware, iou_aware_factor].
)ROC"
,
paddle
::
framework
::
compatible
::
OpVersionDesc
()
.
NewAttr
(
"iou_aware"
,
"Whether use iou aware"
,
false
)
.
NewAttr
(
"iou_aware_factor"
,
"iou aware factor"
,
0.5
f
));
paddle/fluid/operators/detection/yolo_box_op.cu
浏览文件 @
b154470c
...
...
@@ -28,7 +28,8 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
const
int
w
,
const
int
an_num
,
const
int
class_num
,
const
int
box_num
,
int
input_size_h
,
int
input_size_w
,
bool
clip_bbox
,
const
float
scale
,
const
float
bias
)
{
const
float
bias
,
bool
iou_aware
,
const
float
iou_aware_factor
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
T
box
[
4
];
...
...
@@ -43,23 +44,29 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int
img_height
=
imgsize
[
2
*
i
];
int
img_width
=
imgsize
[
2
*
i
+
1
];
int
obj_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
4
);
int
obj_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
4
,
iou_aware
);
T
conf
=
sigmoid
<
T
>
(
input
[
obj_idx
]);
if
(
iou_aware
)
{
int
iou_idx
=
GetIoUIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
);
T
iou
=
sigmoid
<
T
>
(
input
[
iou_idx
]);
conf
=
pow
(
conf
,
static_cast
<
T
>
(
1.
-
iou_aware_factor
))
*
pow
(
iou
,
static_cast
<
T
>
(
iou_aware_factor
));
}
if
(
conf
<
conf_thresh
)
{
continue
;
}
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
0
);
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
0
,
iou_aware
);
GetYoloBox
<
T
>
(
box
,
input
,
anchors
,
l
,
k
,
j
,
h
,
w
,
input_size_h
,
input_size_w
,
box_idx
,
grid_num
,
img_height
,
img_width
,
scale
,
bias
);
box_idx
=
(
i
*
box_num
+
j
*
grid_num
+
k
*
w
+
l
)
*
4
;
CalcDetectionBox
<
T
>
(
boxes
,
box
,
box_idx
,
img_height
,
img_width
,
clip_bbox
);
int
label_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
5
);
int
label_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
5
,
iou_aware
);
int
score_idx
=
(
i
*
box_num
+
j
*
grid_num
+
k
*
w
+
l
)
*
class_num
;
CalcLabelScore
<
T
>
(
scores
,
input
,
label_idx
,
score_idx
,
class_num
,
conf
,
grid_num
);
...
...
@@ -80,6 +87,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
float
conf_thresh
=
ctx
.
Attr
<
float
>
(
"conf_thresh"
);
int
downsample_ratio
=
ctx
.
Attr
<
int
>
(
"downsample_ratio"
);
bool
clip_bbox
=
ctx
.
Attr
<
bool
>
(
"clip_bbox"
);
bool
iou_aware
=
ctx
.
Attr
<
bool
>
(
"iou_aware"
);
float
iou_aware_factor
=
ctx
.
Attr
<
float
>
(
"iou_aware_factor"
);
float
scale
=
ctx
.
Attr
<
float
>
(
"scale_x_y"
);
float
bias
=
-
0.5
*
(
scale
-
1.
);
...
...
@@ -115,7 +124,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
imgsize_data
,
boxes_data
,
scores_data
,
conf_thresh
,
anchors_data
,
n
,
h
,
w
,
an_num
,
class_num
,
box_num
,
input_size_h
,
input_size_w
,
clip_bbox
,
scale
,
bias
);
input_size_w
,
clip_bbox
,
scale
,
bias
,
iou_aware
,
iou_aware_factor
);
}
};
...
...
paddle/fluid/operators/detection/yolo_box_op.h
浏览文件 @
b154470c
...
...
@@ -13,6 +13,7 @@
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
...
...
@@ -43,8 +44,19 @@ HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
HOSTDEVICE
inline
int
GetEntryIndex
(
int
batch
,
int
an_idx
,
int
hw_idx
,
int
an_num
,
int
an_stride
,
int
stride
,
int
entry
)
{
int
entry
,
bool
iou_aware
)
{
if
(
iou_aware
)
{
return
(
batch
*
an_num
+
an_idx
)
*
an_stride
+
(
batch
*
an_num
+
an_num
+
entry
)
*
stride
+
hw_idx
;
}
else
{
return
(
batch
*
an_num
+
an_idx
)
*
an_stride
+
entry
*
stride
+
hw_idx
;
}
}
// Returns the flat offset into the input buffer of the IoU logit for anchor
// `an_idx` at spatial position `hw_idx` of sample `batch`.
//
// The offset is the sum of three parts:
//   * batch * an_num * an_stride      - skips the per-anchor box/class data of
//                                       all preceding samples,
//   * (batch * an_num + an_idx) * stride
//                                     - skips the IoU planes of all preceding
//                                       samples plus the first `an_idx` IoU
//                                       planes of this sample,
//   * hw_idx                          - position inside the selected plane.
//
// NOTE(review): `an_stride` is presumably (5 + class_num) * stride and `stride`
// presumably h * w; both are computed by the callers outside this view -
// confirm before relying on it.
HOSTDEVICE inline int GetIoUIndex(int batch, int an_idx, int hw_idx, int an_num,
                                  int an_stride, int stride) {
  const int anchors_before_batch = batch * an_num;
  const int box_data_offset = anchors_before_batch * an_stride;
  const int iou_plane_offset = (anchors_before_batch + an_idx) * stride;
  return box_data_offset + iou_plane_offset + hw_idx;
}
template
<
typename
T
>
...
...
@@ -92,6 +104,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
float
conf_thresh
=
ctx
.
Attr
<
float
>
(
"conf_thresh"
);
int
downsample_ratio
=
ctx
.
Attr
<
int
>
(
"downsample_ratio"
);
bool
clip_bbox
=
ctx
.
Attr
<
bool
>
(
"clip_bbox"
);
bool
iou_aware
=
ctx
.
Attr
<
bool
>
(
"iou_aware"
);
float
iou_aware_factor
=
ctx
.
Attr
<
float
>
(
"iou_aware_factor"
);
float
scale
=
ctx
.
Attr
<
float
>
(
"scale_x_y"
);
float
bias
=
-
0.5
*
(
scale
-
1.
);
...
...
@@ -127,15 +141,22 @@ class YoloBoxKernel : public framework::OpKernel<T> {
for
(
int
j
=
0
;
j
<
an_num
;
j
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
int
obj_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
4
);
int
obj_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
4
,
iou_aware
);
T
conf
=
sigmoid
<
T
>
(
input_data
[
obj_idx
]);
if
(
iou_aware
)
{
int
iou_idx
=
GetIoUIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
);
T
iou
=
sigmoid
<
T
>
(
input_data
[
iou_idx
]);
conf
=
pow
(
conf
,
static_cast
<
T
>
(
1.
-
iou_aware_factor
))
*
pow
(
iou
,
static_cast
<
T
>
(
iou_aware_factor
));
}
if
(
conf
<
conf_thresh
)
{
continue
;
}
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
0
);
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
0
,
iou_aware
);
GetYoloBox
<
T
>
(
box
,
input_data
,
anchors_data
,
l
,
k
,
j
,
h
,
w
,
input_size_h
,
input_size_w
,
box_idx
,
stride
,
img_height
,
img_width
,
scale
,
bias
);
...
...
@@ -143,8 +164,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
CalcDetectionBox
<
T
>
(
boxes_data
,
box
,
box_idx
,
img_height
,
img_width
,
clip_bbox
);
int
label_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
5
);
int
label_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
5
,
iou_aware
);
int
score_idx
=
(
i
*
box_num
+
j
*
stride
+
k
*
w
+
l
)
*
class_num
;
CalcLabelScore
<
T
>
(
scores_data
,
input_data
,
label_idx
,
score_idx
,
class_num
,
conf
,
stride
);
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
b154470c
...
...
@@ -1139,7 +1139,9 @@ def yolo_box(x,
downsample_ratio
,
clip_bbox
=
True
,
name
=
None
,
scale_x_y
=
1.
):
scale_x_y
=
1.
,
iou_aware
=
False
,
iou_aware_factor
=
0.5
):
"""
${comment}
...
...
@@ -1156,6 +1158,8 @@ def yolo_box(x,
name (string): The default value is None. Normally there is no need
for user to set this property. For more information,
please refer to :ref:`api_guide_Name`
iou_aware (bool): ${iou_aware_comment}
iou_aware_factor (float): ${iou_aware_factor_comment}
Returns:
Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes,
...
...
@@ -1204,6 +1208,8 @@ def yolo_box(x,
"downsample_ratio"
:
downsample_ratio
,
"clip_bbox"
:
clip_bbox
,
"scale_x_y"
:
scale_x_y
,
"iou_aware"
:
iou_aware
,
"iou_aware_factor"
:
iou_aware_factor
}
helper
.
append_op
(
...
...
python/paddle/fluid/tests/unittests/test_yolo_box_op.py
浏览文件 @
b154470c
...
...
@@ -35,10 +35,16 @@ def YoloBox(x, img_size, attrs):
downsample
=
attrs
[
'downsample'
]
clip_bbox
=
attrs
[
'clip_bbox'
]
scale_x_y
=
attrs
[
'scale_x_y'
]
iou_aware
=
attrs
[
'iou_aware'
]
iou_aware_factor
=
attrs
[
'iou_aware_factor'
]
bias_x_y
=
-
0.5
*
(
scale_x_y
-
1.
)
input_h
=
downsample
*
h
input_w
=
downsample
*
w
if
iou_aware
:
ioup
=
x
[:,
:
an_num
,
:,
:]
ioup
=
np
.
expand_dims
(
ioup
,
axis
=-
1
)
x
=
x
[:,
an_num
:,
:,
:]
x
=
x
.
reshape
((
n
,
an_num
,
5
+
class_num
,
h
,
w
)).
transpose
((
0
,
1
,
3
,
4
,
2
))
pred_box
=
x
[:,
:,
:,
:,
:
4
].
copy
()
...
...
@@ -57,6 +63,10 @@ def YoloBox(x, img_size, attrs):
pred_box
[:,
:,
:,
:,
2
]
=
np
.
exp
(
pred_box
[:,
:,
:,
:,
2
])
*
anchor_w
pred_box
[:,
:,
:,
:,
3
]
=
np
.
exp
(
pred_box
[:,
:,
:,
:,
3
])
*
anchor_h
if
iou_aware
:
pred_conf
=
sigmoid
(
x
[:,
:,
:,
:,
4
:
5
])
**
(
1
-
iou_aware_factor
)
*
sigmoid
(
ioup
)
**
iou_aware_factor
else
:
pred_conf
=
sigmoid
(
x
[:,
:,
:,
:,
4
:
5
])
pred_conf
[
pred_conf
<
conf_thresh
]
=
0.
pred_score
=
sigmoid
(
x
[:,
:,
:,
:,
5
:])
*
pred_conf
...
...
@@ -97,6 +107,8 @@ class TestYoloBoxOp(OpTest):
"downsample"
:
self
.
downsample
,
"clip_bbox"
:
self
.
clip_bbox
,
"scale_x_y"
:
self
.
scale_x_y
,
"iou_aware"
:
self
.
iou_aware
,
"iou_aware_factor"
:
self
.
iou_aware_factor
}
self
.
inputs
=
{
...
...
@@ -123,6 +135,8 @@ class TestYoloBoxOp(OpTest):
self
.
x_shape
=
(
self
.
batch_size
,
an_num
*
(
5
+
self
.
class_num
),
13
,
13
)
self
.
imgsize_shape
=
(
self
.
batch_size
,
2
)
self
.
scale_x_y
=
1.
self
.
iou_aware
=
False
self
.
iou_aware_factor
=
0.5
class
TestYoloBoxOpNoClipBbox
(
TestYoloBoxOp
):
...
...
@@ -137,6 +151,8 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
self
.
x_shape
=
(
self
.
batch_size
,
an_num
*
(
5
+
self
.
class_num
),
13
,
13
)
self
.
imgsize_shape
=
(
self
.
batch_size
,
2
)
self
.
scale_x_y
=
1.
self
.
iou_aware
=
False
self
.
iou_aware_factor
=
0.5
class
TestYoloBoxOpScaleXY
(
TestYoloBoxOp
):
...
...
@@ -151,19 +167,36 @@ class TestYoloBoxOpScaleXY(TestYoloBoxOp):
self
.
x_shape
=
(
self
.
batch_size
,
an_num
*
(
5
+
self
.
class_num
),
13
,
13
)
self
.
imgsize_shape
=
(
self
.
batch_size
,
2
)
self
.
scale_x_y
=
1.2
self
.
iou_aware
=
False
self
.
iou_aware_factor
=
0.5
class TestYoloBoxOpIoUAware(TestYoloBoxOp):
    """yolo_box test case with ``iou_aware`` enabled.

    Differs from the other visible cases in that the channel dimension of
    the input is ``an_num * (6 + class_num)`` instead of
    ``an_num * (5 + class_num)``.
    """

    def initTestCase(self):
        # Flat anchor list; two values per anchor.
        self.anchors = [10, 13, 16, 30, 33, 23]
        num_anchors = int(len(self.anchors) // 2)

        # Operator attributes.
        self.class_num = 2
        self.conf_thresh = 0.5
        self.downsample = 32
        self.clip_bbox = True
        self.scale_x_y = 1.
        self.iou_aware = True
        self.iou_aware_factor = 0.5

        # Input shapes: one extra channel per anchor compared with the
        # non-IoU-aware cases.
        self.batch_size = 32
        channels = num_anchors * (6 + self.class_num)
        self.x_shape = (self.batch_size, channels, 13, 13)
        self.imgsize_shape = (self.batch_size, 2)
class
TestYoloBoxDygraph
(
unittest
.
TestCase
):
def
test_dygraph
(
self
):
paddle
.
disable_static
()
x
=
np
.
random
.
random
([
2
,
14
,
8
,
8
]).
astype
(
'float32'
)
img_size
=
np
.
ones
((
2
,
2
)).
astype
(
'int32'
)
x
=
paddle
.
to_tensor
(
x
)
img_size
=
paddle
.
to_tensor
(
img_size
)
x1
=
np
.
random
.
random
([
2
,
14
,
8
,
8
]).
astype
(
'float32'
)
x1
=
paddle
.
to_tensor
(
x1
)
boxes
,
scores
=
paddle
.
vision
.
ops
.
yolo_box
(
x
,
x
1
,
img_size
=
img_size
,
anchors
=
[
10
,
13
,
16
,
30
],
class_num
=
2
,
...
...
@@ -172,16 +205,30 @@ class TestYoloBoxDygraph(unittest.TestCase):
clip_bbox
=
True
,
scale_x_y
=
1.
)
assert
boxes
is
not
None
and
scores
is
not
None
x2
=
np
.
random
.
random
([
2
,
16
,
8
,
8
]).
astype
(
'float32'
)
x2
=
paddle
.
to_tensor
(
x2
)
boxes
,
scores
=
paddle
.
vision
.
ops
.
yolo_box
(
x2
,
img_size
=
img_size
,
anchors
=
[
10
,
13
,
16
,
30
],
class_num
=
2
,
conf_thresh
=
0.01
,
downsample_ratio
=
8
,
clip_bbox
=
True
,
scale_x_y
=
1.
,
iou_aware
=
True
,
iou_aware_factor
=
0.5
)
paddle
.
enable_static
()
class
TestYoloBoxStatic
(
unittest
.
TestCase
):
def
test_static
(
self
):
x
=
paddle
.
static
.
data
(
'x
'
,
[
2
,
14
,
8
,
8
],
'float32'
)
x
1
=
paddle
.
static
.
data
(
'x1
'
,
[
2
,
14
,
8
,
8
],
'float32'
)
img_size
=
paddle
.
static
.
data
(
'img_size'
,
[
2
,
2
],
'int32'
)
boxes
,
scores
=
paddle
.
vision
.
ops
.
yolo_box
(
x
,
x
1
,
img_size
=
img_size
,
anchors
=
[
10
,
13
,
16
,
30
],
class_num
=
2
,
...
...
@@ -191,6 +238,20 @@ class TestYoloBoxStatic(unittest.TestCase):
scale_x_y
=
1.
)
assert
boxes
is
not
None
and
scores
is
not
None
x2
=
paddle
.
static
.
data
(
'x2'
,
[
2
,
16
,
8
,
8
],
'float32'
)
boxes
,
scores
=
paddle
.
vision
.
ops
.
yolo_box
(
x2
,
img_size
=
img_size
,
anchors
=
[
10
,
13
,
16
,
30
],
class_num
=
2
,
conf_thresh
=
0.01
,
downsample_ratio
=
8
,
clip_bbox
=
True
,
scale_x_y
=
1.
,
iou_aware
=
True
,
iou_aware_factor
=
0.5
)
assert
boxes
is
not
None
and
scores
is
not
None
class
TestYoloBoxOpHW
(
TestYoloBoxOp
):
def
initTestCase
(
self
):
...
...
@@ -204,6 +265,8 @@ class TestYoloBoxOpHW(TestYoloBoxOp):
self
.
x_shape
=
(
self
.
batch_size
,
an_num
*
(
5
+
self
.
class_num
),
13
,
9
)
self
.
imgsize_shape
=
(
self
.
batch_size
,
2
)
self
.
scale_x_y
=
1.
self
.
iou_aware
=
False
self
.
iou_aware_factor
=
0.5
if
__name__
==
"__main__"
:
...
...
python/paddle/vision/ops.py
浏览文件 @
b154470c
...
...
@@ -247,7 +247,9 @@ def yolo_box(x,
downsample_ratio
,
clip_bbox
=
True
,
name
=
None
,
scale_x_y
=
1.
):
scale_x_y
=
1.
,
iou_aware
=
False
,
iou_aware_factor
=
0.5
):
r
"""
This operator generates YOLO detection boxes from output of YOLOv3 network.
...
...
@@ -256,7 +258,8 @@ def yolo_box(x,
should be the same, H and W specify the grid size, each grid point predict
given number boxes, this given number, which following will be represented as S,
is specified by the number of anchors. In the second dimension(the channel
dimension), C should be equal to S * (5 + class_num), class_num is the object
dimension), C should be equal to S * (5 + class_num) if :attr:`iou_aware` is false,
otherwise C should be equal to S * (6 + class_num). class_num is the object
category number of source dataset(such as 80 in coco dataset), so the
second(channel) dimension, apart from 4 box location coordinates x, y, w, h,
also includes confidence score of the box and class one-hot key of each anchor
...
...
@@ -292,6 +295,15 @@ def yolo_box(x,
score_{pred} = score_{conf} * score_{class}
$$
where the confidence scores follow the formula below
.. math::
score_{conf} = \begin{cases}
obj, & \text{if } iou\_aware == false \\
obj^{1 - iou\_aware\_factor} * iou^{iou\_aware\_factor}, & \text{otherwise}
\end{cases}
Args:
x (Tensor): The input tensor of YoloBox operator is a 4-D tensor with
shape of [N, C, H, W]. The second dimension(C) stores box
...
...
@@ -313,13 +325,14 @@ def yolo_box(x,
should be set for the first, second, and third
:attr:`yolo_box` layer.
clip_bbox (bool): Whether to clip the output bounding box in :attr:`img_size`
boundary. Default true."
"
boundary. Default true.
scale_x_y (float): Scale the center point of decoded bounding box.
Default 1.0
name (string): The default value is None. Normally there is no need
for user to set this property. For more information,
please refer to :ref:`api_guide_Name`
iou_aware (bool): Whether use iou aware. Default false
iou_aware_factor (float): iou aware factor. Default 0.5
Returns:
Tensor: A 3-D tensor with shape [N, M, 4], the coordinates of boxes,
...
...
@@ -358,7 +371,8 @@ def yolo_box(x,
boxes
,
scores
=
core
.
ops
.
yolo_box
(
x
,
img_size
,
'anchors'
,
anchors
,
'class_num'
,
class_num
,
'conf_thresh'
,
conf_thresh
,
'downsample_ratio'
,
downsample_ratio
,
'clip_bbox'
,
clip_bbox
,
'scale_x_y'
,
scale_x_y
)
'clip_bbox'
,
clip_bbox
,
'scale_x_y'
,
scale_x_y
,
'iou_aware'
,
iou_aware
,
'iou_aware_factor'
,
iou_aware_factor
)
return
boxes
,
scores
helper
=
LayerHelper
(
'yolo_box'
,
**
locals
())
...
...
@@ -378,6 +392,8 @@ def yolo_box(x,
"downsample_ratio"
:
downsample_ratio
,
"clip_bbox"
:
clip_bbox
,
"scale_x_y"
:
scale_x_y
,
"iou_aware"
:
iou_aware
,
"iou_aware_factor"
:
iou_aware_factor
}
helper
.
append_op
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录