Unverified commit 84c9a0d6, authored by LiYuRio, committed by GitHub

refine comm api implementation (#47713)

Parent 3198af20
......@@ -391,9 +391,25 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
}
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) {
std::vector<phi::DenseTensor> in_wrapper = {in_tensor};
std::vector<phi::DenseTensor> out_wrapper = {*out_tensor};
return AllGather(in_wrapper, out_wrapper, true);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
return AllGather(in_tensors, out_tensors, true);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op) {
std::shared_ptr<AllgatherGlooTask> task;
auto tag = next_tag();
auto context = get_context();
......
......@@ -108,6 +108,11 @@ class ProcessGroupGloo : public ProcessGroup {
~ProcessGroupGloo() = default;
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
......@@ -144,6 +149,11 @@ class ProcessGroupGloo : public ProcessGroup {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......
......@@ -129,24 +129,7 @@ void BindDistributed(py::module *m) {
.def("size", &distributed::ProcessGroup::GetSize)
.def("name", &distributed::ProcessGroup::GetBackendName)
.def(
"allreduce",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts;
opts.reduce_op = op;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"allreduce",
"all_reduce",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
distributed::ReduceOp op,
......@@ -164,23 +147,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"broadcast",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int source_rank) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts;
opts.source_rank = source_rank;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("source_rank"),
py::call_guard<py::gil_scoped_release>())
.def(
"broadcast",
[](distributed::ProcessGroup &self,
......@@ -200,31 +166,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"barrier",
[](distributed::ProcessGroup &self, std::vector<int> place_ids) {
distributed::BarrierOptions opts;
opts.place_ids = place_ids;
return self.Barrier(opts);
},
py::arg("place_ids") = std::vector<int>{},
py::call_guard<py::gil_scoped_release>())
.def(
"send",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors, dst);
},
py::arg("tensor"),
py::arg("dst"),
py::call_guard<py::gil_scoped_release>())
.def(
"send",
[](distributed::ProcessGroup &self,
......@@ -242,27 +183,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.Send_Partial(*dense, dst_rank, offset, send_numel);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_partial",
[](distributed::ProcessGroup &self,
......@@ -287,21 +207,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors, src);
},
py::arg("tensor"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv",
[](distributed::ProcessGroup &self,
......@@ -319,27 +224,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id;
return self.Recv_Partial(*dense, src_rank, offset, recv_numel);
},
py::arg("tensor"),
py::arg("src"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_partial",
[](distributed::ProcessGroup &self,
......@@ -366,25 +250,6 @@ void BindDistributed(py::module *m) {
.def(
"all_gather",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllGather(in_tensors, out_tensors);
},
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"allgather",
[](distributed::ProcessGroup &self,
py::handle py_out_tensor_list,
py::handle py_in_tensor,
......@@ -413,7 +278,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"allgather_into_tensor",
"all_gather_into_tensor",
[](distributed::ProcessGroup &self,
py::handle py_out_tensor,
py::handle py_in_tensor,
......@@ -436,53 +301,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather_partial",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
int nranks,
int rank_id) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
int64_t numel = (*in_dense).numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.AllGather_Partial(
in_tensors, out_tensors, offset, send_numel);
},
py::arg("in"),
py::arg("out"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll(in_tensors, out_tensors);
},
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall",
"all_to_all",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor_list,
......@@ -515,7 +334,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_tensor",
"all_to_all_tensor",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
......@@ -538,31 +357,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_single",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> in_sizes,
std::vector<int64_t> out_sizes) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll_Single(
in_tensors, out_tensors, in_sizes, out_sizes);
},
py::arg("in"),
py::arg("out"),
py::arg("in_sizes"),
py::arg("out_sizes"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_single",
"all_to_all_single",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
......@@ -589,26 +384,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"reduce",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
int dst,
distributed::ReduceOp op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
distributed::ReduceOptions opts;
opts.reduce_op = op;
opts.root_rank = dst;
auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Reduce(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"reduce",
[](distributed::ProcessGroup &self,
......@@ -685,29 +460,6 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"scatter",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
int src) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
distributed::ScatterOptions opts;
opts.root_rank = src;
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.Scatter(in_tensors, out_tensors, opts);
},
py::arg("in"),
py::arg("out"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"scatter",
[](distributed::ProcessGroup &self,
......@@ -762,6 +514,255 @@ void BindDistributed(py::module *m) {
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"barrier",
[](distributed::ProcessGroup &self, std::vector<int> place_ids) {
distributed::BarrierOptions opts;
opts.place_ids = place_ids;
return self.Barrier(opts);
},
py::arg("place_ids") = std::vector<int>{},
py::call_guard<py::gil_scoped_release>())
// TODO(liyurui): Interface below will be removed in the future.
.def(
"allreduce",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::AllreduceOptions opts;
opts.reduce_op = op;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.AllReduce(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"broadcast",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int source_rank) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
distributed::BroadcastOptions opts;
opts.source_rank = source_rank;
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Broadcast(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("source_rank"),
py::call_guard<py::gil_scoped_release>())
.def(
"send",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Send(tensors, dst);
},
py::arg("tensor"),
py::arg("dst"),
py::call_guard<py::gil_scoped_release>())
.def(
"send_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int dst_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.Send_Partial(*dense, dst_rank, offset, send_numel);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Recv(tensors, src);
},
py::arg("tensor"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"recv_partial",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
int src_rank,
int nranks,
int rank_id) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
int64_t numel = (*dense).numel();
int64_t recv_numel = numel / nranks;
int64_t offset = recv_numel * rank_id;
return self.Recv_Partial(*dense, src_rank, offset, recv_numel);
},
py::arg("tensor"),
py::arg("src"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllGather(in_tensors, out_tensors);
},
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather_partial",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
int nranks,
int rank_id) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
int64_t numel = (*in_dense).numel();
int64_t send_numel = numel / nranks;
int64_t offset = send_numel * rank_id;
return self.AllGather_Partial(
in_tensors, out_tensors, offset, send_numel);
},
py::arg("in"),
py::arg("out"),
py::arg("num"),
py::arg("id"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll(in_tensors, out_tensors);
},
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_single",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> in_sizes,
std::vector<int64_t> out_sizes) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll_Single(
in_tensors, out_tensors, in_sizes, out_sizes);
},
py::arg("in"),
py::arg("out"),
py::arg("in_sizes"),
py::arg("out_sizes"),
py::call_guard<py::gil_scoped_release>())
.def(
"reduce",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
int dst,
distributed::ReduceOp op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
distributed::ReduceOptions opts;
opts.reduce_op = op;
opts.root_rank = dst;
auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> tensors = {*dense};
return self.Reduce(tensors, tensors, opts);
},
py::arg("tensor"),
py::arg("dst"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"scatter",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
int src) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
distributed::ScatterOptions opts;
opts.root_rank = src;
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.Scatter(in_tensors, out_tensors, opts);
},
py::arg("in"),
py::arg("out"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"_reduce_scatter_base",
[](distributed::ProcessGroup &self,
......@@ -788,7 +789,7 @@ void BindDistributed(py::module *m) {
std::shared_ptr<distributed::ProcessGroupStream>>(
*m, "ProcessGroupStream", ProcessGroup)
.def(
"allgather_on_calc_stream",
"all_gather_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_out_tensor_list,
py::handle py_in_tensor) {
......@@ -818,7 +819,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"allgather_into_tensor_on_calc_stream",
"all_gather_into_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_out_tensor,
py::handle py_in_tensor) {
......@@ -873,7 +874,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"allreduce_on_calc_stream",
"all_reduce_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_tensor,
distributed::ReduceOp op) {
......@@ -890,11 +891,11 @@ void BindDistributed(py::module *m) {
/*use_calc_stream*/ true);
},
py::arg("tensor"),
py::arg("op"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_on_calc_stream",
"all_to_all_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor_list) {
......@@ -927,7 +928,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_tensor_on_calc_stream",
"all_to_all_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
......@@ -951,7 +952,7 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_single_on_calc_stream",
"all_to_all_single_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
......
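The pybind changes above rename the sync_op-aware collectives to snake_case (all_reduce, all_gather, all_to_all, all_to_all_single, ...) and regroup the legacy bindings (allreduce, broadcast, send, recv, ...) at the end of the class behind a removal TODO. A minimal sketch of how the two generations of bindings are called from dygraph Python follows; it assumes an already-initialized parallel environment, and the _get_global_group / _get_reduce_op import path is inferred from the hunk headers above rather than confirmed.

```python
# Hedged sketch only: mirrors the call pattern of the dygraph helpers in this
# commit. Run under a multi-process launcher (e.g. paddle.distributed.launch).
import paddle
import paddle.distributed as dist
# Assumed import path, taken from the "from paddle.distributed.communication.group
# import (" hunk header in this commit; it may differ in other Paddle versions.
from paddle.distributed.communication.group import _get_global_group, _get_reduce_op

dist.init_parallel_env()
group = _get_global_group()
tensor = paddle.ones([2], dtype="float32")
op_type = _get_reduce_op(dist.ReduceOp.SUM, "all_reduce")

# New binding: snake_case name plus an explicit sync_op flag; returns a task.
task = group.process_group.all_reduce(tensor, op_type, True)  # sync_op=True
task.wait()

# Legacy binding, kept for now behind the "will be removed" TODO above.
group.process_group.allreduce(tensor, op_type)
```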
......@@ -546,7 +546,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
out = paddle.empty(tensor_shape, tensor.dtype)
else:
out = paddle.concat(tensor_list, axis=0)
task = group.process_group.all_gather(tensor, out)
task = group.process_group.all_gather_into_tensor(out, tensor, sync_op)
task.wait()
tensor_list.clear()
list_of_tensor = paddle.split(out, group.nranks, 0)
......
......@@ -44,12 +44,12 @@ def _all_gather_into_tensor_in_dygraph(
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
if use_calc_stream:
return group.process_group.allgather_into_tensor_on_calc_stream(
return group.process_group.all_gather_into_tensor_on_calc_stream(
out_tensor,
in_tensor,
)
task = group.process_group.allgather_into_tensor(
task = group.process_group.all_gather_into_tensor(
out_tensor, in_tensor, sync_op
)
if sync_op:
......@@ -69,9 +69,11 @@ def _all_gather_in_dygraph(
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
if use_calc_stream:
return group.process_group.allgather_on_calc_stream(tensor_list, tensor)
return group.process_group.all_gather_on_calc_stream(
tensor_list, tensor
)
task = group.process_group.allgather(tensor_list, tensor, sync_op)
task = group.process_group.all_gather(tensor_list, tensor, sync_op)
if sync_op:
task.wait()
......
......@@ -25,11 +25,10 @@ from paddle.distributed.communication.group import (
def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream):
op_type = _get_reduce_op(op, "allreduce")
group = _get_global_group() if group is None else group
if use_calc_stream:
return group.process_group.allreduce_on_calc_stream(tensor, op_type)
return group.process_group.all_reduce_on_calc_stream(tensor, op_type)
task = group.process_group.allreduce(tensor, op_type, sync_op)
task = group.process_group.all_reduce(tensor, op_type, sync_op)
if sync_op:
task.wait()
......@@ -119,6 +118,7 @@ def all_reduce(
)
if framework.in_dygraph_mode():
group = _get_global_group() if group is None else group
return _all_reduce_in_dygraph(
tensor, op, group, sync_op, use_calc_stream
)
......
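Each stream helper above follows the same dispatch: use_calc_stream routes to the *_on_calc_stream binding (blocking, no task object), otherwise the sync_op-aware binding is called and the returned task is waited on only when sync_op is true. A hedged usage sketch of the three modes, assuming the module modified here is exported as paddle.distributed.stream:

```python
# Illustrative only; assumes paddle.distributed.stream exposes the all_reduce
# helper modified above and that the parallel env is already initialized.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
x = paddle.to_tensor([1.0, 2.0])

# Async: returns a task backed by process_group.all_reduce(..., sync_op=False).
task = dist.stream.all_reduce(x, sync_op=False)
task.wait()

# Sync: the helper waits on the task internally before returning.
dist.stream.all_reduce(x, sync_op=True)

# Calculation stream: maps to all_reduce_on_calc_stream; nothing to wait on.
dist.stream.all_reduce(x, use_calc_stream=True)
```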
......@@ -47,11 +47,11 @@ def _all_to_all_tensor_in_dygraph(
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
if use_calc_stream:
return group.process_group.alltoall_tensor_on_calc_stream(
return group.process_group.all_to_all_tensor_on_calc_stream(
in_tensor, out_tensor
)
task = group.process_group.alltoall_tensor(in_tensor, out_tensor, sync_op)
task = group.process_group.all_to_all_tensor(in_tensor, out_tensor, sync_op)
if sync_op:
task.wait()
......@@ -74,11 +74,11 @@ def _all_to_all_in_dygraph(
)
if use_calc_stream:
return group.process_group.alltoall_on_calc_stream(
return group.process_group.all_to_all_on_calc_stream(
in_tensor_list, out_tensor_list
)
task = group.process_group.alltoall(
task = group.process_group.all_to_all(
in_tensor_list, out_tensor_list, sync_op
)
if sync_op:
......@@ -249,11 +249,11 @@ def _alltoall_single_in_dygraph(
in_split_sizes = []
if use_calc_stream:
return group.process_group.alltoall_single_on_calc_stream(
return group.process_group.all_to_all_single_on_calc_stream(
in_tensor, out_tensor, in_split_sizes, out_split_sizes
)
task = group.process_group.alltoall_single(
task = group.process_group.all_to_all_single(
in_tensor, out_tensor, in_split_sizes, out_split_sizes, sync_op
)
if sync_op:
......
......@@ -52,7 +52,6 @@ def _reduce_scatter_tensor_in_dygraph(
caller="reduce_scatter",
):
op_type = _get_reduce_op(op, caller)
group = _get_global_group() if group is None else group
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
......@@ -74,7 +73,6 @@ def _reduce_scatter_in_dygraph(
tensor, tensor_list, op, group, sync_op, use_calc_stream
):
op_type = _get_reduce_op(op, "reduce_scatter")
group = _get_global_group() if group is None else group
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
......@@ -149,6 +147,7 @@ def reduce_scatter(
)
if framework.in_dygraph_mode():
group = _get_global_group() if group is None else group
if paddle.is_tensor(tensor_or_tensor_list):
return _reduce_scatter_tensor_in_dygraph(
tensor,
......@@ -230,6 +229,7 @@ def _reduce_scatter_base(
)
if framework.in_dygraph_mode():
group = _get_global_group() if group is None else group
return _reduce_scatter_tensor_in_dygraph(
out_tensor,
in_tensor,
......
......@@ -62,7 +62,7 @@ def _c_identity(tensor, group=None):
@staticmethod
def backward(ctx, dy):
op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
group.process_group.allreduce_on_calc_stream(dy, op_type)
group.process_group.all_reduce_on_calc_stream(dy, op_type)
return dy
return c_identity_eager.apply(tensor)
......@@ -255,7 +255,7 @@ def _mp_allreduce(
if use_calc_stream:
op_type = _get_reduce_op(op, "_mp_allreduce")
group.process_group.allreduce_on_calc_stream(
group.process_group.all_reduce_on_calc_stream(
tensor, op_type
)
return tensor
......
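At the public API surface nothing user-visible changes: the collectives keep their Python signatures and simply route to the renamed bindings. A short end-to-end sketch, assuming the script is started with paddle.distributed.launch:

```python
# Signatures follow the "def all_gather(tensor_list, tensor, group=None,
# sync_op=True)" hunk above; the all_reduce keyword set is assumed to match.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
x = paddle.to_tensor([float(dist.get_rank())])

# Routes to process_group.all_reduce(tensor, op, sync_op) after this commit.
dist.all_reduce(x, op=dist.ReduceOp.SUM, sync_op=True)

# With an empty output list this routes to
# process_group.all_gather_into_tensor(out, tensor, sync_op), then splits.
gathered = []
dist.all_gather(gathered, x, sync_op=True)
print(dist.get_rank(), x.numpy(), [t.numpy() for t in gathered])
```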