From 15577630c940bf94279f881ecc58d27d956fd620 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Wed, 15 Jun 2022 13:02:19 +0800 Subject: [PATCH] Use int64_t in GetGpuLaunchConfig1D and ElementwiseKernel as index type to support large tensor. (#43506) * Change some data type from int to int64_t in GetGpuLaunchConfig1D to support large tensor. * Use int64_t in ElementwiseKernel as index type to support large tensor. --- paddle/phi/backends/gpu/gpu_launch_config.h | 62 +++++++----- paddle/phi/kernels/funcs/elementwise_base.h | 96 +++++++------------ .../phi/kernels/primitive/kernel_primitives.h | 13 +++ 3 files changed, 81 insertions(+), 90 deletions(-) diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h index 2dd1431ff58..04b2786c4d0 100644 --- a/paddle/phi/backends/gpu/gpu_launch_config.h +++ b/paddle/phi/backends/gpu/gpu_launch_config.h @@ -37,8 +37,7 @@ // HIP results in error or nan if > 256 #define PREDEFINED_BLOCK_SIZE 256 #else -/* CUDA performs better as thread_per_block - num is between [64, 512] */ +// CUDA performs better when thread_per_block is between [64, 512] #define PREDEFINED_BLOCK_SIZE 512 #endif @@ -46,22 +45,27 @@ namespace phi { namespace backends { namespace gpu { -inline int DivUp(int a, int b) { return (a + b - 1) / b; } +template +inline T DivUp(T a, T b) { + return (a + b - 1) / b; +} -/* https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 - for round integer value into next highest power of 2. */ -static inline int RoundToPowerOfTwo(int n) { +// https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 +// for round integer value into next highest power of 2. +inline int64_t RoundToPowerOfTwo(int64_t n) { n--; n |= (n >> 1); n |= (n >> 2); n |= (n >> 4); n |= (n >> 8); n |= (n >> 16); + int64_t min_val = 32; #ifdef __HIPCC__ - return std::min(256, std::max(32, (n + 1))); + int64_t max_val = 256; #else - return std::min(1024, std::max(32, (n + 1))); + int64_t max_val = 1024; #endif + return std::min(max_val, std::max(min_val, (n + 1))); } #ifdef WITH_NV_JETSON @@ -106,12 +110,17 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context, PADDLE_ENFORCE_GE(numel, 0, phi::errors::InvalidArgument( - "element quantity should be greater than or equal 0," - " but received value is: %d.", + "numel is expected to be greater than or equal 0," + " but received %d.", numel)); + PADDLE_ENFORCE_GE( + vec_size, + 1, + phi::errors::InvalidArgument( + "vec_size is expected greater than 0, but received %d.", vec_size)); // Get compute_capability const int capability = context.GetComputeCapability(); - /* If thread number per block is 64/128/256/512, cuda performs better.*/ + // If thread number per block is 64/128/256/512, cuda performs better. int limit_threads = std::min(PREDEFINED_BLOCK_SIZE, context.GetMaxThreadsPerBlock()); #ifdef WITH_NV_JETSON @@ -121,7 +130,7 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context, #endif int threads = limit_threads; int sm_count = context.GetSMCount(); - int active_threads_num = numel / vec_size; + int64_t active_threads_num = numel / vec_size; if (active_threads_num / (sm_count << 1) < limit_threads) { // Round up threads number into an exponential multiple of 2, while number // of acitve blocks is about twice of SM, to acquire better performance. @@ -133,7 +142,7 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context, } // Number of threads per block shall be larger than 64. threads = std::max(64, threads); - int blocks = DivUp(DivUp(numel, vec_size), threads); + int blocks = DivUp(DivUp(numel, vec_size), threads); int limit_blocks = context.GetCUDAMaxGridDimSize()[0]; if (blocks > limit_blocks) { blocks = limit_blocks; @@ -143,6 +152,11 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context, config.thread_per_block.x = threads; config.block_per_grid.x = blocks; config.compute_capability = capability; + + VLOG(3) << "Get 1-D launch config: numel=" << numel + << ", vec_size=" << vec_size << ", block_size=" << threads + << ", grid_size=" << blocks << ", limit_blocks=" << limit_blocks + << ", limit_threads=" << limit_threads; return config; } @@ -163,19 +177,18 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context, y_dim)); const int kThreadsPerBlock = 256; - int block_cols = (std::min)(x_dim, kThreadsPerBlock); - int block_rows = (std::max)(kThreadsPerBlock / block_cols, 1); + int block_cols = std::min(x_dim, kThreadsPerBlock); + int block_rows = std::max(kThreadsPerBlock / block_cols, 1); int max_physical_threads = context.GetMaxPhysicalThreadCount(); - const int max_blocks = (std::max)(max_physical_threads / kThreadsPerBlock, 1); + const int max_blocks = std::max(max_physical_threads / kThreadsPerBlock, 1); GpuLaunchConfig config; // Noticed, block size is not align to 32, if needed do it yourself. config.thread_per_block = dim3(block_cols, block_rows, 1); - int grid_x = (std::min)(DivUp(x_dim, block_cols), max_blocks); - int grid_y = - (std::min)(max_blocks / grid_x, (std::max)(y_dim / block_rows, 1)); + int grid_x = std::min(DivUp(x_dim, block_cols), max_blocks); + int grid_y = std::min(max_blocks / grid_x, std::max(y_dim / block_rows, 1)); config.block_per_grid = dim3(grid_x, grid_y, 1); return config; @@ -202,13 +215,10 @@ inline GpuLaunchConfig GetGpuLaunchConfig3D(const phi::GPUContext& context, int block_y = std::min(GetLastPow2(height), max_threads / block_x); int block_z = std::min(num_img, max_threads / block_x / block_y); - auto max_grid_dim = context.GetCUDAMaxGridDimSize(); - int grid_x = - std::min(max_grid_dim[0], backends::gpu::DivUp(width, block_x)); - int grid_y = - std::min(max_grid_dim[1], backends::gpu::DivUp(height, block_y)); - int grid_z = std::min(max_grid_dim[2], - backends::gpu::DivUp(num_img, block_z * 4)); + std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); + int grid_x = std::min(max_grid_dim[0], DivUp(width, block_x)); + int grid_y = std::min(max_grid_dim[1], DivUp(height, block_y)); + int grid_z = std::min(max_grid_dim[2], DivUp(num_img, block_z * 4)); const int capability = context.GetComputeCapability(); GpuLaunchConfig config; diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 8b5a3cf8aaa..daaf88a2395 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -511,8 +511,8 @@ struct Loader { template static __device__ void Apply(const Array &in, ArgsT *args, + kps::IndexType offset, int num, - int data_offset, int read_lens, bool is_boundary) { using Type = std::tuple_element_t; @@ -521,13 +521,13 @@ struct Loader { if (is_boundary) { kps::ReadData( args, - reinterpret_cast(in[Index]) + data_offset, + reinterpret_cast(in[Index]) + offset, num, read_lens); } else { kps::ReadData( args, - reinterpret_cast(in[Index]) + data_offset, + reinterpret_cast(in[Index]) + offset, num, read_lens); } @@ -681,46 +681,12 @@ struct SameDimsElementwisePrimitiveCaller { } }; -template -struct ElementwiseWriteDataCaller { - __device__ __forceinline__ void operator()( - phi::Array<_ptr_ OutT *, NumOuts> outs, - ConditionalT src[VecSize], - int block_offset, - int num) { - OutT dst[NumOuts][VecSize]; -#pragma unroll - for (int i = 0; i < VecSize; ++i) { -#pragma unroll - for (int j = 0; j < NumOuts; ++j) { - dst[j][i] = (src[i])[j]; - } - } -#pragma unroll - for (int i = 0; i < NumOuts; ++i) { - kps::WriteData( - outs[i] + block_offset, dst[i], num); - } - } -}; - -template -struct ElementwiseWriteDataCaller { - __device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT *, 1> outs, - OutT src[VecSize], - int block_offset, - int num) { - kps::WriteData( - outs[0] + block_offset, src, num); - } -}; - template struct ElementwiseWriteDataCallerBc { __device__ __forceinline__ void operator()( phi::Array<_ptr_ OutT *, NumOuts> outs, ConditionalT src[VecSize], - int block_offset, + kps::IndexType block_offset, int num, int read_lens) { OutT dst[NumOuts][VecSize]; @@ -743,7 +709,7 @@ template struct ElementwiseWriteDataCallerBc { __device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT *, 1> outs, OutT src[VecSize], - int block_offset, + kps::IndexType block_offset, int num, int read_lens) { kps::WriteData( @@ -758,11 +724,10 @@ template __device__ void VectorizedElementwiseKernelImpl( - const phi::Array &in, phi::Array<_ptr_ OutT *, NumOuts> outs, + kps::IndexType offset, int num, - int data_offset, int read_lens, Functor func) { using Traits = paddle::platform::FunctionTraits; @@ -771,7 +736,7 @@ __device__ void VectorizedElementwiseKernelImpl( ConditionalT result[VecSize]; Unroller::step( - in, args, num, data_offset, read_lens, IsBoundary); + in, args, offset, num, read_lens, IsBoundary); SameDimsElementwisePrimitiveCaller, VecSize, @@ -780,19 +745,19 @@ __device__ void VectorizedElementwiseKernelImpl( Arity>()(func, args, result, read_lens); ElementwiseWriteDataCallerBc()( - outs, result, data_offset, num, read_lens); + outs, result, offset, num, read_lens); } template __global__ void VectorizedElementwiseKernel( phi::Array ins, phi::Array<_ptr_ OutT *, NumOuts> outs, - int size, - int main_offset, + kps::IndexType numel, + kps::IndexType main_offset, int read_lens, Functor func) { - int data_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens; - int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens; + kps::IndexType data_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens; + kps::IndexType stride = BLOCK_NUM_X * GRID_NUM_X * read_lens; for (; data_offset < main_offset; data_offset += stride) { VectorizedElementwiseKernelImpl( - ins, outs, read_lens * BLOCK_NUM_X, data_offset, read_lens, func); + ins, outs, data_offset, read_lens * BLOCK_NUM_X, read_lens, func); } - int num = size - data_offset; - if (num > 0) { + int remain = numel - data_offset; + if (remain > 0) { VectorizedElementwiseKernelImpl( - ins, outs, num, data_offset, read_lens, func); + ins, outs, data_offset, remain, read_lens, func); } } template -void ElementwiseCudaKernel(const KPDevice &ctx, - const std::vector &ins, - std::vector *outs, - int read_lens, - Functor func) { - auto numel = - (*outs)[0]->numel(); // To avoid running errors when ins.size()== 0 +void LaunchElementwiseCudaKernel(const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + int read_lens, + Functor func) { + // There are at least 1 output, but maybe 0 input (ins.size() == 0). + // For large tensor numel * sizeof(T) > 2^31, we must use int64_t as index + // type. + int64_t numel = (*outs)[0]->numel(); phi::Array ins_data; phi::Array<_ptr_ OutT *, NumOuts> outs_data; @@ -834,15 +801,16 @@ void ElementwiseCudaKernel(const KPDevice &ctx, int block_size = 64; int grid_size = 8; auto stream = ctx.x_context()->xpu_stream; - int main_offset = (numel / (read_lens * block_size)) * read_lens * block_size; + int64_t main_offset = + (numel / (read_lens * block_size)) * read_lens * block_size; VectorizedElementwiseKernel <<>>( ins_data, outs_data, numel, main_offset, read_lens, func); #else auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize); - int main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) * VecSize * - gpu_config.GetBlockSize(); + int64_t main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) * + VecSize * gpu_config.GetBlockSize(); auto stream = ctx.stream(); VectorizedElementwiseKernel <<>>( @@ -901,15 +869,15 @@ void ElementwiseKernel(const KPDevice &ctx, #endif switch (vec_size) { case VecSizeL: - ElementwiseCudaKernel( + LaunchElementwiseCudaKernel( ctx, ins, outs, read_lens, func); break; case VecSizeM: - ElementwiseCudaKernel( + LaunchElementwiseCudaKernel( ctx, ins, outs, read_lens, func); break; case VecSizeS: - ElementwiseCudaKernel( + LaunchElementwiseCudaKernel( ctx, ins, outs, read_lens, func); break; default: { diff --git a/paddle/phi/kernels/primitive/kernel_primitives.h b/paddle/phi/kernels/primitive/kernel_primitives.h index f68a046ae07..729402ed559 100644 --- a/paddle/phi/kernels/primitive/kernel_primitives.h +++ b/paddle/phi/kernels/primitive/kernel_primitives.h @@ -87,3 +87,16 @@ #include "paddle/phi/kernels/primitive/functor_primitives.h" #endif + +namespace phi { +namespace kps { + +#ifdef PADDLE_WITH_XPU_KP +// The type of index used in kernel +using IndexType = int; +#else +using IndexType = int64_t; +#endif + +} // namespace kps +} // namespace phi -- GitLab