未验证 提交 4b3e8d56 编写于 作者: W wawltor 提交者: GitHub

fix the cumsum bug for large size (#43722)

上级 561d09b9
...@@ -176,10 +176,8 @@ __global__ void BlockScanKernel(T* d_out, ...@@ -176,10 +176,8 @@ __global__ void BlockScanKernel(T* d_out,
} temp_storage; } temp_storage;
int bx = blockIdx.x; int bx = blockIdx.x;
int by = blockIdx.y;
BlockPrefixCallbackOp<T, Op> prefix_op(Identity<T, Op>::value, op); BlockPrefixCallbackOp<T, Op> prefix_op(Identity<T, Op>::value, op);
T block_aggregate = static_cast<T>(0);
// Obtain this block's segment of consecutive keys (blocked across threads) // Obtain this block's segment of consecutive keys (blocked across threads)
int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD; int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
...@@ -192,7 +190,7 @@ __global__ void BlockScanKernel(T* d_out, ...@@ -192,7 +190,7 @@ __global__ void BlockScanKernel(T* d_out,
valid_item = scan_size; valid_item = scan_size;
} }
int offset = bx * scan_size + block_offset + by * (inner_size * scan_size); int offset = block_offset + bx * scan_size;
T thread_keys[ITEMS_PER_THREAD]; T thread_keys[ITEMS_PER_THREAD];
BlockLoadT(temp_storage.load) BlockLoadT(temp_storage.load)
...@@ -307,6 +305,7 @@ void ScanKernel(const Context& dev_ctx, ...@@ -307,6 +305,7 @@ void ScanKernel(const Context& dev_ctx,
int outer_size = height / scan_size; int outer_size = height / scan_size;
int inner_size = width; int inner_size = width;
// Consider the size of shared memory, here block size is 128 // Consider the size of shared memory, here block size is 128
dim3 scan_grid(outer_size, inner_size); dim3 scan_grid(outer_size, inner_size);
dim3 reverse_grid = scan_grid; dim3 reverse_grid = scan_grid;
if (reverse) { if (reverse) {
...@@ -322,13 +321,14 @@ void ScanKernel(const Context& dev_ctx, ...@@ -322,13 +321,14 @@ void ScanKernel(const Context& dev_ctx,
in_data, out_data, scan_size, outer_size, inner_size); in_data, out_data, scan_size, outer_size, inner_size);
} }
} }
int64_t grid_size = outer_size * inner_size;
if (!transpose && !reverse) { if (!transpose && !reverse) {
BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>( BlockScanKernel<T, 128, 4, Op><<<grid_size, 128, 0, dev_ctx.stream()>>>(
out_data, in_data, outer_size, inner_size, scan_size, exclusive, op); out_data, in_data, outer_size, inner_size, scan_size, exclusive, op);
} else { } else {
BlockScanKernel<T, 128, 4, Op> BlockScanKernel<T, 128, 4, Op>
<<<scan_grid, 128, 0, dev_ctx.stream()>>>(next_out_data, <<<grid_size, 128, 0, dev_ctx.stream()>>>(next_out_data,
next_in_data, next_in_data,
outer_size, outer_size,
inner_size, inner_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册