diff --git a/paddle/fluid/operators/index_impl.cu.h b/paddle/fluid/operators/index_impl.cu.h index 3d6a5e0ea88a28addaf09d90cae9659cbea85305..2e3e6569ef5a88f8dfcb6646974b70bcc6c0c95f 100644 --- a/paddle/fluid/operators/index_impl.cu.h +++ b/paddle/fluid/operators/index_impl.cu.h @@ -31,24 +31,24 @@ namespace operators { namespace kps = phi::kps; template -__global__ void VectorizedIndexKernel(T *out, int numel, int main_offset, +__global__ void VectorizedIndexKernel(T *out, size_t numel, size_t main_offset, Functor func) { - int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; - int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; - int args[VecSize]; + size_t data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; + size_t stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; + size_t args[VecSize]; T result[VecSize]; for (; data_offset < main_offset; data_offset += stride) { - kps::InitWithDataIndex(&args[0], data_offset); - kps::ElementwiseUnary(&result[0], &args[0], - func); + kps::InitWithDataIndex(&args[0], data_offset); + kps::ElementwiseUnary(&result[0], + &args[0], func); kps::WriteData(out + data_offset, &result[0], BLOCK_NUM_X * VecSize); } - int num = numel - data_offset; + size_t num = numel - data_offset; if (num > 0) { - kps::InitWithDataIndex(&args[0], data_offset); - kps::ElementwiseUnary(&result[0], &args[0], - func); + kps::InitWithDataIndex(&args[0], data_offset); + kps::ElementwiseUnary(&result[0], + &args[0], func); kps::WriteData(out + data_offset, &result[0], num); } } @@ -58,7 +58,7 @@ void IndexKernel(const KPDevice &dev_ctx, Tensor *out, Functor func) { int numel = out->numel(); T *out_data = out->mutable_data(dev_ctx.GetPlace()); if (numel <= 0) return; - int vec_size = paddle::platform::GetVectorizedSize((out->data())); + int vec_size = paddle::platform::GetVectorizedSize(out_data); #ifdef PADDLE_WITH_XPU_KP int block = 64; int grid = 8; @@ -70,8 +70,7 @@ void IndexKernel(const KPDevice &dev_ctx, Tensor *out, Functor func) { int block = config.thread_per_block.x; auto stream = dev_ctx.stream(); #endif - - int main_offset = (numel / (vec_size * block)) * vec_size * block; + size_t main_offset = (numel / (vec_size * block)) * vec_size * block; switch (vec_size) { case 4: VectorizedIndexKernel<<>>(