未验证 提交 b818429a 编写于 作者: W wangchaochaohu 提交者: GitHub

optimize cumsum OP (#29193)

上级 27b42183
......@@ -14,8 +14,10 @@ limitations under the License. */
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/gather.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/cum_op.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
......@@ -25,223 +27,157 @@ using LoDTensor = paddle::framework::LoDTensor;
namespace paddle {
namespace operators {
template <typename T>
__global__ void OuterScan(const T* in, T* out, int inner_dim_size,
int outer_dim_size, int scan_dim_size, bool exclusive,
bool reverse) {
int id = blockIdx.y * blockDim.x + threadIdx.x;
for (int outer_index = blockIdx.x; outer_index < outer_dim_size;
outer_index += gridDim.x) {
for (int inner_index = blockIdx.y * blockDim.x + threadIdx.x;
inner_index < inner_dim_size; inner_index += gridDim.y * blockDim.x) {
int scan_index_init = 0;
int forward_direction = 1;
int src_index =
outer_index * scan_dim_size * inner_dim_size + inner_index;
int dst_index =
outer_index * scan_dim_size * inner_dim_size + inner_index;
if (reverse) {
src_index = src_index + (scan_dim_size - 1) * inner_dim_size;
dst_index = dst_index + (scan_dim_size - 1) * inner_dim_size;
forward_direction = -1;
}
if (exclusive) {
scan_index_init = 1;
out[dst_index] = 0;
dst_index = dst_index + (forward_direction * inner_dim_size);
}
T acc = 0;
for (int scan_index = scan_index_init; scan_index < scan_dim_size;
++scan_index) {
acc = in[src_index] + acc;
out[dst_index] = acc;
src_index += (forward_direction * inner_dim_size);
dst_index += (forward_direction * inner_dim_size);
}
}
}
}
// inclusive scan
template <typename T, int num_threads_x, int num_threads_y>
__global__ void InnerMostDimInclusiveScan(const T* in, T* out,
int inner_dim_size,
int outer_dim_size, int scan_dim_size,
bool reverse) {
__shared__ T share_data[num_threads_y][num_threads_x * 2];
T* share_row = share_data[threadIdx.y];
int forward_direction = 1;
if (reverse) forward_direction = -1;
for (int block_row = blockIdx.x * blockDim.y; block_row < outer_dim_size;
block_row += blockDim.y * gridDim.x) {
int row = block_row + threadIdx.y;
T acc = 0;
const T* row_src = in + row * scan_dim_size;
T* row_dst = out + row * scan_dim_size;
int block_col = 0;
bool loop_condition = (block_col < scan_dim_size);
if (reverse) {
loop_condition = (block_col >= 0);
block_col = scan_dim_size - 1;
}
while (loop_condition) {
// Load data into share memory(two value per thread)
int col1 = block_col + threadIdx.x * forward_direction;
int col2 = block_col + (num_threads_x + threadIdx.x) * forward_direction;
if (row < outer_dim_size) {
if (col1 < scan_dim_size && col1 >= 0) {
share_row[threadIdx.x] = row_src[col1];
} else {
share_row[threadIdx.x] = 0;
}
template <typename T, int BLOCK_SIZE>
__device__ void BlockReverse(const T* idata, T* odata, int src_base,
int dst_base, int valid_item) {
__shared__ T sh_mem[BLOCK_SIZE];
int tx = threadIdx.x;
if (col2 < scan_dim_size && col2 >= 0) {
share_row[num_threads_x + threadIdx.x] = row_src[col2];
int offset = tx;
int in_index = src_base + offset;
if (offset >= valid_item) {
sh_mem[offset] = 0;
} else {
share_row[num_threads_x + threadIdx.x] = 0;
int sh_mem_index = BLOCK_SIZE - offset - 1;
T data = idata[in_index];
sh_mem[sh_mem_index] = data;
}
// Add the previous block acc to the result
if (threadIdx.x == 0) {
share_row[0] = share_row[0] + acc;
}
}
__syncthreads();
// Up-Sweep
for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < outer_dim_size && threadIdx.x < s) {
unsigned offset = (2 * threadIdx.x + 1) * d - 1;
share_row[offset + d] = share_row[offset] + share_row[offset + d];
}
__syncthreads();
}
// Down-Sweep
for (unsigned s = 2, d = blockDim.x / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < outer_dim_size && threadIdx.x < s - 1) {
unsigned offset = 2 * (threadIdx.x + 1) * d - 1;
share_row[offset + d] = share_row[offset] + share_row[offset + d];
}
__syncthreads();
int out_index = dst_base - offset;
if (offset < valid_item) {
int sh_mem_index = BLOCK_SIZE - offset - 1;
odata[out_index] = sh_mem[sh_mem_index];
}
}
// Write to the output
if (row < outer_dim_size) {
if (col1 < scan_dim_size && col1 >= 0)
row_dst[col1] = share_row[threadIdx.x];
if (col2 < scan_dim_size && col2 >= 0)
row_dst[col2] = share_row[num_threads_x + threadIdx.x];
}
acc = share_row[2 * num_threads_x - 1];
__syncthreads();
block_col += 2 * num_threads_x * forward_direction;
if (reverse)
loop_condition = (block_col >= 0);
else
loop_condition = (block_col < scan_dim_size);
}
template <typename T>
__global__ void MatrixRowReverse(const T* matrix_data, T* reverse_data,
int reverse_size, int outer_size,
int inner_size) {
int bx = blockIdx.x;
int by = blockIdx.y;
int item_per_block = 1024;
for (int block_offset = 0; block_offset < reverse_size;
block_offset += item_per_block) {
int valid_item = (reverse_size - block_offset > item_per_block)
? item_per_block
: reverse_size - block_offset;
int src_offset =
bx * reverse_size + block_offset + by * (inner_size * reverse_size);
int dst_offset = bx * reverse_size + by * (inner_size * reverse_size) +
reverse_size - 1 - block_offset;
if (reverse_size < item_per_block) {
valid_item = reverse_size;
}
BlockReverse<T, 1024>(matrix_data, reverse_data, src_offset, dst_offset,
valid_item);
}
}
// exclusive block scan and store block sum for large scan
template <typename T>
__global__ void InnerMostDimExclusiveScan(const T* in, T* out, T* sum_data,
int inner_dim_size,
int outer_dim_size, int scan_dim_size,
int two_power, bool reverse) {
// https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
extern __shared__ __align__(sizeof(T)) unsigned char raw_tmp[];
T* share_tmp = reinterpret_cast<T*>(raw_tmp);
int thread_id = threadIdx.x;
int block_id = blockIdx.x;
int block_scan_size = blockDim.x * 2;
int remain = scan_dim_size % (2 * blockDim.x);
if (block_id == gridDim.x - 1 && remain != 0) block_scan_size = remain;
int col1 = thread_id;
int col2 = thread_id + (block_scan_size) / 2;
int index1 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col1;
int index2 = blockIdx.y * (scan_dim_size) + block_id * blockDim.x * 2 + col2;
if (reverse) {
index1 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 -
(block_id * blockDim.x * 2 + col1);
index2 = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 -
(block_id * blockDim.x * 2 + col2);
struct BlockPrefixCallbackOp {
// Running prefix
T running_total;
// Constructor
__device__ BlockPrefixCallbackOp(T running_total)
: running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide
// scan.
__device__ T operator()(T block_aggregate) {
T old_prefix = running_total;
running_total = old_prefix + block_aggregate;
return old_prefix;
}
int sum_index = blockIdx.y * gridDim.x + block_id;
if (thread_id < block_scan_size) {
share_tmp[col1 + (col1 >> 5)] = in[index1];
share_tmp[col2 + (col2 >> 5)] = in[index2];
};
// No bank-conflict transpose
// Same as transposeCoalesced except the first tile dimension is padded
// to avoid shared memory bank conflicts.
template <typename T, int TILE_DIM, int BLOCK_ROWS>
__global__ void MatrixTranspose(T* odata, const T* idata, size_t height,
size_t width) {
__shared__ T tile[TILE_DIM][TILE_DIM + 1];
int x = blockIdx.x * TILE_DIM + threadIdx.x;
int y = blockIdx.y * TILE_DIM + threadIdx.y;
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
if (x < width && (y + j) < height) {
tile[threadIdx.y + j][threadIdx.x] = idata[(y + j) * width + x];
} else {
share_tmp[col1 + (col1 >> 5)] = 0;
share_tmp[col2 + (col2 >> 5)] = 0;
tile[threadIdx.y + j][threadIdx.x] = 0;
}
}
// Up-Sweep
int offset = 1;
for (int d = (two_power / 2); d > 0; d >>= 1) {
__syncthreads();
if (thread_id < d) {
int tmp_index1 = offset * (2 * thread_id + 1) - 1;
int tmp_index2 = offset * (2 * thread_id + 2) - 1;
tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
tmp_index2 = tmp_index2 + (tmp_index2 >> 5);
share_tmp[tmp_index2] += share_tmp[tmp_index1];
x = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset
y = blockIdx.x * TILE_DIM + threadIdx.y;
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
if (x < height && (y + j) < width) {
odata[(y + j) * height + x] = tile[threadIdx.x][threadIdx.y + j];
}
offset *= 2;
}
__syncthreads();
}
if (thread_id == 0) {
int tmp_index = (two_power - 1) + ((two_power - 1) >> 5);
sum_data[sum_index] = share_tmp[tmp_index];
share_tmp[tmp_index] = 0;
}
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void BlockScanKernel(T* d_out, const T* d_in, int inner_size,
int outer_size, int scan_size, bool exclusive) {
// Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
typedef cub::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD,
cub::BLOCK_LOAD_TRANSPOSE>
BlockLoadT;
typedef cub::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD,
cub::BLOCK_STORE_TRANSPOSE>
BlockStoreT;
typedef cub::BlockScan<T, BLOCK_THREADS> BlockScanT;
// Allocate type-safe, repurposable shared memory for collectives
__shared__ union {
typename BlockLoadT::TempStorage load;
typename BlockStoreT::TempStorage store;
typename BlockScanT::TempStorage scan;
} temp_storage;
int bx = blockIdx.x;
int by = blockIdx.y;
BlockPrefixCallbackOp<T> prefix_op(0);
T block_aggregate = static_cast<T>(0);
// Obtain this block's segment of consecutive keys (blocked across threads)
int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
for (int block_offset = 0; block_offset < scan_size;
block_offset += BLOCK_THREADS * ITEMS_PER_THREAD) {
int valid_item = (scan_size - block_offset > item_per_block)
? item_per_block
: (scan_size - block_offset);
if (scan_size < item_per_block) {
valid_item = scan_size;
}
int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
T thread_keys[ITEMS_PER_THREAD];
BlockLoadT(temp_storage.load)
.Load(d_in + offset, thread_keys, valid_item, 0);
// Down Sweep
for (int d = 1; d < two_power; d *= 2) {
offset >>= 1;
__syncthreads();
if (thread_id < d) {
int tmp_index1 = offset * (2 * thread_id + 1) - 1;
int tmp_index2 = offset * (2 * thread_id + 2) - 1;
tmp_index1 = tmp_index1 + (tmp_index1 >> 5);
tmp_index2 = tmp_index2 + (tmp_index2 >> 5);
T tmp = share_tmp[tmp_index1];
share_tmp[tmp_index1] = share_tmp[tmp_index2];
share_tmp[tmp_index2] += tmp;
}
if (exclusive) {
T init_value = static_cast<T>(0);
BlockScanT(temp_storage.scan)
.ExclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
} else {
BlockScanT(temp_storage.scan)
.InclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op);
}
__syncthreads();
if (col1 < block_scan_size) out[index1] = share_tmp[col1 + (col1 >> 5)];
if (col2 < block_scan_size) out[index2] = share_tmp[col2 + (col2 >> 5)];
}
// for large scan_dim_size array we need to add for correct result
template <typename T>
__global__ void AddBlockScan(T* result, T* sum, int size, int scan_dim_size,
int sum_size, bool reverse) {
int idx = threadIdx.x + blockDim.x * (blockIdx.x + blockIdx.y * gridDim.x);
int block_id_start = blockIdx.y * sum_size;
int block_id_end = blockIdx.x + blockIdx.y * sum_size;
int block_id = blockIdx.x;
int thread_id = threadIdx.x;
int col = block_id * blockDim.x + thread_id + size;
int index = blockIdx.y * (scan_dim_size) + col;
if (reverse) {
index = blockIdx.y * (scan_dim_size) + scan_dim_size - 1 - col;
}
if (col >= scan_dim_size || col < 0) return;
for (int i = block_id_start; i <= block_id_end; i++) {
result[index] += sum[i];
BlockStoreT(temp_storage.store)
.Store(d_out + offset, thread_keys, valid_item);
}
}
......@@ -298,72 +234,79 @@ class CumCUDAKernel : public framework::OpKernel<T> {
return;
}
const int& scan_dim_size = out_dims[axis];
bool optimize_condition = (axis == (out_dims.size() - 1)) ? true : false;
int outer_dim_size = 1;
int inner_dim_size = 1;
// treat all dim index < axis as outer_dim_size
for (size_t i = 0; i < axis; i++) {
outer_dim_size *= out_dims[i];
size_t height = 1;
size_t width = 1;
for (size_t i = 0; i <= axis; i++) {
height *= out_dims[i];
}
// treat all dim index > axis as innner_dim_size
for (size_t i = axis + 1; i < out_dims.size(); i++) {
inner_dim_size *= out_dims[i];
width *= out_dims[i];
}
int scan_size = out_dims[axis];
bool transpose = (axis != out_dims.size() - 1);
int tile_size = 32;
dim3 blocks(32, 8);
dim3 transpose_grids((width + tile_size - 1) / tile_size,
(height + tile_size - 1) / tile_size);
auto& dev_ctx = context.template device_context<DeviceContext>();
if (optimize_condition) {
auto nextPowerOfTwo = [](int x) -> int {
int ret = 1;
while (ret < x) ret = ret * 2;
return ret;
Tensor tmp;
tmp.Resize(out_dims);
auto* tmp_data = tmp.mutable_data<T>(context.GetPlace());
T* next_in_data = out_data;
T* next_out_data = tmp_data;
if (transpose) {
MatrixTranspose<T, 32,
8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
out_data, in_data, height, width);
next_in_data = out_data;
next_out_data = tmp_data;
}
auto swap_ptr = [](T*& ptr1, T*& ptr2) {
T* tmp = ptr2;
ptr2 = ptr1;
ptr1 = tmp;
};
if (exclusive) {
int element_per_block = nextPowerOfTwo(scan_dim_size) / 2;
if (element_per_block > 512 || element_per_block < 32) {
element_per_block = 64;
int outer_size = height / scan_size;
int inner_size = width;
// Consider the size of shared memory, here block size is 128
dim3 scan_grid(outer_size, inner_size);
dim3 reverse_grid = scan_grid;
if (reverse) {
if (transpose) {
reverse_grid.x = scan_grid.y;
reverse_grid.y = scan_grid.x;
MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
next_in_data, next_out_data, scan_size, outer_size, inner_size);
if (!transpose) next_in_data = tmp_data;
swap_ptr(next_in_data, next_out_data);
} else {
MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
in_data, out_data, scan_size, outer_size, inner_size);
}
int two_power = element_per_block * 2;
dim3 block(element_per_block);
dim3 grid(((scan_dim_size + 1) / 2 + block.x - 1) / block.x,
outer_dim_size);
int offset_size = (element_per_block * 2) >> 5;
int share_mem_size = (element_per_block * 2 + offset_size) * sizeof(T);
Tensor scan_sum;
paddle::framework::DDim dims{
((scan_dim_size + 1) / 2 + block.x - 1) / block.x, outer_dim_size};
scan_sum.Resize(dims);
T* sum_data = scan_sum.mutable_data<T>(context.GetPlace());
InnerMostDimExclusiveScan<
T><<<grid, block, share_mem_size, dev_ctx.stream()>>>(
in_data, out_data, sum_data, inner_dim_size, outer_dim_size,
scan_dim_size, two_power, reverse);
// for large scan array we need to do add for correct result
int element_size = element_per_block * 2;
if (scan_dim_size > element_size) {
dim3 sum_block(element_per_block * 2);
dim3 sum_grid((scan_dim_size - element_size + block.x - 1) / block.x,
outer_dim_size);
int sum_size = ((scan_dim_size + 1) / 2 + block.x - 1) / block.x;
AddBlockScan<T><<<sum_grid, sum_block, 0, dev_ctx.stream()>>>(
out_data, sum_data, element_size, scan_dim_size, sum_size,
reverse);
}
if (!transpose && !reverse) {
BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
out_data, in_data, outer_size, inner_size, scan_size, exclusive);
} else {
dim3 block(32, 16);
dim3 grid((outer_dim_size + block.y - 1) / block.y);
InnerMostDimInclusiveScan<T, 32,
16><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
reverse);
BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
next_out_data, next_in_data, outer_size, inner_size, scan_size,
exclusive);
}
} else {
dim3 block(std::min(512, inner_dim_size));
dim3 grid(outer_dim_size, (inner_dim_size + block.x - 1) / block.x);
OuterScan<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, out_data, inner_dim_size, outer_dim_size, scan_dim_size,
exclusive, reverse);
swap_ptr(next_in_data, next_out_data);
if (reverse) {
MatrixRowReverse<T><<<reverse_grid, 1024, 0, dev_ctx.stream()>>>(
next_in_data, next_out_data, scan_size, outer_size, inner_size);
swap_ptr(next_in_data, next_out_data);
}
if (transpose) {
transpose_grids.x = (height + tile_size - 1) / tile_size;
transpose_grids.y = (width + tile_size - 1) / tile_size;
MatrixTranspose<T, 32,
8><<<transpose_grids, blocks, 0, dev_ctx.stream()>>>(
next_out_data, next_in_data, width, height);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册