From b818429ae7ecadde20136c7340bc6dc1497ebc0b Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Sat, 28 Nov 2020 22:58:19 +0800
Subject: [PATCH] optimize cumsum OP (#29193)

---
 paddle/fluid/operators/cumsum_op.cu | 449 ++++++++++++----------------
 1 file changed, 196 insertions(+), 253 deletions(-)

diff --git a/paddle/fluid/operators/cumsum_op.cu b/paddle/fluid/operators/cumsum_op.cu
index 85cbf444a5..4bf839f748 100644
--- a/paddle/fluid/operators/cumsum_op.cu
+++ b/paddle/fluid/operators/cumsum_op.cu
@@ -14,8 +14,10 @@ limitations under the License. */
 
 #include 
 #include 
+#include 
 #include 
 #include 
+#include "cub/cub.cuh"
 #include "paddle/fluid/operators/cum_op.h"
 #include "paddle/fluid/platform/gpu_launch_config.h"
 
@@ -25,223 +27,157 @@ using LoDTensor = paddle::framework::LoDTensor;
 namespace paddle {
 namespace operators {
 
-template <typename T>
-__global__ void OuterScan(const T* in, T* out, int inner_dim_size,
-                          int outer_dim_size, int scan_dim_size, bool exclusive,
-                          bool reverse) {
-  int id = blockIdx.y * blockDim.x + threadIdx.x;
-
-  for (int outer_index = blockIdx.x; outer_index < outer_dim_size;
-       outer_index += gridDim.x) {
-    for (int inner_index = blockIdx.y * blockDim.x + threadIdx.x;
-         inner_index < inner_dim_size; inner_index += gridDim.y * blockDim.x) {
-      int scan_index_init = 0;
-      int forward_direction = 1;
-      int src_index =
-          outer_index * scan_dim_size * inner_dim_size + inner_index;
-      int dst_index =
-          outer_index * scan_dim_size * inner_dim_size + inner_index;
-      if (reverse) {
-        src_index = src_index + (scan_dim_size - 1) * inner_dim_size;
-        dst_index = dst_index + (scan_dim_size - 1) * inner_dim_size;
-        forward_direction = -1;
-      }
-      if (exclusive) {
-        scan_index_init = 1;
-        out[dst_index] = 0;
-        dst_index = dst_index + (forward_direction * inner_dim_size);
-      }
-      T acc = 0;
-
-      for (int scan_index = scan_index_init; scan_index < scan_dim_size;
-           ++scan_index) {
-        acc = in[src_index] + acc;
-        out[dst_index] = acc;
-        src_index += (forward_direction * inner_dim_size);
-        dst_index += (forward_direction * inner_dim_size);
-      }
-    }
+template <typename T, int BLOCK_SIZE>
+__device__ void BlockReverse(const T* idata, T* odata, int src_base,
+                             int dst_base, int valid_item) {
+  __shared__ T sh_mem[BLOCK_SIZE];
+  int tx = threadIdx.x;
+
+  int offset = tx;
+  int in_index = src_base + offset;
+  if (offset >= valid_item) {
+    sh_mem[offset] = 0;
+  } else {
+    int sh_mem_index = BLOCK_SIZE - offset - 1;
+    T data = idata[in_index];
+    sh_mem[sh_mem_index] = data;
+  }
+
+  __syncthreads();
+  int out_index = dst_base - offset;
+  if (offset < valid_item) {
+    int sh_mem_index = BLOCK_SIZE - offset - 1;
+    odata[out_index] = sh_mem[sh_mem_index];
   }
 }
 
-// inclusive scan
-template <typename T, int num_threads_x, int num_threads_y>
-__global__ void InnerMostDimInclusiveScan(const T* in, T* out,
-                                          int inner_dim_size,
-                                          int outer_dim_size, int scan_dim_size,
-                                          bool reverse) {
-  __shared__ T share_data[num_threads_y][num_threads_x * 2];
-  T* share_row = share_data[threadIdx.y];
-  int forward_direction = 1;
-  if (reverse) forward_direction = -1;
-
-  for (int block_row = blockIdx.x * blockDim.y; block_row < outer_dim_size;
-       block_row += blockDim.y * gridDim.x) {
-    int row = block_row + threadIdx.y;
-    T acc = 0;
-    const T* row_src = in + row * scan_dim_size;
-    T* row_dst = out + row * scan_dim_size;
-    int block_col = 0;
-    bool loop_condition = (block_col < scan_dim_size);
-    if (reverse) {
-      loop_condition = (block_col >= 0);
-      block_col = scan_dim_size - 1;
+template <typename T>
+__global__ void MatrixRowReverse(const T* matrix_data, T* reverse_data,
+                                 int reverse_size, int outer_size,
+                                 int inner_size) {
+  int bx = blockIdx.x;
+  int by = blockIdx.y;
+  int item_per_block = 1024;
+
+  for (int block_offset = 0; block_offset < reverse_size;
+       block_offset += item_per_block) {
+    int valid_item = (reverse_size - block_offset > item_per_block)
+                         ? item_per_block
+                         : reverse_size - block_offset;
+    int src_offset =
+        bx * reverse_size + block_offset + by * (inner_size * reverse_size);
+    int dst_offset = bx * reverse_size + by * (inner_size * reverse_size) +
+                     reverse_size - 1 - block_offset;
+    if (reverse_size < item_per_block) {
+      valid_item = reverse_size;
     }
-    while (loop_condition) {
-      // Load data into share memory(two value per thread)
-      int col1 = block_col + threadIdx.x * forward_direction;
-      int col2 = block_col + (num_threads_x + threadIdx.x) * forward_direction;
-      if (row < outer_dim_size) {
-        if (col1 < scan_dim_size && col1 >= 0) {
-          share_row[threadIdx.x] = row_src[col1];
-        } else {
-          share_row[threadIdx.x] = 0;
-        }
-        if (col2 < scan_dim_size && col2 >= 0) {
-          share_row[num_threads_x + threadIdx.x] = row_src[col2];
-        } else {
-          share_row[num_threads_x + threadIdx.x] = 0;
-        }
-
-        // Add the previous block acc to the result
-        if (threadIdx.x == 0) {
-          share_row[0] = share_row[0] + acc;
-        }
-      }
-      __syncthreads();
-
-      // Up-Sweep
-      for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
-        if (row < outer_dim_size && threadIdx.x < s) {
-          unsigned offset = (2 * threadIdx.x + 1) * d - 1;
-          share_row[offset + d] = share_row[offset] + share_row[offset + d];
-        }
-        __syncthreads();
-      }
-      // Down-Sweep
-      for (unsigned s = 2, d = blockDim.x / 2; d >= 1; s <<= 1, d >>= 1) {
-        if (row < outer_dim_size && threadIdx.x < s - 1) {
-          unsigned offset = 2 * (threadIdx.x + 1) * d - 1;
-          share_row[offset + d] = share_row[offset] + share_row[offset + d];
-        }
-        __syncthreads();
-      }
-
-      // Write to the output
-      if (row < outer_dim_size) {
-        if (col1 < scan_dim_size && col1 >= 0)
-          row_dst[col1] = share_row[threadIdx.x];
-        if (col2 < scan_dim_size && col2 >= 0)
-          row_dst[col2] = share_row[num_threads_x + threadIdx.x];
-      }
-      acc = share_row[2 * num_threads_x - 1];
-      __syncthreads();
-      block_col += 2 * num_threads_x * forward_direction;
-      if (reverse)
-        loop_condition = (block_col >= 0);
-      else
-        loop_condition = (block_col < scan_dim_size);
-    }
+    BlockReverse<T, 1024>(matrix_data, reverse_data, src_offset, dst_offset,
+                          valid_item);
   }
 }
 
-// exclusive block scan and store block sum for large scan
 template <typename T>
-__global__ void InnerMostDimExclusiveScan(const T* in, T* out, T* sum_data,
-                                          int inner_dim_size,
-                                          int outer_dim_size, int scan_dim_size,
-                                          int two_power, bool reverse) {
-  // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
-  extern __shared__ __align__(sizeof(T)) unsigned char raw_tmp[];
-  T* share_tmp = reinterpret_cast<T*>(raw_tmp);
-  int thread_id = threadIdx.x;
-  int block_id = blockIdx.x;
-  int block_scan_size = blockDim.x * 2;
-  int remain = scan_dim_size % (2 * blockDim.x);
-  if (block_id == gridDim.x - 1 && remain != 0) block_scan_size = remain;
-  int col1 = thread_id;
-  int col2 = thread_id + (block_scan_size) / 2;
-  int index1 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col1;
-  int index2 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col2;
-  if (reverse) {
-    index1 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 -
-             (block_id * blockDim.x * 2 + col1);
-    index2 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 -
-             (block_id * blockDim.x * 2 + col2);
-  }
-  int sum_index = blockIdx.y * gridDim.x + block_id;
-  if (thread_id < block_scan_size) {
-    share_tmp[col1 + (col1 >> 5)] = in[index1];
-    share_tmp[col2 + (col2 >> 5)] = in[index2];
-  } else {
-    share_tmp[col1 + (col1 >> 5)] = 0;
-    share_tmp[col2 + (col2 >> 5)] = 0;
+struct BlockPrefixCallbackOp {
+  // Running prefix
+  T running_total;
+  // Constructor
+  __device__ BlockPrefixCallbackOp(T running_total)
+      : running_total(running_total) {}
+  // Callback operator to be entered by the first warp of threads in the block.
+  // Thread-0 is responsible for returning a value for seeding the block-wide
+  // scan.
+  __device__ T operator()(T block_aggregate) {
+    T old_prefix = running_total;
+    running_total = old_prefix + block_aggregate;
+    return old_prefix;
   }
+};
 
-  // Up-Sweep
-  int offset = 1;
-  for (int d = (two_power / 2); d > 0; d >>= 1) {
-    __syncthreads();
-    if (thread_id < d) {
-      int tmp_index1 = offset * (2 * thread_id + 1) - 1;
-      int tmp_index2 = offset * (2 * thread_id + 2) - 1;
-      tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
-      tmp_index2 = tmp_index2 + (tmp_index2 >> 5);
-
-      share_tmp[tmp_index2] += share_tmp[tmp_index1];
+// No bank-conflict transpose
+// Same as transposeCoalesced except the first tile dimension is padded
+// to avoid shared memory bank conflicts.
+template <typename T, int TILE_DIM, int BLOCK_ROWS>
+__global__ void MatrixTranspose(T* odata, const T* idata, size_t height,
+                                size_t width) {
+  __shared__ T tile[TILE_DIM][TILE_DIM + 1];
+
+  int x = blockIdx.x * TILE_DIM + threadIdx.x;
+  int y = blockIdx.y * TILE_DIM + threadIdx.y;
+  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
+    if (x < width && (y + j) < height) {
+      tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * width + x];
+    } else {
+      tile[threadIdx.y + j][threadIdx.x] = 0;
     }
-    offset *= 2;
   }
+  __syncthreads();
 
-  if (thread_id == 0) {
-    int tmp_index = (two_power - 1) + ((two_power - 1) >> 5);
-    sum_data[sum_index] = share_tmp[tmp_index];
-    share_tmp[tmp_index] = 0;
-  }
+  x = blockIdx.y * TILE_DIM + threadIdx.x;  // transpose block offset
+  y = blockIdx.x * TILE_DIM + threadIdx.y;
 
-  // Down Sweep
-  for (int d = 1; d < two_power; d *= 2) {
-    offset >>= 1;
-    __syncthreads();
-    if (thread_id < d) {
-      int tmp_index1 = offset * (2 * thread_id + 1) - 1;
-      int tmp_index2 = offset * (2 * thread_id + 2) - 1;
-      tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
-      tmp_index2 = tmp_index2 + (tmp_index2 >> 5);
-
-      T tmp = share_tmp[tmp_index1];
-      share_tmp[tmp_index1] = share_tmp[tmp_index2];
-      share_tmp[tmp_index2] += tmp;
+  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
+    if (x < height && (y + j) < width) {
+      odata[(y + j) * height + x] = tile[threadIdx.x][threadIdx.y + j];
     }
   }
+}
 
-  __syncthreads();
 
+template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
+__global__ void BlockScanKernel(T* d_out, const T* d_in, int inner_size,
+                                int outer_size, int scan_size, bool exclusive) {
+  // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
+  typedef cub::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD,
+                         cub::BLOCK_LOAD_TRANSPOSE>
+      BlockLoadT;
+  typedef cub::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD,
+                          cub::BLOCK_STORE_TRANSPOSE>
+      BlockStoreT;
+  typedef cub::BlockScan<T, BLOCK_THREADS> BlockScanT;
+  // Allocate type-safe, repurposable shared memory for collectives
+  __shared__ union {
+    typename BlockLoadT::TempStorage load;
+    typename BlockStoreT::TempStorage store;
+    typename BlockScanT::TempStorage scan;
+  } temp_storage;
+
+  int bx = blockIdx.x;
+  int by = blockIdx.y;
+
+  BlockPrefixCallbackOp<T> prefix_op(0);
+  T block_aggregate = static_cast<T>(0);
+
+  // Obtain this block's segment of consecutive keys (blocked across threads)
+  int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
+  for (int block_offset = 0; block_offset < scan_size;
+       block_offset += BLOCK_THREADS * ITEMS_PER_THREAD) {
+    int valid_item = (scan_size - block_offset > item_per_block)
+                         ? item_per_block
+                         : (scan_size - block_offset);
+    if (scan_size < item_per_block) {
+      valid_item = scan_size;
+    }
 
-  if (col1 < block_scan_size) out[index1] = share_tmp[col1 + (col1 >> 5)];
-  if (col2 < block_scan_size) out[index2] = share_tmp[col2 + (col2 >> 5)];
-}
+    int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
 
-// for large scan_dim_size array we need to add for correct result
-template <typename T>
-__global__ void AddBlockScan(T* result, T* sum, int size, int scan_dim_size,
-                             int sum_size, bool reverse) {
-  int idx = threadIdx.x + blockDim.x * (blockIdx.x + blockIdx.y * gridDim.x);
-  int block_id_start = blockIdx.y * sum_size;
-  int block_id_end = blockIdx.x + blockIdx.y * sum_size;
-  int block_id = blockIdx.x;
-  int thread_id = threadIdx.x;
-
-  int col = block_id * blockDim.x + thread_id + size;
-  int index = blockIdx.y * (scan_dim_size) + col;
-  if (reverse) {
-    index = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 - col;
-  }
+    T thread_keys[ITEMS_PER_THREAD];
+    BlockLoadT(temp_storage.load)
+        .Load(d_in + offset, thread_keys, valid_item, 0);
 
-  if (col >= scan_dim_size || col < 0) return;
-  for (int i = block_id_start; i <= block_id_end; i++) {
-    result[index] += sum[i];
+    __syncthreads();
+    if (exclusive) {
+      T init_value = static_cast<T>(0);
+      BlockScanT(temp_storage.scan)
+          .ExclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
+    } else {
+      BlockScanT(temp_storage.scan)
+          .InclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
+    }
+    __syncthreads();
+
+    BlockStoreT(temp_storage.store)
+        .Store(d_out + offset, thread_keys, valid_item);
   }
 }
 
@@ -298,72 +234,79 @@ class CumCUDAKernel : public framework::OpKernel<T> {
       return;
     }
 
-    const int& scan_dim_size = out_dims[axis];
-    bool optimize_condition = (axis == (out_dims.size() - 1)) ? true : false;
-    int outer_dim_size = 1;
-    int inner_dim_size = 1;
-    // treat all dim index < axis as outer_dim_size
-    for (size_t i = 0; i < axis; i++) {
-      outer_dim_size *= out_dims[i];
+    size_t height = 1;
+    size_t width = 1;
+    for (size_t i = 0; i <= axis; i++) {
+      height *= out_dims[i];
     }
-    // treat all dim index > axis as innner_dim_size
+
     for (size_t i = axis + 1; i < out_dims.size(); i++) {
-      inner_dim_size *= out_dims[i];
+      width *= out_dims[i];
     }
 
+    int scan_size = out_dims[axis];
+    bool transpose = (axis != out_dims.size() - 1);
+    int tile_size = 32;
+    dim3 blocks(32, 8);
+    dim3 transpose_grids((width + tile_size - 1) / tile_size,
+                         (height + tile_size - 1) / tile_size);
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    if (optimize_condition) {
-      auto nextPowerOfTwo = [](int x) -> int {
-        int ret = 1;
-        while (ret < x) ret = ret * 2;
-        return ret;
-      };
-      if (exclusive) {
-        int element_per_block = nextPowerOfTwo(scan_dim_size) / 2;
-        if (element_per_block > 512 || element_per_block < 32) {
-          element_per_block = 64;
-        }
-        int two_power = element_per_block * 2;
-        dim3 block(element_per_block);
-        dim3 grid(((scan_dim_size + 1) / 2 + block.x - 1) / block.x,
-                  outer_dim_size);
-        int offset_size = (element_per_block * 2) >> 5;
-        int share_mem_size = (element_per_block * 2 + offset_size) * sizeof(T);
-        Tensor scan_sum;
-        paddle::framework::DDim dims{
-            ((scan_dim_size + 1) / 2 + block.x - 1) / block.x, outer_dim_size};
-        scan_sum.Resize(dims);
-        T* sum_data = scan_sum.mutable_data<T>(context.GetPlace());
-        InnerMostDimExclusiveScan<
-            T><<<grid, block, share_mem_size, dev_ctx.stream()>>>(
-            in_data, out_data, sum_data, inner_dim_size, outer_dim_size,
-            scan_dim_size, two_power, reverse);
-        // for large scan array we need to do add for correct result
-        int element_size = element_per_block * 2;
-        if (scan_dim_size > element_size) {
-          dim3 sum_block(element_per_block * 2);
-          dim3 sum_grid((scan_dim_size - element_size + block.x - 1) / block.x,
-                        outer_dim_size);
-          int sum_size = ((scan_dim_size + 1) / 2 + block.x - 1) / block.x;
-          AddBlockScan<T><<<sum_grid, sum_block, 0, dev_ctx.stream()>>>(
-              out_data, sum_data, element_size, scan_dim_size, sum_size,
-              reverse);
-        }
-
+    Tensor tmp;
+    tmp.Resize(out_dims);
+    auto* tmp_data = tmp.mutable_data<T>(context.GetPlace());
+    T* next_in_data = out_data;
+    T* next_out_data = tmp_data;
+    if (transpose) {
+      MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
+          out_data, in_data, height, width);
+      next_in_data = out_data;
+      next_out_data = tmp_data;
+    }
+    auto swap_ptr = [](T*& ptr1, T*& ptr2) {
+      T* tmp = ptr2;
+      ptr2 = ptr1;
+      ptr1 = tmp;
+    };
+    int outer_size = height / scan_size;
+    int inner_size = width;
+    // Consider the size of shared memory, here block size is 128
+    dim3 scan_grid(outer_size, inner_size);
+    dim3 reverse_grid = scan_grid;
+    if (reverse) {
+      if (transpose) {
+        reverse_grid.x = scan_grid.y;
+        reverse_grid.y = scan_grid.x;
+        MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
+            next_in_data, next_out_data, scan_size, outer_size, inner_size);
+        if (!transpose) next_in_data = tmp_data;
+        swap_ptr(next_in_data, next_out_data);
      } else {
-        dim3 block(32, 16);
-        dim3 grid((outer_dim_size + block.y - 1) / block.y);
-        InnerMostDimInclusiveScan<T, 32, 16><<<grid, block, 0, dev_ctx.stream()>>>(
-            in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
-            reverse);
+        MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
+            in_data, out_data, scan_size, outer_size, inner_size);
      }
+    }
+    if (!transpose && !reverse) {
+      BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
+          out_data, in_data, outer_size, inner_size, scan_size, exclusive);
+
     } else {
-      dim3 block(std::min(512, inner_dim_size));
-      dim3 grid(outer_dim_size, (inner_dim_size + block.x - 1) / block.x);
-      OuterScan<T><<<grid, block, 0, dev_ctx.stream()>>>(
-          in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
-          exclusive, reverse);
+      BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
+          next_out_data, next_in_data, outer_size, inner_size, scan_size,
+          exclusive);
+    }
+    swap_ptr(next_in_data, next_out_data);
+    if (reverse) {
+      MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
+          next_in_data, next_out_data, scan_size, outer_size, inner_size);
+      swap_ptr(next_in_data, next_out_data);
+    }
+    if (transpose) {
+      transpose_grids.x = (height + tile_size - 1) / tile_size;
+      transpose_grids.y = (width + tile_size - 1) / tile_size;
+      MatrixTranspose<T, 32, 8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
+          next_out_data, next_in_data, width, height);
     }
   }
 };
-- 
GitLab
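
Note on the pattern used by the new BlockScanKernel above: each thread block walks one row of the (possibly transposed) matrix in tiles of BLOCK_THREADS * ITEMS_PER_THREAD elements, and a block-prefix callback carries the running total from one tile into the next. The following is a minimal, self-contained sketch of that cub::BlockLoad / cub::BlockScan / cub::BlockStore pattern in isolation. It is not part of the patch: the names RowInclusiveScan and RunningPrefix, the 128x4 tile shape, and the test size are illustrative assumptions, and the "cub/cub.cuh" include path assumes the same bundled CUB header the patch uses.

// Illustrative sketch only (not part of the patch): inclusive scan of a single
// row that is longer than one tile, using the same running-prefix pattern as
// BlockScanKernel above.
#include <cstdio>
#include "cub/cub.cuh"

template <typename T>
struct RunningPrefix {
  T total;
  __device__ explicit RunningPrefix(T init) : total(init) {}
  // Called once per tile by the first warp; returns the prefix for this tile.
  __device__ T operator()(T tile_aggregate) {
    T old_total = total;
    total += tile_aggregate;
    return old_total;
  }
};

template <int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
__global__ void RowInclusiveScan(const T* in, T* out, int row_len) {
  typedef cub::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD,
                         cub::BLOCK_LOAD_TRANSPOSE>
      LoadT;
  typedef cub::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD,
                          cub::BLOCK_STORE_TRANSPOSE>
      StoreT;
  typedef cub::BlockScan<T, BLOCK_THREADS> ScanT;
  __shared__ union {
    typename LoadT::TempStorage load;
    typename StoreT::TempStorage store;
    typename ScanT::TempStorage scan;
  } temp_storage;

  RunningPrefix<T> prefix(0);
  const int tile = BLOCK_THREADS * ITEMS_PER_THREAD;
  for (int offset = 0; offset < row_len; offset += tile) {
    int valid = (row_len - offset > tile) ? tile : (row_len - offset);
    T items[ITEMS_PER_THREAD];
    // Out-of-range elements are padded with 0 so they do not disturb the sum.
    LoadT(temp_storage.load).Load(in + offset, items, valid, static_cast<T>(0));
    __syncthreads();
    ScanT(temp_storage.scan).InclusiveScan(items, items, cub::Sum(), prefix);
    __syncthreads();
    StoreT(temp_storage.store).Store(out + offset, items, valid);
    __syncthreads();  // temp_storage is reused by the next tile
  }
}

int main() {
  const int n = 1000;
  int host[n];
  for (int i = 0; i < n; ++i) host[i] = 1;
  int *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(int));
  cudaMalloc(&d_out, n * sizeof(int));
  cudaMemcpy(d_in, host, n * sizeof(int), cudaMemcpyHostToDevice);
  // One block of 128 threads scans the whole row, 4 items per thread per tile.
  RowInclusiveScan<128, 4><<<1, 128>>>(d_in, d_out, n);
  cudaMemcpy(host, d_out, n * sizeof(int), cudaMemcpyDeviceToHost);
  printf("last prefix sum = %d\n", host[n - 1]);  // expected: 1000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Compiled with nvcc, this prints 1000, the inclusive prefix sum of a row of ones; that is what one BlockScanKernel block computes for a single row before the patch's optional reverse and transpose steps are applied around it.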