// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/cumsum_kernel.h"

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

// Copies up to `valid_item` elements in reverse order: the element read from
// idata[src_base + t] is written to odata[dst_base - t], where t is
// threadIdx.x.
//
// Threads with t >= valid_item do nothing. `idata` and `odata` must not alias:
// the copy is not performed in place.
//
// BLOCK_SIZE is kept for interface compatibility with existing call sites. The
// previous implementation staged the data through a __shared__ array of
// BLOCK_SIZE elements, but every thread only ever read back the slot it had
// written itself, while out-of-range threads zero-filled a *different* slot
// (sh_mem[t] versus sh_mem[BLOCK_SIZE - 1 - t]). Those two index sets overlap
// whenever BLOCK_SIZE / 2 < valid_item < BLOCK_SIZE, which was a data race, so
// the staging buffer has been removed.
template <typename T, int BLOCK_SIZE>
__device__ void BlockReverse(
    const T* idata, T* odata, int src_base, int dst_base, int valid_item) {
  const int tx = threadIdx.x;
  if (tx < valid_item) {
    odata[dst_base - tx] = idata[src_base + tx];
  }
}

// Reverses every row of `reverse_size` contiguous elements.
//
// Launch layout: one thread block per row, with
//   row = blockIdx.x + blockIdx.y * inner_size,
// so the launch must use gridDim.x == inner_size. Each block walks its row in
// chunks of 1024 elements and is expected to be launched with 1024 threads.
// `outer_size` is not read by the kernel.
//
// NOTE(review): all offsets are computed in 32-bit int, so tensors with more
// than INT_MAX elements are not supported by this kernel.
template <typename T>
__global__ void MatrixRowReverse(const T* matrix_data,
                                 T* reverse_data,
                                 int reverse_size,
                                 int outer_size,
                                 int inner_size) {
  constexpr int kItemsPerBlock = 1024;
  const int row_base =
      blockIdx.x * reverse_size + blockIdx.y * (inner_size * reverse_size);

  for (int block_offset = 0; block_offset < reverse_size;
       block_offset += kItemsPerBlock) {
    const int remaining = reverse_size - block_offset;
    const int valid_item =
        remaining > kItemsPerBlock ? kItemsPerBlock : remaining;
    const int src_offset = row_base + block_offset;
    // The chunk that starts at `block_offset` lands at the mirrored end of
    // the row.
    const int dst_offset = row_base + reverse_size - 1 - block_offset;

    BlockReverse<T, kItemsPerBlock>(
        matrix_data, reverse_data, src_offset, dst_offset, valid_item);
  }
}

// Stateful prefix functor handed to cub::BlockScan so that a scan spanning
// several chunks carries the running sum from one chunk to the next.
template <typename T>
struct BlockPrefixCallbackOp {
  // Running prefix
  T running_total;

  // Constructor
  __device__ BlockPrefixCallbackOp(T running_total)
      : running_total(running_total) {}

  // Callback operator to be entered by the first warp of threads in the
  // block. Thread-0 is responsible for returning a value for seeding the
  // block-wide scan.
  __device__ T operator()(T block_aggregate) {
    T old_prefix = running_total;
    running_total = old_prefix + block_aggregate;
    return old_prefix;
  }
};

// Tiled matrix transpose: reads `idata` as a row-major height x width matrix
// and writes the row-major width x height transpose to `odata`.
//
// Launch layout: blockDim = (TILE_DIM, BLOCK_ROWS) and
// gridDim = (ceil(width / TILE_DIM), ceil(height / TILE_DIM)).
// The tile has one column of padding to avoid shared-memory bank conflicts
// on the transposed read. `idata` and `odata` must not alias.
template <typename T, int TILE_DIM, int BLOCK_ROWS>
__global__ void MatrixTranspose(T* odata,
                                const T* idata,
                                size_t height,
                                size_t width) {
  __shared__ T tile[TILE_DIM][TILE_DIM + 1];

  int x = blockIdx.x * TILE_DIM + threadIdx.x;
  int y = blockIdx.y * TILE_DIM + threadIdx.y;
  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
    if (x < width && (y + j) < height) {
      tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * width + x];
    } else {
      tile[threadIdx.y + j][threadIdx.x] = 0;
    }
  }

  __syncthreads();

  x = blockIdx.y * TILE_DIM + threadIdx.x;  // transpose block offset
  y = blockIdx.x * TILE_DIM + threadIdx.y;

  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
    if (x < height && (y + j) < width) {
      odata[(y + j) * height + x] = tile[threadIdx.x][threadIdx.y + j];
    }
  }
}

