未验证 提交 b818429a 编写于 作者: W wangchaochaohu 提交者: GitHub

optimize cumsum OP (#29193)

上级 27b42183
...@@ -14,8 +14,10 @@ limitations under the License. */ ...@@ -14,8 +14,10 @@ limitations under the License. */
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/gather.h>
#include <thrust/reverse.h> #include <thrust/reverse.h>
#include <thrust/scan.h> #include <thrust/scan.h>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/cum_op.h" #include "paddle/fluid/operators/cum_op.h"
#include "paddle/fluid/platform/gpu_launch_config.h" #include "paddle/fluid/platform/gpu_launch_config.h"
...@@ -25,223 +27,157 @@ using LoDTensor = paddle::framework::LoDTensor; ...@@ -25,223 +27,157 @@ using LoDTensor = paddle::framework::LoDTensor;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T, int BLOCK_SIZE>
__global__ void OuterScan(const T* in, T* out, int inner_dim_size, __device__ void BlockReverse(const T* idata, T* odata, int src_base,
int outer_dim_size, int scan_dim_size, bool exclusive, int dst_base, int valid_item) {
bool reverse) { __shared__ T sh_mem[BLOCK_SIZE];
int id = blockIdx.y * blockDim.x + threadIdx.x; int tx = threadIdx.x;
for (int outer_index = blockIdx.x; outer_index < outer_dim_size; int offset = tx;
outer_index += gridDim.x) { int in_index = src_base + offset;
for (int inner_index = blockIdx.y * blockDim.x + threadIdx.x; if (offset >= valid_item) {
inner_index < inner_dim_size; inner_index += gridDim.y * blockDim.x) { sh_mem[offset] = 0;
int scan_index_init = 0; } else {
int forward_direction = 1; int sh_mem_index = BLOCK_SIZE - offset - 1;
int src_index = T data = idata[in_index];
outer_index * scan_dim_size * inner_dim_size + inner_index; sh_mem[sh_mem_index] = data;
int dst_index = }
outer_index * scan_dim_size * inner_dim_size + inner_index;
if (reverse) { __syncthreads();
src_index = src_index + (scan_dim_size - 1) * inner_dim_size; int out_index = dst_base - offset;
dst_index = dst_index + (scan_dim_size - 1) * inner_dim_size; if (offset < valid_item) {
forward_direction = -1; int sh_mem_index = BLOCK_SIZE - offset - 1;
} odata[out_index] = sh_mem[sh_mem_index];
if (exclusive) {
scan_index_init = 1;
out[dst_index] = 0;
dst_index = dst_index + (forward_direction * inner_dim_size);
}
T acc = 0;
for (int scan_index = scan_index_init; scan_index < scan_dim_size;
++scan_index) {
acc = in[src_index] + acc;
out[dst_index] = acc;
src_index += (forward_direction * inner_dim_size);
dst_index += (forward_direction * inner_dim_size);
}
}
} }
} }
// inclusive scan template <typename T>
template <typename T, int num_threads_x, int num_threads_y> __global__ void MatrixRowReverse(const T* matrix_data, T* reverse_data,
__global__ void InnerMostDimInclusiveScan(const T* in, T* out, int reverse_size, int outer_size,
int inner_dim_size, int inner_size) {
int outer_dim_size, int scan_dim_size, int bx = blockIdx.x;
bool reverse) { int by = blockIdx.y;
__shared__ T share_data[num_threads_y][num_threads_x * 2]; int item_per_block = 1024;
T* share_row = share_data[threadIdx.y];
int forward_direction = 1; for (int block_offset = 0; block_offset < reverse_size;
if (reverse) forward_direction = -1; block_offset += item_per_block) {
int valid_item = (reverse_size - block_offset > item_per_block)
for (int block_row = blockIdx.x * blockDim.y; block_row < outer_dim_size; ? item_per_block
block_row += blockDim.y * gridDim.x) { : reverse_size - block_offset;
int row = block_row + threadIdx.y; int src_offset =
T acc = 0; bx * reverse_size + block_offset + by * (inner_size * reverse_size);
const T* row_src = in + row * scan_dim_size; int dst_offset = bx * reverse_size + by * (inner_size * reverse_size) +
T* row_dst = out + row * scan_dim_size; reverse_size - 1 - block_offset;
int block_col = 0; if (reverse_size < item_per_block) {
bool loop_condition = (block_col < scan_dim_size); valid_item = reverse_size;
if (reverse) {
loop_condition = (block_col >= 0);
block_col = scan_dim_size - 1;
} }
while (loop_condition) {
// Load data into share memory(two value per thread)
int col1 = block_col + threadIdx.x * forward_direction;
int col2 = block_col + (num_threads_x + threadIdx.x) * forward_direction;
if (row < outer_dim_size) {
if (col1 < scan_dim_size && col1 >= 0) {
share_row[threadIdx.x] = row_src[col1];
} else {
share_row[threadIdx.x] = 0;
}
if (col2 < scan_dim_size && col2 >= 0) { BlockReverse<T, 1024>(matrix_data, reverse_data, src_offset, dst_offset,
share_row[num_threads_x + threadIdx.x] = row_src[col2]; valid_item);
} else {
share_row[num_threads_x + threadIdx.x] = 0;
}
// Add the previous block acc to the result
if (threadIdx.x == 0) {
share_row[0] = share_row[0] + acc;
}
}
__syncthreads();
// Up-Sweep
for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < outer_dim_size && threadIdx.x < s) {
unsigned offset = (2 * threadIdx.x + 1) * d - 1;
share_row[offset + d] = share_row[offset] + share_row[offset + d];
}
__syncthreads();
}
// Down-Sweep
for (unsigned s = 2, d = blockDim.x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < outer_dim_size && threadIdx.x < s - 1) {
unsigned offset = 2 * (threadIdx.x + 1) * d - 1;
share_row[offset + d] = share_row[offset] + share_row[offset + d];
}
__syncthreads();
}
// Write to the output
if (row < outer_dim_size) {
if (col1 < scan_dim_size && col1 >= 0)
row_dst[col1] = share_row[threadIdx.x];
if (col2 < scan_dim_size && col2 >= 0)
row_dst[col2] = share_row[num_threads_x + threadIdx.x];
}
acc = share_row[2 * num_threads_x - 1];
__syncthreads();
block_col += 2 * num_threads_x * forward_direction;
if (reverse)
loop_condition = (block_col >= 0);
else
loop_condition = (block_col < scan_dim_size);
}
} }
} }
// exclusive block scan and store block sum for large scan
template <typename T> template <typename T>
__global__ void InnerMostDimExclusiveScan(const T* in, T* out, T* sum_data, struct BlockPrefixCallbackOp {
int inner_dim_size, // Running prefix
int outer_dim_size, int scan_dim_size, T running_total;
int two_power, bool reverse) { // Constructor
// https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory __device__ BlockPrefixCallbackOp(T running_total)
extern __shared__ __align__(sizeof(T)) unsigned char raw_tmp[]; : running_total(running_total) {}
T* share_tmp = reinterpret_cast<T*>(raw_tmp); // Callback operator to be entered by the first warp of threads in the block.
int thread_id = threadIdx.x; // Thread-0 is responsible for returning a value for seeding the block-wide
int block_id = blockIdx.x; // scan.
int block_scan_size = blockDim.x * 2; __device__ T operator()(T block_aggregate) {
int remain = scan_dim_size % (2 * blockDim.x); T old_prefix = running_total;
if (block_id == gridDim.x - 1 && remain != 0) block_scan_size = remain; running_total = old_prefix + block_aggregate;
int col1 = thread_id; return old_prefix;
int col2 = thread_id + (block_scan_size) / 2;
int index1 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col1;
int index2 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col2;
if (reverse) {
index1 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 -
(block_id * blockDim.x * 2 + col1);
index2 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 -
(block_id * blockDim.x * 2 + col2);
}
int sum_index = blockIdx.y * gridDim.x + block_id;
if (thread_id < block_scan_size) {
share_tmp[col1 + (col1 >> 5)] = in[index1];
share_tmp[col2 + (col2 >> 5)] = in[index2];
} else {
share_tmp[col1 + (col1 >> 5)] = 0;
share_tmp[col2 + (col2 >> 5)] = 0;
} }
};
// Up-Sweep // No bank-conflict transpose
int offset = 1; // Same as transposeCoalesced except the first tile dimension is padded
for (int d = (two_power / 2); d > 0; d >>= 1) { // to avoid shared memory bank conflicts.
__syncthreads(); template <typename T, int TILE_DIM, int BLOCK_ROWS>
if (thread_id < d) { __global__ void MatrixTranspose(T* odata, const T* idata, size_t height,
int tmp_index1 = offset * (2 * thread_id + 1) - 1; size_t width) {
int tmp_index2 = offset * (2 * thread_id + 2) - 1; __shared__ T tile[TILE_DIM][TILE_DIM + 1];
tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
tmp_index2 = tmp_index2 + (tmp_index2 >> 5); int x = blockIdx.x * TILE_DIM + threadIdx.x;
int y = blockIdx.y * TILE_DIM + threadIdx.y;
share_tmp[tmp_index2] += share_tmp[tmp_index1]; for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
if (x < width && (y + j) < height) {
tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * width + x];
} else {
tile[threadIdx.y + j][threadIdx.x] = 0;
} }
offset *= 2;
} }
__syncthreads(); __syncthreads();
if (thread_id == 0) { x = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset
int tmp_index = (two_power - 1) + ((two_power - 1) >> 5); y = blockIdx.x * TILE_DIM + threadIdx.y;
sum_data[sum_index] = share_tmp[tmp_index];
share_tmp[tmp_index] = 0;
}
// Down Sweep for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
for (int d = 1; d < two_power; d *= 2) { if (x < height && (y + j) < width) {
offset >>= 1; odata[(y + j) * height + x] = tile[threadIdx.x][threadIdx.y + j];
__syncthreads();
if (thread_id < d) {
int tmp_index1 = offset * (2 * thread_id + 1) - 1;
int tmp_index2 = offset * (2 * thread_id + 2) - 1;
tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
tmp_index2 = tmp_index2 + (tmp_index2 >> 5);
T tmp = share_tmp[tmp_index1];
share_tmp[tmp_index1] = share_tmp[tmp_index2];
share_tmp[tmp_index2] += tmp;
} }
} }
}
__syncthreads(); template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void BlockScanKernel(T* d_out, const T* d_in, int inner_size,
int outer_size, int scan_size, bool exclusive) {
// Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
typedef cub::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD,
cub::BLOCK_LOAD_TRANSPOSE>
BlockLoadT;
typedef cub::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD,
cub::BLOCK_STORE_TRANSPOSE>
BlockStoreT;
typedef cub::BlockScan<T, BLOCK_THREADS> BlockScanT;
// Allocate type-safe, repurposable shared memory for collectives
__shared__ union {
typename BlockLoadT::TempStorage load;
typename BlockStoreT::TempStorage store;
typename BlockScanT::TempStorage scan;
} temp_storage;
int bx = blockIdx.x;
int by = blockIdx.y;
BlockPrefixCallbackOp<T> prefix_op(0);
T block_aggregate = static_cast<T>(0);
// Obtain this block's segment of consecutive keys (blocked across threads)
int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
for (int block_offset = 0; block_offset < scan_size;
block_offset += BLOCK_THREADS * ITEMS_PER_THREAD) {
int valid_item = (scan_size - block_offset > item_per_block)
? item_per_block
: (scan_size - block_offset);
if (scan_size < item_per_block) {
valid_item = scan_size;
}
if (col1 < block_scan_size) out[index1] = share_tmp[col1 + (col1 >> 5)]; int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
if (col2 < block_scan_size) out[index2] = share_tmp[col2 + (col2 >> 5)];
}
// for large scan_dim_size array we need to add for correct result T thread_keys[ITEMS_PER_THREAD];
template <typename T> BlockLoadT(temp_storage.load)
__global__ void AddBlockScan(T* result, T* sum, int size, int scan_dim_size, .Load(d_in + offset, thread_keys, valid_item, 0);
int sum_size, bool reverse) {
int idx = threadIdx.x + blockDim.x * (blockIdx.x + blockIdx.y * gridDim.x);
int block_id_start = blockIdx.y * sum_size;
int block_id_end = blockIdx.x + blockIdx.y * sum_size;
int block_id = blockIdx.x;
int thread_id = threadIdx.x;
int col = block_id * blockDim.x + thread_id + size;
int index = blockIdx.y * (scan_dim_size) + col;
if (reverse) {
index = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 - col;
}
if (col >= scan_dim_size || col < 0) return; __syncthreads();
for (int i = block_id_start; i <= block_id_end; i++) { if (exclusive) {
result[index] += sum[i]; T init_value = static_cast<T>(0);
BlockScanT(temp_storage.scan)
.ExclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
} else {
BlockScanT(temp_storage.scan)
.InclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
}
__syncthreads();
BlockStoreT(temp_storage.store)
.Store(d_out + offset, thread_keys, valid_item);
} }
} }
...@@ -298,72 +234,79 @@ class CumCUDAKernel : public framework::OpKernel<T> { ...@@ -298,72 +234,79 @@ class CumCUDAKernel : public framework::OpKernel<T> {
return; return;
} }
const int& scan_dim_size = out_dims[axis]; size_t height = 1;
bool optimize_condition = (axis == (out_dims.size() - 1)) ? true : false; size_t width = 1;
int outer_dim_size = 1; for (size_t i = 0; i <= axis; i++) {
int inner_dim_size = 1; height *= out_dims[i];
// treat all dim index < axis as outer_dim_size
for (size_t i = 0; i < axis; i++) {
outer_dim_size *= out_dims[i];
} }
// treat all dim index > axis as innner_dim_size
for (size_t i = axis + 1; i < out_dims.size(); i++) { for (size_t i = axis + 1; i < out_dims.size(); i++) {
inner_dim_size *= out_dims[i]; width *= out_dims[i];
} }
int scan_size = out_dims[axis];
bool transpose = (axis != out_dims.size() - 1);
int tile_size = 32;
dim3 blocks(32, 8);
dim3 transpose_grids((width + tile_size - 1) / tile_size,
(height + tile_size - 1) / tile_size);
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
if (optimize_condition) { Tensor tmp;
auto nextPowerOfTwo = [](int x) -> int { tmp.Resize(out_dims);
int ret = 1; auto* tmp_data = tmp.mutable_data<T>(context.GetPlace());
while (ret < x) ret = ret * 2; T* next_in_data = out_data;
return ret; T* next_out_data = tmp_data;
}; if (transpose) {
if (exclusive) { MatrixTranspose<T, 32,
int element_per_block = nextPowerOfTwo(scan_dim_size) / 2; 8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
if (element_per_block > 512 || element_per_block < 32) { out_data, in_data, height, width);
element_per_block = 64; next_in_data = out_data;
} next_out_data = tmp_data;
int two_power = element_per_block * 2; }
dim3 block(element_per_block); auto swap_ptr = [](T*& ptr1, T*& ptr2) {
dim3 grid(((scan_dim_size + 1) / 2 + block.x - 1) / block.x, T* tmp = ptr2;
outer_dim_size); ptr2 = ptr1;
int offset_size = (element_per_block * 2) >> 5; ptr1 = tmp;
int share_mem_size = (element_per_block * 2 + offset_size) * sizeof(T); };
Tensor scan_sum; int outer_size = height / scan_size;
paddle::framework::DDim dims{ int inner_size = width;
((scan_dim_size + 1) / 2 + block.x - 1) / block.x, outer_dim_size}; // Consider the size of shared memory, here block size is 128
scan_sum.Resize(dims); dim3 scan_grid(outer_size, inner_size);
T* sum_data = scan_sum.mutable_data<T>(context.GetPlace()); dim3 reverse_grid = scan_grid;
InnerMostDimExclusiveScan< if (reverse) {
T><<<grid, block, share_mem_size, dev_ctx.stream()>>>( if (transpose) {
in_data, out_data, sum_data, inner_dim_size, outer_dim_size, reverse_grid.x = scan_grid.y;
scan_dim_size, two_power, reverse); reverse_grid.y = scan_grid.x;
// for large scan array we need to do add for correct result MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
int element_size = element_per_block * 2; next_in_data, next_out_data, scan_size, outer_size, inner_size);
if (scan_dim_size > element_size) { if (!transpose) next_in_data = tmp_data;
dim3 sum_block(element_per_block * 2); swap_ptr(next_in_data, next_out_data);
dim3 sum_grid((scan_dim_size - element_size + block.x - 1) / block.x,
outer_dim_size);
int sum_size = ((scan_dim_size + 1) / 2 + block.x - 1) / block.x;
AddBlockScan<T><<<sum_grid, sum_block, 0, dev_ctx.stream()>>>(
out_data, sum_data, element_size, scan_dim_size, sum_size,
reverse);
}
} else { } else {
dim3 block(32, 16); MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
dim3 grid((outer_dim_size + block.y - 1) / block.y); in_data, out_data, scan_size, outer_size, inner_size);
InnerMostDimInclusiveScan<T, 32,
16><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
reverse);
} }
}
if (!transpose && !reverse) {
BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
out_data, in_data, outer_size, inner_size, scan_size, exclusive);
} else { } else {
dim3 block(std::min(512, inner_dim_size)); BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
dim3 grid(outer_dim_size, (inner_dim_size + block.x - 1) / block.x); next_out_data, next_in_data, outer_size, inner_size, scan_size,
OuterScan<T><<<grid, block, 0, dev_ctx.stream()>>>( exclusive);
in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size, }
exclusive, reverse); swap_ptr(next_in_data, next_out_data);
if (reverse) {
MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
next_in_data, next_out_data, scan_size, outer_size, inner_size);
swap_ptr(next_in_data, next_out_data);
}
if (transpose) {
transpose_grids.x = (height + tile_size - 1) / tile_size;
transpose_grids.y = (width + tile_size - 1) / tile_size;
MatrixTranspose<T, 32,
8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
next_out_data, next_in_data, width, height);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册