未验证 提交 9fd4fd5f 编写于 作者: Y Yuang Liu 提交者: GitHub

use int64 for c split (#52279) (#52340)

上级 84504f35
......@@ -21,10 +21,10 @@ limitations under the License. */
namespace paddle {
namespace operators {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static constexpr int64_t kNumCUDAThreads = 512;
static constexpr int64_t kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
static inline int64_t NumBlocks(const int64_t N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
......@@ -32,21 +32,21 @@ static inline int NumBlocks(const int N) {
template <typename T>
__global__ void SplitFromRank(const T* input,
T* output,
const int rows,
const int columns,
const int64_t rows,
const int64_t columns,
const int rank,
const int nranks,
const int limit) {
CUDA_KERNEL_LOOP(i, limit) {
int row = i / columns;
int col = i % columns;
const int64_t limit) {
CUDA_KERNEL_LOOP_TYPE(i, limit, int64_t) {
int64_t row = i / columns;
int64_t col = i % columns;
int block = columns / nranks;
int start = block * rank;
int end = start + block;
int64_t block = columns / nranks;
int64_t start = block * rank;
int64_t end = start + block;
if (col >= start && col < end) {
int idx = block * row + col % block;
int64_t idx = block * row + col % block;
output[idx] = input[i];
}
}
......@@ -93,9 +93,9 @@ class CSplitOpCUDAKernel : public framework::OpKernel<T> {
auto remain_ddim = phi::slice_ddim(dims, 0, dims_size - 1);
int64_t remain_numel = phi::product(remain_ddim);
int limit = x->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
int64_t limit = x->numel();
int64_t blocks = NumBlocks(limit);
int64_t threads = kNumCUDAThreads;
dims[dims_size - 1] /= nranks;
out->mutable_data<T>(dims, place);
......
......@@ -69,7 +69,7 @@ class PartialRecvOp : public framework::OperatorWithKernel {
out_shape[i]));
}
auto out_dims = phi::make_ddim(out_shape);
int numel = phi::product(out_dims);
int64_t numel = phi::product(out_dims);
PADDLE_ENFORCE_EQ(
(numel % num),
0,
......
......@@ -68,8 +68,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
out->mutable_data<T>(out_dims, place);
int recv_numel = numel / num;
int offset = recv_numel * id;
int64_t recv_numel = numel / num;
int64_t offset = recv_numel * id;
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
......
......@@ -62,8 +62,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"The input numel (%d) must be divisible by num(%d)", numel, num));
int send_numel = numel / num;
int offset = send_numel * id;
int64_t send_numel = numel / num;
int64_t offset = send_numel * id;
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册