// Computes an inclusive or exclusive prefix sum over every row of `scan_size`
// contiguous elements.
//
// Launch layout: one thread block of BLOCK_THREADS threads per row, with
//   row = blockIdx.x + blockIdx.y * inner_size,
// so the value passed as `inner_size` must equal gridDim.x. (The host code in
// this file passes its own `outer_size` in this slot, which is its
// gridDim.x.) `outer_size` is not read by the kernel.
//
// Rows longer than BLOCK_THREADS * ITEMS_PER_THREAD are processed chunk by
// chunk; BlockPrefixCallbackOp carries the running total between chunks.
// d_out may equal d_in: each chunk is fully loaded into registers before it
// is stored.
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void BlockScanKernel(T* d_out,
                                const T* d_in,
                                int inner_size,
                                int outer_size,
                                int scan_size,
                                bool exclusive) {
  // Specialize BlockLoad, BlockStore, and BlockScan collective types
  typedef cub::
      BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
          BlockLoadT;
  typedef cub::BlockStore<T,
                          BLOCK_THREADS,
                          ITEMS_PER_THREAD,
                          cub::BLOCK_STORE_TRANSPOSE>
      BlockStoreT;
  typedef cub::BlockScan<T, BLOCK_THREADS> BlockScanT;

  // Allocate type-safe, repurposable shared memory for collectives
  __shared__ union {
    typename BlockLoadT::TempStorage load;
    typename BlockStoreT::TempStorage store;
    typename BlockScanT::TempStorage scan;
  } temp_storage;

  BlockPrefixCallbackOp<T> prefix_op(0);

  constexpr int kItemsPerBlock = BLOCK_THREADS * ITEMS_PER_THREAD;
  const int row_base =
      blockIdx.x * scan_size + blockIdx.y * (inner_size * scan_size);

  for (int block_offset = 0; block_offset < scan_size;
       block_offset += kItemsPerBlock) {
    const int remaining = scan_size - block_offset;
    const int valid_item =
        remaining > kItemsPerBlock ? kItemsPerBlock : remaining;
    const int offset = row_base + block_offset;

    // Out-of-range items are padded with 0 so they do not affect the sum.
    T thread_keys[ITEMS_PER_THREAD];
    BlockLoadT(temp_storage.load)
        .Load(d_in + offset, thread_keys, valid_item, 0);

    __syncthreads();
    if (exclusive) {
      BlockScanT(temp_storage.scan)
          .ExclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
    } else {
      BlockScanT(temp_storage.scan)
          .InclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
    }
    __syncthreads();

    BlockStoreT(temp_storage.store)
        .Store(d_out + offset, thread_keys, valid_item);

    // temp_storage is a union shared by load/scan/store: the next iteration's
    // Load must not start while other threads are still inside Store. The
    // trip count depends only on scan_size, so every thread of the block
    // reaches this barrier.
    __syncthreads();
  }
}

