Unverified commit af8d2482, authored by FlyingQianMM, committed by GitHub

add maximum limit for grid of index_select (#41127)

* limit grid dim for index select

* mv LimitGridDim into gpu_launch_config.h

* fix conflicts

* fix conflicts

* fix code style

* set block to 256

* fix grid setting

* set dtype of block_dim to unsigned int
Parent: 61e60e68
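The change replaces the per-file LimitGridDim helpers with a single paddle::platform::LimitGridDim in gpu_launch_config.h, and every launch site now clamps its dim3 grid against the device's maximum grid dimensions before launching. A minimal standalone sketch of the same idea in plain CUDA; limit_grid_dim and fill_kernel are illustrative names for this sketch, not Paddle APIs:

// Sketch: query the device's maximum grid dimensions, clamp the launch grid,
// and rely on a grid-stride loop inside the kernel so elements beyond the
// clamped grid are still processed.
#include <cuda_runtime.h>

#include <algorithm>

// Clamp each component of *grid against the device's maxGridSize.
void limit_grid_dim(int device, dim3* grid) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  grid->x = std::min<unsigned int>(grid->x, prop.maxGridSize[0]);
  grid->y = std::min<unsigned int>(grid->y, prop.maxGridSize[1]);
  grid->z = std::min<unsigned int>(grid->z, prop.maxGridSize[2]);
}

// Grid-stride loop: each thread handles every n-th element, so correctness
// does not depend on launching one block per chunk of data.
__global__ void fill_kernel(float* data, long long n, float value) {
  for (long long i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<long long>(gridDim.x) * blockDim.x) {
    data[i] = value;
  }
}

int main() {
  const long long n = 1LL << 20;
  float* data = nullptr;
  cudaMalloc(&data, n * sizeof(float));

  const unsigned int block = 256;  // any reasonable block size for this sketch
  dim3 grid((n + block - 1) / block);
  limit_grid_dim(0, &grid);  // never exceed the hardware grid limit

  fill_kernel<<<grid, block>>>(data, n, 1.0f);
  cudaDeviceSynchronize();
  cudaFree(data);
  return 0;
}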
@@ -170,6 +170,14 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
   return config;
 }
 
+template <typename Context>
+void LimitGridDim(const Context& ctx, dim3* grid_dim) {
+  auto max_grid_dim = reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+                          .GetCUDAMaxGridDimSize();
+  grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
+  grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
+  grid_dim->z = grid_dim->z < max_grid_dim[2] ? grid_dim->z : max_grid_dim[2];
+}
+
 }  // namespace platform
 }  // namespace paddle
...
@@ -24,6 +24,7 @@ limitations under the License. */
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/phi/kernels/primitive/kernel_primitives.h"
 #endif
@@ -49,14 +50,6 @@ namespace phi {
 namespace funcs {
 using DDim = phi::DDim;
-template <typename T>
-void LimitGridDim(const GPUContext &ctx, T *grid_dim) {
-  auto max_grid_dim = ctx.GetCUDAMaxGridDimSize()[0];
-  if (*grid_dim > max_grid_dim) {
-    *grid_dim = max_grid_dim;
-  }
-}
 template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 void CommonGradBroadcastCPU(const DenseTensor &x,
                             const DenseTensor &y,
@@ -978,17 +971,17 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream,
   constexpr int half_walf = 16;
   if (w < half_walf || h < half_walf) {
     int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
-    int gird_size = w;
-    ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+    int grid_size = w;
+    ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
         x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
   } else {
     // suppose perfoemance improves with h increased.
     dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
-    int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
+    dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
     auto gplace = phi::GPUPlace();
     auto *ctx = static_cast<GPUContext *>(
         paddle::platform::DeviceContextPool::Instance().Get(gplace));
-    LimitGridDim(*ctx, &grid_size);
+    paddle::platform::LimitGridDim(*ctx, &grid_size);
     FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
         x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
   }
@@ -1009,13 +1002,12 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream,
                                        T *dx,
                                        T *dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
-  int gird_size = n;
-  int grid_size = n;
+  dim3 grid_size = dim3(n);
   auto gplace = phi::GPUPlace();
   auto *ctx = static_cast<GPUContext *>(
       paddle::platform::DeviceContextPool::Instance().Get(gplace));
-  LimitGridDim(*ctx, &grid_size);
-  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+  paddle::platform::LimitGridDim(*ctx, &grid_size);
+  ElemwiseGradBroadcast2CUDAKernel<<<grid_size, block_size, 0, stream>>>(
       x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy);
 }
@@ -1216,8 +1208,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
                                              is_y);
     } else {
       dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
-      int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
-      LimitGridDim(ctx, &grid_size);
+      dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
+      paddle::platform::LimitGridDim(ctx, &grid_size);
       FastCommonGradBroadcastCUDAKernelHeight<<<grid_size,
                                                 block_size,
                                                 0,
@@ -1253,8 +1245,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
                                              is_y);
     } else {
       dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
-      int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
-      LimitGridDim(ctx, &grid_size);
+      dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
+      paddle::platform::LimitGridDim(ctx, &grid_size);
       FastCommonGradBroadcastCUDAKernelHeight<<<grid_size,
                                                 block_size,
                                                 0,
@@ -1350,8 +1342,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
            << " post:" << post;
    int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
-   int grid_size = pre * post;
-   LimitGridDim(ctx, &grid_size);
+   dim3 grid_size = dim3(pre * post);
+   paddle::platform::LimitGridDim(ctx, &grid_size);
    FastCommonGradBroadcastAllCUDAKernel<<<grid_size, block_size, 0, stream>>>(
        x_data,
@@ -1392,8 +1384,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
                                 1,
                                 std::multiplies<int>());
    int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
-   int grid_size = pre * post;
-   LimitGridDim(ctx, &grid_size);
+   dim3 grid_size = dim3(pre * post);
+   paddle::platform::LimitGridDim(ctx, &grid_size);
    // we need to calc y offset with blockid, so do x_pre/y_pre to get left
    // size.
    if (k_pre != pre) k_pre = pre / k_pre;
@@ -1423,8 +1415,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
                                 1,
                                 std::multiplies<int>());
    int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
-   int grid_size = pre * post;
-   LimitGridDim(ctx, &grid_size);
+   dim3 grid_size = dim3(pre * post);
+   paddle::platform::LimitGridDim(ctx, &grid_size);
    if (k_pre != pre) k_pre = pre / k_pre;
    FastCommonGradBroadcastOneCUDAKernel<<<grid_size,
...
@@ -33,6 +33,7 @@ namespace cub = hipcub;
 #endif
 
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/fast_divmod.h"
 #include "paddle/phi/api/ext/dispatch.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
@@ -309,7 +310,7 @@ struct ReduceConfig {
       : reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {}
 
   // get the parameters of reduceKernel
-  void Run(const paddle::platform::Place& place) {
+  void Run(const KPDevice& dev_ctx) {
     // step1: update the reduce_dim left_dim and x_dim
     SetReduceDim();
@@ -323,7 +324,7 @@ struct ReduceConfig {
     SetBlockDim();
 
     // step5: limit the grid to prevent thead overflow
-    LimitGridDim(place);
+    paddle::platform::LimitGridDim(dev_ctx, &grid);
   }
 
   // when should_reduce_again is true, we need malloc temp space for temp data
@@ -607,15 +608,6 @@ struct ReduceConfig {
     grid = grid_dim;
   }
 
-  void LimitGridDim(const paddle::platform::Place& place) {
-    auto* ctx = static_cast<paddle::platform::CUDADeviceContext*>(
-        paddle::platform::DeviceContextPool::Instance().Get(place));
-    std::array<int, 3> max_grid_dim = ctx->GetCUDAMaxGridDimSize();
-    grid.x = grid.x < max_grid_dim[0] ? grid.x : max_grid_dim[0];
-    grid.y = grid.y < max_grid_dim[1] ? grid.y : max_grid_dim[1];
-    grid.z = grid.z < max_grid_dim[2] ? grid.z : max_grid_dim[2];
-  }
-
  public:
   std::vector<int> reduce_dims_origin;
   std::vector<int> reduce_dim;
@@ -1072,7 +1064,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
   auto x_dim = phi::vectorize<int>(x.dims());
   auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
-  config.Run(x.place());
+  config.Run(dev_ctx);
 
   int numel = x.numel();
   // after config.run()
   // SetOutputData for ReduceHigherDim when should_reduce_again is true,
...
@@ -26,13 +26,6 @@
 namespace phi {
 namespace {
-template <typename Context>
-void LimitGridDim(const Context& ctx, dim3* grid_dim) {
-  auto max_grid_dim =
-      reinterpret_cast<const phi::GPUContext&>(ctx).GetCUDAMaxGridDimSize();
-  grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
-  grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
-}
 #define PREDEFINED_BLOCK_SIZE_X 512
 #define PREDEFINED_BLOCK_SIZE 1024
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -107,7 +100,7 @@ void IndexSampleGradKernel(const Context& ctx,
     dim3 block_dim(block_width, block_height);
     dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
                   (batch_size + block_dim.y - 1) / block_dim.y);
-    LimitGridDim(ctx, &grid_dim);
+    paddle::platform::LimitGridDim(ctx, &grid_dim);
 
     phi::funcs::SetConstant<Context, T> set_zero;
     set_zero(ctx, x_grad, static_cast<T>(0));
...
@@ -25,13 +25,6 @@
 namespace phi {
 namespace {
-template <typename Context>
-void LimitGridDim(const Context& ctx, dim3* grid_dim) {
-  auto max_grid_dim =
-      reinterpret_cast<const phi::GPUContext&>(ctx).GetCUDAMaxGridDimSize();
-  grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
-  grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
-}
 #define PREDEFINED_BLOCK_SIZE_X 512
 #define PREDEFINED_BLOCK_SIZE 1024
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -95,7 +88,7 @@ void IndexSampleKernel(const Context& ctx,
     dim3 block_dim(block_width, block_height);
     dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
                   (batch_size + block_dim.y - 1) / block_dim.y);
-    LimitGridDim(ctx, &grid_dim);
+    paddle::platform::LimitGridDim(ctx, &grid_dim);
 
     if (index_type == DataType::INT64) {
       const int64_t* index_data = index.data<int64_t>();
...
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/index_select_grad_kernel.h"
 
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -89,25 +90,23 @@ void IndexSelectGradKernel(const Context& ctx,
   auto stream = ctx.stream();
 
-  index_select_grad_init<
-      T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
-           PADDLE_CUDA_NUM_THREADS,
-           0,
-           stream>>>(in_grad_data, numel);
-  int blocks =
-      (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
-  int threads = PADDLE_CUDA_NUM_THREADS;
+  unsigned int block_dim = PADDLE_CUDA_NUM_THREADS;
+  dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
+  paddle::platform::LimitGridDim(ctx, &grid_dim);
+
+  index_select_grad_init<T><<<grid_dim, block_dim, 0, stream>>>(in_grad_data,
+                                                                numel);
 
   if (FLAGS_cudnn_deterministic) {
     VLOG(2) << "Run grad kernel of index_select with single thread.";
-    blocks = 1;
-    threads = 1;
+    block_dim = 1;
+    grid_dim.x = 1;
   }
 
   if (index_type == phi::DataType::INT64) {
     const int64_t* index_data = index.data<int64_t>();
-    index_select_grad_cuda_kernel<T, int64_t><<<blocks, threads, 0, stream>>>(
+    index_select_grad_cuda_kernel<T,
+                                  int64_t><<<grid_dim, block_dim, 0, stream>>>(
         output_grad_data,
         in_grad_data,
         index_data,
@@ -118,7 +117,7 @@ void IndexSelectGradKernel(const Context& ctx,
         delta);
   } else {
     const int* index_data = index.data<int>();
-    index_select_grad_cuda_kernel<T, int><<<blocks, threads, 0, stream>>>(
+    index_select_grad_cuda_kernel<T, int><<<grid_dim, block_dim, 0, stream>>>(
        output_grad_data,
        in_grad_data,
        index_data,
...
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/index_select_kernel.h"
 
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -31,16 +32,14 @@ __global__ void index_select_cuda_kernel(const T* input,
                                          int64_t stride,
                                          int64_t size,
                                          int64_t delta) {
-  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (idx >= N) {
-    return;
+  CUDA_KERNEL_LOOP(idx, N) {
+    int64_t pre_idx = idx / (stride * size);
+    int64_t dim_idx = idx % (stride * size) / stride;
+    IndexT src_dim_idx = index[dim_idx];
+    int64_t input_idx =
+        idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
+    output[idx] = input[input_idx];
   }
-  int64_t pre_idx = idx / (stride * size);
-  int64_t dim_idx = idx % (stride * size) / stride;
-  IndexT src_dim_idx = index[dim_idx];
-  int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
-  output[idx] = input[input_idx];
 }
 
 template <typename T, typename Context>
@@ -75,21 +74,17 @@ void IndexSelectKernel(const Context& ctx,
   int64_t numel = output->numel();
   auto stream = ctx.stream();
 
+  unsigned int block_dim = PADDLE_CUDA_NUM_THREADS;
+  dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
+  paddle::platform::LimitGridDim(ctx, &grid_dim);
+
   if (index_type == phi::DataType::INT64) {
     const int64_t* index_data = index.data<int64_t>();
-    index_select_cuda_kernel<T, int64_t><<<
-        (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
-        PADDLE_CUDA_NUM_THREADS,
-        0,
-        stream>>>(in_data, out_data, index_data, numel, stride, size, delta);
+    index_select_cuda_kernel<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
+        in_data, out_data, index_data, numel, stride, size, delta);
   } else {
     const int* index_data = index.data<int>();
-    index_select_cuda_kernel<
-        T,
-        int><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
-               PADDLE_CUDA_NUM_THREADS,
-               0,
-               stream>>>(
+    index_select_cuda_kernel<T, int><<<grid_dim, block_dim, 0, stream>>>(
         in_data, out_data, index_data, numel, stride, size, delta);
   }
 }
...
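Clamping the grid is safe only because index_select_cuda_kernel was rewritten around CUDA_KERNEL_LOOP: Caffe-style kernel-loop macros of this kind expand to a grid-stride loop, so indices beyond gridDim.x * blockDim.x are still visited when the grid is capped at the hardware limit. A sketch of that pattern, assuming (not quoting) the macro's expansion:

// Illustrative grid-stride loop in the shape CUDA_KERNEL_LOOP-style macros
// typically expand to; this is an assumption, not the exact Paddle definition.
#include <cstdint>

#define GRID_STRIDE_LOOP(i, n)                                      \
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);  \
       i += static_cast<int64_t>(blockDim.x) * gridDim.x)

// Every index in [0, n) is visited even if the launch grid was clamped below
// ceil(n / blockDim.x); each thread simply makes extra passes.
__global__ void scale_kernel(float* data, int64_t n, float factor) {
  GRID_STRIDE_LOOP(idx, n) { data[idx] *= factor; }
}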