Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
04b8b9e9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
04b8b9e9
编写于
2月 19, 2019
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add yolo_box_op CUDA kernel
上级
452373de
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
62 addition
and
29 deletion
+62
-29
paddle/fluid/operators/detection/yolo_box_op.cc
paddle/fluid/operators/detection/yolo_box_op.cc
+0
-1
paddle/fluid/operators/detection/yolo_box_op.cu
paddle/fluid/operators/detection/yolo_box_op.cu
+47
-15
paddle/fluid/operators/detection/yolo_box_op.h
paddle/fluid/operators/detection/yolo_box_op.h
+15
-13
未找到文件。
paddle/fluid/operators/detection/yolo_box_op.cc
浏览文件 @
04b8b9e9
...
...
@@ -35,7 +35,6 @@ class YoloBoxOp : public framework::OperatorWithKernel {
auto
anchors
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"anchors"
);
int
anchor_num
=
anchors
.
size
()
/
2
;
auto
class_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"class_num"
);
auto
conf_thresh
=
ctx
->
Attrs
().
Get
<
float
>
(
"conf_thresh"
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
"Input(X) should be a 4-D tensor."
);
PADDLE_ENFORCE_EQ
(
...
...
paddle/fluid/operators/detection/yolo_box_op.cu
浏览文件 @
04b8b9e9
...
...
@@ -20,15 +20,44 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
static
__global__
void
GenDensityPriorBox
(
const
int
height
,
const
int
width
,
const
int
im_height
,
const
int
im_width
,
const
T
offset
,
const
T
step_width
,
const
T
step_height
,
const
int
num_priors
,
const
T
*
ratios_shift
,
bool
is_clip
,
const
T
var_xmin
,
const
T
var_ymin
,
const
T
var_xmax
,
const
T
var_ymax
,
T
*
out
,
T
*
var
)
{
int
gidx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
gidy
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
step_x
=
blockDim
.
x
*
gridDim
.
x
;
int
step_y
=
blockDim
.
y
*
gridDim
.
y
;
__global__
void
KeYoloBoxFw
(
const
T
*
input
,
const
int
*
imgsize
,
T
*
boxes
,
T
*
scores
,
const
float
conf_thresh
,
std
::
vector
<
int
>
anchors
,
const
int
h
,
const
in
w
,
const
int
an_num
,
const
int
class_num
,
const
int
box_num
,
const
int
input_size
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(;
tid
<
box_num
;
tid
+=
stride
)
{
int
grid_num
=
h
*
w
;
int
i
=
tid
/
box_num
;
int
j
=
(
tid
%
box_num
)
/
grid_num
;
int
k
=
(
tid
%
grid_num
)
/
w
;
int
l
=
tid
%
w
;
int
an_stride
=
an_num
*
grid_num
;
int
img_height
=
imgsize
[
2
*
i
];
int
img_width
=
imgsize
[
2
*
i
+
1
];
int
obj_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
4
);
T
conf
=
sigmoid
<
T
>
(
input
[
obj_idx
]);
if
(
conf
<
conf_thresh
)
{
continue
;
}
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
0
);
Box
<
T
>
pred
=
GetYoloBox
<
T
>
(
input
,
anchors
,
l
,
k
,
j
,
h
,
input_size
,
box_idx
,
grid_num
,
img_height
,
img_width
);
box_idx
=
(
i
*
box_num
+
j
*
grid_num
+
k
*
w
+
l
)
*
4
;
CalcDetectionBox
<
T
>
(
boxes
,
pred
,
box_idx
);
int
label_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
5
);
int
score_idx
=
(
i
*
box_num
+
j
*
stride
+
k
*
w
+
l
)
*
class_num
;
CalcLabelScore
<
T
>
(
scores
,
input
,
label_idx
,
score_idx
,
class_num
,
conf
,
grid_num
);
}
}
template
<
typename
T
>
...
...
@@ -36,6 +65,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
auto
*
img_size
=
ctx
.
Input
<
Tensor
>
(
"ImgSize"
);
auto
*
boxes
=
ctx
.
Output
<
Tensor
>
(
"Boxes"
);
auto
*
scores
=
ctx
.
Output
<
Tensor
>
(
"Scores"
);
...
...
@@ -51,14 +81,16 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
const
int
an_num
=
anchors
.
size
()
/
2
;
int
input_size
=
downsample_ratio
*
h
;
const
int
stride
=
h
*
w
;
const
int
an_stride
=
(
class_num
+
5
)
*
stride
;
const
T
*
input_data
=
input
->
data
<
T
>
();
T
*
boxes_data
=
boxes
->
mutable_data
<
T
>
({
n
},
ctx
.
GetPlace
());
memset
(
loss_data
,
0
,
boxes
->
numel
()
*
sizeof
(
T
));
T
*
scores_data
=
scores
->
mutable_data
<
T
>
({
n
},
ctx
.
GetPlace
());
const
int
*
imgsize_data
=
imgsize
->
data
<
int
>
();
T
*
boxes_data
=
boxes
->
mutable_data
<
T
>
({
n
,
box_num
,
4
},
ctx
.
GetPlace
());
memset
(
boxes_data
,
0
,
boxes
->
numel
()
*
sizeof
(
T
));
T
*
scores_data
=
scores
->
mutable_data
<
T
>
({
n
,
box_num
,
class_num
},
ctx
.
GetPlace
());
memset
(
scores_data
,
0
,
scores
->
numel
()
*
sizeof
(
T
));
int
grid_dim
=
(
n
*
box_num
+
512
-
1
)
/
512
;
grid_dim
=
grid_dim
>
8
?
8
:
grid_dim
;
}
};
// namespace operators
...
...
paddle/fluid/operators/detection/yolo_box_op.h
浏览文件 @
04b8b9e9
...
...
@@ -30,10 +30,10 @@ static inline T sigmoid(T x) {
}
template
<
typename
T
>
static
inline
Box
<
T
>
GetYoloBox
(
const
T
*
x
,
std
::
vector
<
int
>
anchors
,
int
i
,
int
j
,
int
an_idx
,
int
grid_size
,
int
input_size
,
int
index
,
int
stride
,
int
img_height
,
int
img_width
)
{
HOSTDEVICE
inline
Box
<
T
>
GetYoloBox
(
const
T
*
x
,
std
::
vector
<
int
>
anchors
,
int
i
,
int
j
,
int
an_idx
,
int
grid_size
,
int
input_size
,
int
index
,
int
stride
,
int
img_height
,
int
img_width
)
{
Box
<
T
>
b
;
b
.
x
=
(
i
+
sigmoid
<
T
>
(
x
[
index
]))
*
img_width
/
grid_size
;
b
.
y
=
(
j
+
sigmoid
<
T
>
(
x
[
index
+
stride
]))
*
img_height
/
grid_size
;
...
...
@@ -44,13 +44,15 @@ static inline Box<T> GetYoloBox(const T* x, std::vector<int> anchors, int i,
return
b
;
}
static
inline
int
GetEntryIndex
(
int
batch
,
int
an_idx
,
int
hw_idx
,
int
an_num
,
int
an_stride
,
int
stride
,
int
entry
)
{
HOSTDEVICE
inline
int
GetEntryIndex
(
int
batch
,
int
an_idx
,
int
hw_idx
,
int
an_num
,
int
an_stride
,
int
stride
,
int
entry
)
{
return
(
batch
*
an_num
+
an_idx
)
*
an_stride
+
entry
*
stride
+
hw_idx
;
}
template
<
typename
T
>
static
inline
void
CalcDetectionBox
(
T
*
boxes
,
Box
<
T
>
pred
,
const
int
box_idx
)
{
HOSTDEVICE
inline
void
CalcDetectionBox
(
T
*
boxes
,
Box
<
T
>
pred
,
const
int
box_idx
)
{
boxes
[
box_idx
]
=
pred
.
x
-
pred
.
w
/
2
;
boxes
[
box_idx
+
1
]
=
pred
.
y
-
pred
.
h
/
2
;
boxes
[
box_idx
+
2
]
=
pred
.
x
+
pred
.
w
/
2
;
...
...
@@ -58,10 +60,10 @@ static inline void CalcDetectionBox(T* boxes, Box<T> pred, const int box_idx) {
}
template
<
typename
T
>
static
inline
void
CalcLabelScore
(
T
*
scores
,
const
T
*
input
,
const
int
label_idx
,
const
int
score_idx
,
const
int
class_num
,
const
T
conf
,
const
int
stride
)
{
HOSTDEVICE
inline
void
CalcLabelScore
(
T
*
scores
,
const
T
*
input
,
const
int
label_idx
,
const
int
score_idx
,
const
int
class_num
,
const
T
conf
,
const
int
stride
)
{
for
(
int
i
=
0
;
i
<
class_num
;
i
++
)
{
scores
[
score_idx
+
i
]
=
conf
*
sigmoid
<
T
>
(
input
[
label_idx
+
i
*
stride
]);
}
...
...
@@ -115,8 +117,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
0
);
Box
<
T
>
pred
=
GetYoloBox
(
input_data
,
anchors
,
l
,
k
,
j
,
h
,
input_size
,
box_idx
,
stride
,
img_height
,
img_width
);
GetYoloBox
<
T
>
(
input_data
,
anchors
,
l
,
k
,
j
,
h
,
input_size
,
box_idx
,
stride
,
img_height
,
img_width
);
box_idx
=
(
i
*
box_num
+
j
*
stride
+
k
*
w
+
l
)
*
4
;
CalcDetectionBox
<
T
>
(
boxes_data
,
pred
,
box_idx
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录