Unverified commit 4adeff06, authored by F FlyingQianMM, committed by GitHub

add block and grid loop for index_sample kernel to deal with a large-shape tensor (#37816)

* add block and grid loop for index_sample kernel to deal with a large-shape tensor

* fix code format

* limit grid dim
Parent ba51a6c8
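Before the diff: the change pairs a grid-stride loop inside each CUDA kernel with a clamp of the launch grid to the device's maximum grid dimensions, so tensors too large for a one-thread-per-element launch are still fully processed. The snippet below is a minimal, self-contained sketch of that pattern for illustration only; it is not part of this commit and not the Paddle kernel. The kernel name ScaleKernel, the element count, and the standalone main are assumptions made for the example.

```cuda
// Sketch only: a 1-D grid-stride loop plus a grid clamp, mirroring the idea
// (not the code) that the commit applies to the 2-D index_sample kernels.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void ScaleKernel(const float* in, float* out, size_t n) {
  // Grid-stride loop: each thread strides by the total number of launched
  // threads, so every element is covered even if the grid was clamped.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<size_t>(blockDim.x) * gridDim.x) {
    out[i] = in[i] * 2.0f;
  }
}

int main() {
  const size_t n = 1 << 20;  // illustrative size
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (size_t i = 0; i < n; ++i) in[i] = 1.0f;

  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);

  unsigned int block = 256;
  unsigned int grid = static_cast<unsigned int>((n + block - 1) / block);
  // Clamp to the hardware limit, the role LimitGridDim plays in the diff.
  if (grid > static_cast<unsigned int>(prop.maxGridSize[0])) {
    grid = static_cast<unsigned int>(prop.maxGridSize[0]);
  }

  ScaleKernel<<<grid, block>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("out[0] = %f\n", out[0]);

  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

In the 2-D launch used by index_sample, gridDim.y is capped at 65535 on current GPUs, so a large batch size can make an unclamped launch fail outright, while clamping without the loop would silently skip rows; the diff therefore adds both the LimitGridDim clamp and the grid/block loops.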
@@ -18,9 +18,22 @@
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #define PREDEFINED_BLOCK_SIZE_X 512
 #define PREDEFINED_BLOCK_SIZE 1024
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 namespace paddle {
 namespace operators {
+namespace {
+void LimitGridDim(const framework::ExecutionContext& ctx, dim3* grid_dim) {
+  dim3 max_grid_dim = ctx.template device_context<platform::CUDADeviceContext>()
+                          .GetCUDAMaxGridDimSize();
+  grid_dim->x = grid_dim->x < max_grid_dim.x ? grid_dim->x : max_grid_dim.x;
+  grid_dim->y = grid_dim->y < max_grid_dim.y ? grid_dim->y : max_grid_dim.y;
+}
+}
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
@@ -28,15 +41,16 @@ template <typename T, typename IndexT = int>
 __global__ void IndexSampleForward(const IndexT* index, const T* in_data,
                                    T* out_data, size_t index_length,
                                    size_t input_length, size_t batch_size) {
-  int index_i = blockDim.x * blockIdx.x + threadIdx.x;
-  int index_j = blockDim.y * blockIdx.y + threadIdx.y;
-  int index_idx = index_j * index_length + index_i;
-  int in_idx = index_j * input_length + index_i;
-  if (index_i < index_length & index_j < batch_size) {
+  unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
+  unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
+  for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
+    for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
+      unsigned int index_idx = index_j * index_length + index_i;
+      unsigned int in_idx = index_j * input_length + index_i;
       IndexT sample_idx = index[index_idx];
       out_data[index_idx] = in_data[in_idx - index_i + sample_idx];
     }
+  }
 }
 template <typename T, typename IndexT = int>
@@ -44,12 +58,13 @@ __global__ void IndexSampleGrad(const IndexT* index, T* in_grad,
                                 const T* out_grad, size_t index_length,
                                 size_t input_length, size_t batch_size,
                                 bool same_data_in_row = true) {
-  int index_i = blockDim.x * blockIdx.x + threadIdx.x;
-  int index_j = blockDim.y * blockIdx.y + threadIdx.y;
-  int index_idx = index_j * index_length + index_i;
-  int in_idx = index_j * input_length + index_i;
+  unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
+  unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
 
-  if (index_i < index_length & index_j < batch_size) {
+  for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
+    for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
+      unsigned int index_idx = index_j * index_length + index_i;
+      unsigned int in_idx = index_j * input_length + index_i;
       IndexT sample_idx = index[index_idx];
       if (same_data_in_row) {
         platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]),
@@ -58,6 +73,7 @@ __global__ void IndexSampleGrad(const IndexT* index, T* in_grad,
         in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx];
       }
     }
+  }
 }
 template <typename T>
@@ -93,12 +109,14 @@ class IndexSampleKernel<platform::CUDADeviceContext, T>
     size_t index_length = index_dim[1];
     auto block_width = platform::RoundToPowerOfTwo(index_length);
     block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X);
     int block_height =
         platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;
     block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width);
     dim3 block_dim(block_width, block_height);
     dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
                   (batch_size + block_dim.y - 1) / block_dim.y);
+    LimitGridDim(ctx, &grid_dim);
     if (index_type == framework::proto::VarType::INT64) {
       const int64_t* index_data = index->data<int64_t>();
@@ -150,11 +168,14 @@ class IndexSampleGradKernel<platform::CUDADeviceContext, T>
     bool same_data_in_index_row = index_length == 1 ? false : true;
     auto block_width = platform::RoundToPowerOfTwo(index_length);
     block_width = MIN(block_width, PREDEFINED_BLOCK_SIZE_X);
     auto block_height =
         platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;
     block_height = MIN(block_height, PREDEFINED_BLOCK_SIZE / block_width);
     dim3 block_dim(block_width, block_height);
     dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
                   (batch_size + block_dim.y - 1) / block_dim.y);
+    LimitGridDim(ctx, &grid_dim);
     math::SetConstant<platform::CUDADeviceContext, T> set_zero;
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
......