From af8d248215a0e6f725179c772bb97252cf84a545 Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Sun, 3 Apr 2022 13:12:55 +0800 Subject: [PATCH] add maximum limit for grid of index_select (#41127) * limit grid dim for index select * mv LimitGridDim into gpu_launch_config.h * fix conflicts * fix conflicts * fix code style * set block to 256 * fix grid setting * set dtype of block_dim to unsigned int --- .../platform/device/gpu/gpu_launch_config.h | 8 ++++ .../phi/kernels/funcs/elementwise_grad_base.h | 44 ++++++++----------- paddle/phi/kernels/funcs/reduce_function.h | 16 ++----- .../kernels/gpu/index_sample_grad_kernel.cu | 9 +--- paddle/phi/kernels/gpu/index_sample_kernel.cu | 9 +--- .../kernels/gpu/index_select_grad_kernel.cu | 23 +++++----- paddle/phi/kernels/gpu/index_select_kernel.cu | 35 +++++++-------- 7 files changed, 58 insertions(+), 86 deletions(-) diff --git a/paddle/fluid/platform/device/gpu/gpu_launch_config.h b/paddle/fluid/platform/device/gpu/gpu_launch_config.h index 4e8b790fa63..4a550e61d42 100644 --- a/paddle/fluid/platform/device/gpu/gpu_launch_config.h +++ b/paddle/fluid/platform/device/gpu/gpu_launch_config.h @@ -170,6 +170,14 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D( return config; } +template +void LimitGridDim(const Context& ctx, dim3* grid_dim) { + auto max_grid_dim = reinterpret_cast(ctx) + .GetCUDAMaxGridDimSize(); + grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0]; + grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1]; + grid_dim->z = grid_dim->z < max_grid_dim[2] ? grid_dim->z : max_grid_dim[2]; +} } // namespace platform } // namespace paddle diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h index 23b8388c745..1021b510b26 100644 --- a/paddle/phi/kernels/funcs/elementwise_grad_base.h +++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h @@ -24,6 +24,7 @@ limitations under the License. */ // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" #endif @@ -49,14 +50,6 @@ namespace phi { namespace funcs { using DDim = phi::DDim; -template -void LimitGridDim(const GPUContext &ctx, T *grid_dim) { - auto max_grid_dim = ctx.GetCUDAMaxGridDimSize()[0]; - if (*grid_dim > max_grid_dim) { - *grid_dim = max_grid_dim; - } -} - template void CommonGradBroadcastCPU(const DenseTensor &x, const DenseTensor &y, @@ -978,17 +971,17 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream, constexpr int half_walf = 16; if (w < half_walf || h < half_walf) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); - int gird_size = w; - ElemwiseGradBroadcast1CUDAKernel<<>>( + int grid_size = w; + ElemwiseGradBroadcast1CUDAKernel<<>>( x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy); } else { // suppose perfoemance improves with h increased. dim3 block_size = dim3(BLOCK_X, BLOCK_Y); - int grid_size = (w + BLOCK_X - 1) / BLOCK_X; + dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X); auto gplace = phi::GPUPlace(); auto *ctx = static_cast( paddle::platform::DeviceContextPool::Instance().Get(gplace)); - LimitGridDim(*ctx, &grid_size); + paddle::platform::LimitGridDim(*ctx, &grid_size); FastElemwiseGradBroadcast1CUDAKernel<<>>( x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy); } @@ -1009,13 +1002,12 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, T *dx, T *dy) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); - int gird_size = n; - int grid_size = n; + dim3 grid_size = dim3(n); auto gplace = phi::GPUPlace(); auto *ctx = static_cast( paddle::platform::DeviceContextPool::Instance().Get(gplace)); - LimitGridDim(*ctx, &grid_size); - ElemwiseGradBroadcast2CUDAKernel<<>>( + paddle::platform::LimitGridDim(*ctx, &grid_size); + ElemwiseGradBroadcast2CUDAKernel<<>>( x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy); } @@ -1216,8 +1208,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x, is_y); } else { dim3 block_size = dim3(BLOCK_X, BLOCK_Y); - int grid_size = (w + BLOCK_X - 1) / BLOCK_X; - LimitGridDim(ctx, &grid_size); + dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X); + paddle::platform::LimitGridDim(ctx, &grid_size); FastCommonGradBroadcastCUDAKernelHeight<<>>( x_data, @@ -1392,8 +1384,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x, 1, std::multiplies()); int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); - int grid_size = pre * post; - LimitGridDim(ctx, &grid_size); + dim3 grid_size = dim3(pre * post); + paddle::platform::LimitGridDim(ctx, &grid_size); // we need to calc y offset with blockid, so do x_pre/y_pre to get left // size. if (k_pre != pre) k_pre = pre / k_pre; @@ -1423,8 +1415,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x, 1, std::multiplies()); int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); - int grid_size = pre * post; - LimitGridDim(ctx, &grid_size); + dim3 grid_size = dim3(pre * post); + paddle::platform::LimitGridDim(ctx, &grid_size); if (k_pre != pre) k_pre = pre / k_pre; FastCommonGradBroadcastOneCUDAKernel<<( - paddle::platform::DeviceContextPool::Instance().Get(place)); - std::array max_grid_dim = ctx->GetCUDAMaxGridDimSize(); - grid.x = grid.x < max_grid_dim[0] ? grid.x : max_grid_dim[0]; - grid.y = grid.y < max_grid_dim[1] ? grid.y : max_grid_dim[1]; - grid.z = grid.z < max_grid_dim[2] ? grid.z : max_grid_dim[2]; - } - public: std::vector reduce_dims_origin; std::vector reduce_dim; @@ -1072,7 +1064,7 @@ void ReduceKernel(const KPDevice& dev_ctx, auto x_dim = phi::vectorize(x.dims()); auto config = ReduceConfig(origin_reduce_dims, x_dim); - config.Run(x.place()); + config.Run(dev_ctx); int numel = x.numel(); // after config.run() // SetOutputData for ReduceHigherDim when should_reduce_again is true, diff --git a/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu b/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu index 669ae115439..c8c025c7fc1 100644 --- a/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu @@ -26,13 +26,6 @@ namespace phi { namespace { -template -void LimitGridDim(const Context& ctx, dim3* grid_dim) { - auto max_grid_dim = - reinterpret_cast(ctx).GetCUDAMaxGridDimSize(); - grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0]; - grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1]; -} #define PREDEFINED_BLOCK_SIZE_X 512 #define PREDEFINED_BLOCK_SIZE 1024 #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -107,7 +100,7 @@ void IndexSampleGradKernel(const Context& ctx, dim3 block_dim(block_width, block_height); dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, (batch_size + block_dim.y - 1) / block_dim.y); - LimitGridDim(ctx, &grid_dim); + paddle::platform::LimitGridDim(ctx, &grid_dim); phi::funcs::SetConstant set_zero; set_zero(ctx, x_grad, static_cast(0)); diff --git a/paddle/phi/kernels/gpu/index_sample_kernel.cu b/paddle/phi/kernels/gpu/index_sample_kernel.cu index 68573d55966..0eca473a565 100644 --- a/paddle/phi/kernels/gpu/index_sample_kernel.cu +++ b/paddle/phi/kernels/gpu/index_sample_kernel.cu @@ -25,13 +25,6 @@ namespace phi { namespace { -template -void LimitGridDim(const Context& ctx, dim3* grid_dim) { - auto max_grid_dim = - reinterpret_cast(ctx).GetCUDAMaxGridDimSize(); - grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0]; - grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1]; -} #define PREDEFINED_BLOCK_SIZE_X 512 #define PREDEFINED_BLOCK_SIZE 1024 #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -95,7 +88,7 @@ void IndexSampleKernel(const Context& ctx, dim3 block_dim(block_width, block_height); dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, (batch_size + block_dim.y - 1) / block_dim.y); - LimitGridDim(ctx, &grid_dim); + paddle::platform::LimitGridDim(ctx, &grid_dim); if (index_type == DataType::INT64) { const int64_t* index_data = index.data(); diff --git a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu index b3bd307e2aa..209ce1ccf5c 100644 --- a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/index_select_grad_kernel.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/kernel_registry.h" @@ -89,25 +90,23 @@ void IndexSelectGradKernel(const Context& ctx, auto stream = ctx.stream(); - index_select_grad_init< - T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(in_grad_data, numel); + unsigned int block_dim = PADDLE_CUDA_NUM_THREADS; + dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim); + paddle::platform::LimitGridDim(ctx, &grid_dim); - int blocks = - (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; - int threads = PADDLE_CUDA_NUM_THREADS; + index_select_grad_init<<>>(in_grad_data, + numel); if (FLAGS_cudnn_deterministic) { VLOG(2) << "Run grad kernel of index_select with single thread."; - blocks = 1; - threads = 1; + block_dim = 1; + grid_dim.x = 1; } if (index_type == phi::DataType::INT64) { const int64_t* index_data = index.data(); - index_select_grad_cuda_kernel<<>>( + index_select_grad_cuda_kernel<<>>( output_grad_data, in_grad_data, index_data, @@ -118,7 +117,7 @@ void IndexSelectGradKernel(const Context& ctx, delta); } else { const int* index_data = index.data(); - index_select_grad_cuda_kernel<<>>( + index_select_grad_cuda_kernel<<>>( output_grad_data, in_grad_data, index_data, diff --git a/paddle/phi/kernels/gpu/index_select_kernel.cu b/paddle/phi/kernels/gpu/index_select_kernel.cu index e82976d46e6..57a13a9aefc 100644 --- a/paddle/phi/kernels/gpu/index_select_kernel.cu +++ b/paddle/phi/kernels/gpu/index_select_kernel.cu @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/index_select_kernel.h" +#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/kernel_registry.h" @@ -31,16 +32,14 @@ __global__ void index_select_cuda_kernel(const T* input, int64_t stride, int64_t size, int64_t delta) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) { - return; + CUDA_KERNEL_LOOP(idx, N) { + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = + idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + output[idx] = input[input_idx]; } - - int64_t pre_idx = idx / (stride * size); - int64_t dim_idx = idx % (stride * size) / stride; - IndexT src_dim_idx = index[dim_idx]; - int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; - output[idx] = input[input_idx]; } template @@ -75,21 +74,17 @@ void IndexSelectKernel(const Context& ctx, int64_t numel = output->numel(); auto stream = ctx.stream(); + unsigned int block_dim = PADDLE_CUDA_NUM_THREADS; + dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim); + paddle::platform::LimitGridDim(ctx, &grid_dim); + if (index_type == phi::DataType::INT64) { const int64_t* index_data = index.data(); - index_select_cuda_kernel<<< - (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(in_data, out_data, index_data, numel, stride, size, delta); + index_select_cuda_kernel<<>>( + in_data, out_data, index_data, numel, stride, size, delta); } else { const int* index_data = index.data(); - index_select_cuda_kernel< - T, - int><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>( + index_select_cuda_kernel<<>>( in_data, out_data, index_data, numel, stride, size, delta); } } -- GitLab