Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6bfa339c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6bfa339c
编写于
9月 15, 2020
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update yolo_box support h != w. test=develop
上级
264e76ca
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
22 addition
and
16 deletion
+22
-16
paddle/fluid/operators/detection/yolo_box_op.cu
paddle/fluid/operators/detection/yolo_box_op.cu
+10
-7
paddle/fluid/operators/detection/yolo_box_op.h
paddle/fluid/operators/detection/yolo_box_op.h
+12
-9
未找到文件。
paddle/fluid/operators/detection/yolo_box_op.cu
浏览文件 @
6bfa339c
...
...
@@ -26,8 +26,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T
*
scores
,
const
float
conf_thresh
,
const
int
*
anchors
,
const
int
n
,
const
int
h
,
const
int
w
,
const
int
an_num
,
const
int
class_num
,
const
int
box_num
,
int
input_size
,
bool
clip_bbox
,
const
float
scale
,
const
float
bias
)
{
const
int
box_num
,
int
input_size_h
,
int
input_size_w
,
bool
clip_bbox
,
const
float
scale
,
const
float
bias
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
T
box
[
4
];
...
...
@@ -51,8 +52,9 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
0
);
GetYoloBox
<
T
>
(
box
,
input
,
anchors
,
l
,
k
,
j
,
h
,
input_size
,
box_idx
,
grid_num
,
img_height
,
img_width
,
scale
,
bias
);
GetYoloBox
<
T
>
(
box
,
input
,
anchors
,
l
,
k
,
j
,
h
,
w
,
input_size_h
,
input_size_w
,
box_idx
,
grid_num
,
img_height
,
img_width
,
scale
,
bias
);
box_idx
=
(
i
*
box_num
+
j
*
grid_num
+
k
*
w
+
l
)
*
4
;
CalcDetectionBox
<
T
>
(
boxes
,
box
,
box_idx
,
img_height
,
img_width
,
clip_bbox
);
...
...
@@ -86,7 +88,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
const
int
w
=
input
->
dims
()[
3
];
const
int
box_num
=
boxes
->
dims
()[
1
];
const
int
an_num
=
anchors
.
size
()
/
2
;
int
input_size
=
downsample_ratio
*
h
;
int
input_size_h
=
downsample_ratio
*
h
;
int
input_size_w
=
downsample_ratio
*
w
;
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
int
bytes
=
sizeof
(
int
)
*
anchors
.
size
();
...
...
@@ -111,8 +114,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
KeYoloBoxFw
<
T
><<<
grid_dim
,
512
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
imgsize_data
,
boxes_data
,
scores_data
,
conf_thresh
,
anchors_data
,
n
,
h
,
w
,
an_num
,
class_num
,
box_num
,
input_size
,
clip_bbox
,
scale
,
bias
);
anchors_data
,
n
,
h
,
w
,
an_num
,
class_num
,
box_num
,
input_size
_h
,
input_size_w
,
clip_bbox
,
scale
,
bias
);
}
};
...
...
paddle/fluid/operators/detection/yolo_box_op.h
浏览文件 @
6bfa339c
...
...
@@ -27,17 +27,18 @@ HOSTDEVICE inline T sigmoid(T x) {
template
<
typename
T
>
HOSTDEVICE
inline
void
GetYoloBox
(
T
*
box
,
const
T
*
x
,
const
int
*
anchors
,
int
i
,
int
j
,
int
an_idx
,
int
grid_size
,
int
input_size
,
int
index
,
int
stride
,
int
j
,
int
an_idx
,
int
grid_size_h
,
int
grid_size_w
,
int
input_size_h
,
int
input_size_w
,
int
index
,
int
stride
,
int
img_height
,
int
img_width
,
float
scale
,
float
bias
)
{
box
[
0
]
=
(
i
+
sigmoid
<
T
>
(
x
[
index
])
*
scale
+
bias
)
*
img_width
/
grid_size
;
box
[
0
]
=
(
i
+
sigmoid
<
T
>
(
x
[
index
])
*
scale
+
bias
)
*
img_width
/
grid_size
_w
;
box
[
1
]
=
(
j
+
sigmoid
<
T
>
(
x
[
index
+
stride
])
*
scale
+
bias
)
*
img_height
/
grid_size
;
grid_size
_h
;
box
[
2
]
=
std
::
exp
(
x
[
index
+
2
*
stride
])
*
anchors
[
2
*
an_idx
]
*
img_width
/
input_size
;
input_size
_w
;
box
[
3
]
=
std
::
exp
(
x
[
index
+
3
*
stride
])
*
anchors
[
2
*
an_idx
+
1
]
*
img_height
/
input_size
;
img_height
/
input_size
_h
;
}
HOSTDEVICE
inline
int
GetEntryIndex
(
int
batch
,
int
an_idx
,
int
hw_idx
,
...
...
@@ -99,7 +100,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
const
int
w
=
input
->
dims
()[
3
];
const
int
box_num
=
boxes
->
dims
()[
1
];
const
int
an_num
=
anchors
.
size
()
/
2
;
int
input_size
=
downsample_ratio
*
h
;
int
input_size_h
=
downsample_ratio
*
h
;
int
input_size_w
=
downsample_ratio
*
w
;
const
int
stride
=
h
*
w
;
const
int
an_stride
=
(
class_num
+
5
)
*
stride
;
...
...
@@ -134,8 +136,9 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
stride
,
0
);
GetYoloBox
<
T
>
(
box
,
input_data
,
anchors_data
,
l
,
k
,
j
,
h
,
input_size
,
box_idx
,
stride
,
img_height
,
img_width
,
scale
,
bias
);
GetYoloBox
<
T
>
(
box
,
input_data
,
anchors_data
,
l
,
k
,
j
,
h
,
w
,
input_size_h
,
input_size_w
,
box_idx
,
stride
,
img_height
,
img_width
,
scale
,
bias
);
box_idx
=
(
i
*
box_num
+
j
*
stride
+
k
*
w
+
l
)
*
4
;
CalcDetectionBox
<
T
>
(
boxes_data
,
box
,
box_idx
,
img_height
,
img_width
,
clip_bbox
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录