// GPU cumulative sum of `x` along `axis`, written to `out`.
//
//   axis      - dimension to scan; negative values count from the back.
//   flatten   - not read by this implementation.
//   exclusive - when true, each output excludes the element at its own
//               position.
//   reverse   - when true, the scan runs from the end of the axis to the
//               start.
//
// Strategy: when the scanned axis holds every element (1-D case) thrust is
// used directly. Otherwise the tensor is viewed as a height x width matrix,
// transposed when `axis` is not the innermost dimension so that scanned
// elements become contiguous, optionally reversed, scanned per row with
// cub::BlockScan, and then un-reversed / un-transposed.
template <typename T, typename Context>
void CumsumKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  int axis,
                  bool flatten,
                  bool exclusive,
                  bool reverse,
                  DenseTensor* out) {
  auto out_dims = out->dims();
  auto size = x.numel();

  PADDLE_ENFORCE_EQ(
      axis < out_dims.size() && axis >= (0 - out_dims.size()),
      true,
      phi::errors::OutOfRange(
          "Attr(axis) is out of range, It's expected "
          "to be in range of [-%d, %d]. But received Attr(axis) = %d.",
          out_dims.size(),
          out_dims.size() - 1,
          axis));
  if (axis < 0) {
    axis += out_dims.size();
  }

  T* out_data = dev_ctx.template Alloc<T>(out);
  const T* in_data = x.data<T>();

  // Nothing to compute for an empty tensor. Returning here also avoids a
  // division by zero (height / scan_size) and zero-sized grid launches below.
  if (size == 0) {
    return;
  }

  // Use thrust for parallel acceleration when the input size is equal to the
  // length of the 'axis' dimension.
  // NOTE(review): thrust::device runs on thrust's default stream, not on
  // dev_ctx.stream() - confirm this ordering is intended.
  if (size == out_dims[axis]) {
    if (reverse) {
      thrust::device_ptr<const T> dev_ptr =
          thrust::device_pointer_cast(in_data);
      thrust::device_vector<T> vec(dev_ptr, dev_ptr + size);
      if (exclusive) {
        thrust::exclusive_scan(
            thrust::device, vec.rbegin(), vec.rend(), out_data);
      } else {
        thrust::inclusive_scan(
            thrust::device, vec.rbegin(), vec.rend(), out_data);
      }
      thrust::reverse(thrust::device, out_data, out_data + size);
    } else {
      if (exclusive) {
        thrust::exclusive_scan(
            thrust::device, in_data, in_data + size, out_data);
      } else {
        thrust::inclusive_scan(
            thrust::device, in_data, in_data + size, out_data);
      }
    }
    return;
  }

  // View the tensor as a height x width matrix split after `axis`.
  size_t height = 1;
  size_t width = 1;
  for (int i = 0; i <= axis; ++i) {
    height *= out_dims[i];
  }
  for (int i = axis + 1; i < out_dims.size(); ++i) {
    width *= out_dims[i];
  }
  int scan_size = out_dims[axis];
  bool transpose = (axis != out_dims.size() - 1);

  int tile_size = 32;
  dim3 blocks(32, 8);
  dim3 transpose_grids((width + tile_size - 1) / tile_size,
                       (height + tile_size - 1) / tile_size);

  // Scratch buffer for the ping-pong between the reverse / scan / transpose
  // passes. It must be distinct from `out`: MatrixTranspose and
  // MatrixRowReverse cannot run in place. It is only needed when at least one
  // of those passes runs.
  DenseTensor tmp_tensor;
  T* tmp_data = nullptr;
  if (transpose || reverse) {
    tmp_tensor.Resize(out_dims);
    tmp_data = dev_ctx.template Alloc<T>(&tmp_tensor);
  }

  T* next_in_data = out_data;
  T* next_out_data = tmp_data;
  if (transpose) {
    MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
        out_data, in_data, height, width);
  }
  auto swap_ptr = [](T*& ptr1, T*& ptr2) {
    T* tmp = ptr2;
    ptr2 = ptr1;
    ptr1 = tmp;
  };
  int outer_size = height / scan_size;
  int inner_size = width;

  // Consider the size of shared memory, here block size is 128
  // NOTE(review): the items-per-thread value (4) was reconstructed - confirm.
  constexpr int kScanThreads = 128;
  constexpr int kScanItemsPerThread = 4;
  constexpr int kReverseThreads = 1024;

  // NOTE(review): gridDim.y is limited to 65535, so inner_size (or outer_size
  // for reverse_grid) above that limit will fail to launch.
  dim3 scan_grid(outer_size, inner_size);
  dim3 reverse_grid = scan_grid;
  if (reverse) {
    if (transpose) {
      // MatrixRowReverse indexes rows as blockIdx.x + blockIdx.y * inner_size,
      // so its grid needs inner_size along x.
      reverse_grid.x = scan_grid.y;
      reverse_grid.y = scan_grid.x;
      MatrixRowReverse<T>
          <<<reverse_grid, kReverseThreads, 0, dev_ctx.stream()>>>(
              next_in_data, next_out_data, scan_size, outer_size, inner_size);
      swap_ptr(next_in_data, next_out_data);
    } else {
      // No transpose pass ran, so the first reverse reads the input directly
      // and leaves its result in out_data (== next_in_data).
      MatrixRowReverse<T><<<scan_grid, kReverseThreads, 0, dev_ctx.stream()>>>(
          in_data, out_data, scan_size, outer_size, inner_size);
    }
  }

  // BlockScanKernel's third parameter (named inner_size there) must equal
  // gridDim.x, which is outer_size here; hence the argument order.
  if (!transpose && !reverse) {
    BlockScanKernel<T, kScanThreads, kScanItemsPerThread>
        <<<scan_grid, kScanThreads, 0, dev_ctx.stream()>>>(
            out_data, in_data, outer_size, inner_size, scan_size, exclusive);
  } else {
    BlockScanKernel<T, kScanThreads, kScanItemsPerThread>
        <<<scan_grid, kScanThreads, 0, dev_ctx.stream()>>>(next_out_data,
                                                           next_in_data,
                                                           outer_size,
                                                           inner_size,
                                                           scan_size,
                                                           exclusive);
  }
  swap_ptr(next_in_data, next_out_data);

  if (reverse) {
    MatrixRowReverse<T><<<reverse_grid, kReverseThreads, 0, dev_ctx.stream()>>>(
        next_in_data, next_out_data, scan_size, outer_size, inner_size);
    swap_ptr(next_in_data, next_out_data);
  }
  if (transpose) {
    // Transpose back: the intermediate matrix is width x height.
    transpose_grids.x = (height + tile_size - 1) / tile_size;
    transpose_grids.y = (width + tile_size - 1) / tile_size;
    MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
        next_out_data, next_in_data, width, height);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(cumsum,
                   GPU,
                   ALL_LAYOUT,
                   phi::CumsumKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t) {}