From 964497b5af0d9506b4203e3b0e0c063b3747c221 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Thu, 30 Mar 2023 15:33:23 +0800
Subject: [PATCH] use int64 for c split (#52279)

---
 .../fluid/operators/collective/c_split_op.cu  | 32 +++++++++----------
 .../operators/collective/partial_recv_op.cc   |  2 +-
 .../collective/partial_recv_op.cu.cc          |  4 +--
 .../collective/partial_send_op.cu.cc          |  4 +--
 4 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/paddle/fluid/operators/collective/c_split_op.cu b/paddle/fluid/operators/collective/c_split_op.cu
index 5b34e4ba9d5..8bf887b954a 100644
--- a/paddle/fluid/operators/collective/c_split_op.cu
+++ b/paddle/fluid/operators/collective/c_split_op.cu
@@ -21,10 +21,10 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-static constexpr int kNumCUDAThreads = 512;
-static constexpr int kNumMaxinumNumBlocks = 4096;
+static constexpr int64_t kNumCUDAThreads = 512;
+static constexpr int64_t kNumMaxinumNumBlocks = 4096;
 
-static inline int NumBlocks(const int N) {
+static inline int64_t NumBlocks(const int64_t N) {
   return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                   kNumMaxinumNumBlocks);
 }
@@ -32,21 +32,21 @@ static inline int NumBlocks(const int N) {
 template <typename T>
 __global__ void SplitFromRank(const T* input,
                               T* output,
-                              const int rows,
-                              const int columns,
+                              const int64_t rows,
+                              const int64_t columns,
                               const int rank,
                               const int nranks,
-                              const int limit) {
-  CUDA_KERNEL_LOOP(i, limit) {
-    int row = i / columns;
-    int col = i % columns;
+                              const int64_t limit) {
+  CUDA_KERNEL_LOOP_TYPE(i, limit, int64_t) {
+    int64_t row = i / columns;
+    int64_t col = i % columns;
 
-    int block = columns / nranks;
-    int start = block * rank;
-    int end = start + block;
+    int64_t block = columns / nranks;
+    int64_t start = block * rank;
+    int64_t end = start + block;
 
     if (col >= start && col < end) {
-      int idx = block * row + col % block;
+      int64_t idx = block * row + col % block;
       output[idx] = input[i];
     }
   }
@@ -93,9 +93,9 @@ class CSplitOpCUDAKernel : public framework::OpKernel<T> {
     auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1);
     int64_t remain_numel = phi::product(remain_ddim);
 
-    int limit = x->numel();
-    int blocks = NumBlocks(limit);
-    int threads = kNumCUDAThreads;
+    int64_t limit = x->numel();
+    int64_t blocks = NumBlocks(limit);
+    int64_t threads = kNumCUDAThreads;
 
     dims[dims_size - 1] /= nranks;
     out->mutable_data<T>(dims, place);
diff --git a/paddle/fluid/operators/collective/partial_recv_op.cc b/paddle/fluid/operators/collective/partial_recv_op.cc
index f0effde61b7..9b860f93e3b 100644
--- a/paddle/fluid/operators/collective/partial_recv_op.cc
+++ b/paddle/fluid/operators/collective/partial_recv_op.cc
@@ -69,7 +69,7 @@ class PartialRecvOp : public framework::OperatorWithKernel {
                             out_shape[i]));
     }
     auto out_dims = phi::make_ddim(out_shape);
-    int numel = phi::product(out_dims);
+    int64_t numel = phi::product(out_dims);
     PADDLE_ENFORCE_EQ(
         (numel % num),
         0,
diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
index 526f9425992..ad44ce6a109 100644
--- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
@@ -68,8 +68,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     out->mutable_data<T>(out_dims, place);
 
-    int recv_numel = numel / num;
-    int offset = recv_numel * id;
+    int64_t recv_numel = numel / num;
+    int64_t offset = recv_numel * id;
 
     auto map = distributed::ProcessGroupMapFromGid::getInstance();
     if (map->has(rid)) {
diff --git a/paddle/fluid/operators/collective/partial_send_op.cu.cc b/paddle/fluid/operators/collective/partial_send_op.cu.cc
index 84b1e7148df..fb49318c012 100644
--- a/paddle/fluid/operators/collective/partial_send_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_send_op.cu.cc
@@ -62,8 +62,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument(
             "The input numel (%d) must be divisible by num(%d)", numel, num));
 
-    int send_numel = numel / num;
-    int offset = send_numel * id;
+    int64_t send_numel = numel / num;
+    int64_t offset = send_numel * id;
 
     auto map = distributed::ProcessGroupMapFromGid::getInstance();
     if (map->has(rid)) {
--
GitLab
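Context for the CUDA_KERNEL_LOOP -> CUDA_KERNEL_LOOP_TYPE change: the old macro
iterates with a plain int, so a tensor whose numel exceeds INT_MAX (~2.1e9
elements) overflows the index; the _TYPE variant parameterizes the index type.
Below is a minimal standalone sketch of the same grid-stride pattern and launch
math, not Paddle's actual macro expansion; the kernel and function names are
made up for illustration, and the 512-thread / 4096-block constants mirror
kNumCUDAThreads and kNumMaxinumNumBlocks in the patch.

#include <algorithm>
#include <cstdint>

// Grid-stride loop with an int64_t induction variable, so indexing stays
// valid when limit > INT_MAX. Widening before the multiply matters:
// blockIdx.x * blockDim.x alone would be computed in 32 bits.
template <typename T>
__global__ void CopyElements(const T* input, T* output, const int64_t limit) {
  for (int64_t i =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < limit;
       i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    output[i] = input[i];  // per-element work stands in for the split logic
  }
}

// Launch-side arithmetic in 64 bits as well, mirroring NumBlocks() above.
template <typename T>
void LaunchCopy(const T* input, T* output, int64_t limit) {
  const int64_t threads = 512;
  const int64_t blocks =
      std::min((limit + threads - 1) / threads, static_cast<int64_t>(4096));
  CopyElements<T><<<blocks, threads>>>(input, output, limit);
}

Capping the grid at 4096 blocks and letting the loop stride over the remainder
is why the block count itself never needs more than 32 bits even though the
element index does; only the per-element index arithmetic has to widen.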