Unverified · Commit edff59b1 authored by wawltor, committed by GitHub

[cherry-pick] fix the cumsum big shape and random bug (#43777)

Parent e700ffdc
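The diff below makes two changes: the scan kernel is launched on a flat 1-D grid of outer_size * inner_size blocks instead of a dim3(outer_size, inner_size) grid, and the transpose path gets its own temporary tensor instead of reusing the output buffer. Below is a minimal sketch of the grid-flattening idea, outside Paddle and under the assumption that the 2-D launch broke for very large shapes because gridDim.y is capped at 65535 while a 1-D gridDim.x can reach 2^31 - 1; the kernel row_cumsum and its serial per-row scan are illustrative stand-ins, not the BlockScanKernel from the patch.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// One block per (outer, inner) row; blockIdx.x alone indexes the row because
// the grid is launched as a flat 1-D grid of outer_size * inner_size blocks.
__global__ void row_cumsum(const float* in, float* out, int scan_size) {
  long long offset = static_cast<long long>(blockIdx.x) * scan_size;
  if (threadIdx.x == 0) {  // serial per-row scan, kept trivial for clarity
    float acc = 0.f;
    for (int i = 0; i < scan_size; ++i) {
      acc += in[offset + i];
      out[offset + i] = acc;
    }
  }
}

int main() {
  const int outer = 2, inner = 3, scan = 4;
  const int n = outer * inner * scan;
  float h[n];
  for (int i = 0; i < n; ++i) h[i] = 1.f;

  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h, n * sizeof(float), cudaMemcpyHostToDevice);

  // Flattened launch: a 1-D grid of outer * inner blocks instead of
  // dim3(outer, inner), so inner is no longer capped by the gridDim.y limit.
  int grid_size = outer * inner;
  row_cumsum<<<grid_size, 32>>>(d_in, d_out, scan);
  cudaMemcpy(h, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("last prefix sum: %.1f\n", h[n - 1]);  // expected 4.0

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

With this launch shape each block recovers its row purely from blockIdx.x, which is what the patched offset computation (block_offset + bx * scan_size) relies on.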
@@ -152,10 +152,8 @@ __global__ void BlockScanKernel(T* d_out,
   } temp_storage;

   int bx = blockIdx.x;
-  int by = blockIdx.y;

   BlockPrefixCallbackOp<T> prefix_op(0);
-  T block_aggregate = static_cast<T>(0);

   // Obtain this block's segment of consecutive keys (blocked across threads)
   int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
@@ -168,7 +166,7 @@ __global__ void BlockScanKernel(T* d_out,
       valid_item = scan_size;
     }

-    int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
+    int offset = block_offset + bx * scan_size;

     T thread_keys[ITEMS_PER_THREAD];
     BlockLoadT(temp_storage.load)
@@ -260,8 +258,10 @@ void CumsumKernel(const Context& dev_ctx,
     dim3 blocks(32, 8);
     dim3 transpose_grids((width + tile_size - 1) / tile_size,
                          (height + tile_size - 1) / tile_size);
-    out->Resize(out_dims);
-    auto* tmp_data = out->data<T>();
+    DenseTensor tmp_tensor;
+    tmp_tensor.Resize(out_dims);
+    auto* tmp_data = dev_ctx.template Alloc<T>(&tmp_tensor);
     T* next_in_data = out_data;
     T* next_out_data = tmp_data;
@@ -281,6 +281,8 @@ void CumsumKernel(const Context& dev_ctx,
     // Consider the size of shared memory, here block size is 128
     dim3 scan_grid(outer_size, inner_size);
     dim3 reverse_grid = scan_grid;
+    int64_t grid_size = outer_size * inner_size;
     if (reverse) {
       if (transpose) {
         reverse_grid.x = scan_grid.y;
@@ -295,12 +297,12 @@ void CumsumKernel(const Context& dev_ctx,
       }
     }
     if (!transpose && !reverse) {
-      BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
+      BlockScanKernel<T, 128, 4><<<grid_size, 128, 0, dev_ctx.stream()>>>(
           out_data, in_data, outer_size, inner_size, scan_size, exclusive);
     } else {
-      BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
-          next_out_data,
+      BlockScanKernel<T, 128, 4>
+          <<<grid_size, 128, 0, dev_ctx.stream()>>>(next_out_data,
           next_in_data,
           outer_size,
           inner_size,
...
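The other half of the fix targets the random-result bug: before the patch tmp_data came from out->data<T>(), so it presumably aliased out_data, meaning next_in_data and next_out_data could point at the same storage and one stage of the transpose/scan pipeline could overwrite values another stage was still reading; the patch allocates a dedicated tmp_tensor instead. Below is a minimal sketch of that ping-pong pattern with a separate scratch allocation, not Paddle code; the scale_stage kernel is a hypothetical stand-in for the real stages.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative stand-in for one transpose/scan stage: reads `in`, writes `out`.
__global__ void scale_stage(const float* in, float* out, int n, float k) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * k;
}

int main() {
  const int n = 8;
  float h[n];
  for (int i = 0; i < n; ++i) h[i] = 1.f;

  float *out, *scratch;                     // `scratch` plays the role of tmp_tensor
  cudaMalloc(&out, n * sizeof(float));
  cudaMalloc(&scratch, n * sizeof(float));  // separate allocation, never `out` itself
  cudaMemcpy(out, h, n * sizeof(float), cudaMemcpyHostToDevice);

  float* next_in = out;       // mirrors next_in_data
  float* next_out = scratch;  // mirrors next_out_data
  for (int stage = 0; stage < 2; ++stage) {
    scale_stage<<<(n + 127) / 128, 128>>>(next_in, next_out, n, 2.f);
    float* t = next_in;       // swap roles so the next stage consumes this output
    next_in = next_out;
    next_out = t;
  }
  // After an even number of stages the final result lives in `out` again.
  cudaMemcpy(h, out, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("%.1f\n", h[0]);     // expected 4.0: two stages, each multiplying by 2
  cudaFree(out);
  cudaFree(scratch);
  return 0;
}
```

Keeping the read and write buffers distinct is what removes the data race, while swapping the pointers after each stage keeps the memory footprint at two buffers.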