Unverified · Commit 84c9a0d6 authored by LiYuRio, committed by GitHub

refine comm api implementation (#47713)

Parent 3198af20
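This commit renames the Python-facing ProcessGroup bindings from concatenated names (allreduce, allgather, alltoall, alltoall_single, ...) to snake_case names (all_reduce, all_gather, all_to_all, all_to_all_single, ...), makes the sync_op-aware overloads the primary interface, and keeps the legacy bindings at the end of BindDistributed behind a TODO for later removal. A minimal usage sketch follows (assumption: a dygraph program started under paddle.distributed.launch with two ranks); the public collective API is unchanged, only the underlying process_group method names differ:

    # Hedged sketch, not part of this commit: run under paddle.distributed.launch (2 ranks).
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    tensor = paddle.to_tensor([float(dist.get_rank())])

    # The public API is unchanged; after this commit the dygraph helpers call
    # group.process_group.all_reduce(...) or all_reduce_on_calc_stream(...)
    # instead of the old allreduce(...) binding.
    dist.all_reduce(tensor)
    print(tensor.numpy())  # sum of all ranks, printed on every process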
...@@ -391,9 +391,25 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
}
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) {
std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
return AllGather(in_wrapper, out_wrapper, true);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
return AllGather(in_tensors, out_tensors, true);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op) {
std::shared_ptr<AllgatherGlooTask> task;
auto tag = next_tag();
auto context = get_context();
...
...@@ -108,6 +108,11 @@ class ProcessGroupGloo : public ProcessGroup {
~ProcessGroupGloo() = default;
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
...@@ -144,6 +149,11 @@ class ProcessGroupGloo : public ProcessGroup {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
...
...@@ -129,24 +129,7 @@ void BindDistributed(py::module *m) {
.def("size", &distributed::ProcessGroup::GetSize)
.def("name", &distributed::ProcessGroup::GetBackendName)
.def(
"allreduce",
"all_reduce",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts;
opts.reduce_op = op;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"allreduce",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
distributed::ReduceOp op,
...@@ -164,23 +147,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"broadcast",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int source_rank) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts;
opts.source_rank = source_rank;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("source_rank"),
py::call_guard<py::gil_scoped_release>())
.def(
"broadcast",
[](distributed::ProcessGroup &self,
...@@ -200,31 +166,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"barrier",
[](distributed::ProcessGroup &self, std::vector<int> place_ids) {
distributed::BarrierOptions opts;
opts.place_ids = place_ids;
return self.Barrier(opts);
},
py::arg("place_ids") = std::vector<int>{},
py::call_guard<py::gil_scoped_release>())
.def(
"send",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors, dst);
},
py::arg("tensor"),
py::arg("dst"),
py::call_guard<py::gil_scoped_release>())
.def(
"send",
[](distributed::ProcessGroup &self,
...@@ -242,27 +183,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.Send_Partial(*dense, dst_rank, offset, send_numel);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_partial",
[](distributed::ProcessGroup &self,
...@@ -287,21 +207,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors, src);
},
py::arg("tensor"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv",
[](distributed::ProcessGroup &self,
...@@ -319,27 +224,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id;
return self.Recv_Partial(*dense, src_rank, offset, recv_numel);
},
py::arg("tensor"),
py::arg("src"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_partial",
[](distributed::ProcessGroup &self,
...@@ -366,25 +250,6 @@ void BindDistributed(py::module *m) {
.def(
"all_gather",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllGather(in_tensors, out_tensors);
},
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"allgather",
[](distributed::ProcessGroup &self,
py::handle py_out_tensor_list,
py::handle py_in_tensor,
...@@ -413,7 +278,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"allgather_into_tensor",
"all_gather_into_tensor",
[](distributed::ProcessGroup &self,
py::handle py_out_tensor,
py::handle py_in_tensor,
...@@ -436,53 +301,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather_partial",
"all_to_all",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
int nranks,
int rank_id) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
int64_t numel = (*in_dense).numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.AllGather_Partial(
in_tensors, out_tensors, offset, send_numel);
},
py::arg("in"),
py::arg("out"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll(in_tensors, out_tensors);
},
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor_list,
...@@ -515,7 +334,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_tensor",
"all_to_all_tensor",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
...@@ -538,31 +357,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_single",
"all_to_all_single",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> in_sizes,
std::vector<int64_t> out_sizes) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll_Single(
in_tensors, out_tensors, in_sizes, out_sizes);
},
py::arg("in"),
py::arg("out"),
py::arg("in_sizes"),
py::arg("out_sizes"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_single",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
...@@ -589,26 +384,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"reduce",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
int dst,
distributed::ReduceOp op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
distributed::ReduceOptions opts;
opts.reduce_op = op;
opts.root_rank = dst;
auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Reduce(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"reduce",
[](distributed::ProcessGroup &self,
...@@ -685,29 +460,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"scatter",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
int src) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
distributed::ScatterOptions opts;
opts.root_rank = src;
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.Scatter(in_tensors, out_tensors, opts);
},
py::arg("in"),
py::arg("out"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"scatter",
[](distributed::ProcessGroup &self,
...@@ -762,6 +514,255 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"barrier",
[](distributed::ProcessGroup &self, std::vector<int> place_ids) {
distributed::BarrierOptions opts;
opts.place_ids = place_ids;
return self.Barrier(opts);
},
py::arg("place_ids") = std::vector<int>{},
py::call_guard<py::gil_scoped_release>())
// TODO(liyurui): Interface below will be removed in the future.
.def(
"allreduce",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts;
opts.reduce_op = op;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"broadcast",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int source_rank) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts;
opts.source_rank = source_rank;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("source_rank"),
py::call_guard<py::gil_scoped_release>())
.def(
"send",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors, dst);
},
py::arg("tensor"),
py::arg("dst"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.Send_Partial(*dense, dst_rank, offset, send_numel);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors, src);
},
py::arg("tensor"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id;
return self.Recv_Partial(*dense, src_rank, offset, recv_numel);
},
py::arg("tensor"),
py::arg("src"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllGather(in_tensors, out_tensors);
},
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather_partial",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
int nranks,
int rank_id) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
int64_t numel = (*in_dense).numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.AllGather_Partial(
in_tensors, out_tensors, offset, send_numel);
},
py::arg("in"),
py::arg("out"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll(in_tensors, out_tensors);
},
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_single",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> in_sizes,
std::vector<int64_t> out_sizes) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll_Single(
in_tensors, out_tensors, in_sizes, out_sizes);
},
py::arg("in"),
py::arg("out"),
py::arg("in_sizes"),
py::arg("out_sizes"),
py::call_guard<py::gil_scoped_release>())
.def(
"reduce",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
int dst,
distributed::ReduceOp op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
distributed::ReduceOptions opts;
opts.reduce_op = op;
opts.root_rank = dst;
auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Reduce(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"scatter",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
int src) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
distributed::ScatterOptions opts;
opts.root_rank = src;
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.Scatter(in_tensors, out_tensors, opts);
},
py::arg("in"),
py::arg("out"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"_reduce_scatter_base",
[](distributed::ProcessGroup &self,
...@@ -788,7 +789,7 @@ void BindDistributed(py::module *m) {
std::shared_ptr<distributed::ProcessGroupStream>>(
*m, "ProcessGroupStream", ProcessGroup)
.def(
"allgather_on_calc_stream",
"all_gather_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_out_tensor_list,
py::handle py_in_tensor) {
...@@ -818,7 +819,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"allgather_into_tensor_on_calc_stream",
"all_gather_into_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_out_tensor,
py::handle py_in_tensor) {
...@@ -873,7 +874,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"allreduce_on_calc_stream",
"all_reduce_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_tensor,
distributed::ReduceOp op) {
...@@ -890,11 +891,11 @@ void BindDistributed(py::module *m) {
/*use_calc_stream*/ true);
},
py::arg("tensor"),
py::arg("op"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_on_calc_stream",
"all_to_all_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor_list) {
...@@ -927,7 +928,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_tensor_on_calc_stream",
"all_to_all_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
...@@ -951,7 +952,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_single_on_calc_stream",
"all_to_all_single_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
...
...@@ -546,7 +546,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
out = paddle.empty(tensor_shape, tensor.dtype)
else:
out = paddle.concat(tensor_list, axis=0)
task = group.process_group.all_gather(tensor, out)
task = group.process_group.all_gather_into_tensor(out, tensor, sync_op)
task.wait()
tensor_list.clear()
list_of_tensor = paddle.split(out, group.nranks, 0)
...
...@@ -44,12 +44,12 @@ def _all_gather_into_tensor_in_dygraph(
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
if use_calc_stream:
return group.process_group.allgather_into_tensor_on_calc_stream(
return group.process_group.all_gather_into_tensor_on_calc_stream(
out_tensor,
in_tensor,
)
task = group.process_group.allgather_into_tensor(
task = group.process_group.all_gather_into_tensor(
out_tensor, in_tensor, sync_op
)
if sync_op:
...@@ -69,9 +69,11 @@ def _all_gather_in_dygraph(
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
if use_calc_stream:
return group.process_group.allgather_on_calc_stream(tensor_list, tensor)
return group.process_group.all_gather_on_calc_stream(
tensor_list, tensor
)
task = group.process_group.allgather(tensor_list, tensor, sync_op)
task = group.process_group.all_gather(tensor_list, tensor, sync_op)
if sync_op:
task.wait()
...
...@@ -25,11 +25,10 @@ from paddle.distributed.communication.group import (
def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream):
op_type = _get_reduce_op(op, "allreduce")
group = _get_global_group() if group is None else group
if use_calc_stream:
return group.process_group.allreduce_on_calc_stream(tensor, op_type)
return group.process_group.all_reduce_on_calc_stream(tensor, op_type)
task = group.process_group.allreduce(tensor, op_type, sync_op)
task = group.process_group.all_reduce(tensor, op_type, sync_op)
if sync_op:
task.wait()
...@@ -119,6 +118,7 @@ def all_reduce(
)
if framework.in_dygraph_mode():
group = _get_global_group() if group is None else group
return _all_reduce_in_dygraph(
tensor, op, group, sync_op, use_calc_stream
)
...
...@@ -47,11 +47,11 @@ def _all_to_all_tensor_in_dygraph(
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
if use_calc_stream:
return group.process_group.alltoall_tensor_on_calc_stream(
return group.process_group.all_to_all_tensor_on_calc_stream(
in_tensor, out_tensor
)
task = group.process_group.alltoall_tensor(in_tensor, out_tensor, sync_op)
task = group.process_group.all_to_all_tensor(in_tensor, out_tensor, sync_op)
if sync_op:
task.wait()
...@@ -74,11 +74,11 @@ def _all_to_all_in_dygraph(
)
if use_calc_stream:
return group.process_group.alltoall_on_calc_stream(
return group.process_group.all_to_all_on_calc_stream(
in_tensor_list, out_tensor_list
)
task = group.process_group.alltoall(
task = group.process_group.all_to_all(
in_tensor_list, out_tensor_list, sync_op
)
if sync_op:
...@@ -249,11 +249,11 @@ def _alltoall_single_in_dygraph(
in_split_sizes = []
if use_calc_stream:
return group.process_group.alltoall_single_on_calc_stream(
return group.process_group.all_to_all_single_on_calc_stream(
in_tensor, out_tensor, in_split_sizes, out_split_sizes
)
task = group.process_group.alltoall_single(
task = group.process_group.all_to_all_single(
in_tensor, out_tensor, in_split_sizes, out_split_sizes, sync_op
)
if sync_op:
...
...@@ -52,7 +52,6 @@ def _reduce_scatter_tensor_in_dygraph(
caller="reduce_scatter",
):
op_type = _get_reduce_op(op, caller)
group = _get_global_group() if group is None else group
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
...@@ -74,7 +73,6 @@ def _reduce_scatter_in_dygraph(
tensor, tensor_list, op, group, sync_op, use_calc_stream
):
op_type = _get_reduce_op(op, "reduce_scatter")
group = _get_global_group() if group is None else group
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
...@@ -149,6 +147,7 @@ def reduce_scatter(
)
if framework.in_dygraph_mode():
group = _get_global_group() if group is None else group
if paddle.is_tensor(tensor_or_tensor_list):
return _reduce_scatter_tensor_in_dygraph(
tensor,
...@@ -230,6 +229,7 @@ def _reduce_scatter_base(
)
if framework.in_dygraph_mode():
group = _get_global_group() if group is None else group
return _reduce_scatter_tensor_in_dygraph(
out_tensor,
in_tensor,
...
...@@ -62,7 +62,7 @@ def _c_identity(tensor, group=None):
@staticmethod
def backward(ctx, dy):
op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
group.process_group.allreduce_on_calc_stream(dy, op_type)
group.process_group.all_reduce_on_calc_stream(dy, op_type)
return dy
return c_identity_eager.apply(tensor)
...@@ -255,7 +255,7 @@ def _mp_allreduce(
if use_calc_stream:
op_type = _get_reduce_op(op, "_mp_allreduce")
group.process_group.allreduce_on_calc_stream(
group.process_group.all_reduce_on_calc_stream(
tensor, op_type
)
return tensor
...
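All dygraph helpers in paddle.distributed.communication.stream now follow the same dispatch pattern against the renamed bindings. A condensed sketch of that pattern (the function name _dispatch is illustrative only; the real helpers are _all_reduce_in_dygraph, _all_gather_in_dygraph, and the other helpers shown in the diffs above):

    # Illustrative only: the common shape of the dygraph dispatch helpers after this commit.
    def _dispatch(tensor, op_type, group, sync_op, use_calc_stream):
        if use_calc_stream:
            # Runs on the calculation stream; nothing to wait on, so return directly.
            return group.process_group.all_reduce_on_calc_stream(tensor, op_type)
        task = group.process_group.all_reduce(tensor, op_type, sync_op)
        if sync_op:
            task.wait()
        return task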