Unverified commit 608a5f55, authored by FlyingQianMM, committed by GitHub

add maximum limit for grid of reduce, elementwise, gather and scatter (#40813)

* add maximum limit for grid of reduce, elementwise and gather

* add {} after if
Parent 609077e9
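The common thread of this patch is capping gridDim.x (and, for reduce, all three grid axes) at the limit reported by GetCUDAMaxGridDimSize(). The kernel bodies are not part of the excerpt; the cap is only lossless if they iterate with a grid-stride loop, so the sketch below illustrates that assumption with hypothetical names (ScaleKernel, LaunchScale), not Paddle's actual kernels.

```cpp
#include <cstdint>
#include <cuda_runtime.h>

// Grid-stride sketch: each thread keeps striding by gridDim.x * blockDim.x,
// so a capped grid still covers all n elements, just with more iterations
// per thread.
__global__ void ScaleKernel(const float* in, float* out, int64_t n, float a) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n; i += stride) {
    out[i] = a * in[i];
  }
}

// Host side, mirroring the pattern added throughout the patch: divide and
// round up to get the block count, then clamp it to the device limit.
void LaunchScale(const float* in, float* out, int64_t n, float a,
                 cudaStream_t stream, int64_t limit_blocks) {
  const int block = 512;
  int64_t grid = (n + block - 1) / block;
  if (grid > limit_blocks) {
    grid = limit_blocks;  // limit_blocks ~ GetCUDAMaxGridDimSize()[0]
  }
  ScaleKernel<<<static_cast<unsigned int>(grid), block, 0, stream>>>(
      in, out, n, a);
}
```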
@@ -128,6 +128,10 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(
// Number of threads per block shall be larger than 64.
threads = std::max(64, threads);
int blocks = DivUp(DivUp(numel, vec_size), threads);
int limit_blocks = context.GetCUDAMaxGridDimSize()[0];
if (blocks > limit_blocks) {
blocks = limit_blocks;
}
GpuLaunchConfig config;
config.thread_per_block.x = threads;
......
@@ -132,6 +132,10 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context,
// Number of threads per block shall be larger than 64.
threads = std::max(64, threads);
int blocks = DivUp(DivUp(numel, vec_size), threads);
int limit_blocks = context.GetCUDAMaxGridDimSize()[0];
if (blocks > limit_blocks) {
blocks = limit_blocks;
}
GpuLaunchConfig config;
config.thread_per_block.x = threads;
......
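Both hunks above make the same change to GetGpuLaunchConfig1D, once per device-context overload. A self-contained sketch of the capped computation follows; Get1DConfig and SimpleLaunchConfig are illustrative stand-ins, and how `threads` is chosen before the shown context lines is not part of the excerpt.

```cpp
#include <algorithm>
#include <cstdint>

inline int64_t DivUp(int64_t a, int64_t b) { return (a + b - 1) / b; }

struct SimpleLaunchConfig {
  int threads_per_block;
  int blocks_per_grid;
};

// Sketch of the 1-D launch configuration with the block cap applied.
inline SimpleLaunchConfig Get1DConfig(int64_t numel, int vec_size, int threads,
                                      int limit_blocks) {
  // Number of threads per block shall be larger than 64 (as in the excerpt).
  threads = std::max(64, threads);
  int blocks = static_cast<int>(DivUp(DivUp(numel, vec_size), threads));
  if (blocks > limit_blocks) {  // limit_blocks ~ GetCUDAMaxGridDimSize()[0]
    blocks = limit_blocks;
  }
  return {threads, blocks};
}
```

For example, numel = 1 << 30 with vec_size = 4 and threads = 256 gives DivUp(DivUp(2^30, 4), 256) = 1,048,576 blocks before the cap is applied.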
@@ -49,6 +49,14 @@ namespace phi {
namespace funcs {
using DDim = phi::DDim;
template <typename T>
void LimitGridDim(const GPUContext &ctx, T *grid_dim) {
auto max_grid_dim = ctx.GetCUDAMaxGridDimSize()[0];
if (*grid_dim > max_grid_dim) {
*grid_dim = max_grid_dim;
}
}
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonGradBroadcastCPU(const DenseTensor &x,
const DenseTensor &y,
@@ -977,6 +985,10 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream,
// suppose performance improves as h increases.
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
auto gplace = phi::GPUPlace();
auto *ctx = static_cast<GPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(gplace));
LimitGridDim(*ctx, &grid_size);
FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
}
@@ -998,6 +1010,11 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream,
T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
int grid_size = n;
auto gplace = phi::GPUPlace();
auto *ctx = static_cast<GPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(gplace));
LimitGridDim(*ctx, &grid_size);
ElemwiseGradBroadcast2CUDAKernel<<<grid_size, block_size, 0, stream>>>(
x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy);
}
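ElemwiseGradBroadcast1CUDA and ElemwiseGradBroadcast2CUDA only receive a raw gpuStream_t, so the patch looks the current GPUContext up from the global DeviceContextPool before clamping. A hypothetical alternative (not in the patch) is a clamp that returns the bounded value, which makes it harder to launch with a stale, un-clamped variable:

```cpp
// Hypothetical variant of LimitGridDim that returns the clamped value
// instead of mutating in place.
template <typename T>
T ClampGrid(const phi::GPUContext& ctx, T grid_dim) {
  const T max_grid_dim = static_cast<T>(ctx.GetCUDAMaxGridDimSize()[0]);
  return grid_dim > max_grid_dim ? max_grid_dim : grid_dim;
}

// Usage at a launch site (kernel name elided):
//   int grid_size = ClampGrid(*ctx, n);
//   SomeKernel<<<grid_size, block_size, 0, stream>>>(...);
```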
@@ -1200,6 +1217,7 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
} else {
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
LimitGridDim(ctx, &grid_size);
FastCommonGradBroadcastCUDAKernelHeight<<<grid_size,
block_size,
0,
@@ -1236,6 +1254,7 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
} else {
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
LimitGridDim(ctx, &grid_size);
FastCommonGradBroadcastCUDAKernelHeight<<<grid_size,
block_size,
0,
@@ -1332,6 +1351,7 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
int grid_size = pre * post;
LimitGridDim(ctx, &grid_size);
FastCommonGradBroadcastAllCUDAKernel<<<grid_size, block_size, 0, stream>>>(
x_data,
@@ -1373,6 +1393,7 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
std::multiplies<int>());
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
int grid_size = pre * post;
LimitGridDim(ctx, &grid_size);
// we need to calculate the y offset from the block id, so compute
// x_pre/y_pre to get the left size.
if (k_pre != pre) k_pre = pre / k_pre;
@@ -1403,6 +1424,7 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
std::multiplies<int>());
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
int grid_size = pre * post;
LimitGridDim(ctx, &grid_size);
if (k_pre != pre) k_pre = pre / k_pre;
FastCommonGradBroadcastOneCUDAKernel<<<grid_size,
......
@@ -112,6 +112,10 @@ void GPUGather(const phi::GPUContext& ctx,
int block = 512;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
if (grid > maxGridDimX) {
grid = maxGridDimX;
}
GatherCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_src, p_index, p_output, index_size, slice_size);
@@ -161,6 +165,10 @@ void GPUGatherNd(const phi::GPUContext& ctx,
int block = 512;
int64_t n = slice_size * remain_numel;
int64_t grid = (n + block - 1) / block;
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
if (grid > maxGridDimX) {
grid = maxGridDimX;
}
GatherNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(p_input,
g_input_dims,
......
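GPUGather and GPUGatherNd clamp an int64_t grid against the unsigned device limit before launching. The kernel bodies are not shown here; assuming they walk the flattened index space with a grid-stride loop (the usual CUDA_KERNEL_LOOP pattern), the clamp reduces parallelism but never skips elements. A hedged sketch of such a gather kernel, with illustrative names:

```cpp
#include <cstdint>

// Hypothetical gather kernel with an explicit grid-stride loop. Logical
// index i picks an index-array entry (i / slice_size) and an offset inside
// the gathered slice (i % slice_size); a capped grid only increases the
// number of loop iterations per thread.
template <typename T, typename IndexT>
__global__ void GatherSketchKernel(const T* src, const IndexT* index, T* out,
                                   int64_t index_size, int64_t slice_size) {
  const int64_t n = index_size * slice_size;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n; i += stride) {
    const int64_t indices_i = i / slice_size;
    const int64_t slice_i = i - indices_i * slice_size;
    const IndexT gather_i = index[indices_i];
    out[i] = src[gather_i * slice_size + slice_i];
  }
}
```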
@@ -309,7 +309,7 @@ struct ReduceConfig {
: reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {}
// get the parameters of reduceKernel
void Run() {
void Run(const paddle::platform::Place& place) {
// step1: update the reduce_dim left_dim and x_dim
SetReduceDim();
@@ -321,6 +321,9 @@ struct ReduceConfig {
// step4: set the block and grid for launch kernel
SetBlockDim();
// step5: limit the grid to prevent thread overflow
LimitGridDim(place);
}
// when should_reduce_again is true, we need to malloc temp space for temp data
@@ -604,6 +607,15 @@ struct ReduceConfig {
grid = grid_dim;
}
void LimitGridDim(const paddle::platform::Place& place) {
auto* ctx = static_cast<paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
std::array<int, 3> max_grid_dim = ctx->GetCUDAMaxGridDimSize();
grid.x = grid.x < max_grid_dim[0] ? grid.x : max_grid_dim[0];
grid.y = grid.y < max_grid_dim[1] ? grid.y : max_grid_dim[1];
grid.z = grid.z < max_grid_dim[2] ? grid.z : max_grid_dim[2];
}
public:
std::vector<int> reduce_dims_origin;
std::vector<int> reduce_dim;
@@ -1060,7 +1072,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
auto x_dim = phi::vectorize<int>(x.dims());
auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
config.Run();
config.Run(x.place());
int numel = x.numel();
// after config.run()
// SetOutputData for ReduceHigherDim when should_reduce_again is true,
......
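ReduceConfig clamps all three axes because reduce kernels may launch multi-dimensional grids, and the y/z limits (65535 on current architectures, versus 2^31-1 for x) are easy to exceed for large reduction shapes. Below is a standalone sketch of the per-axis clamp against a plain limits array; ClampGrid3D is an illustrative name, while the real code reads the limits from the CUDADeviceContext as shown above.

```cpp
#include <algorithm>
#include <array>
#include <cuda_runtime.h>

// Sketch of the per-axis clamp added as ReduceConfig::LimitGridDim, written
// against a plain std::array of limits so it reads in isolation.
inline dim3 ClampGrid3D(dim3 grid, const std::array<int, 3>& max_grid_dim) {
  grid.x = std::min<unsigned int>(grid.x, max_grid_dim[0]);
  grid.y = std::min<unsigned int>(grid.y, max_grid_dim[1]);
  grid.z = std::min<unsigned int>(grid.z, max_grid_dim[2]);
  return grid;
}
```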
@@ -156,6 +156,8 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
int block = 512;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
grid = grid > maxGridDimX ? maxGridDimX : grid;
// if not overwrite mode, init data
if (!overwrite) {
@@ -240,6 +242,8 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
int block = 512;
int64_t n = slice_size * remain_numel;
int64_t grid = (n + block - 1) / block;
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
grid = grid > maxGridDimX ? maxGridDimX : grid;
ScatterNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_update,
@@ -252,4 +256,4 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
}
} // namespace funcs
} // namespace pten
} // namespace phi
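GPUScatterAssign and GPUScatterNdAdd apply the same cap with a ternary. The limit exposed through GetCUDAMaxGridDimSize() ultimately comes from the device properties; a standalone way to inspect it (plain CUDA runtime, not a Paddle API) is sketched below.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Print the per-axis grid limits that the clamps in this patch guard
// against; most current GPUs report 2147483647 65535 65535.
int main() {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::printf("no CUDA device available\n");
    return 1;
  }
  std::printf("maxGridSize: %d %d %d\n", prop.maxGridSize[0],
              prop.maxGridSize[1], prop.maxGridSize[2]);
  return 0;
}
```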