未验证 提交 d926c270 编写于 作者: W Wen Sun 提交者: GitHub

Refactor collective communication P2P C++ API (#47801)

* refactor: send, recv, send_partial, recv_partial

* refactor: rm useless const ref
上级 1b250710
......@@ -139,6 +139,44 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv_partial with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor*,
int dst_rank,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send_partial with sync_op flag",
GetBackendName()));
}
// TODO(liyurui): This API will be moved later
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
......
......@@ -139,7 +139,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream) {
gpuStream_t stream) {
return platform::dynload::ncclAllGather(
input.data(),
output->data(),
......@@ -165,7 +165,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream) {
gpuStream_t stream) {
return platform::dynload::ncclAllReduce(
input.data(),
output->data(),
......@@ -209,7 +209,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream) {
gpuStream_t stream) {
int root = opts.source_rank + opts.source_root;
return platform::dynload::ncclBroadcast(
input.data(),
......@@ -225,6 +225,118 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream) {
return PointToPoint(
tensor,
src_rank,
[&](phi::DenseTensor* output,
int src,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclRecv(
output->data(),
output->numel(),
platform::ToNCCLDataType(output->dtype()),
src,
comm,
stream);
},
CommType::RECV,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
phi::DenseTensor tensor_flattened;
tensor_flattened.ShareDataWith(*tensor).Resize({tensor->numel()});
phi::DenseTensor tensor_recv =
tensor_flattened.Slice(offset, offset + length);
return PointToPoint(
&tensor_recv,
src_rank,
[&](phi::DenseTensor* output,
int src,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclRecv(
output->data(),
output->numel(),
platform::ToNCCLDataType(output->dtype()),
src,
comm,
stream);
},
CommType::RECV,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
phi::DenseTensor* tensor,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
return PointToPoint(
tensor,
dst_rank,
[&](phi::DenseTensor* input,
int dst,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclSend(
input->data(),
input->numel(),
platform::ToNCCLDataType(input->dtype()),
dst,
comm,
stream);
},
CommType::SEND,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
phi::DenseTensor tensor_flattened;
tensor_flattened.ShareDataWith(*tensor).Resize({tensor->numel()});
phi::DenseTensor tensor_send =
tensor_flattened.Slice(offset, offset + length);
return PointToPoint(
&tensor_send,
dst_rank,
[&](phi::DenseTensor* input,
int dst,
ncclComm_t comm,
gpuStream_t stream) {
return platform::dynload::ncclSend(
input->data(),
input->numel(),
platform::ToNCCLDataType(input->dtype()),
dst,
comm,
stream);
},
CommType::SEND,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
const Place& place,
int rank,
......@@ -331,6 +443,45 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
return task;
}
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
phi::DenseTensor* tensor,
int rank,
Fn fn,
CommType comm_type,
bool sync_op,
bool use_calc_stream) {
const auto& place = tensor->place();
const auto& key = GetKeyFromPlace(place);
platform::CUDADeviceGuard cuda_guard(place);
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateNCCLEnvCache(place, key);
}
if (!use_calc_stream) {
SyncCalcStream(place);
}
auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto& comm_ctx = place_to_comm_ctx_.at(key);
auto nccl_comm = comm_ctx->nccl_comm();
auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
fn(tensor, rank, nccl_comm, nccl_stream);
if (!use_calc_stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
memory::RecordStream(tensor->Holder(), nccl_stream);
}
task->UpdateWaitChain(*comm_ctx);
}
return task;
}
void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
std::vector<int64_t> tensor_shape) {
int64_t len_size = (*split_sizes).size();
......@@ -864,34 +1015,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
std::vector<phi::DenseTensor>& tensors,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](phi::DenseTensor& input,
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) {
return platform::dynload::ncclSend(
input.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
},
dst_rank,
CommType::SEND,
sync_op,
use_calc_stream);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
......@@ -915,34 +1038,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
std::vector<phi::DenseTensor>& tensors,
int src_rank,
bool sync_op,
bool use_calc_stream) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
auto task = PointToPoint(
tensors,
[&](phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) {
return platform::dynload::ncclRecv(
output.data(),
output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
},
src_rank,
CommType::RECV,
sync_op,
use_calc_stream);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor& tensors, int dst_rank, int64_t offset, int64_t length) {
// CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
......
......@@ -118,6 +118,32 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
int dst_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) override;
static void GroupStart();
static void GroupEnd();
......@@ -139,21 +165,9 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors,
int dst_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors,
int src_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors,
int dst_rank,
int64_t offset,
......@@ -273,6 +287,14 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint(phi::DenseTensor* tensor,
int rank,
Fn fn,
CommType op_type,
bool sync_op,
bool use_calc_stream);
void SyncCalcStream(const Place& place);
// TODO(sunyilun): methods below will be removed later
......
......@@ -92,6 +92,87 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
"ProcessGroup%s does not support do broadcast", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor, int src_rank, bool sync_op) {
return Recv(tensor,
src_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op) {
return RecvPartial(tensor,
src_rank,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do recv_partial", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
phi::DenseTensor* tensor, int dst_rank, bool sync_op) {
return Send(tensor,
dst_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
phi::DenseTensor*, int dst_rank, bool sync_op, bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) {
return SendPartial(tensor,
dst_rank,
offset,
length,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send partial", GetBackendName()));
}
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
......@@ -203,23 +284,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
"ProcessGroup%s does not support do scatter", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
std::vector<phi::DenseTensor>& tensors, int dst_rank, bool sync_op) {
return Send(tensors,
dst_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
std::vector<phi::DenseTensor>& tensors,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do send", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send_Partial(
phi::DenseTensor& tensors,
int dst_rank,
......
......@@ -98,6 +98,52 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> RecvPartial(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> RecvPartial(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
int dst_rank,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
int dst_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> SendPartial(phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> SendPartial(
phi::DenseTensor* tensor,
int dst_rank,
int64_t offset,
int64_t length,
bool sync_op,
bool use_calc_stream);
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
......@@ -164,17 +210,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int dst_rank,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>& tensors, // NOLINT
int dst_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor& tensors, // NOLINT
int dst_rank,
......
......@@ -173,10 +173,10 @@ void BindDistributed(py::module *m) {
int dst,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors, dst, sync_op);
auto *out_dense = p_dense.get();
return self.Send(out_dense, dst, sync_op);
},
py::arg("tensor"),
py::arg("dst"),
......@@ -192,13 +192,14 @@ void BindDistributed(py::module *m) {
int rank_id,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t numel = p_dense->numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.Send_Partial(
*dense, dst_rank, offset, send_numel, sync_op);
auto *out_dense = p_dense.get();
return self.SendPartial(
out_dense, dst_rank, offset, send_numel, sync_op);
},
py::arg("tensor"),
py::arg("dst"),
......@@ -214,10 +215,10 @@ void BindDistributed(py::module *m) {
int src,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors, src, sync_op);
auto *in_dense = p_dense.get();
return self.Recv(in_dense, src, sync_op);
},
py::arg("tensor"),
py::arg("src"),
......@@ -233,13 +234,14 @@ void BindDistributed(py::module *m) {
int rank_id,
bool sync_op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t numel = p_dense->numel();
int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id;
return self.Recv_Partial(
*dense, src_rank, offset, recv_numel, sync_op);
auto *out_dense = p_dense.get();
return self.RecvPartial(
out_dense, src_rank, offset, recv_numel, sync_op);
},
py::arg("tensor"),
py::arg("src"),
......@@ -1125,10 +1127,10 @@ void BindDistributed(py::module *m) {
py::handle py_tensor,
int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors,
auto *out_dense = p_dense.get();
return self.Send(out_dense,
dst,
/*sync_op*/ true,
/*use_calc_stream*/ true);
......@@ -1145,12 +1147,13 @@ void BindDistributed(py::module *m) {
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t numel = p_dense->numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.Send_Partial(*dense,
auto *out_dense = p_dense.get();
return self.SendPartial(out_dense,
dst_rank,
offset,
send_numel,
......@@ -1169,10 +1172,10 @@ void BindDistributed(py::module *m) {
py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors,
auto *in_dense = p_dense.get();
return self.Recv(in_dense,
src,
/*sync_op*/ true,
/*use_calc_stream*/ true);
......@@ -1189,12 +1192,13 @@ void BindDistributed(py::module *m) {
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t numel = p_dense->numel();
int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id;
return self.Recv_Partial(*dense,
auto *out_dense = p_dense.get();
return self.RecvPartial(out_dense,
src_rank,
offset,
recv_numel,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册