diff --git a/paddle/fluid/operators/collective/c_split_op.cu b/paddle/fluid/operators/collective/c_split_op.cu
index 5b34e4ba9d59466a7f89cbc534957d083d36b25e..8bf887b954aac8e7195136dfe98fe028b66a7ef5 100644
--- a/paddle/fluid/operators/collective/c_split_op.cu
+++ b/paddle/fluid/operators/collective/c_split_op.cu
@@ -21,10 +21,10 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-static constexpr int kNumCUDAThreads = 512;
-static constexpr int kNumMaxinumNumBlocks = 4096;
+static constexpr int64_t kNumCUDAThreads = 512;
+static constexpr int64_t kNumMaxinumNumBlocks = 4096;
 
-static inline int NumBlocks(const int N) {
+static inline int64_t NumBlocks(const int64_t N) {
   return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                   kNumMaxinumNumBlocks);
 }
@@ -32,21 +32,21 @@ static inline int NumBlocks(const int N) {
 template <typename T>
 __global__ void SplitFromRank(const T* input,
                               T* output,
-                              const int rows,
-                              const int columns,
+                              const int64_t rows,
+                              const int64_t columns,
                               const int rank,
                               const int nranks,
-                              const int limit) {
-  CUDA_KERNEL_LOOP(i, limit) {
-    int row = i / columns;
-    int col = i % columns;
+                              const int64_t limit) {
+  CUDA_KERNEL_LOOP_TYPE(i, limit, int64_t) {
+    int64_t row = i / columns;
+    int64_t col = i % columns;
 
-    int block = columns / nranks;
-    int start = block * rank;
-    int end = start + block;
+    int64_t block = columns / nranks;
+    int64_t start = block * rank;
+    int64_t end = start + block;
 
     if (col >= start && col < end) {
-      int idx = block * row + col % block;
+      int64_t idx = block * row + col % block;
       output[idx] = input[i];
     }
   }
@@ -93,9 +93,9 @@ class CSplitOpCUDAKernel : public framework::OpKernel<T> {
     auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1);
     int64_t remain_numel = phi::product(remain_ddim);
 
-    int limit = x->numel();
-    int blocks = NumBlocks(limit);
-    int threads = kNumCUDAThreads;
+    int64_t limit = x->numel();
+    int64_t blocks = NumBlocks(limit);
+    int64_t threads = kNumCUDAThreads;
 
     dims[dims_size - 1] /= nranks;
     out->mutable_data<T>(dims, place);
diff --git a/paddle/fluid/operators/collective/partial_recv_op.cc b/paddle/fluid/operators/collective/partial_recv_op.cc
index f0effde61b7e2426d247e58e6b95aee89cdcb855..9b860f93e3bd2b5d3f9e1750069f5bd4678ffe24 100644
--- a/paddle/fluid/operators/collective/partial_recv_op.cc
+++ b/paddle/fluid/operators/collective/partial_recv_op.cc
@@ -69,7 +69,7 @@ class PartialRecvOp : public framework::OperatorWithKernel {
             out_shape[i]));
       }
     auto out_dims = phi::make_ddim(out_shape);
-    int numel = phi::product(out_dims);
+    int64_t numel = phi::product(out_dims);
     PADDLE_ENFORCE_EQ(
         (numel % num),
         0,
diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
index 526f942599210a686e4196cb5f307392d34c3301..ad44ce6a1095bac43df00697cb8ec9983632571e 100644
--- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
@@ -68,8 +68,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     out->mutable_data<T>(out_dims, place);
 
-    int recv_numel = numel / num;
-    int offset = recv_numel * id;
+    int64_t recv_numel = numel / num;
+    int64_t offset = recv_numel * id;
 
     auto map = distributed::ProcessGroupMapFromGid::getInstance();
     if (map->has(rid)) {
diff --git a/paddle/fluid/operators/collective/partial_send_op.cu.cc b/paddle/fluid/operators/collective/partial_send_op.cu.cc
index 84b1e7148df0216eaf69721898a2e2a16cfbfcba..fb49318c01221a40fcc0d0059ba26e4e62c7d115 100644
--- a/paddle/fluid/operators/collective/partial_send_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_send_op.cu.cc
@@ -62,8 +62,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument(
             "The input numel (%d) must be divisible by num(%d)", numel, num));
 
-    int send_numel = numel / num;
-    int offset = send_numel * id;
+    int64_t send_numel = numel / num;
+    int64_t offset = send_numel * id;
 
     auto map = distributed::ProcessGroupMapFromGid::getInstance();
     if (map->has(rid)) {