未验证 提交 4b3e8d56 编写于 作者: W wawltor 提交者: GitHub

Fix the cumsum bug for large input sizes (#43722)

上级 561d09b9
......@@ -176,10 +176,8 @@ __global__ void BlockScanKernel(T* d_out,
} temp_storage;
int bx = blockIdx.x;
int by = blockIdx.y;
BlockPrefixCallbackOp<T, Op> prefix_op(Identity<T, Op>::value, op);
T block_aggregate = static_cast<T>(0);
// Obtain this block's segment of consecutive keys (blocked across threads)
int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
......@@ -192,7 +190,7 @@ __global__ void BlockScanKernel(T* d_out,
valid_item = scan_size;
}
int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
int offset = block_offset + bx * scan_size;
T thread_keys[ITEMS_PER_THREAD];
BlockLoadT(temp_storage.load)
......@@ -307,6 +305,7 @@ void ScanKernel(const Context& dev_ctx,
int outer_size = height / scan_size;
int inner_size = width;
// Consider the size of shared memory, here block size is 128
dim3 scan_grid(outer_size, inner_size);
dim3 reverse_grid = scan_grid;
if (reverse) {
......@@ -322,13 +321,14 @@ void ScanKernel(const Context& dev_ctx,
in_data, out_data, scan_size, outer_size, inner_size);
}
}
int64_t grid_size = outer_size * inner_size;
if (!transpose && !reverse) {
BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
BlockScanKernel<T, 128, 4, Op><<<grid_size, 128, 0, dev_ctx.stream()>>>(
out_data, in_data, outer_size, inner_size, scan_size, exclusive, op);
} else {
BlockScanKernel<T, 128, 4, Op>
<<<scan_grid, 128, 0, dev_ctx.stream()>>>(next_out_data,
<<<grid_size, 128, 0, dev_ctx.stream()>>>(next_out_data,
next_in_data,
outer_size,
inner_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册