From a1dbee23be892212cffc541abafc9dde7ff42084 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Mon, 5 Sep 2022 11:01:17 +0800
Subject: [PATCH] fix some op int32 exceed range (#45711)

---
 .../platform/device/gpu/cuda/cuda_helper.h   |  7 ++---
 .../platform/device/gpu/gpu_launch_config.h  | 25 ++++++++++-------
 .../platform/device/gpu/rocm/rocm_helper.h   |  7 ++---
 paddle/phi/backends/gpu/cuda/cuda_helper.h   |  7 ++---
 paddle/phi/backends/gpu/gpu_launch_config.h  | 11 ++++----
 paddle/phi/backends/gpu/rocm/rocm_helper.h   |  7 ++---
 paddle/phi/kernels/gpu/one_hot_kernel.cu     | 11 ++++----
 paddle/phi/kernels/gpu/stack_kernel.cu       | 27 ++++++++++---------
 8 files changed, 58 insertions(+), 44 deletions(-)

diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h b/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h
index 7185d2356a..d1d33d50a5 100644
--- a/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h
+++ b/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h
@@ -70,9 +70,10 @@ namespace platform {
  *
  */
 
-#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)              \
-  int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x;   \
-  for (index_type i = __index__; __index__ < (num);            \
+#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)                  \
+  int64_t __index__ =                                              \
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; \
+  for (index_type i = __index__; __index__ < (num);                \
        __index__ += blockDim.x * gridDim.x, i = __index__)
 
 class CublasHandleHolder {
diff --git a/paddle/fluid/platform/device/gpu/gpu_launch_config.h b/paddle/fluid/platform/device/gpu/gpu_launch_config.h
index 3628b7e041..ca861d543f 100644
--- a/paddle/fluid/platform/device/gpu/gpu_launch_config.h
+++ b/paddle/fluid/platform/device/gpu/gpu_launch_config.h
@@ -44,7 +44,10 @@
 namespace paddle {
 namespace platform {
 
-inline int DivUp(int a, int b) { return (a + b - 1) / b; }
+template <typename T>
+inline T DivUp(T a, T b) {
+  return (a + b - 1) / b;
+}
 
 /* https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
    for round integer value into next highest power of 2. */
@@ -120,7 +123,7 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context,
 #endif
   int threads = limit_threads;
   int sm_count = context.GetSMCount();
-  int active_threads_num = numel / vec_size;
+  int64_t active_threads_num = numel / vec_size;
   if (active_threads_num / (sm_count << 1) < limit_threads) {
     // Round up threads number into an exponential multiple of 2, while number
     // of acitve blocks is about twice of SM, to acquire better performance.
@@ -132,7 +135,7 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context,
   }
   // Number of threads per block shall be larger than 64.
   threads = std::max(64, threads);
-  int blocks = DivUp(DivUp(numel, vec_size), threads);
+  int64_t blocks = DivUp<int64_t>(DivUp<int64_t>(numel, vec_size), threads);
   int limit_blocks = context.GetCUDAMaxGridDimSize()[0];
   if (blocks > limit_blocks) {
     blocks = limit_blocks;
@@ -146,8 +149,8 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context,
 }
 
 inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context,
-                                            int x_dim,
-                                            int y_dim) {
+                                            int64_t x_dim,
+                                            int64_t y_dim) {
   PADDLE_ENFORCE_GT(
       x_dim,
       0,
@@ -162,8 +165,10 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context,
                         y_dim));
 
   const int kThreadsPerBlock = 256;
-  int block_cols = (std::min)(x_dim, kThreadsPerBlock);
-  int block_rows = (std::max)(kThreadsPerBlock / block_cols, 1);
+  // NOTE(zengjinle): cast std::min result to int is safe here, because
+  // kThreadsPerBlock is always very small.
+  int block_cols = std::min<int64_t>(x_dim, kThreadsPerBlock);
+  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
 
   int max_physical_threads = context.GetMaxPhysicalThreadCount();
   const int max_blocks = (std::max)(max_physical_threads / kThreadsPerBlock, 1);
@@ -172,9 +177,9 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context,
   // Noticed, block size is not align to 32, if needed do it yourself.
   config.thread_per_block = dim3(block_cols, block_rows, 1);
 
-  int grid_x = (std::min)(DivUp(x_dim, block_cols), max_blocks);
-  int grid_y =
-      (std::min)(max_blocks / grid_x, (std::max)(y_dim / block_rows, 1));
+  int grid_x = std::min<int64_t>(DivUp<int64_t>(x_dim, block_cols), max_blocks);
+  int grid_y = std::min<int64_t>(max_blocks / grid_x,
+                                 std::max<int64_t>(y_dim / block_rows, 1));
 
   config.block_per_grid = dim3(grid_x, grid_y, 1);
   return config;
diff --git a/paddle/fluid/platform/device/gpu/rocm/rocm_helper.h b/paddle/fluid/platform/device/gpu/rocm/rocm_helper.h
index c0f6f173a7..8bcae15d35 100644
--- a/paddle/fluid/platform/device/gpu/rocm/rocm_helper.h
+++ b/paddle/fluid/platform/device/gpu/rocm/rocm_helper.h
@@ -67,9 +67,10 @@ namespace platform {
  *
  */
 
-#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)                       \
-  int64_t __index__ = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;   \
-  for (index_type i = __index__; __index__ < (num);                     \
+#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)                           \
+  int64_t __index__ =                                                       \
+      static_cast<int64_t>(hipBlockIdx_x) * hipBlockDim_x + hipThreadIdx_x; \
+  for (index_type i = __index__; __index__ < (num);                         \
        __index__ += hipBlockDim_x * hipGridDim_x, i = __index__)
 
 class CublasHandleHolder {
diff --git a/paddle/phi/backends/gpu/cuda/cuda_helper.h b/paddle/phi/backends/gpu/cuda/cuda_helper.h
index c62addfd25..6d33d802b1 100644
--- a/paddle/phi/backends/gpu/cuda/cuda_helper.h
+++ b/paddle/phi/backends/gpu/cuda/cuda_helper.h
@@ -62,9 +62,10 @@ namespace gpu {
  *
  */
 
-#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)              \
-  int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x;   \
-  for (index_type i = __index__; __index__ < (num);            \
+#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)                  \
+  int64_t __index__ =                                              \
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; \
+  for (index_type i = __index__; __index__ < (num);                \
        __index__ += blockDim.x * gridDim.x, i = __index__)
 
 }  // namespace gpu
diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h
index 6ea206178c..f0a37d4fb7 100644
--- a/paddle/phi/backends/gpu/gpu_launch_config.h
+++ b/paddle/phi/backends/gpu/gpu_launch_config.h
@@ -162,8 +162,8 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context,
 }
 
 inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context,
-                                            int x_dim,
-                                            int y_dim) {
+                                            int64_t x_dim,
+                                            int64_t y_dim) {
   PADDLE_ENFORCE_GT(
       x_dim,
       0,
@@ -178,7 +178,7 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context,
                         y_dim));
 
   const int kThreadsPerBlock = 256;
-  int block_cols = std::min(x_dim, kThreadsPerBlock);
+  int block_cols = std::min<int64_t>(x_dim, kThreadsPerBlock);
   int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
 
   int max_physical_threads = context.GetMaxPhysicalThreadCount();
@@ -188,8 +188,9 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(const phi::GPUContext& context,
   // Noticed, block size is not align to 32, if needed do it yourself.
   config.thread_per_block = dim3(block_cols, block_rows, 1);
 
-  int grid_x = std::min(DivUp(x_dim, block_cols), max_blocks);
-  int grid_y = std::min(max_blocks / grid_x, std::max(y_dim / block_rows, 1));
+  int grid_x = std::min<int64_t>(DivUp<int64_t>(x_dim, block_cols), max_blocks);
+  int grid_y = std::min<int64_t>(max_blocks / grid_x,
+                                 std::max<int64_t>(y_dim / block_rows, 1));
 
   config.block_per_grid = dim3(grid_x, grid_y, 1);
   return config;
diff --git a/paddle/phi/backends/gpu/rocm/rocm_helper.h b/paddle/phi/backends/gpu/rocm/rocm_helper.h
index 14e9ca660b..e25dea28e3 100644
--- a/paddle/phi/backends/gpu/rocm/rocm_helper.h
+++ b/paddle/phi/backends/gpu/rocm/rocm_helper.h
@@ -62,9 +62,10 @@ namespace gpu {
  *
  */
 
-#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)                       \
-  int64_t __index__ = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;   \
-  for (index_type i = __index__; __index__ < (num);                     \
+#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)                           \
+  int64_t __index__ =                                                       \
+      static_cast<int64_t>(hipBlockIdx_x) * hipBlockDim_x + hipThreadIdx_x; \
+  for (index_type i = __index__; __index__ < (num);                         \
        __index__ += hipBlockDim_x * hipGridDim_x, i = __index__)
 
 }  // namespace gpu
diff --git a/paddle/phi/kernels/gpu/one_hot_kernel.cu b/paddle/phi/kernels/gpu/one_hot_kernel.cu
index 2ae9e9333e..abe7757df7 100644
--- a/paddle/phi/kernels/gpu/one_hot_kernel.cu
+++ b/paddle/phi/kernels/gpu/one_hot_kernel.cu
@@ -16,6 +16,7 @@
 
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
@@ -28,8 +29,7 @@ __global__ void FillOutputKernel(const InT* p_in_data,
                                  OutT* p_out_data,
                                  const int64_t numel,
                                  const int depth) {
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (idx < numel) {
+  CUDA_KERNEL_LOOP_TYPE(idx, numel, int64_t) {
     PADDLE_ENFORCE(p_in_data[idx] >= 0 && p_in_data[idx] < depth,
                    "Illegal index value, Input(input) value should be "
                    "greater than or equal to 0, and less than depth [%d], "
@@ -62,9 +62,10 @@ struct OneHotV2OpCUDAFunctor {
     auto stream = ctx_.stream();
     funcs::set_constant(ctx_, out_, 0.0);
 
-    FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
-                           PADDLE_CUDA_NUM_THREADS,
-                       PADDLE_CUDA_NUM_THREADS,
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx_, numel);
+
+    FillOutputKernel<<<config.block_per_grid,
+                       config.thread_per_block,
                        0,
                        stream>>>(p_in_data, p_out_data, numel, depth_);
   }
diff --git a/paddle/phi/kernels/gpu/stack_kernel.cu b/paddle/phi/kernels/gpu/stack_kernel.cu
index e5c8d392e6..c079b61c06 100644
--- a/paddle/phi/kernels/gpu/stack_kernel.cu
+++ b/paddle/phi/kernels/gpu/stack_kernel.cu
@@ -23,20 +23,23 @@ namespace phi {
 
 template <typename T, typename IntType>
 __global__ void StackCUDAKernel(T** input_ptrs,
-                                int split_size,
-                                int rows,
-                                int cols,
+                                IntType split_size,
+                                IntType rows,
+                                IntType cols,
                                 T* __restrict__ output) {
-  IntType grid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  IntType grid_x = static_cast<IntType>(blockIdx.x) * blockDim.x + threadIdx.x;
+  IntType grid_x_stride = static_cast<IntType>(blockDim.x) * gridDim.x;
+  IntType grid_y_stride = static_cast<IntType>(blockDim.y) * gridDim.y;
 
-  for (; grid_x < cols; grid_x += blockDim.x * gridDim.x) {
-    IntType grid_y = blockIdx.y * blockDim.y + threadIdx.y;
+  for (; grid_x < cols; grid_x += grid_x_stride) {
+    IntType grid_y =
+        static_cast<IntType>(blockIdx.y) * blockDim.y + threadIdx.y;
 
     IntType split = grid_x / split_size;
     const T* input_ptr = input_ptrs[split];
     IntType col_offset = grid_x % split_size;
 #pragma unroll
-    for (; grid_y < rows; grid_y += blockDim.y * gridDim.y) {
+    for (; grid_y < rows; grid_y += grid_y_stride) {
       output[grid_y * cols + grid_x] =
           input_ptr[grid_y * split_size + col_offset];
     }
@@ -69,12 +72,12 @@ void StackKernel(const Context& dev_ctx,
                    dev_ctx.stream());
 
   // Split x dim from axis to matrix
-  int x_row = 1, x_col = 1;
+  int64_t x_row = 1, x_col = 1;
   for (int i = 0; i < axis; ++i) {
     x_row *= x[0]->dims()[i];
   }
   x_col = x[0]->numel() / x_row;
-  int out_col = x_col * n;
+  int64_t out_col = x_col * n;
 
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);
@@ -85,9 +88,9 @@ void StackKernel(const Context& dev_ctx,
            config.thread_per_block,
            0,
            dev_ctx.stream()>>>(reinterpret_cast<T**>(tmp_x_data->ptr()),
-                               x_col,
-                               x_row,
-                               out_col,
+                               static_cast<int32_t>(x_col),
+                               static_cast<int32_t>(x_row),
+                               static_cast<int32_t>(out_col),
                                y_data);
   } else {
     StackCUDAKernel<T, int64_t>
-- 
GitLab
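
For context, below is a minimal standalone sketch (not part of the patch itself) of the int32 overflow that the CUDA_KERNEL_LOOP_TYPE and StackCUDAKernel changes avoid. The kernel name DemoFill and its launch shape are made up for illustration; only the index-promotion pattern mirrors the patched code.

#include <cstdint>

__global__ void DemoFill(float* out, int64_t numel) {
  // In pure 32-bit arithmetic, blockIdx.x * blockDim.x wraps modulo 2^32 once
  // the flat index passes ~2.1e9 elements, so a line like
  //   int bad = blockIdx.x * blockDim.x + threadIdx.x;
  // silently produces a wrong (possibly negative) index for large tensors.
  // Casting the first operand forces the whole expression into 64 bits,
  // which is what the patched macro and kernels do.
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  // Grid-stride loop so the kernel stays correct even when the grid is
  // clamped to the device's maximum grid dimension.
  for (; idx < numel; idx += stride) {
    out[idx] = 1.0f;
  }
}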