未验证 提交 edff59b1 编写于 作者: W wawltor 提交者: GitHub

[cherry-pick] fix the cumsum big shape and random bug (#43777)

上级 e700ffdc
......@@ -152,10 +152,8 @@ __global__ void BlockScanKernel(T* d_out,
} temp_storage;
int bx = blockIdx.x;
int by = blockIdx.y;
BlockPrefixCallbackOp<T> prefix_op(0);
T block_aggregate = static_cast<T>(0);
// Obtain this block's segment of consecutive keys (blocked across threads)
int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
......@@ -168,7 +166,7 @@ __global__ void BlockScanKernel(T* d_out,
valid_item = scan_size;
}
int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
int offset = block_offset + bx * scan_size;
T thread_keys[ITEMS_PER_THREAD];
BlockLoadT(temp_storage.load)
......@@ -260,8 +258,10 @@ void CumsumKernel(const Context& dev_ctx,
dim3 blocks(32, 8);
dim3 transpose_grids((width + tile_size - 1) / tile_size,
(height + tile_size - 1) / tile_size);
out->Resize(out_dims);
auto* tmp_data = out->data<T>();
DenseTensor tmp_tensor;
tmp_tensor.Resize(out_dims);
auto* tmp_data = dev_ctx.template Alloc<T>(&tmp_tensor);
T* next_in_data = out_data;
T* next_out_data = tmp_data;
......@@ -281,6 +281,8 @@ void CumsumKernel(const Context& dev_ctx,
// Consider the size of shared memory, here block size is 128
dim3 scan_grid(outer_size, inner_size);
dim3 reverse_grid = scan_grid;
int64_t grid_size = outer_size * inner_size;
if (reverse) {
if (transpose) {
reverse_grid.x = scan_grid.y;
......@@ -295,17 +297,17 @@ void CumsumKernel(const Context& dev_ctx,
}
}
if (!transpose && !reverse) {
BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
BlockScanKernel<T, 128, 4><<<grid_size, 128, 0, dev_ctx.stream()>>>(
out_data, in_data, outer_size, inner_size, scan_size, exclusive);
} else {
BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
next_out_data,
next_in_data,
outer_size,
inner_size,
scan_size,
exclusive);
BlockScanKernel<T, 128, 4>
<<<grid_size, 128, 0, dev_ctx.stream()>>>(next_out_data,
next_in_data,
outer_size,
inner_size,
scan_size,
exclusive);
}
swap_ptr(next_in_data, next_out_data);
if (reverse) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册