From 4adeff06aebf2d824e361caced9f94506a68533b Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Fri, 21 Jan 2022 12:51:57 +0800 Subject: [PATCH] add block and grid loop for index_sample kernel to deal with a large-shape tensor (#37816) * add block and grid loop for index_sample kernel to deal with a large-shape tensor * fix code format * limit grid dim --- paddle/fluid/operators/index_sample_op.cu | 63 +++++++++++++++-------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index 4260d0516e..45f63c2b2f 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -18,9 +18,22 @@ #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#define PREDEFINED_BLOCK_SIZE_X 512 +#define PREDEFINED_BLOCK_SIZE 1024 +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + namespace paddle { namespace operators { +namespace { +void LimitGridDim(const framework::ExecutionContext& ctx, dim3* grid_dim) { + dim3 max_grid_dim = ctx.template device_context() + .GetCUDAMaxGridDimSize(); + grid_dim->x = grid_dim->x < max_grid_dim.x ? grid_dim->x : max_grid_dim.x; + grid_dim->y = grid_dim->y < max_grid_dim.y ? grid_dim->y : max_grid_dim.y; +} +} + using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; @@ -28,14 +41,15 @@ template __global__ void IndexSampleForward(const IndexT* index, const T* in_data, T* out_data, size_t index_length, size_t input_length, size_t batch_size) { - int index_i = blockDim.x * blockIdx.x + threadIdx.x; - int index_j = blockDim.y * blockIdx.y + threadIdx.y; - int index_idx = index_j * index_length + index_i; - int in_idx = index_j * input_length + index_i; - - if (index_i < index_length & index_j < batch_size) { - IndexT sample_idx = index[index_idx]; - out_data[index_idx] = in_data[in_idx - index_i + sample_idx]; + unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x; + unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; + for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { + for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { + unsigned int index_idx = index_j * index_length + index_i; + unsigned int in_idx = index_j * input_length + index_i; + IndexT sample_idx = index[index_idx]; + out_data[index_idx] = in_data[in_idx - index_i + sample_idx]; + } } } @@ -44,18 +58,20 @@ __global__ void IndexSampleGrad(const IndexT* index, T* in_grad, const T* out_grad, size_t index_length, size_t input_length, size_t batch_size, bool same_data_in_row = true) { - int index_i = blockDim.x * blockIdx.x + threadIdx.x; - int index_j = blockDim.y * blockIdx.y + threadIdx.y; - int index_idx = index_j * index_length + index_i; - int in_idx = index_j * input_length + index_i; - - if (index_i < index_length & index_j < batch_size) { - IndexT sample_idx = index[index_idx]; - if (same_data_in_row) { - platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]), - out_grad[sample_idx]); - } else { - in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx]; + unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x; + unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; + + for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { + for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { + unsigned int index_idx = index_j * index_length + index_i; + unsigned int in_idx = index_j * input_length + index_i; + IndexT sample_idx = index[index_idx]; + if (same_data_in_row) { + platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]), + out_grad[sample_idx]); + } else { + in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx]; + } } } } @@ -93,12 +109,14 @@ class IndexSampleKernel size_t index_length = index_dim[1]; auto block_width = platform::RoundToPowerOfTwo(index_length); + block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X); int block_height = platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; - + block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width); dim3 block_dim(block_width, block_height); dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, (batch_size + block_dim.y - 1) / block_dim.y); + LimitGridDim(ctx, &grid_dim); if (index_type == framework::proto::VarType::INT64) { const int64_t* index_data = index->data(); @@ -150,11 +168,14 @@ class IndexSampleGradKernel bool same_data_in_index_row = index_length == 1 ? false : true; auto block_width = platform::RoundToPowerOfTwo(index_length); + block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X); auto block_height = platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; + block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width); dim3 block_dim(block_width, block_height); dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, (batch_size + block_dim.y - 1) / block_dim.y); + LimitGridDim(ctx, &grid_dim); math::SetConstant set_zero; auto& dev_ctx = ctx.template device_context(); -- GitLab