diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index 5e2ab922850ff93b1f209fd1e3ce1057c8c6d793..1b1fb8467f5f4915e4ad92bec946684e433621f1 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -44,6 +44,7 @@ __global__ void IndexSampleForward(const IndexT* index, const T* in_data, unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x; unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { + index_i = blockDim.x * blockIdx.x + threadIdx.x; for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { unsigned int index_idx = index_j * index_length + index_i; unsigned int in_idx = index_j * input_length + index_i; @@ -62,6 +63,7 @@ __global__ void IndexSampleGrad(const IndexT* index, T* in_grad, unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { + index_i = blockDim.x * blockIdx.x + threadIdx.x; for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { unsigned int index_idx = index_j * index_length + index_i; unsigned int in_idx = index_j * input_length + index_i;