From c6950ab2573aece1fa0728aef1446bd8b0b8c1a0 Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Sun, 20 Feb 2022 11:47:55 +0800 Subject: [PATCH] add index initialization in the block loop for index_sample kernel when dealing with a input tensor whose shape is larger than block_dim * grid_dim (#39736) * add block and grid loop for index_sample kernel to deal with a large-shape tensor * fix code format * limit grid dim * fix the omissive initialization of index_i in the second cycle for index_sample kernel * fix conflicts --- paddle/fluid/operators/index_sample_op.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index 5e2ab92285..1b1fb8467f 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -44,6 +44,7 @@ __global__ void IndexSampleForward(const IndexT* index, const T* in_data, unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x; unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { + index_i = blockDim.x * blockIdx.x + threadIdx.x; for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { unsigned int index_idx = index_j * index_length + index_i; unsigned int in_idx = index_j * input_length + index_i; @@ -62,6 +63,7 @@ __global__ void IndexSampleGrad(const IndexT* index, T* in_grad, unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { + index_i = blockDim.x * blockIdx.x + threadIdx.x; for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { unsigned int index_idx = index_j * index_length + index_i; unsigned int in_idx = index_j * input_length + index_i; -- GitLab