MegEngine 天元 / MegEngine
Commit 53c288a3

Authored by Megvii Engine Team on May 21, 2020
Committed by Xu Xinran on Jun 19, 2020

fix(dnn/cuda): fix topk grid oversize
GitOrigin-RevId: d3c811a034e09f72576173130d9aac26d601fbf6
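
For context: CUDA caps gridDim.y at prop.maxGridSize[1] (65535 on current NVIDIA GPUs), so launching the top-k kernels with gridDim.y equal to the batch size fails once the batch exceeds that limit. The diff below therefore queries the limit and walks the batch in chunks. A minimal standalone sketch of that chunked-launch pattern (the kernel and all names here are illustrative, not the MegEngine kernels):

    #include <cuda_runtime.h>

    // Illustrative kernel: one gridDim.y slice per batch row (hypothetical).
    __global__ void process_rows(const float* input, float* output,
                                 uint32_t length, uint32_t row_offset) {
        uint32_t row = row_offset + blockIdx.y;
        uint32_t col = blockIdx.x * blockDim.x + threadIdx.x;
        if (col < length)
            output[row * length + col] = input[row * length + col];
    }

    // Chunk the batch so gridDim.y never exceeds the device limit.
    cudaError_t launch_chunked(const float* input, float* output,
                               uint32_t batch, uint32_t length,
                               cudaStream_t stream) {
        int device_id;
        if (cudaGetDevice(&device_id) != cudaSuccess)
            return cudaGetLastError();
        cudaDeviceProp prop;
        if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess)
            return cudaGetLastError();
        uint32_t grid_dim_y_limit = prop.maxGridSize[1];  // typically 65535
        uint32_t grid_dim_x = (length + 255) / 256;
        for (uint32_t batch_idx = 0; batch_idx < batch;) {
            uint32_t grid_dim_y = batch - batch_idx >= grid_dim_y_limit
                                          ? grid_dim_y_limit
                                          : batch - batch_idx;
            process_rows<<<dim3(grid_dim_x, grid_dim_y), 256, 0, stream>>>(
                    input, output, length, batch_idx);
            batch_idx += grid_dim_y;
        }
        return cudaGetLastError();
    }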

Parent: 124767b4

Showing 2 changed files with 125 additions and 51 deletions (+125 -51)
dnn/src/cuda/topk/topk_radix.cu  +119 -51
dnn/test/common/topk.cpp  +6 -0
dnn/src/cuda/topk/topk_radix.cu
@@ -470,7 +470,17 @@ static size_t get_scan_workspace(uint32_t size) {
 uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length) {
     using namespace cuda_topk_impl::kth;
-    return (batch * get_grid_dim_x(length) * NR_BUCKET + batch * 2) *
+    int device_id;
+    if (cudaGetDevice(&device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    cudaDeviceProp prop;
+    if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    uint32_t grid_dim_y_limit = prop.maxGridSize[1];
+    uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch;
+    return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) *
            sizeof(uint32_t);
 }
@@ -491,35 +501,65 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
         // assert
         megdnn_trap();
     }
+    int device_id;
+    if (cudaGetDevice(&device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    cudaDeviceProp prop;
+    if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    uint32_t grid_dim_y_limit = prop.maxGridSize[1];
+    uint32_t batch_idx = 0;
     uint32_t grid_dim_x = get_grid_dim_x(length);
-    dim3 grid_dim(grid_dim_x, batch);
-    uint32_t* dev_k = static_cast<uint32_t*>(workspace);
-    uint32_t* dev_prefix = dev_k + batch;
-    uint32_t* bucket_cnt = dev_prefix + batch;
-    compute_histogram<ctype, false, 24><<<grid_dim, BLOCK_DIM, 0, stream>>>(
-            input, bucket_cnt, length, lda, nullptr);
-    // use float to make compiler happy; it is not used since last == false
-    update_prefix_and_k<true, false, 24, float>
-            <<<batch, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix, dev_k, k,
-                                              grid_dim_x, nullptr);
-    compute_histogram<ctype, true, 16><<<grid_dim, BLOCK_DIM, 0, stream>>>(
-            input, bucket_cnt, length, lda, dev_prefix);
-    update_prefix_and_k<false, false, 16, float>
-            <<<batch, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix, dev_k, k,
-                                              grid_dim_x, nullptr);
-    compute_histogram<ctype, true, 8><<<grid_dim, BLOCK_DIM, 0, stream>>>(
-            input, bucket_cnt, length, lda, dev_prefix);
-    update_prefix_and_k<false, false, 8, float>
-            <<<batch, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix, dev_k, k,
-                                              grid_dim_x, nullptr);
-    compute_histogram<ctype, true, 0><<<grid_dim, BLOCK_DIM, 0, stream>>>(
-            input, bucket_cnt, length, lda, dev_prefix);
-    update_prefix_and_k<false, true, 0, ctype><<<batch, NR_BUCKET, 0, stream>>>(
-            bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, output);
+    uint32_t grid_dim_y = 1;
+    while (batch_idx < batch) {
+        if (batch - batch_idx >= grid_dim_y_limit) {
+            grid_dim_y = grid_dim_y_limit;
+        } else {
+            grid_dim_y = batch - batch_idx;
+        }
+        dim3 grid_dim(grid_dim_x, grid_dim_y);
+        uint32_t* dev_k = static_cast<uint32_t*>(workspace);
+        uint32_t* dev_prefix = dev_k + grid_dim_y;
+        uint32_t* bucket_cnt = dev_prefix + grid_dim_y;
+        compute_histogram<ctype, false, 24><<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                input + batch_idx * lda, bucket_cnt, length, lda, nullptr);
+        // use float to make compiler happy; it is not used since last == false
+        update_prefix_and_k<true, false, 24, float>
+                <<<grid_dim_y, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix,
+                                                       dev_k, k, grid_dim_x,
+                                                       nullptr);
+        compute_histogram<ctype, true, 16><<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix);
+        update_prefix_and_k<false, false, 16, float>
+                <<<grid_dim_y, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix,
+                                                       dev_k, k, grid_dim_x,
+                                                       nullptr);
+        compute_histogram<ctype, true, 8><<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix);
+        update_prefix_and_k<false, false, 8, float>
+                <<<grid_dim_y, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix,
+                                                       dev_k, k, grid_dim_x,
+                                                       nullptr);
+        compute_histogram<ctype, true, 0><<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix);
+        update_prefix_and_k<false, true, 0, ctype>
+                <<<grid_dim_y, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix,
+                                                       dev_k, k, grid_dim_x,
+                                                       output + batch_idx);
+        batch_idx += grid_dim_y;
+    }
     return cudaGetLastError();
 }
@@ -530,12 +570,18 @@ cudaError_t topk::topk_select(const ctype* input, const ctype* thresh,
                               int32_t lda, int32_t k, cudaStream_t stream) {
     using namespace cuda_topk_impl;
     using namespace cuda_topk_impl::select;
-    uint32_t length_split = DIVUP(length, REDUCE_SIZE),
-             scan_size = batch * length_split;
-    size_t scan_wk = get_scan_workspace(scan_size);
-    uint64_t* scan_inp = static_cast<uint64_t*>(workspace) +
-                                 scan_wk / sizeof(uint64_t),
-            *scan_out = scan_inp + scan_size;
+    int device_id;
+    if (cudaGetDevice(&device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    cudaDeviceProp prop;
+    if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    uint32_t batch_upper_limit = prop.maxGridSize[1];
+    uint32_t length_split = DIVUP(length, REDUCE_SIZE);
     void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t,
                                   int32_t, uint64_t*, uint32_t);
@@ -585,25 +631,47 @@ cudaError_t topk::topk_select(const ctype* input, const ctype* thresh,
 #undef CASE_SHARD
 #undef CASE_SHARD_ON
-    // reduce to scan_inp
-    kptr_reduce_block_cnt<<<dim3(DIVUP(length_split, REDUCE_SHARD), batch),
-                            dim3(REDUCE_WARP_SIZE, REDUCE_SHARD), 0, stream>>>(
-            input, thresh, length, lda, scan_inp, length_split);
+    uint32_t batch_idx = 0;
+    uint32_t batch_real = 1;
-    // scan to scan_out
-    scan_out += 1;  // set scan[-1] to 0
-    cudaError_t err = invoke_cub_scan(scan_inp, scan_out, workspace, scan_wk,
-                                      scan_size, stream);
-    if (err != cudaSuccess) {
-        return err;
-    }
-    kern_init_zero<<<1, 1, 0, stream>>>(scan_out - 1);
+    while (batch_idx < batch) {
+        if (batch - batch_idx >= batch_upper_limit) {
+            batch_real = batch_upper_limit;
+        } else {
+            batch_real = batch - batch_idx;
+        }
-    // copy result
-    kptr_copy<<<dim3(DIVUP(length_split, kern_copy_shard), batch),
-                dim3(WARP_SIZE, kern_copy_shard), 0, stream>>>(
-            input, thresh, scan_out, length_split, output_value, output_idx,
-            length, k, lda);
+        size_t scan_size = batch_real * length_split;
+        size_t scan_wk = get_scan_workspace(scan_size);
+        uint64_t* scan_inp = static_cast<uint64_t*>(workspace) +
+                                     scan_wk / sizeof(uint64_t),
+                *scan_out = scan_inp + scan_size;
+        // reduce to scan_inp
+        kptr_reduce_block_cnt<<<dim3(DIVUP(length_split, REDUCE_SHARD),
+                                     batch_real),
+                                dim3(REDUCE_WARP_SIZE, REDUCE_SHARD), 0,
+                                stream>>>(input + batch_idx * lda,
+                                          thresh + batch_idx, length, lda,
+                                          scan_inp, length_split);
+        // scan to scan_out
+        scan_out += 1;  // set scan[-1] to 0
+        cudaError_t err = invoke_cub_scan(scan_inp, scan_out, workspace,
+                                          scan_wk, scan_size, stream);
+        if (err != cudaSuccess) {
+            return err;
+        }
+        kern_init_zero<<<1, 1, 0, stream>>>(scan_out - 1);
+        // copy result
+        kptr_copy<<<dim3(DIVUP(length_split, kern_copy_shard), batch_real),
+                    dim3(WARP_SIZE, kern_copy_shard), 0, stream>>>(
+                input + batch_idx * lda, thresh + batch_idx, scan_out,
+                length_split, output_value + std::abs(k) * batch_idx,
+                output_idx + std::abs(k) * batch_idx, length, k, lda);
+        batch_idx += batch_real;
+    }
     return cudaGetLastError();
 }
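
Two consequences of the change above are worth spelling out. First, find_kth_radix_workspace now sizes the per-row buffers (dev_k, dev_prefix, bucket_cnt) for at most maxGridSize[1] rows, since the launch loop reuses them for every chunk. As a worked example, assuming NR_BUCKET == 256 (one bucket per 8-bit radix digit, consistent with the 24/16/8/0 shifts above) and a device limit of 65535, a batch of 70000 needs

    limit     = min(70000, 65535) = 65535
    workspace = (65535 * get_grid_dim_x(length) * 256 + 65535 * 2) * sizeof(uint32_t)

rather than a buffer scaled by the full 70000. Second, in topk_select each chunk writes its results at output_value + std::abs(k) * batch_idx (and likewise for output_idx): the sign of k selects which end of the sorted order is taken (compare the "equiv to sort" / "equiv to rev sort" cases in the test below), but every row still produces |k| outputs, so std::abs(k) is the per-row output stride.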
dnn/test/common/topk.cpp
@@ -169,6 +169,12 @@ void test::run_topk_test(Handle* handle) {
         run(5, 123, 3, mode);         // equiv to sort
         run(-5, 123, 3, mode);        // equiv to rev sort
         run(5, 3, 1231, mode, 2000);  // non contig
+//! opencl on armv7's CI does not support large batch.
+//! but P30 and MI9 are ok. fix it in the future.
+#if !defined(MEGDNN_ARMV7) && defined(MGB_CUDA)
+        run(3, 70000, 5, mode, 10);  // non contig
+#endif
     }
     // special case to check if tie-break is correct
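
The new test case uses a batch of 70000, which only exercises the fix because it exceeds the device's gridDim.y limit. A quick standalone way to check that limit on a given GPU (illustrative snippet, not part of the test suite):

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
        int device_id = 0;
        if (cudaGetDevice(&device_id) != cudaSuccess)
            return 1;
        cudaDeviceProp prop;
        if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess)
            return 1;
        // maxGridSize[1] is 65535 on current NVIDIA GPUs, so a batch of 70000
        // cannot be covered by a single launch with gridDim.y == batch.
        std::printf("maxGridSize = (%d, %d, %d)\n", prop.maxGridSize[0],
                    prop.maxGridSize[1], prop.maxGridSize[2]);
        return 0;
    }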