提交 53c288a3 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(dnn/cuda): fix topk grid oversize

GitOrigin-RevId: d3c811a034e09f72576173130d9aac26d601fbf6
上级 124767b4
...@@ -470,7 +470,17 @@ static size_t get_scan_workspace(uint32_t size) { ...@@ -470,7 +470,17 @@ static size_t get_scan_workspace(uint32_t size) {
uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length) { uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length) {
using namespace cuda_topk_impl::kth; using namespace cuda_topk_impl::kth;
return (batch * get_grid_dim_x(length) * NR_BUCKET + batch * 2) * int device_id;
if (cudaGetDevice(&device_id) != cudaSuccess) {
megdnn_trap();
}
cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
megdnn_trap();
}
uint32_t grid_dim_y_limit = prop.maxGridSize[1];
uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch;
return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) *
sizeof(uint32_t); sizeof(uint32_t);
} }
...@@ -491,35 +501,65 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, ...@@ -491,35 +501,65 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
// assert // assert
megdnn_trap(); megdnn_trap();
} }
int device_id;
if (cudaGetDevice(&device_id) != cudaSuccess) {
megdnn_trap();
}
cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
megdnn_trap();
}
uint32_t grid_dim_y_limit = prop.maxGridSize[1];
uint32_t batch_idx = 0;
uint32_t grid_dim_x = get_grid_dim_x(length); uint32_t grid_dim_x = get_grid_dim_x(length);
dim3 grid_dim(grid_dim_x, batch); uint32_t grid_dim_y = 1;
while (batch_idx < batch) {
if (batch - batch_idx >= grid_dim_y_limit) {
grid_dim_y = grid_dim_y_limit;
} else {
grid_dim_y = batch - batch_idx;
}
dim3 grid_dim(grid_dim_x, grid_dim_y);
uint32_t* dev_k = static_cast<uint32_t*>(workspace); uint32_t* dev_k = static_cast<uint32_t*>(workspace);
uint32_t* dev_prefix = dev_k + batch; uint32_t* dev_prefix = dev_k + grid_dim_y;
uint32_t* bucket_cnt = dev_prefix + batch; uint32_t* bucket_cnt = dev_prefix + grid_dim_y;
compute_histogram<ctype, false, 24><<<grid_dim, BLOCK_DIM, 0, stream>>>( compute_histogram<ctype, false, 24><<<grid_dim, BLOCK_DIM, 0, stream>>>(
input, bucket_cnt, length, lda, nullptr); input + batch_idx * lda, bucket_cnt, length, lda, nullptr);
// use float to make compiler happy; it is not used since last == false // use float to make compiler happy; it is not used since last == false
update_prefix_and_k<true, false, 24, float> update_prefix_and_k<true, false, 24, float>
<<<batch, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix, dev_k, k, <<<grid_dim_y, NR_BUCKET, 0, stream>>>(
grid_dim_x, nullptr); bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr);
compute_histogram<ctype, true, 16><<<grid_dim, BLOCK_DIM, 0, stream>>>( compute_histogram<ctype, true, 16><<<grid_dim, BLOCK_DIM, 0, stream>>>(
input, bucket_cnt, length, lda, dev_prefix); input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix);
update_prefix_and_k<false, false, 16, float> update_prefix_and_k<false, false, 16, float>
<<<batch, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix, dev_k, k, <<<grid_dim_y, NR_BUCKET, 0, stream>>>(
grid_dim_x, nullptr); bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr);
compute_histogram<ctype, true, 8><<<grid_dim, BLOCK_DIM, 0, stream>>>( compute_histogram<ctype, true, 8><<<grid_dim, BLOCK_DIM, 0, stream>>>(
input, bucket_cnt, length, lda, dev_prefix); input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix);
update_prefix_and_k<false, false, 8, float> update_prefix_and_k<false, false, 8, float>
<<<batch, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix, dev_k, k, <<<grid_dim_y, NR_BUCKET, 0, stream>>>(
grid_dim_x, nullptr); bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr);
compute_histogram<ctype, true, 0><<<grid_dim, BLOCK_DIM, 0, stream>>>( compute_histogram<ctype, true, 0><<<grid_dim, BLOCK_DIM, 0, stream>>>(
input, bucket_cnt, length, lda, dev_prefix); input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix);
update_prefix_and_k<false, true, 0, ctype><<<batch, NR_BUCKET, 0, stream>>>(
bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, output); update_prefix_and_k<false, true, 0, ctype>
<<<grid_dim_y, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix,
dev_k, k, grid_dim_x,
output + batch_idx);
batch_idx += grid_dim_y;
}
return cudaGetLastError(); return cudaGetLastError();
} }
...@@ -530,12 +570,18 @@ cudaError_t topk::topk_select(const ctype* input, const ctype* thresh, ...@@ -530,12 +570,18 @@ cudaError_t topk::topk_select(const ctype* input, const ctype* thresh,
int32_t lda, int32_t k, cudaStream_t stream) { int32_t lda, int32_t k, cudaStream_t stream) {
using namespace cuda_topk_impl; using namespace cuda_topk_impl;
using namespace cuda_topk_impl::select; using namespace cuda_topk_impl::select;
uint32_t length_split = DIVUP(length, REDUCE_SIZE),
scan_size = batch * length_split; int device_id;
size_t scan_wk = get_scan_workspace(scan_size); if (cudaGetDevice(&device_id) != cudaSuccess) {
uint64_t *scan_inp = static_cast<uint64_t*>(workspace) + megdnn_trap();
scan_wk / sizeof(uint64_t), }
*scan_out = scan_inp + scan_size; cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
megdnn_trap();
}
uint32_t batch_upper_limit = prop.maxGridSize[1];
uint32_t length_split = DIVUP(length, REDUCE_SIZE);
void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t, void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t,
uint64_t*, uint32_t); uint64_t*, uint32_t);
...@@ -585,25 +631,47 @@ cudaError_t topk::topk_select(const ctype* input, const ctype* thresh, ...@@ -585,25 +631,47 @@ cudaError_t topk::topk_select(const ctype* input, const ctype* thresh,
#undef CASE_SHARD #undef CASE_SHARD
#undef CASE_SHARD_ON #undef CASE_SHARD_ON
uint32_t batch_idx = 0;
uint32_t batch_real = 1;
while (batch_idx < batch) {
if (batch - batch_idx >= batch_upper_limit) {
batch_real = batch_upper_limit;
} else {
batch_real = batch - batch_idx;
}
size_t scan_size = batch_real * length_split;
size_t scan_wk = get_scan_workspace(scan_size);
uint64_t *scan_inp = static_cast<uint64_t*>(workspace) +
scan_wk / sizeof(uint64_t),
*scan_out = scan_inp + scan_size;
// reduce to scan_inp // reduce to scan_inp
kptr_reduce_block_cnt<<<dim3(DIVUP(length_split, REDUCE_SHARD), batch), kptr_reduce_block_cnt<<<
dim3(DIVUP(length_split, REDUCE_SHARD), batch_real),
dim3(REDUCE_WARP_SIZE, REDUCE_SHARD), 0, stream>>>( dim3(REDUCE_WARP_SIZE, REDUCE_SHARD), 0, stream>>>(
input, thresh, length, lda, scan_inp, length_split); input + batch_idx * lda, thresh + batch_idx, length, lda,
scan_inp, length_split);
// scan to scan_out // scan to scan_out
scan_out += 1; // set scan[-1] to 0 scan_out += 1; // set scan[-1] to 0
cudaError_t err = invoke_cub_scan(scan_inp, scan_out, workspace, scan_wk, cudaError_t err = invoke_cub_scan(scan_inp, scan_out, workspace,
scan_size, stream); scan_wk, scan_size, stream);
if (err != cudaSuccess) { if (err != cudaSuccess) {
return err; return err;
} }
kern_init_zero<<<1, 1, 0, stream>>>(scan_out - 1); kern_init_zero<<<1, 1, 0, stream>>>(scan_out - 1);
// copy result // copy result
kptr_copy<<<dim3(DIVUP(length_split, kern_copy_shard), batch), kptr_copy<<<dim3(DIVUP(length_split, kern_copy_shard), batch_real),
dim3(WARP_SIZE, kern_copy_shard), 0, stream>>>( dim3(WARP_SIZE, kern_copy_shard), 0, stream>>>(
input, thresh, scan_out, length_split, output_value, output_idx, input + batch_idx * lda, thresh + batch_idx, scan_out,
length, k, lda); length_split, output_value + std::abs(k) * batch_idx,
output_idx + std::abs(k) * batch_idx, length, k, lda);
batch_idx += batch_real;
}
return cudaGetLastError(); return cudaGetLastError();
} }
......
...@@ -169,6 +169,12 @@ void test::run_topk_test(Handle* handle) { ...@@ -169,6 +169,12 @@ void test::run_topk_test(Handle* handle) {
run(5, 123, 3, mode); // equiv to sort run(5, 123, 3, mode); // equiv to sort
run(-5, 123, 3, mode); // equiv to rev sort run(-5, 123, 3, mode); // equiv to rev sort
run(5, 3, 1231, mode, 2000); // non contig run(5, 3, 1231, mode, 2000); // non contig
//! opencl on armv7's CI does not support large batch.
//! but P30 and MI9 are ok. fix it in the future.
#if !defined(MEGDNN_ARMV7) && defined(MGB_CUDA)
run(3, 70000, 5, mode, 10); // non contig
#endif
} }
// special case to check if tie-break is correct // special case to check if tie-break is correct
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册