diff --git a/paddle/phi/kernels/gpu/cumsum_kernel.cu b/paddle/phi/kernels/gpu/cumsum_kernel.cu index 9073bcb6c027b0ee924ac3d08f72bbd7c1160fd1..4e7eda7537b5a43ba076fcaacda8364e4bdf8e6e 100644 --- a/paddle/phi/kernels/gpu/cumsum_kernel.cu +++ b/paddle/phi/kernels/gpu/cumsum_kernel.cu @@ -152,10 +152,8 @@ __global__ void BlockScanKernel(T* d_out, } temp_storage; int bx = blockIdx.x; - int by = blockIdx.y; BlockPrefixCallbackOp prefix_op(0); - T block_aggregate = static_cast(0); // Obtain this block's segment of consecutive keys (blocked across threads) int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD; @@ -168,7 +166,7 @@ __global__ void BlockScanKernel(T* d_out, valid_item = scan_size; } - int offset = bx * scan_size + block_offset + by * (inner_size * scan_size); + int offset = block_offset + bx * scan_size; T thread_keys[ITEMS_PER_THREAD]; BlockLoadT(temp_storage.load) @@ -260,8 +258,10 @@ void CumsumKernel(const Context& dev_ctx, dim3 blocks(32, 8); dim3 transpose_grids((width + tile_size - 1) / tile_size, (height + tile_size - 1) / tile_size); - out->Resize(out_dims); - auto* tmp_data = out->data(); + + DenseTensor tmp_tensor; + tmp_tensor.Resize(out_dims); + auto* tmp_data = dev_ctx.template Alloc(&tmp_tensor); T* next_in_data = out_data; T* next_out_data = tmp_data; @@ -281,6 +281,8 @@ void CumsumKernel(const Context& dev_ctx, // Consider the size of shared memory, here block size is 128 dim3 scan_grid(outer_size, inner_size); dim3 reverse_grid = scan_grid; + int64_t grid_size = outer_size * inner_size; + if (reverse) { if (transpose) { reverse_grid.x = scan_grid.y; @@ -295,17 +297,17 @@ void CumsumKernel(const Context& dev_ctx, } } if (!transpose && !reverse) { - BlockScanKernel<<>>( + BlockScanKernel<<>>( out_data, in_data, outer_size, inner_size, scan_size, exclusive); } else { - BlockScanKernel<<>>( - next_out_data, - next_in_data, - outer_size, - inner_size, - scan_size, - exclusive); + BlockScanKernel + <<>>(next_out_data, + next_in_data, + outer_size, + inner_size, + scan_size, + exclusive); } swap_ptr(next_in_data, next_out_data); if (reverse) {