Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ad897304
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ad897304
编写于
2月 26, 2019
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix pre-commit. test=develop
上级
72a18bb1
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
43 addition
and
29 deletion
+43
-29
paddle/fluid/operators/detection/yolo_box_op.cu
paddle/fluid/operators/detection/yolo_box_op.cu
+9
-9
paddle/fluid/operators/detection/yolo_box_op.h
paddle/fluid/operators/detection/yolo_box_op.h
+25
-17
python/paddle/fluid/tests/unittests/test_yolo_box_op.py
python/paddle/fluid/tests/unittests/test_yolo_box_op.py
+9
-3
未找到文件。
paddle/fluid/operators/detection/yolo_box_op.cu
浏览文件 @
ad897304
...
...
@@ -22,9 +22,9 @@ using Tensor = framework::Tensor;
template
<
typename
T
>
__global__
void
KeYoloBoxFw
(
const
T
*
input
,
const
int
*
imgsize
,
T
*
boxes
,
T
*
scores
,
const
float
conf_thresh
,
const
int
*
anchors
,
const
int
n
,
const
int
h
,
const
int
w
,
const
int
an_num
,
const
int
class_num
,
T
*
scores
,
const
float
conf_thresh
,
const
int
*
anchors
,
const
int
n
,
const
int
h
,
const
int
w
,
const
int
an_num
,
const
int
class_num
,
const
int
box_num
,
int
input_size
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
...
...
@@ -50,7 +50,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
0
);
GetYoloBox
<
T
>
(
box
,
input
,
anchors
,
l
,
k
,
j
,
h
,
input_size
,
box_idx
,
grid_num
,
img_height
,
img_width
);
grid_num
,
img_height
,
img_width
);
box_idx
=
(
i
*
box_num
+
j
*
grid_num
+
k
*
w
+
l
)
*
4
;
CalcDetectionBox
<
T
>
(
boxes
,
box
,
box_idx
,
img_height
,
img_width
);
...
...
@@ -84,7 +84,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
int
input_size
=
downsample_ratio
*
h
;
Tensor
anchors_t
,
cpu_anchors_t
;
auto
cpu_anchors_data
=
cpu_anchors_t
.
mutable_data
<
int
>
({
an_num
*
2
},
platform
::
CPUPlace
());
auto
cpu_anchors_data
=
cpu_anchors_t
.
mutable_data
<
int
>
({
an_num
*
2
},
platform
::
CPUPlace
());
std
::
copy
(
anchors
.
begin
(),
anchors
.
end
(),
cpu_anchors_data
);
TensorCopySync
(
cpu_anchors_t
,
ctx
.
GetPlace
(),
&
anchors_t
);
auto
anchors_data
=
anchors_t
.
data
<
int
>
();
...
...
@@ -103,8 +104,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
grid_dim
=
grid_dim
>
8
?
8
:
grid_dim
;
KeYoloBoxFw
<
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
imgsize_data
,
boxes_data
,
scores_data
,
conf_thresh
,
anchors_data
,
n
,
h
,
w
,
an_num
,
class_num
,
box_num
,
input_size
);
input_data
,
imgsize_data
,
boxes_data
,
scores_data
,
conf_thresh
,
anchors_data
,
n
,
h
,
w
,
an_num
,
class_num
,
box_num
,
input_size
);
}
};
...
...
@@ -112,6 +113,5 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
yolo_box
,
ops
::
YoloBoxOpCUDAKernel
<
float
>
,
REGISTER_OP_CUDA_KERNEL
(
yolo_box
,
ops
::
YoloBoxOpCUDAKernel
<
float
>
,
ops
::
YoloBoxOpCUDAKernel
<
double
>
);
paddle/fluid/operators/detection/yolo_box_op.h
浏览文件 @
ad897304
...
...
@@ -20,7 +20,6 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
HOSTDEVICE
inline
T
sigmoid
(
T
x
)
{
return
1.0
/
(
1.0
+
std
::
exp
(
-
x
));
...
...
@@ -28,15 +27,15 @@ HOSTDEVICE inline T sigmoid(T x) {
template
<
typename
T
>
HOSTDEVICE
inline
void
GetYoloBox
(
T
*
box
,
const
T
*
x
,
const
int
*
anchors
,
int
i
,
int
j
,
int
an_idx
,
int
grid_size
,
int
input_size
,
int
index
,
int
stride
,
int
img_height
,
int
img_width
)
{
int
j
,
int
an_idx
,
int
grid_size
,
int
input_size
,
int
index
,
int
stride
,
int
img_height
,
int
img_width
)
{
box
[
0
]
=
(
i
+
sigmoid
<
T
>
(
x
[
index
]))
*
img_width
/
grid_size
;
box
[
1
]
=
(
j
+
sigmoid
<
T
>
(
x
[
index
+
stride
]))
*
img_height
/
grid_size
;
box
[
2
]
=
std
::
exp
(
x
[
index
+
2
*
stride
])
*
anchors
[
2
*
an_idx
]
*
img_width
/
input_size
;
box
[
3
]
=
std
::
exp
(
x
[
index
+
3
*
stride
])
*
anchors
[
2
*
an_idx
+
1
]
*
img_height
/
input_size
;
input_size
;
box
[
3
]
=
std
::
exp
(
x
[
index
+
3
*
stride
])
*
anchors
[
2
*
an_idx
+
1
]
*
img_height
/
input_size
;
}
HOSTDEVICE
inline
int
GetEntryIndex
(
int
batch
,
int
an_idx
,
int
hw_idx
,
...
...
@@ -47,16 +46,22 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
template
<
typename
T
>
HOSTDEVICE
inline
void
CalcDetectionBox
(
T
*
boxes
,
T
*
box
,
const
int
box_idx
,
const
int
img_height
,
const
int
img_width
)
{
const
int
img_height
,
const
int
img_width
)
{
boxes
[
box_idx
]
=
box
[
0
]
-
box
[
2
]
/
2
;
boxes
[
box_idx
+
1
]
=
box
[
1
]
-
box
[
3
]
/
2
;
boxes
[
box_idx
+
2
]
=
box
[
0
]
+
box
[
2
]
/
2
;
boxes
[
box_idx
+
3
]
=
box
[
1
]
+
box
[
3
]
/
2
;
boxes
[
box_idx
]
=
boxes
[
box_idx
]
>
0
?
boxes
[
box_idx
]
:
static_cast
<
T
>
(
0
);
boxes
[
box_idx
+
1
]
=
boxes
[
box_idx
+
1
]
>
0
?
boxes
[
box_idx
+
1
]
:
static_cast
<
T
>
(
0
);
boxes
[
box_idx
+
2
]
=
boxes
[
box_idx
+
2
]
<
img_width
-
1
?
boxes
[
box_idx
+
2
]
:
static_cast
<
T
>
(
img_width
-
1
);
boxes
[
box_idx
+
3
]
=
boxes
[
box_idx
+
3
]
<
img_height
-
1
?
boxes
[
box_idx
+
3
]
:
static_cast
<
T
>
(
img_height
-
1
);
boxes
[
box_idx
+
1
]
=
boxes
[
box_idx
+
1
]
>
0
?
boxes
[
box_idx
+
1
]
:
static_cast
<
T
>
(
0
);
boxes
[
box_idx
+
2
]
=
boxes
[
box_idx
+
2
]
<
img_width
-
1
?
boxes
[
box_idx
+
2
]
:
static_cast
<
T
>
(
img_width
-
1
);
boxes
[
box_idx
+
3
]
=
boxes
[
box_idx
+
3
]
<
img_height
-
1
?
boxes
[
box_idx
+
3
]
:
static_cast
<
T
>
(
img_height
-
1
);
}
template
<
typename
T
>
...
...
@@ -92,8 +97,10 @@ class YoloBoxKernel : public framework::OpKernel<T> {
const
int
stride
=
h
*
w
;
const
int
an_stride
=
(
class_num
+
5
)
*
stride
;
int
anchors_
[
anchors
.
size
()];
std
::
copy
(
anchors
.
begin
(),
anchors
.
end
(),
anchors_
);
Tensor
anchors_
;
auto
anchors_data
=
anchors_
.
mutable_data
<
int
>
({
an_num
*
2
},
ctx
.
GetPlace
());
std
::
copy
(
anchors
.
begin
(),
anchors
.
end
(),
anchors_data
);
const
T
*
input_data
=
input
->
data
<
T
>
();
const
int
*
imgsize_data
=
imgsize
->
data
<
int
>
();
...
...
@@ -120,10 +127,11 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
0
);
GetYoloBox
<
T
>
(
box
,
input_data
,
anchors_
,
l
,
k
,
j
,
h
,
input_size
,
box_idx
,
stride
,
img_height
,
img_width
);
GetYoloBox
<
T
>
(
box
,
input_data
,
anchors_data
,
l
,
k
,
j
,
h
,
input_size
,
box_idx
,
stride
,
img_height
,
img_width
);
box_idx
=
(
i
*
box_num
+
j
*
stride
+
k
*
w
+
l
)
*
4
;
CalcDetectionBox
<
T
>
(
boxes_data
,
box
,
box_idx
,
img_height
,
img_width
);
CalcDetectionBox
<
T
>
(
boxes_data
,
box
,
box_idx
,
img_height
,
img_width
);
int
label_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
5
);
...
...
python/paddle/fluid/tests/unittests/test_yolo_box_op.py
浏览文件 @
ad897304
...
...
@@ -59,12 +59,19 @@ def YoloBox(x, img_size, attrs):
pred_box
[:,
:,
:
2
],
pred_box
[:,
:,
2
:
4
]
=
\
pred_box
[:,
:,
:
2
]
-
pred_box
[:,
:,
2
:
4
]
/
2.
,
\
pred_box
[:,
:,
:
2
]
+
pred_box
[:,
:,
2
:
4
]
/
2.0
# pred_box = pred_box * input_size
pred_box
[:,
:,
0
]
=
pred_box
[:,
:,
0
]
*
img_size
[:,
1
][:,
np
.
newaxis
]
pred_box
[:,
:,
1
]
=
pred_box
[:,
:,
1
]
*
img_size
[:,
0
][:,
np
.
newaxis
]
pred_box
[:,
:,
2
]
=
pred_box
[:,
:,
2
]
*
img_size
[:,
1
][:,
np
.
newaxis
]
pred_box
[:,
:,
3
]
=
pred_box
[:,
:,
3
]
*
img_size
[:,
0
][:,
np
.
newaxis
]
for
i
in
range
(
len
(
pred_box
)):
pred_box
[
i
,
:,
0
]
=
np
.
clip
(
pred_box
[
i
,
:,
0
],
0
,
np
.
inf
)
pred_box
[
i
,
:,
1
]
=
np
.
clip
(
pred_box
[
i
,
:,
1
],
0
,
np
.
inf
)
pred_box
[
i
,
:,
2
]
=
np
.
clip
(
pred_box
[
i
,
:,
2
],
-
np
.
inf
,
img_size
[
i
,
1
]
-
1
)
pred_box
[
i
,
:,
3
]
=
np
.
clip
(
pred_box
[
i
,
:,
3
],
-
np
.
inf
,
img_size
[
i
,
0
]
-
1
)
return
pred_box
,
pred_score
.
reshape
((
n
,
-
1
,
class_num
))
...
...
@@ -93,8 +100,7 @@ class TestYoloBoxOp(OpTest):
}
def
test_check_output
(
self
):
place
=
core
.
CUDAPlace
(
0
)
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
self
.
check_output
()
def
initTestCase
(
self
):
self
.
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录