From 0de94cd994710ecd49e9aab72be36550c366f8ac Mon Sep 17 00:00:00 2001
From: limingshu <61349199+JamesLim-sy@users.noreply.github.com>
Date: Tue, 3 Jan 2023 20:35:15 +0800
Subject: [PATCH] H2D data transfer optimization for concat kernel (#49040)

---
 paddle/phi/backends/gpu/gpu_launch_config.h   |  11 +-
 .../kernels/funcs/concat_and_split_functor.cu | 490 ++++++++++--------
 paddle/phi/kernels/gpu/stack_kernel.cu        | 120 +++--
 3 files changed, 342 insertions(+), 279 deletions(-)

diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h
index f0a37d4fb7b..05d97dc45a4 100644
--- a/paddle/phi/backends/gpu/gpu_launch_config.h
+++ b/paddle/phi/backends/gpu/gpu_launch_config.h
@@ -53,20 +53,25 @@ inline T DivUp(T a, T b) {
 
 // https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
 // for round integer value into next highest power of 2.
-inline int64_t RoundToPowerOfTwo(int64_t n) {
+inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) {
   n--;
   n |= (n >> 1);
   n |= (n >> 2);
   n |= (n >> 4);
   n |= (n >> 8);
   n |= (n >> 16);
-  int64_t min_val = 32;
+  return std::max(min_val, (n + 1));
+}
+
+inline int64_t RoundToPowerOfTwo(int64_t n) {
+  constexpr int64_t min_val = 32;
+  int64_t num = RoundToNextHighPowOfTwo(n, min_val);
 #ifdef __HIPCC__
   int64_t max_val = 256;
 #else
   int64_t max_val = 1024;
 #endif
-  return std::min(max_val, std::max(min_val, (n + 1)));
+  return std::min(max_val, num);
 }
 
 #ifdef WITH_NV_JETSON
diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.cu b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
index 3c1f9ec6cc1..d238853f6a2 100644
--- a/paddle/phi/kernels/funcs/concat_and_split_functor.cu
+++ b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
@@ -15,49 +15,155 @@ limitations under the License.
 */
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 
 namespace phi {
 namespace funcs {
 
+template <typename T, int Size>
+struct PointerWrapper {
+ public:
+  const T* ins_addr[Size];
+  __device__ inline const T* operator[](int i) const { return ins_addr[i]; }
+
+  PointerWrapper() {}
+  PointerWrapper(const phi::GPUContext& ctx,
+                 const std::vector<phi::DenseTensor>& ins,
+                 const T** pre_alloced_host_ptr) {
+    for (auto i = 0; i < ins.size(); ++i) {
+      ins_addr[i] = ins[i].data<T>();
+    }
+  }
+};
+
 template <typename T>
-__global__ void ConcatKernel_(const T** inputs,
-                              const int64_t* input_cols,
-                              int col_size,
-                              const int64_t output_rows,
-                              const int64_t output_cols,
-                              T* output) {
-  int64_t curr_segment = 0;
-  int64_t curr_offset = input_cols[0];
-  CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, int64_t) {
-    int64_t curr_col_offset = input_cols[curr_segment + 1];
+struct PointerToPointer {
+ public:
+  T** ins_addr{nullptr};
+  __device__ inline const T* operator[](int i) const { return ins_addr[i]; }
+
+  PointerToPointer() {}
+  PointerToPointer(const phi::GPUContext& ctx,
+                   const std::vector<phi::DenseTensor>& ins,
+                   const T** pre_alloced_host_ptr,
+                   paddle::memory::AllocationPtr* dev_ins_ptr) {
+    auto in_num = ins.size();
+    for (auto i = 0; i < in_num; ++i) {
+      pre_alloced_host_ptr[i] = ins[i].data<T>();
+    }
+    *dev_ins_ptr = paddle::memory::Alloc(
+        ctx.GetPlace(),
+        in_num * sizeof(T*),
+        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
+    auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
+        pre_alloced_host_ptr, in_num);
+    paddle::memory::Copy(ctx.GetPlace(),
+                         (*dev_ins_ptr)->ptr(),
+                         phi::CPUPlace(),
+                         restored,
+                         in_num * sizeof(T*),
+                         ctx.stream());
+    ins_addr = reinterpret_cast<T**>((*dev_ins_ptr)->ptr());
+  }
+};
+
+template <typename T, typename IndexT, int Size>
+struct PointerAndColWrapper {
+ public:
+  IndexT col_length[Size];
+  PointerAndColWrapper(const phi::GPUContext& ctx,
+                       const std::vector<phi::DenseTensor>& ins,
+                       const IndexT& inputs_col_num,
+                       const T** pre_alloced_host_ptr,
+                       IndexT* inputs_col) {
+    for (auto i = 0; i < inputs_col_num; ++i) {
+      col_length[i] = inputs_col[i];
+    }
+    ins_ptr_wrapper = PointerWrapper<T, Size>(ctx, ins, pre_alloced_host_ptr);
+  }
+
+  __device__ inline const T* operator[](int i) const {
+    return ins_ptr_wrapper[i];
+  }
+
+ private:
+  PointerWrapper<T, Size> ins_ptr_wrapper;
+};
+
+template <typename T, typename IndexT>
+struct PointerToPointerAndCol {
+ public:
+  IndexT* col_length{nullptr};
+  PointerToPointerAndCol(const phi::GPUContext& ctx,
+                         const std::vector<phi::DenseTensor>& ins,
+                         const IndexT inputs_col_num,
+                         const T** pre_alloced_host_ptr,
+                         IndexT* inputs_col,
+                         paddle::memory::AllocationPtr* dev_ins_ptr,
+                         paddle::memory::AllocationPtr* dev_col_ptr) {
+    *dev_col_ptr = paddle::memory::Alloc(
+        ctx.GetPlace(),
+        inputs_col_num * sizeof(IndexT),
+        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
+    auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
+        inputs_col, inputs_col_num);
+    paddle::memory::Copy(ctx.GetPlace(),
+                         (*dev_col_ptr)->ptr(),
+                         phi::CPUPlace(),
+                         restored,
+                         inputs_col_num * sizeof(IndexT),
+                         ctx.stream());
+    col_length = static_cast<IndexT*>((*dev_col_ptr)->ptr());
+    ins_ptr_wrapper =
+        PointerToPointer<T>(ctx, ins, pre_alloced_host_ptr, dev_ins_ptr);
+  }
+
+  __device__ inline const T* operator[](int i) const {
+    return ins_ptr_wrapper[i];
+  }
+
+ private:
+  PointerToPointer<T> ins_ptr_wrapper;
+};
+
+template <typename T, typename IndexT, typename PointerAndColWrapperT>
+__global__ void ConcatTensorWithDifferentShape(PointerAndColWrapperT ins_datas,
+                                               int col_size,
+                                               const IndexT output_rows,
+                                               const IndexT output_cols,
+                                               T* output) {
+  IndexT curr_segment = 0;
+  IndexT curr_offset = ins_datas.col_length[0];
+  CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, IndexT) {
+    IndexT curr_col_offset = ins_datas.col_length[curr_segment + 1];
     while (curr_col_offset <= tid_x) {
       curr_offset = curr_col_offset;
       ++curr_segment;
-      curr_col_offset = input_cols[curr_segment + 1];
+      curr_col_offset = ins_datas.col_length[curr_segment + 1];
     }
 
-    int64_t local_col = tid_x - curr_offset;
-    int64_t segment_width = curr_col_offset - curr_offset;
+    IndexT local_col = tid_x - curr_offset;
+    IndexT segment_width = curr_col_offset - curr_offset;
 
-    const T* input_ptr = inputs[curr_segment];
-    int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    const T* input_ptr = ins_datas[curr_segment];
+    IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
       output[tid_y * output_cols + tid_x] =
           input_ptr[tid_y * segment_width + local_col];
   }
 }
 
-template <typename T>
-__device__ void ConcatKernelDetail(const T** inputs_data,
-                                   const int64_t fixed_in_col,
-                                   const int64_t out_rows,
-                                   const int64_t out_cols,
-                                   T* output_data) {
-  CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, int64_t) {
-    int64_t split = tid_x * 1.0 / fixed_in_col;
-    int64_t in_offset = tid_x - split * fixed_in_col;
-    const T* input_ptr = inputs_data[split];
-    int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+template <typename T, typename IndexT, typename PointerWrapperT>
+__global__ void ConcatTensorWithSameShape(PointerWrapperT ins_data,
+                                          const IndexT fixed_in_col,
+                                          const IndexT out_rows,
+                                          const IndexT out_cols,
+                                          T* output_data) {
+  CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, IndexT) {
+    IndexT split = tid_x / fixed_in_col;
+    IndexT in_offset = tid_x - split * fixed_in_col;
+    const T* input_ptr = ins_data[split];
+    IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
       output_data[tid_y * out_cols + tid_x] =
           input_ptr[tid_y * fixed_in_col + in_offset];
@@ -65,65 +171,6 @@ __device__ void ConcatKernelDetail(const T** inputs_data,
   }
 }
 
-template <typename T>
-__global__ void ConcatKernel_(const T* input_addr0,
-                              const T* input_addr1,
-                              const int64_t fixed_in_col,
-                              const int64_t out_rows,
-                              const int64_t out_cols,
-                              T* output_data) {
-  const T* inputs_data[2];
-  inputs_data[0] = input_addr0;
-  inputs_data[1] = input_addr1;
-  ConcatKernelDetail<T>(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-template <typename T>
-__global__ void ConcatKernel_(const T* input_addr0,
-                              const T* input_addr1,
-                              const T* input_addr2,
-                              const int64_t fixed_in_col,
-                              const int64_t out_rows,
-                              const int64_t out_cols,
-                              T* output_data) {
-  const T* inputs_data[3];
-  inputs_data[0] = input_addr0;
-  inputs_data[1] = input_addr1;
-  inputs_data[2] = input_addr2;
-  ConcatKernelDetail<T>(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-template <typename T>
-__global__ void ConcatKernel_(const T* input_addr0,
-                              const T* input_addr1,
-                              const T* input_addr2,
-                              const T* input_addr3,
-                              const int64_t fixed_in_col,
-                              const int64_t out_rows,
-                              const int64_t out_cols,
-                              T* output_data) {
-  const T* inputs_data[4];
-  inputs_data[0] = input_addr0;
-  inputs_data[1] = input_addr1;
-  inputs_data[2] = input_addr2;
-  inputs_data[3] = input_addr3;
-  ConcatKernelDetail<T>(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-template <typename T>
-__global__ void ConcatKernel_(const T** inputs_data,
-                              const int in_num,
-                              const int64_t fixed_in_col,
-                              const int64_t out_rows,
-                              const int64_t out_cols,
-                              T* output_data) {
-  ConcatKernelDetail<T>(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
 template <typename T>
 __global__ void SplitKernel_(const T* input_data,
                              const int64_t in_row,
@@ -254,155 +301,146 @@ static inline void GetBlockDims(const phi::GPUContext& context,
  * All tensors' dimension should be the same and the values of
  * each dimension must be the same, except the axis dimension.
  */
-
-template <typename T>
-struct ConcatFunctor<phi::GPUContext, T> {
-  void operator()(const phi::GPUContext& context,
-                  const std::vector<phi::DenseTensor>& input,
-                  int axis,
-                  phi::DenseTensor* output) {
-    // TODO(zcd): Add input data validity checking
-    int64_t in_num = input.size();
-    int64_t in_row = 1;
-    auto dim_0 = input[0].dims();
-    for (int i = 0; i < axis; ++i) {
-      in_row *= dim_0[i];
-    }
-    int64_t in_col = input[0].numel() / in_row;
-    int64_t out_row = in_row, out_col = 0;
-
-    int64_t inputs_col_num = in_num + 1;
-    std::vector<const T*> inputs_data_vec(in_num);
-    std::vector<int64_t> inputs_col_vec(inputs_col_num);
-    const T** inputs_data = inputs_data_vec.data();
-    int64_t* inputs_col = inputs_col_vec.data();
-
-// There are some differences between hip runtime and NV runtime.
-// In NV, when the pageable memory data less than 64K is transferred from
-// hosttodevice, it will be automatically asynchronous.
-// However, only pinned memory in hip can copy asynchronously
-// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
-// 3.2.6.1. Concurrent Execution between Host and Device
-// Memory copies from host to device of a memory block of 64 KB or less
+template <typename T, typename IndexT>
+void ConcatFunctorWithIndexType(const phi::GPUContext& ctx,
+                                const std::vector<phi::DenseTensor>& ins,
+                                int axis,
+                                phi::DenseTensor* output) {
+  // TODO(zcd): Add input data validity checking
+  IndexT in_num = ins.size();
+  IndexT in_row = 1;
+  auto dim_0 = ins[0].dims();
+  for (int i = 0; i < axis; ++i) {
+    in_row *= dim_0[i];
+  }
+  IndexT in_col = ins[0].numel() / in_row;
+  IndexT out_row = in_row, out_col = 0;
+
+  IndexT inputs_col_num = in_num + 1;
+  std::vector<const T*> inputs_data_vec(in_num, nullptr);
+  std::vector<IndexT> inputs_col_vec(inputs_col_num, 0);
+  const T** inputs_data = inputs_data_vec.data();
+  IndexT* inputs_col = inputs_col_vec.data();
 #ifdef PADDLE_WITH_HIP
-    paddle::memory::AllocationPtr data_alloc, col_alloc;
-    // TODO(chentianyu03): try to find a method to remove the Alloc function
-    data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
-                                       in_num * sizeof(T*));
-    inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
-    // TODO(chentianyu03): try to find a method to remove the Alloc function
-    col_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
-                                      inputs_col_num * sizeof(int));
-    inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
+  // TODO(chentianyu03): try to find a method to remove the Alloc function
+  paddle::memory::AllocationPtr data_alloc = paddle::memory::Alloc(
+      paddle::platform::CUDAPinnedPlace(), in_num * sizeof(T*));
+  inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
+  paddle::memory::AllocationPtr col_alloc = paddle::memory::Alloc(
+      paddle::platform::CUDAPinnedPlace(), inputs_col_num * sizeof(IndexT));
+  inputs_col = reinterpret_cast<IndexT*>(col_alloc->ptr());
 #endif
-    inputs_col[0] = 0;
-    bool has_same_shape = true;
-    for (int i = 0; i < in_num; ++i) {
-      int64_t t_cols = input[i].numel() / in_row;
-      if (has_same_shape) {
-        if (t_cols != in_col) has_same_shape = false;
-      }
-      out_col += t_cols;
-      inputs_col[i + 1] = out_col;
-      inputs_data[i] = input[i].data<T>();
+  bool has_same_shape = true;
+  for (int i = 0; i < in_num; ++i) {
+    IndexT t_cols = ins[i].numel() / in_row;
+    if (has_same_shape) {
+      has_same_shape &= (t_cols == in_col);
     }
-
-    dim3 block_dims;
-    dim3 grid_dims;
-    GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);
-
-    paddle::memory::allocation::AllocationPtr tmp_dev_ins_data;
-    const T** dev_ins_data = nullptr;
-    if (!has_same_shape || in_num < 2 || in_num > 4) {
-      tmp_dev_ins_data = paddle::memory::Alloc(
-          context.GetPlace(),
-          in_num * sizeof(T*),
-          phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
-      auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
-          inputs_data, in_num);
-      paddle::memory::Copy(context.GetPlace(),
-                           tmp_dev_ins_data->ptr(),
-                           paddle::platform::CPUPlace(),
-                           restored,
-                           in_num * sizeof(T*),
-                           context.stream());
-      dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
+    out_col += t_cols;
+    inputs_col[i + 1] = out_col;
+  }
+  dim3 block_dims;
+  dim3 grid_dims;
+  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
+  IndexT limit_num = has_same_shape ? in_num : inputs_col_num;
+
+#define IMPL_CONCATE_CUDA_KERNEL_HELPER(func_impl, ...) \
+  func_impl(4, ##__VA_ARGS__);                          \
+  func_impl(8, ##__VA_ARGS__);                          \
+  func_impl(16, ##__VA_ARGS__);                         \
+  func_impl(32, ##__VA_ARGS__);                         \
+  func_impl(64, ##__VA_ARGS__);                         \
+  func_impl(128, ##__VA_ARGS__);
+
+  if (has_same_shape) {
+#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...)                \
+  case size_: {                                                 \
+    PointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data);  \
+    __VA_ARGS__;                                                \
+  } break;
+
+    switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
+      IMPL_CONCATE_CUDA_KERNEL_HELPER(
+          IMPL_CONCAT_CUDA_KERNEL_CASE,
+          ConcatTensorWithSameShape<T, IndexT, decltype(ptr_array)>
+          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
+              ptr_array, in_col, out_row, out_col, output->data<T>()));
+      default: {
+        paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
+        PointerToPointer<T> ptr_array(ctx, ins, inputs_data, &dev_ins_ptr);
+        ConcatTensorWithSameShape<T, IndexT, decltype(ptr_array)>
+            <<<grid_dims, block_dims, 0, ctx.stream()>>>(
+                ptr_array, in_col, out_row, out_col, output->data<T>());
+      }
     }
-
-    if (has_same_shape) {
-      if (in_num == 2) {
-        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-            inputs_data[0],
-            inputs_data[1],
-            in_col,
-            out_row,
-            out_col,
-            output->data<T>());
-      } else if (in_num == 3) {
-        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-            inputs_data[0],
-            inputs_data[1],
-            inputs_data[2],
-            in_col,
-            out_row,
-            out_col,
-            output->data<T>());
-      } else if (in_num == 4) {
-        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-            inputs_data[0],
-            inputs_data[1],
-            inputs_data[2],
-            inputs_data[3],
-            in_col,
-            out_row,
-            out_col,
-            output->data<T>());
-      } else {
-        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-            dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
+#undef IMPL_CONCAT_CUDA_KERNEL_CASE
+  } else {
+#define IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE(size_, ...)      \
+  case size_: {                                               \
+    PointerAndColWrapper<T, IndexT, size_> ptr_col_array(     \
+        ctx, ins, inputs_col_num, inputs_data, inputs_col);   \
+    __VA_ARGS__;                                              \
+  } break;
+
+    switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
+      IMPL_CONCATE_CUDA_KERNEL_HELPER(
+          IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE,
+          ConcatTensorWithDifferentShape<T, IndexT, decltype(ptr_col_array)>
+          <<<grid_dims, block_dims, 0, ctx.stream()>>>(ptr_col_array,
+                                                       inputs_col_num,
+                                                       out_row,
+                                                       out_col,
+                                                       output->data<T>()));
+      default: {
+        paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
+        paddle::memory::AllocationPtr dev_col_ptr{nullptr};
+        PointerToPointerAndCol<T, IndexT> ptr_col_array(ctx,
+                                                        ins,
+                                                        inputs_col_num,
+                                                        inputs_data,
+                                                        inputs_col,
+                                                        &dev_ins_ptr,
+                                                        &dev_col_ptr);
+        ConcatTensorWithDifferentShape<T, IndexT, decltype(ptr_col_array)>
+        <<<grid_dims, block_dims, 0, ctx.stream()>>>(ptr_col_array,
+                                                     inputs_col_num,
+                                                     out_row,
+                                                     out_col,
+                                                     output->data<T>());
       }
-    } else {
-      auto tmp_dev_ins_col_data = paddle::memory::Alloc(
-          context.GetPlace(),
-          inputs_col_num * sizeof(int64_t),
-          phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
-
-      auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
-          inputs_col, inputs_col_num);
-      paddle::memory::Copy(context.GetPlace(),
-                           tmp_dev_ins_col_data->ptr(),
-                           paddle::platform::CPUPlace(),
-                           restored,
-                           inputs_col_num * sizeof(int64_t),
-                           context.stream());
-      int64_t* dev_ins_col_data =
-          static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
-
-      ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
-          dev_ins_data,
-          dev_ins_col_data,
-          static_cast<int>(inputs_col_num),
-          out_row,
-          out_col,
-          output->data<T>());
     }
+#undef IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE
+  }
+#undef IMPL_CONCATE_CUDA_KERNEL_HELPER
 
 #ifdef PADDLE_WITH_HIP
-    // Prevent the pinned memory value from being covered and release the memory
-    // after the launch kernel of the stream is executed (reapply pinned memory
-    // next time)
-    auto* data_alloc_released = data_alloc.release();
-    auto* col_alloc_released = col_alloc.release();
-    context.AddStreamCallback([data_alloc_released, col_alloc_released] {
-      VLOG(4) << "Delete cuda pinned at " << data_alloc_released;
-      VLOG(4) << "Delete cuda pinned at " << col_alloc_released;
-      paddle::memory::allocation::Allocator::AllocationDeleter(
-          data_alloc_released);
-      paddle::memory::allocation::Allocator::AllocationDeleter(
-          col_alloc_released);
-    });
+  // Prevent pinned memory from being covered and release the memory after
+  // kernel launch of the stream is executed (reapply pinned memory next time)
+  auto* data_alloc_released = data_alloc.release();
+  auto* col_alloc_released = col_alloc.release();
+  ctx.AddStreamCallback([data_alloc_released, col_alloc_released] {
+    VLOG(4) << "Delete cuda pinned at " << data_alloc_released;
+    VLOG(4) << "Delete cuda pinned at " << col_alloc_released;
+    paddle::memory::allocation::Allocator::AllocationDeleter(
+        data_alloc_released);
+    paddle::memory::allocation::Allocator::AllocationDeleter(
+        col_alloc_released);
+  });
 #endif
+}
+
+template <typename T>
+struct ConcatFunctor<phi::GPUContext, T> {
+  void operator()(const phi::GPUContext& context,
+                  const std::vector<phi::DenseTensor>& input,
+                  int axis,
+                  phi::DenseTensor* output) {
+    if (output->numel() < std::numeric_limits<int32_t>::max()) {
+      ConcatFunctorWithIndexType<T, int32_t>(context, input, axis, output);
+    } else {
+      ConcatFunctorWithIndexType<T, int64_t>(context, input, axis, output);
+    }
   }
 };
 
@@ -488,7 +526,7 @@ class SplitFunctor<phi::GPUContext, T> {
           outputs_data, o_num);
       paddle::memory::Copy(context.GetPlace(),
                            tmp_dev_outs_data->ptr(),
-                           paddle::platform::CPUPlace(),
+                           phi::CPUPlace(),
                            restored,
                            o_num * sizeof(T*),
                            context.stream());
@@ -539,7 +577,7 @@ class SplitFunctor<phi::GPUContext, T> {
           outputs_cols, outputs_cols_num);
      paddle::memory::Copy(context.GetPlace(),
                            tmp_dev_ins_col_data->ptr(),
-                           paddle::platform::CPUPlace(),
+                           phi::CPUPlace(),
                            restored,
                            outputs_cols_num * sizeof(int64_t),
                            context.stream());
diff --git a/paddle/phi/kernels/gpu/stack_kernel.cu b/paddle/phi/kernels/gpu/stack_kernel.cu
index 3cfb98beca3..d3f24095069 100644
--- a/paddle/phi/kernels/gpu/stack_kernel.cu
+++ b/paddle/phi/kernels/gpu/stack_kernel.cu
@@ -25,7 +25,9 @@ namespace phi {
 template <typename IndexT>
 struct DivmodWarpper {
  public:
-  void SetDivden(IndexT dividen) { divmoder = phi::funcs::FastDivMod(dividen); }
+  void SetDivisor(IndexT divisor) {
+    divmoder = phi::funcs::FastDivMod(divisor);
+  }
   __device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
     return divmoder.Divmod(val);
   }
@@ -39,7 +41,7 @@ struct DivmodWarpper<int64_t> {
  public:
   using DivModT = phi::AlignedVector<int64_t, 2>;
 
-  void SetDivden(int64_t dividen) { dividen_ = dividen; }
+  void SetDivisor(int64_t divisor) { dividen_ = divisor; }
   __device__ inline DivModT div_mod(int64_t val) {
     DivModT data;
     data[0] = val / dividen_;
@@ -51,15 +53,14 @@ struct DivmodWarpper<int64_t> {
   int64_t dividen_;
 };
 
-constexpr int kWarpperSize = 64;
-template <typename T, typename IndexT>
+template <typename T, typename IndexT, int Size>
 struct PointerArray : public DivmodWarpper<IndexT> {
  public:
-  const T* data[kWarpperSize];
+  const T* data[Size];
   PointerArray(const std::vector<const DenseTensor*>& x,
                int num,
-               int64_t dividen) {
-    this->SetDivden(dividen);
+               IndexT divisor) {
+    this->SetDivisor(divisor);
     for (auto i = 0; i < num; ++i) {
       data[i] = x[i]->data<T>();
     }
@@ -69,33 +70,33 @@ struct PointerArray : public DivmodWarpper<IndexT> {
 template <typename Context, typename T, typename IndexT>
 struct PointerToPointer : public DivmodWarpper<IndexT> {
  public:
-  T** data;
+  T** data{nullptr};
   PointerToPointer(const Context& ctx,
                    const std::vector<const DenseTensor*>& x,
-                   int num,
-                   int64_t dividen) {
-    this->SetDivden(dividen);
-    auto byte_len = num * sizeof(T*);
+                   IndexT num,
+                   IndexT divisor,
+                   paddle::memory::AllocationPtr* dev_ins_ptr) {
+    this->SetDivisor(divisor);
     std::vector<const T*> x_datas(num);
     for (int i = 0; i < num; ++i) {
      x_datas[i] = x[i]->data<T>();
     }
-    auto tmp_x_data = paddle::memory::Alloc(
+    *dev_ins_ptr = paddle::memory::Alloc(
         ctx.GetPlace(),
-        byte_len,
+        num * sizeof(T*),
         phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
     paddle::memory::Copy(ctx.GetPlace(),
-                         tmp_x_data->ptr(),
+                         (*dev_ins_ptr)->ptr(),
                          phi::CPUPlace(),
                          reinterpret_cast<void*>(x_datas.data()),
-                         x_datas.size() * sizeof(T*),
+                         num * sizeof(T*),
                          ctx.stream());
-    data = reinterpret_cast<T**>(tmp_x_data->ptr());
+    data = reinterpret_cast<T**>((*dev_ins_ptr)->ptr());
   }
 };
 
-template <typename T, typename IndexT, typename WarpT>
-__global__ void StackCUDAKernel(WarpT input_warpper,
+template <typename T, typename IndexT, typename WrapT>
+__global__ void StackCUDAKernel(WrapT input_warpper,
                                 IndexT split_size,
                                 IndexT rows,
                                 IndexT cols,
@@ -117,14 +118,56 @@ __global__ void StackCUDAKernel(WarpT input_warpper,
   }
 }
 
+template <typename IndexT, typename Context, typename T>
+void LaunchStackCUDAKernelWithIndexType(
+    const Context& ctx,
+    const IndexT x_col,
+    const IndexT x_row,
+    const IndexT out_col,
+    const phi::backends::gpu::GpuLaunchConfig& cfg,
+    const std::vector<const DenseTensor*>& x,
+    T* dst_data) {
+  int num = static_cast<int>(x.size());
+#define IMPL_STACK_CUDA_KERNEL_CASE(size_, ...)               \
+  case size_: {                                               \
+    PointerArray<T, IndexT, size_> ptr_array(x, num, x_col);  \
+    __VA_ARGS__;                                              \
+  } break;
+
+#define IMPL_STACK_CUDA_KERNEL_HELPER(...)                    \
+  IMPL_STACK_CUDA_KERNEL_CASE(2, ##__VA_ARGS__);              \
+  IMPL_STACK_CUDA_KERNEL_CASE(8, ##__VA_ARGS__);              \
+  IMPL_STACK_CUDA_KERNEL_CASE(16, ##__VA_ARGS__);             \
+  IMPL_STACK_CUDA_KERNEL_CASE(32, ##__VA_ARGS__);             \
+  IMPL_STACK_CUDA_KERNEL_CASE(64, ##__VA_ARGS__);             \
+  IMPL_STACK_CUDA_KERNEL_CASE(128, ##__VA_ARGS__);
+
+  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(num, 4)) {
+    IMPL_STACK_CUDA_KERNEL_HELPER(
+        StackCUDAKernel<T, IndexT, decltype(ptr_array)>
+        <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
+            ptr_array, x_col, x_row, out_col, dst_data));
+    default: {
+      paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
+      PointerToPointer<Context, T, IndexT> ptr_array(
+          ctx, x, num, x_col, &dev_ins_ptr);
+      StackCUDAKernel<T, IndexT, decltype(ptr_array)>
+          <<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
+              ptr_array, x_col, x_row, out_col, dst_data);
+    }
+  }
+#undef IMPL_STACK_CUDA_KERNEL_HELPER
+#undef IMPL_STACK_CUDA_KERNEL_CASE
+}
+
 template <typename T, typename Context>
 void StackKernel(const Context& dev_ctx,
                  const std::vector<const DenseTensor*>& x,
                  int axis,
                  DenseTensor* out) {
   if (axis < 0) axis += (x[0]->dims().size() + 1);
-  int n = static_cast<int>(x.size());
-  T* y_data = dev_ctx.template Alloc<T>(out);
+  int num = static_cast<int>(x.size());
+  T* dst_data = dev_ctx.template Alloc<T>(out);
 
   // Split x dim from axis to matrix
   int64_t x_row = 1, x_col = 1;
@@ -132,40 +175,17 @@ void StackKernel(const Context& dev_ctx,
     x_row *= x[0]->dims()[i];
   }
   x_col = x[0]->numel() / x_row;
-  int64_t out_col = x_col * n;
+  int64_t out_col = x_col * num;
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);
 
-#define IMPL_STACK_CUDA_KERNEL(index_t, input_warpper)        \
-  StackCUDAKernel<T, index_t, decltype(input_warpper)>        \
-      <<<config.block_per_grid,                               \
-         config.thread_per_block,                             \
-         0,                                                   \
-         dev_ctx.stream()>>>(input_warpper,                   \
-                             static_cast<index_t>(x_col),     \
-                             static_cast<index_t>(x_row),     \
-                             static_cast<index_t>(out_col),   \
-                             y_data);
-
-  bool use_int32 = out->numel() < std::numeric_limits<int32_t>::max();
-  if (n <= kWarpperSize) {
-    if (use_int32) {
-      PointerArray<T, int32_t> ptr_array(x, n, x_col);
-      IMPL_STACK_CUDA_KERNEL(int32_t, ptr_array);
-    } else {
-      PointerArray<T, int64_t> ptr_array(x, n, x_col);
-      IMPL_STACK_CUDA_KERNEL(int64_t, ptr_array);
-    }
+  if (out->numel() < std::numeric_limits<int32_t>::max()) {
+    LaunchStackCUDAKernelWithIndexType<int32_t, Context, T>(
+        dev_ctx, x_col, x_row, out_col, config, x, dst_data);
   } else {
-    if (use_int32) {
-      PointerToPointer<Context, T, int32_t> ptr_array(dev_ctx, x, n, x_col);
-      IMPL_STACK_CUDA_KERNEL(int32_t, ptr_array);
-    } else {
-      PointerToPointer<Context, T, int64_t> ptr_array(dev_ctx, x, n, x_col);
-      IMPL_STACK_CUDA_KERNEL(int64_t, ptr_array);
-    }
+    LaunchStackCUDAKernelWithIndexType<int64_t, Context, T>(
+        dev_ctx, x_col, x_row, out_col, config, x, dst_data);
   }
-#undef IMPL_STACK_CUDA_KERNEL
 }
 
 }  // namespace phi
--
GitLab
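Note on the technique used above (an editor's illustration, not part of the patch): the PR avoids a separate host-to-device memcpy of the input-pointer array by packing the pointers into a fixed-size wrapper struct (PointerWrapper in the concat functor, PointerArray in the stack kernel) that is passed to the kernel by value, so the pointers travel in kernel-argument space with the launch itself; only when the input count exceeds the largest specialization does the code fall back to staging a pointer array through device memory (PointerToPointer). The standalone CUDA sketch below illustrates the same idea under simplified assumptions; the names PtrArgs and ConcatSameShape, the fixed capacity of 4, and the flat 1-D indexing are hypothetical and are not Paddle APIs.

// Minimal sketch (not the Paddle code above): concatenates `num` equally-shaped
// row-major [rows x cols] inputs along the column axis. The input pointers
// travel to the device inside the kernel-argument struct itself, so no explicit
// cudaMemcpy of the pointer array is needed for small input counts.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

template <typename T, int Size>
struct PtrArgs {            // passed by value; lives in kernel parameter space
  const T* in[Size];
};

template <typename T, int Size>
__global__ void ConcatSameShape(PtrArgs<T, Size> args, int num_inputs,
                                int rows, int cols, T* out) {
  int out_cols = cols * num_inputs;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < rows * out_cols;
       idx += blockDim.x * gridDim.x) {
    int row = idx / out_cols;
    int col = idx % out_cols;
    int which = col / cols;           // which input tensor this column maps to
    int local = col - which * cols;   // column inside that tensor
    out[idx] = args.in[which][row * cols + local];
  }
}

int main() {
  const int rows = 2, cols = 3, num = 2;
  std::vector<float> h0 = {0, 1, 2, 3, 4, 5}, h1 = {10, 11, 12, 13, 14, 15};
  float *d0, *d1, *dout;
  cudaMalloc(&d0, h0.size() * sizeof(float));
  cudaMalloc(&d1, h1.size() * sizeof(float));
  cudaMalloc(&dout, rows * cols * num * sizeof(float));
  cudaMemcpy(d0, h0.data(), h0.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d1, h1.data(), h1.size() * sizeof(float), cudaMemcpyHostToDevice);

  PtrArgs<float, 4> args{};   // capacity 4; only the first `num` slots are used
  args.in[0] = d0;
  args.in[1] = d1;
  ConcatSameShape<float, 4><<<1, 128>>>(args, num, rows, cols, dout);

  std::vector<float> hout(rows * cols * num);
  cudaMemcpy(hout.data(), dout, hout.size() * sizeof(float),
             cudaMemcpyDeviceToHost);
  for (float v : hout) printf("%g ", v);   // 0 1 2 10 11 12 3 4 5 13 14 15
  printf("\n");
  cudaFree(d0); cudaFree(d1); cudaFree(dout);
  return 0;
}

Because the struct is a kernel argument, the pointer list reaches the device as part of the launch itself; an explicit device allocation plus copy, as in the default: branches of the patch, is only needed when the pointer array no longer fits in a parameter struct.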
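The switch statements in ConcatFunctorWithIndexType and LaunchStackCUDAKernelWithIndexType choose a wrapper capacity by rounding the input count up to the next power of two with a floor of 4, falling through to the default: pointer-to-pointer path when the count exceeds the largest case. A small host-only check of that helper follows; the bit-smearing body is copied from gpu_launch_config.h above, while the main() driver and the asserted values are illustrative assumptions, not part of the patch.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Same round-up-to-power-of-two trick as in gpu_launch_config.h:
// smear the highest set bit rightward, then add one.
inline int64_t RoundToNextHighPowOfTwo(int64_t n, int64_t min_val = 1) {
  n--;
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(min_val, n + 1);
}

int main() {
  // With min_val = 4: 3 inputs -> 4, 5 -> 8, 9 -> 16, 33 -> 64, so the switch
  // lands on the smallest wrapper capacity that can hold all input pointers.
  assert(RoundToNextHighPowOfTwo(3, 4) == 4);
  assert(RoundToNextHighPowOfTwo(5, 4) == 8);
  assert(RoundToNextHighPowOfTwo(9, 4) == 16);
  assert(RoundToNextHighPowOfTwo(33, 4) == 64);
  assert(RoundToNextHighPowOfTwo(129, 4) == 256);  // beyond 128: default: path
  return 0;
}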