From 1a0b3661ae5ffea421cdbb7fe32ccd307294abe7 Mon Sep 17 00:00:00 2001
From: ZZK <359521840@qq.com>
Date: Mon, 9 Jan 2023 19:35:43 +0800
Subject: [PATCH] Add concat optimization (#49540)

* add concat optimization

* refine

* remove annotation

* use alignas instead of aligned_storage
---
 .../kernels/funcs/concat_and_split_functor.cu | 493 +++++++++++++-----
 1 file changed, 376 insertions(+), 117 deletions(-)

diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.cu b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
index d238853f6a2..af2ab0a8a70 100644
--- a/paddle/phi/kernels/funcs/concat_and_split_functor.cu
+++ b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
@@ -20,18 +20,43 @@ limitations under the License. */
 namespace phi {
 namespace funcs {
 
+static inline void GetBlockDims(const phi::GPUContext& context,
+                                int64_t num_rows,
+                                int64_t num_cols,
+                                dim3* block_dims,
+                                dim3* grid_dims) {
+  // Set the thread block and grid according to CurrentDeviceId
+  const int kThreadsPerBlock = 1024;
+  int block_cols = kThreadsPerBlock;
+  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
+    block_cols = ((num_cols + 31) >> 5) << 5;
+  }
+  int block_rows = kThreadsPerBlock / block_cols;
+  *block_dims = dim3(block_cols, block_rows, 1);
+
+  constexpr int waves = 1;
+  int max_threads = context.GetMaxPhysicalThreadCount() * waves;
+  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+  int grid_cols =
+      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
+  int grid_rows = std::min(max_blocks / grid_cols,
+                           std::max(num_rows / block_rows, (int64_t)1));
+  *grid_dims = dim3(grid_cols, grid_rows, 1);
+}
+
 template <typename T, int Size>
 struct PointerWrapper {
  public:
-  const T* ins_addr[Size];
-  __device__ inline const T* operator[](int i) const { return ins_addr[i]; }
+  const void* ins_addr[Size];
+  __device__ inline const void* operator[](int i) const { return ins_addr[i]; }
 
   PointerWrapper() {}
   PointerWrapper(const phi::GPUContext& ctx,
                  const std::vector<phi::DenseTensor>& ins,
                  const T** pre_alloced_host_ptr) {
     for (auto i = 0; i < ins.size(); ++i) {
-      ins_addr[i] = ins[i].data<T>();
+      ins_addr[i] = ins[i].data();
     }
   }
 };
@@ -39,8 +64,8 @@ struct PointerWrapper {
 template <typename T>
 struct PointerToPointer {
  public:
-  T** ins_addr{nullptr};
-  __device__ inline const T* operator[](int i) const { return ins_addr[i]; }
+  void** ins_addr{nullptr};
+  __device__ inline const void* operator[](int i) const { return ins_addr[i]; }
 
   PointerToPointer() {}
   PointerToPointer(const phi::GPUContext& ctx,
@@ -63,7 +88,7 @@ struct PointerToPointer {
         restored,
         in_num * sizeof(T*),
         ctx.stream());
-    ins_addr = reinterpret_cast<T**>((*dev_ins_ptr)->ptr());
+    ins_addr = reinterpret_cast<void**>((*dev_ins_ptr)->ptr());
   }
 };
 
@@ -82,7 +107,7 @@ struct PointerAndColWrapper {
     ins_ptr_wrapper = PointerWrapper<T, Size>(ctx, ins, pre_alloced_host_ptr);
   }
 
-  __device__ inline const T* operator[](int i) const {
+  __device__ inline const void* operator[](int i) const {
     return ins_ptr_wrapper[i];
   }
 
@@ -118,7 +143,7 @@ struct PointerToPointerAndCol {
         PointerToPointer<T>(ctx, ins, pre_alloced_host_ptr, dev_ins_ptr);
   }
 
-  __device__ inline const T* operator[](int i) const {
+  __device__ inline const void* operator[](int i) const {
     return ins_ptr_wrapper[i];
   }
 
   IndexT* col_length;
   PointerToPointer<T> ins_ptr_wrapper;
 };
 
-template <typename T, typename IndexT, typename PointerAndColWrapperT>
-__global__ void ConcatTensorWithDifferentShape(PointerAndColWrapperT ins_datas,
-                                               int col_size,
-                                               const IndexT output_rows,
-                                               const IndexT output_cols,
-                                               T* output) {
+template <int MovSize>
+struct alignas(MovSize) Packed {
+  __device__ Packed() {
+    // do nothing
+  }
+  union {
+    char buf[MovSize];
+  };
+};
+
+template <typename IndexT, int MovSize, typename PointerAndColWrapperT>
+__global__ void ConcatTensorWithDifferentShape(
+    const PointerAndColWrapperT ins_datas,
+    int col_size,
+    const IndexT output_rows,
+    const IndexT output_cols,
+    void* output) {
+  Packed<MovSize>* dst = reinterpret_cast<Packed<MovSize>*>(output);
+
   IndexT curr_segment = 0;
   IndexT curr_offset = ins_datas.col_length[0];
+
   CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, IndexT) {
     IndexT curr_col_offset = ins_datas.col_length[curr_segment + 1];
+
     while (curr_col_offset <= tid_x) {
       curr_offset = curr_col_offset;
       ++curr_segment;
@@ -145,32 +185,335 @@ __global__ void ConcatTensorWithDifferentShape(PointerAndColWrapperT ins_datas,
     IndexT local_col = tid_x - curr_offset;
     IndexT segment_width = curr_col_offset - curr_offset;
-    const T* input_ptr = ins_datas[curr_segment];
+    const Packed<MovSize>* input_ptr =
+        reinterpret_cast<const Packed<MovSize>*>(ins_datas[curr_segment]);
+
     IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
-      output[tid_y * output_cols + tid_x] =
+
+    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
+      dst[tid_y * output_cols + tid_x] =
           input_ptr[tid_y * segment_width + local_col];
+    }
   }
 }
 
-template <typename T, typename IndexT, typename PointerWrapperT>
-__global__ void ConcatTensorWithSameShape(PointerWrapperT ins_data,
+template <typename IndexT, int MovSize, typename PointerWrapperT>
+__global__ void ConcatTensorWithSameShape(const PointerWrapperT ins_data,
                                           const IndexT fixed_in_col,
                                           const IndexT out_rows,
                                           const IndexT out_cols,
-                                          T* output_data) {
+                                          void* output_data) {
+  Packed<MovSize>* dst = reinterpret_cast<Packed<MovSize>*>(output_data);
   CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, IndexT) {
     IndexT split = tid_x / fixed_in_col;
     IndexT in_offset = tid_x - split * fixed_in_col;
-    const T* input_ptr = ins_data[split];
+    const Packed<MovSize>* input_ptr =
+        reinterpret_cast<const Packed<MovSize>*>(ins_data[split]);
     IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
-      output_data[tid_y * out_cols + tid_x] =
+      dst[tid_y * out_cols + tid_x] =
           input_ptr[tid_y * fixed_in_col + in_offset];
     }
   }
 }
 
+#define IMPL_CONCATE_CUDA_KERNEL_HELPER(func_impl, ...) \
+  func_impl(4, ##__VA_ARGS__);                          \
+  func_impl(8, ##__VA_ARGS__);                          \
+  func_impl(16, ##__VA_ARGS__);                         \
+  func_impl(32, ##__VA_ARGS__);                         \
+  func_impl(64, ##__VA_ARGS__);                         \
+  func_impl(128, ##__VA_ARGS__);
+
+template <typename T, typename IndexT, int MovSize>
+void DispatchConcatWithDifferentShapeKernelLimitNum(
+    const phi::GPUContext& ctx,
+    const std::vector<phi::DenseTensor>& ins,
+    const IndexT inputs_col_num,
+    const T** inputs_data,
+    IndexT* inputs_col,
+    const IndexT out_row,
+    const IndexT out_col,
+    phi::DenseTensor* output,
+    const IndexT in_num,
+    const IndexT limit_num) {
+  dim3 block_dims;
+  dim3 grid_dims;
+  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
+
+#define IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE(size_, ...)    \
+  case size_: {                                             \
+    PointerAndColWrapper<T, IndexT, size_> ptr_col_array(   \
+        ctx, ins, inputs_col_num, inputs_data, inputs_col); \
+    __VA_ARGS__;                                            \
+  } break;
+  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
+    IMPL_CONCATE_CUDA_KERNEL_HELPER(
+        IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE,
+        ConcatTensorWithDifferentShape<IndexT, MovSize>
+        <<<grid_dims, block_dims, 0, ctx.stream()>>>(
+            ptr_col_array, inputs_col_num, out_row, out_col, output->data()));
+    default: {
+      paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
+      paddle::memory::AllocationPtr dev_col_ptr{nullptr};
+      PointerToPointerAndCol<T, IndexT> ptr_col_array(ctx,
+                                                      ins,
+                                                      inputs_col_num,
+                                                      inputs_data,
+                                                      inputs_col,
+                                                      &dev_ins_ptr,
+                                                      &dev_col_ptr);
+      ConcatTensorWithDifferentShape<IndexT, MovSize>
+          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
+              ptr_col_array, inputs_col_num, out_row, out_col, output->data());
+    }
+  }
+#undef IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE
+}
+
+template <typename T, typename IndexT>
+void DispatchConcatWithDifferentShapeMovsize(
+    const phi::GPUContext& ctx,
+    const std::vector<phi::DenseTensor>& ins,
+    const IndexT inputs_col_num,
+    const T** inputs_data,
+    IndexT* inputs_col,
+    const IndexT out_row,
+    const IndexT out_col,
+    phi::DenseTensor* output,
+    const IndexT mov_size,
+    const IndexT in_num,
+    const IndexT limit_num) {
+  if (mov_size == 16) {
+    DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 16>(
+        ctx,
+        ins,
+        inputs_col_num,
+        inputs_data,
+        inputs_col,
+        out_row,
+        out_col,
+        output,
+        in_num,
+        limit_num);
+  } else if (mov_size == 8) {
+    DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 8>(ctx,
+                                                                 ins,
+                                                                 inputs_col_num,
+                                                                 inputs_data,
+                                                                 inputs_col,
+                                                                 out_row,
+                                                                 out_col,
+                                                                 output,
+                                                                 in_num,
+                                                                 limit_num);
+  } else if (mov_size == 4) {
+    DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 4>(ctx,
+                                                                 ins,
+                                                                 inputs_col_num,
+                                                                 inputs_data,
+                                                                 inputs_col,
+                                                                 out_row,
+                                                                 out_col,
+                                                                 output,
+                                                                 in_num,
+                                                                 limit_num);
+  } else if (mov_size == 2) {
+    DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 2>(ctx,
+                                                                 ins,
+                                                                 inputs_col_num,
+                                                                 inputs_data,
+                                                                 inputs_col,
+                                                                 out_row,
+                                                                 out_col,
+                                                                 output,
+                                                                 in_num,
+                                                                 limit_num);
+  } else {
+    DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 1>(ctx,
+                                                                 ins,
+                                                                 inputs_col_num,
+                                                                 inputs_data,
+                                                                 inputs_col,
+                                                                 out_row,
+                                                                 out_col,
+                                                                 output,
+                                                                 in_num,
+                                                                 limit_num);
+  }
+}
+
+template <typename T, typename IndexT, int MovSize>
+void DispatchConcatWithSameShapeKernelLimitNum(
+    const phi::GPUContext& ctx,
+    const std::vector<phi::DenseTensor>& ins,
+    const T** inputs_data,
+    IndexT in_col,
+    const IndexT out_row,
+    const IndexT out_col,
+    phi::DenseTensor* output,
+    const IndexT in_num,
+    const IndexT limit_num) {
+  dim3 block_dims;
+  dim3 grid_dims;
+  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
+
+#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...)                \
+  case size_: {                                                 \
+    PointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data);  \
+    __VA_ARGS__;                                                 \
+  } break;
+
+  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
+    IMPL_CONCATE_CUDA_KERNEL_HELPER(
+        IMPL_CONCAT_CUDA_KERNEL_CASE,
+        ConcatTensorWithSameShape<IndexT, MovSize>
+        <<<grid_dims, block_dims, 0, ctx.stream()>>>(
+            ptr_array, in_col, out_row, out_col, output->data()));
+    default: {
+      paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
+      PointerToPointer<T> ptr_array(ctx, ins, inputs_data, &dev_ins_ptr);
+      ConcatTensorWithSameShape<IndexT, MovSize>
+          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
+              ptr_array, in_col, out_row, out_col, output->data());
+    }
+  }
+#undef IMPL_CONCAT_CUDA_KERNEL_CASE
+}
+
+#undef IMPL_CONCATE_CUDA_KERNEL_HELPER
+
+template <typename T, typename IndexT>
+void DispatchConcatWithSameShapeMovsize(
+    const phi::GPUContext& ctx,
+    const std::vector<phi::DenseTensor>& ins,
+    const T** inputs_data,
+    IndexT in_col,
+    const IndexT out_row,
+    const IndexT out_col,
+    phi::DenseTensor* output,
+    const IndexT mov_size,
+    const IndexT in_num,
+    const IndexT limit_num) {
+  if (mov_size == 16) {
+    DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 16>(ctx,
+                                                             ins,
+                                                             inputs_data,
+                                                             in_col,
+                                                             out_row,
+                                                             out_col,
+                                                             output,
+                                                             in_num,
+                                                             limit_num);
+  } else if (mov_size == 8) {
+    DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 8>(ctx,
+                                                            ins,
+                                                            inputs_data,
+                                                            in_col,
+                                                            out_row,
+                                                            out_col,
+                                                            output,
+                                                            in_num,
+                                                            limit_num);
+  } else if (mov_size == 4) {
+    DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 4>(ctx,
+                                                            ins,
+                                                            inputs_data,
+                                                            in_col,
+                                                            out_row,
+                                                            out_col,
+                                                            output,
+                                                            in_num,
+                                                            limit_num);
+  } else if (mov_size == 2) {
+    DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 2>(ctx,
+                                                            ins,
+                                                            inputs_data,
+                                                            in_col,
+                                                            out_row,
+                                                            out_col,
+                                                            output,
+                                                            in_num,
+                                                            limit_num);
+  } else {
+    DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 1>(ctx,
+                                                            ins,
+                                                            inputs_data,
+                                                            in_col,
+                                                            out_row,
+                                                            out_col,
+                                                            output,
+                                                            in_num,
+                                                            limit_num);
+  }
+}
+
+template <typename T, typename IndexT>
+void DispatchConcatKernel(const phi::GPUContext& ctx,
+                          const std::vector<phi::DenseTensor>& ins,
+                          const IndexT inputs_col_num,
+                          const T** inputs_data,
+                          IndexT* inputs_col,
+                          const IndexT out_row,
+                          const IndexT out_col,
+                          phi::DenseTensor* output,
+                          const IndexT in_num,
+                          const IndexT limit_num,
+                          bool has_same_shape) {
+  constexpr IndexT MaxVecSize = 16 / sizeof(T);
+  bool find_vecsize_flag = false;
+  IndexT dispatch_vec_size = 1;
+  for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) {
+    for (IndexT idx = 0; idx < in_num + 1; idx++) {
+      // Since input_cols[0] is 0, we need to jump.
+      const IndexT input_col = inputs_col[idx + 1] - inputs_col[idx];
+      if (input_col % vec_size == 0) {
+        if (idx == in_num - 1) {
+          find_vecsize_flag = true;
+        }
+      } else {
+        break;
+      }
+    }
+    if (find_vecsize_flag) {
+      dispatch_vec_size = vec_size;
+      break;
+    }
+  }
+
+  const int64_t vectorized_out_col = out_col / dispatch_vec_size;
+  for (IndexT idx = 0; idx < in_num + 1; idx++) {
+    inputs_col[idx] /= dispatch_vec_size;
+  }
+  const IndexT mov_size = sizeof(T) * dispatch_vec_size;
+  if (has_same_shape) {
+    // In same shape situation, each input's col are equal, so here we select to
+    // use inputs_col[1].
+    DispatchConcatWithSameShapeMovsize(ctx,
+                                       ins,
+                                       inputs_data,
+                                       inputs_col[1],
+                                       out_row,
+                                       vectorized_out_col,
+                                       output,
+                                       mov_size,
+                                       in_num,
+                                       limit_num);
+  } else {
+    DispatchConcatWithDifferentShapeMovsize(ctx,
+                                            ins,
+                                            inputs_col_num,
+                                            inputs_data,
+                                            inputs_col,
+                                            out_row,
+                                            vectorized_out_col,
+                                            output,
+                                            mov_size,
+                                            in_num,
+                                            limit_num);
+  }
+}
+
 template <typename T>
 __global__ void SplitKernel_(const T* input_data,
                              const int64_t in_row,
@@ -273,30 +616,6 @@ __global__ void SplitKernel_(const T* input_data,
   SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
 }
 
-static inline void GetBlockDims(const phi::GPUContext& context,
-                                int64_t num_rows,
-                                int64_t num_cols,
-                                dim3* block_dims,
-                                dim3* grid_dims) {
-  // Set the thread block and grid according to CurrentDeviceId
-  const int kThreadsPerBlock = 1024;
-  int block_cols = kThreadsPerBlock;
-  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
-    block_cols = ((num_cols + 31) >> 5) << 5;
-  }
-  int block_rows = kThreadsPerBlock / block_cols;
-  *block_dims = dim3(block_cols, block_rows, 1);
-
-  int max_threads = context.GetMaxPhysicalThreadCount();
-  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
-
-  int grid_cols =
-      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
-  int grid_rows = std::min(max_blocks / grid_cols,
-                           std::max(num_rows / block_rows, (int64_t)1));
-  *grid_dims = dim3(grid_cols, grid_rows, 1);
-}
-
 /*
  * All tensors' dimension should be the same and the values of
  * each dimension must be the same, except the axis dimension.
@@ -340,79 +659,19 @@ void ConcatFunctorWithIndexType(const phi::GPUContext& ctx,
     out_col += t_cols;
     inputs_col[i + 1] = out_col;
   }
-  dim3 block_dims;
-  dim3 grid_dims;
-  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
 
   IndexT limit_num = has_same_shape ? in_num : inputs_col_num;
 
-#define IMPL_CONCATE_CUDA_KERNEL_HELPER(func_impl, ...) \
-  func_impl(4, ##__VA_ARGS__);                          \
-  func_impl(8, ##__VA_ARGS__);                          \
-  func_impl(16, ##__VA_ARGS__);                         \
-  func_impl(32, ##__VA_ARGS__);                         \
-  func_impl(64, ##__VA_ARGS__);                         \
-  func_impl(128, ##__VA_ARGS__);
-
-  if (has_same_shape) {
-#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...)                \
-  case size_: {                                                 \
-    PointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data);  \
-    __VA_ARGS__;                                                 \
-  } break;
-
-    switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
-      IMPL_CONCATE_CUDA_KERNEL_HELPER(
-          IMPL_CONCAT_CUDA_KERNEL_CASE,
-          ConcatTensorWithSameShape<T, IndexT>
-          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
-              ptr_array, in_col, out_row, out_col, output->data<T>()));
-      default: {
-        paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
-        PointerToPointer<T> ptr_array(ctx, ins, inputs_data, &dev_ins_ptr);
-        ConcatTensorWithSameShape<T, IndexT>
-            <<<grid_dims, block_dims, 0, ctx.stream()>>>(
-                ptr_array, in_col, out_row, out_col, output->data<T>());
-      }
-    }
-#undef IMPL_CONCAT_CUDA_KERNEL_CASE
-  } else {
-#define IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE(size_, ...)        \
-  case size_: {                                                 \
-    PointerAndColWrapper<T, IndexT, size_> ptr_col_array(       \
-        ctx, ins, inputs_col_num, inputs_data, inputs_col);     \
-    __VA_ARGS__;                                                 \
-  } break;
-
-    switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
-      IMPL_CONCATE_CUDA_KERNEL_HELPER(
-          IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE,
-          ConcatTensorWithDifferentShape<T, IndexT>
-          <<<grid_dims, block_dims, 0, ctx.stream()>>>(ptr_col_array,
-                                                       inputs_col_num,
-                                                       out_row,
-                                                       out_col,
-                                                       output->data<T>()));
-      default: {
-        paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
-        paddle::memory::AllocationPtr dev_col_ptr{nullptr};
-        PointerToPointerAndCol<T, IndexT> ptr_col_array(ctx,
-                                                        ins,
-                                                        inputs_col_num,
-                                                        inputs_data,
-                                                        inputs_col,
-                                                        &dev_ins_ptr,
-                                                        &dev_col_ptr);
-        ConcatTensorWithDifferentShape<T, IndexT>
-            <<<grid_dims, block_dims, 0, ctx.stream()>>>(ptr_col_array,
-                                                         inputs_col_num,
-                                                         out_row,
-                                                         out_col,
-                                                         output->data<T>());
-      }
-    }
-#undef IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE
-  }
-#undef IMPL_CONCATE_CUDA_KERNEL_HELPER
+  DispatchConcatKernel(ctx,
+                       ins,
+                       inputs_col_num,
+                       inputs_data,
+                       inputs_col,
+                       out_row,
+                       out_col,
+                       output,
+                       in_num,
+                       limit_num,
+                       has_same_shape);
 
 #ifdef PADDLE_WITH_HIP
   // Prevent pinned memory from being covered and release the memory after
-- 
GitLab
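The core trick in this patch is to re-express every column count in units of a Packed<MovSize> chunk, where MovSize = sizeof(T) * vec_size and vec_size is the largest power of two (capped at 16 / sizeof(T)) that divides every input's column width. Each thread then issues one aligned 2/4/8/16-byte load and store instead of one element-sized move. The following is a minimal, self-contained sketch of that idea outside of Paddle, assuming nothing beyond the CUDA runtime; ChooseVecSize and CopyPacked are illustrative names of this sketch only, and the kernel is a plain packed copy rather than the concat kernels above.

// vectorized_move_sketch.cu -- illustrative only, not part of the patch.
#include <cstdio>
#include <cuda_runtime.h>

// Same idea as the patch's Packed<MovSize>: an aligned, trivially copyable
// blob of MovSize bytes, so one assignment compiles to one wide load/store.
template <int MovSize>
struct alignas(MovSize) Packed {
  char buf[MovSize];
};

// Grid-stride copy of n_chunks chunks, each MovSize bytes.
template <int MovSize>
__global__ void CopyPacked(const void* src, void* dst, int n_chunks) {
  const Packed<MovSize>* s = reinterpret_cast<const Packed<MovSize>*>(src);
  Packed<MovSize>* d = reinterpret_cast<Packed<MovSize>*>(dst);
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n_chunks;
       i += blockDim.x * gridDim.x) {
    d[i] = s[i];  // one MovSize-byte load plus one MovSize-byte store
  }
}

// Largest element count per move that divides every column width, capped so
// a single move is at most 16 bytes (the MaxVecSize = 16 / sizeof(T) rule).
int ChooseVecSize(const int* cols, int n, int elem_size) {
  for (int vec = 16 / elem_size; vec > 1; vec /= 2) {
    bool divides_all = true;
    for (int i = 0; i < n; ++i) {
      divides_all = divides_all && (cols[i] % vec == 0);
    }
    if (divides_all) return vec;
  }
  return 1;
}

int main() {
  const int num_floats = 1 << 20;
  const int cols[] = {256, 512};  // per-input column widths of a fake concat
  int vec = ChooseVecSize(cols, 2, sizeof(float));  // 4 floats => 16 bytes

  float *src = nullptr, *dst = nullptr;
  cudaMalloc(&src, num_floats * sizeof(float));
  cudaMalloc(&dst, num_floats * sizeof(float));

  int chunks = num_floats / vec;  // vec is 4 here, matching the <16> below
  CopyPacked<16><<<(chunks + 255) / 256, 256>>>(src, dst, chunks);
  cudaDeviceSynchronize();
  printf("vec_size=%d, copied %d chunks of 16 bytes\n", vec, chunks);

  cudaFree(src);
  cudaFree(dst);
  return 0;
}

In the patch itself the equivalent selection happens at runtime in DispatchConcatKernel, which divides out_col and inputs_col by the chosen vec_size and then dispatches to a kernel specialized on the resulting MovSize.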