未验证 提交 964497b5 编写于 作者: Y Yuang Liu 提交者: GitHub

use int64 for c split (#52279)

上级 bd3b6adf
...@@ -21,10 +21,10 @@ limitations under the License. */ ...@@ -21,10 +21,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Launch configuration for elementwise CUDA kernels.
// Both constants are int64_t so the block-count arithmetic below is done
// in 64 bits and cannot overflow for tensors with more than INT_MAX
// elements (the point of this commit).
static constexpr int64_t kNumCUDAThreads = 512;
static constexpr int64_t kNumMaxinumNumBlocks = 4096;

// Returns the number of CUDA blocks needed to cover N elements with
// kNumCUDAThreads threads per block, capped at kNumMaxinumNumBlocks
// (excess elements are handled by the kernel's grid-stride loop macro).
static inline int64_t NumBlocks(const int64_t N) {
  // Ceiling division; both operands are int64_t so std::min deduces a
  // single type and large N does not overflow.
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaxinumNumBlocks);
}
...@@ -32,21 +32,21 @@ static inline int NumBlocks(const int N) { ...@@ -32,21 +32,21 @@ static inline int NumBlocks(const int N) {
// Copies this rank's shard of each row from `input` to `output`.
//
// `input` is viewed as a (rows x columns) matrix. Each row's columns are
// split evenly across `nranks` ranks; rank `rank` keeps the contiguous
// column range [rank * (columns / nranks), (rank + 1) * (columns / nranks))
// and packs it densely into `output`. `limit` is the total number of input
// elements to scan (rows * columns).
//
// All index arithmetic is int64_t so inputs with more than INT_MAX
// elements index correctly (CUDA_KERNEL_LOOP_TYPE iterates with an
// int64_t induction variable).
template <typename T>
__global__ void SplitFromRank(const T* input,
                              T* output,
                              const int64_t rows,
                              const int64_t columns,
                              const int rank,
                              const int nranks,
                              const int64_t limit) {
  CUDA_KERNEL_LOOP_TYPE(i, limit, int64_t) {
    int64_t row = i / columns;
    int64_t col = i % columns;
    // Width of one rank's shard. NOTE(review): assumes columns is evenly
    // divisible by nranks — presumably enforced by the op's InferShape;
    // confirm at the call site.
    int64_t block = columns / nranks;
    int64_t start = block * rank;
    int64_t end = start + block;
    if (col >= start && col < end) {
      // Destination index within the packed per-rank output.
      int64_t idx = block * row + col % block;
      output[idx] = input[i];
    }
  }
}
...@@ -93,9 +93,9 @@ class CSplitOpCUDAKernel : public framework::OpKernel<T> { ...@@ -93,9 +93,9 @@ class CSplitOpCUDAKernel : public framework::OpKernel<T> {
auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1); auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1);
int64_t remain_numel = phi::product(remain_ddim); int64_t remain_numel = phi::product(remain_ddim);
int limit = x->numel(); int64_t limit = x->numel();
int blocks = NumBlocks(limit); int64_t blocks = NumBlocks(limit);
int threads = kNumCUDAThreads; int64_t threads = kNumCUDAThreads;
dims[dims_size - 1] /= nranks; dims[dims_size - 1] /= nranks;
out->mutable_data<T>(dims, place); out->mutable_data<T>(dims, place);
......
...@@ -69,7 +69,7 @@ class PartialRecvOp : public framework::OperatorWithKernel { ...@@ -69,7 +69,7 @@ class PartialRecvOp : public framework::OperatorWithKernel {
out_shape[i])); out_shape[i]));
} }
auto out_dims = phi::make_ddim(out_shape); auto out_dims = phi::make_ddim(out_shape);
int numel = phi::product(out_dims); int64_t numel = phi::product(out_dims);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(numel % num), (numel % num),
0, 0,
......
...@@ -68,8 +68,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> { ...@@ -68,8 +68,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
out->mutable_data<T>(out_dims, place); out->mutable_data<T>(out_dims, place);
int recv_numel = numel / num; int64_t recv_numel = numel / num;
int offset = recv_numel * id; int64_t offset = recv_numel * id;
auto map = distributed::ProcessGroupMapFromGid::getInstance(); auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) { if (map->has(rid)) {
......
...@@ -62,8 +62,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> { ...@@ -62,8 +62,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The input numel (%d) must be divisible by num(%d)", numel, num)); "The input numel (%d) must be divisible by num(%d)", numel, num));
int send_numel = numel / num; int64_t send_numel = numel / num;
int offset = send_numel * id; int64_t offset = send_numel * id;
auto map = distributed::ProcessGroupMapFromGid::getInstance(); auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) { if (map->has(rid)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册