diff --git a/paddle/phi/kernels/funcs/gather.cu.h b/paddle/phi/kernels/funcs/gather.cu.h
index 59c8c9f3b8f0ed0569f40ab1476f629bd847d0c2..617d249308cda2abe47385cd1293e2b2fb9e4b29 100644
--- a/paddle/phi/kernels/funcs/gather.cu.h
+++ b/paddle/phi/kernels/funcs/gather.cu.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/memory/memcpy.h"
 // TODO(paddle-dev): move gpu_primitives.h to phi
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/common/place.h"
@@ -110,11 +111,8 @@ void GPUGather(const phi::GPUContext& ctx,
 
   int block = 512;
   int64_t n = slice_size * index_size;
-  int64_t grid = (n + block - 1) / block;
-  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  if (grid > maxGridDimX) {
-    grid = maxGridDimX;
-  }
+  dim3 grid = dim3((n + block - 1) / block);
+  paddle::platform::LimitGridDim(ctx, &grid);
 
   GatherCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
       p_src, p_index, p_output, index_size, slice_size);
@@ -155,11 +153,8 @@ void GPUGatherNd(const phi::GPUContext& ctx,
 
   int block = 512;
   int64_t n = slice_size * remain_numel;
-  int64_t grid = (n + block - 1) / block;
-  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  if (grid > maxGridDimX) {
-    grid = maxGridDimX;
-  }
+  dim3 grid = dim3((n + block - 1) / block);
+  paddle::platform::LimitGridDim(ctx, &grid);
 
   GatherNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(p_input,
                                                                   g_input_dims,
diff --git a/paddle/phi/kernels/funcs/scatter.cu.h b/paddle/phi/kernels/funcs/scatter.cu.h
index 254dd45edb596243a8867dee52708d5de6776bea..87083af3bc6a2e66cb59c7a185f2f3a0c4d2bde4 100644
--- a/paddle/phi/kernels/funcs/scatter.cu.h
+++ b/paddle/phi/kernels/funcs/scatter.cu.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <unordered_set>
 #include <vector>
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
@@ -155,9 +156,8 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
   // set block and grid num
   int block = 512;
   int64_t n = slice_size * index_size;
-  int64_t grid = (n + block - 1) / block;
-  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  grid = grid > maxGridDimX ? maxGridDimX : grid;
+  dim3 grid = dim3((n + block - 1) / block);
+  paddle::platform::LimitGridDim(ctx, &grid);
 
   // if not overwrite mode, init data
   if (!overwrite) {
@@ -188,9 +188,8 @@ void GPUScatterGradForX(const phi::GPUContext& ctx,
   int64_t block = 512;
   int64_t n = slice_size * index_size;
   int64_t height = (n + block - 1) / block;
-
-  int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
-  int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;
+  dim3 grid = dim3((n + block - 1) / block);
+  paddle::platform::LimitGridDim(ctx, &grid);
 
   ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
       p_index, p_output, index_size, slice_size);
@@ -230,9 +229,8 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
 
   int block = 512;
   int64_t n = slice_size * remain_numel;
-  int64_t grid = (n + block - 1) / block;
-  unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
-  grid = grid > maxGridDimX ? maxGridDimX : grid;
+  dim3 grid = dim3((n + block - 1) / block);
+  paddle::platform::LimitGridDim(ctx, &grid);
 
   ScatterNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
       p_update,
diff --git a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu
index 75ae1bbcd0a08d9e7c31fa639bf0b90355fd454d..84094f4c1ee5ae1d4e96ce882355c94feb003915 100644
--- a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu
@@ -19,6 +19,7 @@
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 DECLARE_bool(cudnn_deterministic);
 
@@ -35,7 +36,7 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
                                               int64_t stride,
                                               int64_t size,
                                               int64_t delta) {
-  CUDA_KERNEL_LOOP(idx, N) {
+  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
     int64_t pre_idx = idx / (stride * size);
     int64_t dim_idx = idx % (stride * size) / stride;
     IndexT src_dim_idx = index[dim_idx];
@@ -45,15 +46,6 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
   }
 }
 
-template <typename T>
-__global__ void index_select_grad_init(T* input_grad, int64_t N) {
-  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (idx >= N) {
-    return;
-  }
-  input_grad[idx] = 0.0;
-}
-
 template <typename T, typename Context>
 void IndexSelectGradKernel(const Context& ctx,
                            const DenseTensor& x,
@@ -97,8 +89,8 @@ void IndexSelectGradKernel(const Context& ctx,
 
   dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
   paddle::platform::LimitGridDim(ctx, &grid_dim);
-  index_select_grad_init<T><<<grid_dim, block_dim, 0, stream>>>(in_grad_data,
-                                                                numel);
+  phi::funcs::SetConstant<Context, T> index_select_grad_init;
+  index_select_grad_init(ctx, x_grad, static_cast<T>(0));
 
   if (FLAGS_cudnn_deterministic) {
     VLOG(2) << "Run grad kernel of index_select with single thread.";
diff --git a/paddle/phi/kernels/gpu/index_select_kernel.cu b/paddle/phi/kernels/gpu/index_select_kernel.cu
index 38a6582d790f8c86c4255e443acb9bc26e74f094..0a6ac69cef0981edebb2d273c04e103e9679e3ad 100644
--- a/paddle/phi/kernels/gpu/index_select_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_select_kernel.cu
@@ -32,7 +32,7 @@ __global__ void index_select_cuda_kernel(const T* input,
                                          int64_t stride,
                                          int64_t size,
                                          int64_t delta) {
-  CUDA_KERNEL_LOOP(idx, N) {
+  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
     int64_t pre_idx = idx / (stride * size);
     int64_t dim_idx = idx % (stride * size) / stride;
     IndexT src_dim_idx = index[dim_idx];
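
Note on the pattern: every hunk above replaces a hand-rolled clamp against `GetCUDAMaxGridDimSize()[0]` with `paddle::platform::LimitGridDim`, and the affected kernels iterate with a grid-stride loop (`CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t)`) so a clamped grid still covers all `N` elements and the loop index no longer overflows 32 bits. Below is a minimal standalone sketch of that combination; `LimitGridDimX` and `FillKernel` are hypothetical stand-ins written for illustration, not Paddle's actual helpers.

```cpp
#include <cuda_runtime.h>

// Grid-stride loop with a 64-bit index, analogous to what
// CUDA_KERNEL_LOOP_TYPE(idx, n, int64_t) expands to: each thread
// revisits the range in steps of the total thread count, so the
// kernel is correct for any grid size, including a clamped one.
__global__ void FillKernel(float* out, int64_t n, float v) {
  for (int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x; i < n;
       i += (int64_t)blockDim.x * gridDim.x) {
    out[i] = v;
  }
}

// Sketch of the clamp (assumed behavior of LimitGridDim): cap grid.x at
// the device's maximum gridDim.x so huge element counts cannot produce
// an invalid launch configuration.
void LimitGridDimX(dim3* grid) {
  int dev = 0;
  cudaGetDevice(&dev);
  int max_grid_x = 0;
  cudaDeviceGetAttribute(&max_grid_x, cudaDevAttrMaxGridDimX, dev);
  if (grid->x > static_cast<unsigned int>(max_grid_x)) {
    grid->x = static_cast<unsigned int>(max_grid_x);
  }
}

int main() {
  const int64_t n = 1 << 20;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));

  int block = 512;
  dim3 grid(static_cast<unsigned int>((n + block - 1) / block));
  LimitGridDimX(&grid);  // grid may shrink; the loop stride compensates
  FillKernel<<<grid, block>>>(d, n, 0.0f);
  cudaDeviceSynchronize();

  cudaFree(d);
  return 0;
}
```

The two halves depend on each other: clamping alone would silently skip every element past `maxGridDimX * block`, and it is the grid-stride loop inside the kernel that makes the clamp safe.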