From 84c9a0d6771aea8524811f600967f8a7ee72753f Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Tue, 8 Nov 2022 18:56:59 +0800 Subject: [PATCH] refine comm api implementation (#47713) --- .../collective/ProcessGroupGloo.cc | 16 + .../distributed/collective/ProcessGroupGloo.h | 10 + paddle/fluid/pybind/distributed_py.cc | 521 +++++++++--------- python/paddle/distributed/collective.py | 2 +- .../communication/stream/all_gather.py | 10 +- .../communication/stream/all_reduce.py | 6 +- .../communication/stream/all_to_all.py | 12 +- .../communication/stream/reduce_scatter.py | 4 +- .../distributed/fleet/layers/mpu/mp_ops.py | 4 +- 9 files changed, 307 insertions(+), 278 deletions(-) diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc index d6d7f328ae..5cb4daf728 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc @@ -391,9 +391,25 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask { } }; +std::shared_ptr ProcessGroupGloo::AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + bool sync_op) { + std::vector in_wrapper = {in_tensor}; + std::vector out_wrapper = {*out_tensor}; + return AllGather(in_wrapper, out_wrapper, true); +} + std::shared_ptr ProcessGroupGloo::AllGather( std::vector& in_tensors, std::vector& out_tensors) { + return AllGather(in_tensors, out_tensors, true); +} + +std::shared_ptr ProcessGroupGloo::AllGather( + std::vector& in_tensors, + std::vector& out_tensors, + bool sync_op) { std::shared_ptr task; auto tag = next_tag(); auto context = get_context(); diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.h b/paddle/fluid/distributed/collective/ProcessGroupGloo.h index 72309bf769..9796f91663 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupGloo.h +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.h @@ -108,6 +108,11 @@ class ProcessGroupGloo : public ProcessGroup { ~ProcessGroupGloo() = default; + std::shared_ptr AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + bool sync_op) override; + std::shared_ptr Broadcast( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, @@ -144,6 +149,11 @@ class ProcessGroupGloo : public ProcessGroup { std::vector& in_tensors, std::vector& out_tensors) override; + std::shared_ptr AllGather( + std::vector& in_tensors, + std::vector& out_tensors, + bool sync_op) override; + std::shared_ptr Reduce( std::vector& in_tensors, std::vector& out_tensors, diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 2b9d8d1631..06b26d66a6 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -129,24 +129,7 @@ void BindDistributed(py::module *m) { .def("size", &distributed::ProcessGroup::GetSize) .def("name", &distributed::ProcessGroup::GetBackendName) .def( - "allreduce", - [](distributed::ProcessGroup &self, - py::handle py_tensor, - distributed::ReduceOp op) { - auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); - distributed::AllreduceOptions opts; - opts.reduce_op = op; - auto dense = - std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.AllReduce(tensors, tensors, opts); - }, - py::arg("tensor"), - py::arg("op") = distributed::ReduceOp::SUM, - py::call_guard()) - - .def( - "allreduce", + "all_reduce", [](distributed::ProcessGroup 
&self, py::handle py_tensor, distributed::ReduceOp op, @@ -164,23 +147,6 @@ void BindDistributed(py::module *m) { py::arg("sync_op"), py::call_guard()) - .def( - "broadcast", - [](distributed::ProcessGroup &self, - py::handle py_tensor, - int source_rank) { - auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); - distributed::BroadcastOptions opts; - opts.source_rank = source_rank; - auto dense = - std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.Broadcast(tensors, tensors, opts); - }, - py::arg("tensor"), - py::arg("source_rank"), - py::call_guard()) - .def( "broadcast", [](distributed::ProcessGroup &self, @@ -200,31 +166,6 @@ void BindDistributed(py::module *m) { py::arg("sync_op"), py::call_guard()) - .def( - "barrier", - [](distributed::ProcessGroup &self, std::vector place_ids) { - distributed::BarrierOptions opts; - opts.place_ids = place_ids; - return self.Barrier(opts); - }, - py::arg("place_ids") = std::vector{}, - py::call_guard()) - - .def( - "send", - [](distributed::ProcessGroup &self, - py::handle py_tensor, - int dst) { - auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); - auto dense = - std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.Send(tensors, dst); - }, - py::arg("tensor"), - py::arg("dst"), - py::call_guard()) - .def( "send", [](distributed::ProcessGroup &self, @@ -242,27 +183,6 @@ void BindDistributed(py::module *m) { py::arg("sync_op"), py::call_guard()) - .def( - "send_partial", - [](distributed::ProcessGroup &self, - py::handle py_tensor, - int dst_rank, - int nranks, - int rank_id) { - auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); - auto dense = - std::dynamic_pointer_cast(tensor.impl()); - int64_t numel = (*dense).numel(); - int64_t send_numel = numel / nranks; - int64_t offset = send_numel * rank_id; - return self.Send_Partial(*dense, dst_rank, offset, send_numel); - }, - py::arg("tensor"), - py::arg("dst"), - py::arg("num"), - py::arg("id"), - py::call_guard()) - .def( "send_partial", [](distributed::ProcessGroup &self, @@ -287,21 +207,6 @@ void BindDistributed(py::module *m) { py::arg("sync_op"), py::call_guard()) - .def( - "recv", - [](distributed::ProcessGroup &self, - py::handle py_tensor, - int src) { - auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); - auto dense = - std::dynamic_pointer_cast(tensor.impl()); - std::vector tensors = {*dense}; - return self.Recv(tensors, src); - }, - py::arg("tensor"), - py::arg("src"), - py::call_guard()) - .def( "recv", [](distributed::ProcessGroup &self, @@ -319,27 +224,6 @@ void BindDistributed(py::module *m) { py::arg("sync_op"), py::call_guard()) - .def( - "recv_partial", - [](distributed::ProcessGroup &self, - py::handle py_tensor, - int src_rank, - int nranks, - int rank_id) { - auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); - auto dense = - std::dynamic_pointer_cast(tensor.impl()); - int64_t numel = (*dense).numel(); - int64_t recv_numel = numel / nranks; - int64_t offset = recv_numel * rank_id; - return self.Recv_Partial(*dense, src_rank, offset, recv_numel); - }, - py::arg("tensor"), - py::arg("src"), - py::arg("num"), - py::arg("id"), - py::call_guard()) - .def( "recv_partial", [](distributed::ProcessGroup &self, @@ -366,25 +250,6 @@ void BindDistributed(py::module *m) { .def( "all_gather", - [](distributed::ProcessGroup &self, - py::handle py_in_tensor, - py::handle py_out_tensor) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto 
in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - auto out_dense = std::dynamic_pointer_cast( - out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.AllGather(in_tensors, out_tensors); - }, - py::arg("in"), - py::arg("out"), - py::call_guard()) - - .def( - "allgather", [](distributed::ProcessGroup &self, py::handle py_out_tensor_list, py::handle py_in_tensor, @@ -413,7 +278,7 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "allgather_into_tensor", + "all_gather_into_tensor", [](distributed::ProcessGroup &self, py::handle py_out_tensor, py::handle py_in_tensor, @@ -436,53 +301,7 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "all_gather_partial", - [](distributed::ProcessGroup &self, - py::handle py_in_tensor, - py::handle py_out_tensor, - int nranks, - int rank_id) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - auto out_dense = std::dynamic_pointer_cast( - out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - int64_t numel = (*in_dense).numel(); - int64_t send_numel = numel / nranks; - int64_t offset = send_numel * rank_id; - return self.AllGather_Partial( - in_tensors, out_tensors, offset, send_numel); - }, - py::arg("in"), - py::arg("out"), - py::arg("num"), - py::arg("id"), - py::call_guard()) - - .def( - "alltoall", - [](distributed::ProcessGroup &self, - py::handle py_in_tensor, - py::handle py_out_tensor) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - auto out_dense = std::dynamic_pointer_cast( - out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.AllToAll(in_tensors, out_tensors); - }, - py::arg("in"), - py::arg("out"), - py::call_guard()) - - .def( - "alltoall", + "all_to_all", [](distributed::ProcessGroup &self, py::handle py_in_tensor_list, py::handle py_out_tensor_list, @@ -515,7 +334,7 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "alltoall_tensor", + "all_to_all_tensor", [](distributed::ProcessGroup &self, py::handle py_in_tensor, py::handle py_out_tensor, @@ -538,31 +357,7 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "alltoall_single", - [](distributed::ProcessGroup &self, - py::handle py_in_tensor, - py::handle py_out_tensor, - std::vector in_sizes, - std::vector out_sizes) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - auto out_dense = std::dynamic_pointer_cast( - out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.AllToAll_Single( - in_tensors, out_tensors, in_sizes, out_sizes); - }, - py::arg("in"), - py::arg("out"), - py::arg("in_sizes"), - py::arg("out_sizes"), - py::call_guard()) - - .def( - "alltoall_single", + "all_to_all_single", [](distributed::ProcessGroup &self, py::handle py_in_tensor, py::handle py_out_tensor, @@ -589,26 +384,6 @@ void BindDistributed(py::module *m) { py::arg("sync_op"), py::call_guard()) - .def( - "reduce", - [](distributed::ProcessGroup &self, - py::handle py_in_tensor, - int dst, - 
distributed::ReduceOp op) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - distributed::ReduceOptions opts; - opts.reduce_op = op; - opts.root_rank = dst; - auto dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector tensors = {*dense}; - return self.Reduce(tensors, tensors, opts); - }, - py::arg("tensor"), - py::arg("dst"), - py::arg("op") = distributed::ReduceOp::SUM, - py::call_guard()) - .def( "reduce", [](distributed::ProcessGroup &self, @@ -685,29 +460,6 @@ void BindDistributed(py::module *m) { py::arg("sync_op"), py::call_guard()) - .def( - "scatter", - [](distributed::ProcessGroup &self, - py::handle py_in_tensor, - py::handle py_out_tensor, - int src) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - distributed::ScatterOptions opts; - opts.root_rank = src; - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - auto out_dense = std::dynamic_pointer_cast( - out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.Scatter(in_tensors, out_tensors, opts); - }, - py::arg("in"), - py::arg("out"), - py::arg("src"), - py::call_guard()) - .def( "scatter", [](distributed::ProcessGroup &self, @@ -762,6 +514,255 @@ void BindDistributed(py::module *m) { py::arg("sync_op"), py::call_guard()) + .def( + "barrier", + [](distributed::ProcessGroup &self, std::vector place_ids) { + distributed::BarrierOptions opts; + opts.place_ids = place_ids; + return self.Barrier(opts); + }, + py::arg("place_ids") = std::vector{}, + py::call_guard()) + + // TODO(liyurui): Interface below will be removed in the future. + .def( + "allreduce", + [](distributed::ProcessGroup &self, + py::handle py_tensor, + distributed::ReduceOp op) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + distributed::AllreduceOptions opts; + opts.reduce_op = op; + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + std::vector tensors = {*dense}; + return self.AllReduce(tensors, tensors, opts); + }, + py::arg("tensor"), + py::arg("op") = distributed::ReduceOp::SUM, + py::call_guard()) + + .def( + "broadcast", + [](distributed::ProcessGroup &self, + py::handle py_tensor, + int source_rank) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + distributed::BroadcastOptions opts; + opts.source_rank = source_rank; + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + std::vector tensors = {*dense}; + return self.Broadcast(tensors, tensors, opts); + }, + py::arg("tensor"), + py::arg("source_rank"), + py::call_guard()) + + .def( + "send", + [](distributed::ProcessGroup &self, + py::handle py_tensor, + int dst) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + std::vector tensors = {*dense}; + return self.Send(tensors, dst); + }, + py::arg("tensor"), + py::arg("dst"), + py::call_guard()) + + .def( + "send_partial", + [](distributed::ProcessGroup &self, + py::handle py_tensor, + int dst_rank, + int nranks, + int rank_id) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + int64_t numel = (*dense).numel(); + int64_t send_numel = numel / nranks; + int64_t offset = send_numel * rank_id; + return self.Send_Partial(*dense, dst_rank, offset, send_numel); + }, + py::arg("tensor"), + py::arg("dst"), + py::arg("num"), + py::arg("id"), + py::call_guard()) + + .def( + "recv", + [](distributed::ProcessGroup &self, + py::handle 
py_tensor, + int src) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + std::vector tensors = {*dense}; + return self.Recv(tensors, src); + }, + py::arg("tensor"), + py::arg("src"), + py::call_guard()) + + .def( + "recv_partial", + [](distributed::ProcessGroup &self, + py::handle py_tensor, + int src_rank, + int nranks, + int rank_id) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + int64_t numel = (*dense).numel(); + int64_t recv_numel = numel / nranks; + int64_t offset = recv_numel * rank_id; + return self.Recv_Partial(*dense, src_rank, offset, recv_numel); + }, + py::arg("tensor"), + py::arg("src"), + py::arg("num"), + py::arg("id"), + py::call_guard()) + + .def( + "all_gather", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_out_tensor) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector in_tensors = {*in_dense}; + std::vector out_tensors = {*out_dense}; + return self.AllGather(in_tensors, out_tensors); + }, + py::arg("in"), + py::arg("out"), + py::call_guard()) + + .def( + "all_gather_partial", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + int nranks, + int rank_id) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector in_tensors = {*in_dense}; + std::vector out_tensors = {*out_dense}; + int64_t numel = (*in_dense).numel(); + int64_t send_numel = numel / nranks; + int64_t offset = send_numel * rank_id; + return self.AllGather_Partial( + in_tensors, out_tensors, offset, send_numel); + }, + py::arg("in"), + py::arg("out"), + py::arg("num"), + py::arg("id"), + py::call_guard()) + + .def( + "alltoall", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_out_tensor) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector in_tensors = {*in_dense}; + std::vector out_tensors = {*out_dense}; + return self.AllToAll(in_tensors, out_tensors); + }, + py::arg("in"), + py::arg("out"), + py::call_guard()) + + .def( + "alltoall_single", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + std::vector in_sizes, + std::vector out_sizes) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector in_tensors = {*in_dense}; + std::vector out_tensors = {*out_dense}; + return self.AllToAll_Single( + in_tensors, out_tensors, in_sizes, out_sizes); + }, + py::arg("in"), + py::arg("out"), + py::arg("in_sizes"), + py::arg("out_sizes"), + py::call_guard()) + + .def( + "reduce", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + int dst, + distributed::ReduceOp op) { 
+ auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + distributed::ReduceOptions opts; + opts.reduce_op = op; + opts.root_rank = dst; + auto dense = std::dynamic_pointer_cast( + in_tensor.impl()); + std::vector tensors = {*dense}; + return self.Reduce(tensors, tensors, opts); + }, + py::arg("tensor"), + py::arg("dst"), + py::arg("op") = distributed::ReduceOp::SUM, + py::call_guard()) + + .def( + "scatter", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + int src) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + distributed::ScatterOptions opts; + opts.root_rank = src; + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector in_tensors = {*in_dense}; + std::vector out_tensors = {*out_dense}; + return self.Scatter(in_tensors, out_tensors, opts); + }, + py::arg("in"), + py::arg("out"), + py::arg("src"), + py::call_guard()) + .def( "_reduce_scatter_base", [](distributed::ProcessGroup &self, @@ -788,7 +789,7 @@ void BindDistributed(py::module *m) { std::shared_ptr>( *m, "ProcessGroupStream", ProcessGroup) .def( - "allgather_on_calc_stream", + "all_gather_on_calc_stream", [](distributed::ProcessGroupStream &self, py::handle py_out_tensor_list, py::handle py_in_tensor) { @@ -818,7 +819,7 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "allgather_into_tensor_on_calc_stream", + "all_gather_into_tensor_on_calc_stream", [](distributed::ProcessGroupStream &self, py::handle py_out_tensor, py::handle py_in_tensor) { @@ -873,7 +874,7 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "allreduce_on_calc_stream", + "all_reduce_on_calc_stream", [](distributed::ProcessGroupStream &self, py::handle py_tensor, distributed::ReduceOp op) { @@ -890,11 +891,11 @@ void BindDistributed(py::module *m) { /*use_calc_stream*/ true); }, py::arg("tensor"), - py::arg("op"), + py::arg("op") = distributed::ReduceOp::SUM, py::call_guard()) .def( - "alltoall_on_calc_stream", + "all_to_all_on_calc_stream", [](distributed::ProcessGroupStream &self, py::handle py_in_tensor_list, py::handle py_out_tensor_list) { @@ -927,7 +928,7 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "alltoall_tensor_on_calc_stream", + "all_to_all_tensor_on_calc_stream", [](distributed::ProcessGroupStream &self, py::handle py_in_tensor, py::handle py_out_tensor) { @@ -951,7 +952,7 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "alltoall_single_on_calc_stream", + "all_to_all_single_on_calc_stream", [](distributed::ProcessGroupStream &self, py::handle py_in_tensor, py::handle py_out_tensor, diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 4923a531f9..cd5b5ac914 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -546,7 +546,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True): out = paddle.empty(tensor_shape, tensor.dtype) else: out = paddle.concat(tensor_list, axis=0) - task = group.process_group.all_gather(tensor, out) + task = group.process_group.all_gather_into_tensor(out, tensor, sync_op) task.wait() tensor_list.clear() list_of_tensor = paddle.split(out, group.nranks, 0) diff --git a/python/paddle/distributed/communication/stream/all_gather.py b/python/paddle/distributed/communication/stream/all_gather.py index e74623f948..12f9e08f9d 100644 --- 
a/python/paddle/distributed/communication/stream/all_gather.py +++ b/python/paddle/distributed/communication/stream/all_gather.py @@ -44,12 +44,12 @@ def _all_gather_into_tensor_in_dygraph( _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks) if use_calc_stream: - return group.process_group.allgather_into_tensor_on_calc_stream( + return group.process_group.all_gather_into_tensor_on_calc_stream( out_tensor, in_tensor, ) - task = group.process_group.allgather_into_tensor( + task = group.process_group.all_gather_into_tensor( out_tensor, in_tensor, sync_op ) if sync_op: @@ -69,9 +69,11 @@ def _all_gather_in_dygraph( _check_tensor_list_shape(tensor_list, tensor.shape, group.nranks) if use_calc_stream: - return group.process_group.allgather_on_calc_stream(tensor_list, tensor) + return group.process_group.all_gather_on_calc_stream( + tensor_list, tensor + ) - task = group.process_group.allgather(tensor_list, tensor, sync_op) + task = group.process_group.all_gather(tensor_list, tensor, sync_op) if sync_op: task.wait() diff --git a/python/paddle/distributed/communication/stream/all_reduce.py b/python/paddle/distributed/communication/stream/all_reduce.py index 79a359ab5f..1969b4d058 100644 --- a/python/paddle/distributed/communication/stream/all_reduce.py +++ b/python/paddle/distributed/communication/stream/all_reduce.py @@ -25,11 +25,10 @@ from paddle.distributed.communication.group import ( def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream): op_type = _get_reduce_op(op, "allreduce") - group = _get_global_group() if group is None else group if use_calc_stream: - return group.process_group.allreduce_on_calc_stream(tensor, op_type) + return group.process_group.all_reduce_on_calc_stream(tensor, op_type) - task = group.process_group.allreduce(tensor, op_type, sync_op) + task = group.process_group.all_reduce(tensor, op_type, sync_op) if sync_op: task.wait() @@ -119,6 +118,7 @@ def all_reduce( ) if framework.in_dygraph_mode(): + group = _get_global_group() if group is None else group return _all_reduce_in_dygraph( tensor, op, group, sync_op, use_calc_stream ) diff --git a/python/paddle/distributed/communication/stream/all_to_all.py b/python/paddle/distributed/communication/stream/all_to_all.py index 663079fc0a..d05b53564a 100644 --- a/python/paddle/distributed/communication/stream/all_to_all.py +++ b/python/paddle/distributed/communication/stream/all_to_all.py @@ -47,11 +47,11 @@ def _all_to_all_tensor_in_dygraph( _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks) if use_calc_stream: - return group.process_group.alltoall_tensor_on_calc_stream( + return group.process_group.all_to_all_tensor_on_calc_stream( in_tensor, out_tensor ) - task = group.process_group.alltoall_tensor(in_tensor, out_tensor, sync_op) + task = group.process_group.all_to_all_tensor(in_tensor, out_tensor, sync_op) if sync_op: task.wait() @@ -74,11 +74,11 @@ def _all_to_all_in_dygraph( ) if use_calc_stream: - return group.process_group.alltoall_on_calc_stream( + return group.process_group.all_to_all_on_calc_stream( in_tensor_list, out_tensor_list ) - task = group.process_group.alltoall( + task = group.process_group.all_to_all( in_tensor_list, out_tensor_list, sync_op ) if sync_op: @@ -249,11 +249,11 @@ def _alltoall_single_in_dygraph( in_split_sizes = [] if use_calc_stream: - return group.process_group.alltoall_single_on_calc_stream( + return group.process_group.all_to_all_single_on_calc_stream( in_tensor, out_tensor, in_split_sizes, out_split_sizes ) - task = 
group.process_group.alltoall_single( + task = group.process_group.all_to_all_single( in_tensor, out_tensor, in_split_sizes, out_split_sizes, sync_op ) if sync_op: diff --git a/python/paddle/distributed/communication/stream/reduce_scatter.py b/python/paddle/distributed/communication/stream/reduce_scatter.py index 3e46b51dde..aa0d1e9b95 100644 --- a/python/paddle/distributed/communication/stream/reduce_scatter.py +++ b/python/paddle/distributed/communication/stream/reduce_scatter.py @@ -52,7 +52,6 @@ def _reduce_scatter_tensor_in_dygraph( caller="reduce_scatter", ): op_type = _get_reduce_op(op, caller) - group = _get_global_group() if group is None else group _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks) @@ -74,7 +73,6 @@ def _reduce_scatter_in_dygraph( tensor, tensor_list, op, group, sync_op, use_calc_stream ): op_type = _get_reduce_op(op, "reduce_scatter") - group = _get_global_group() if group is None else group _check_tensor_list_shape(tensor_list, tensor.shape, group.nranks) @@ -149,6 +147,7 @@ def reduce_scatter( ) if framework.in_dygraph_mode(): + group = _get_global_group() if group is None else group if paddle.is_tensor(tensor_or_tensor_list): return _reduce_scatter_tensor_in_dygraph( tensor, @@ -230,6 +229,7 @@ def _reduce_scatter_base( ) if framework.in_dygraph_mode(): + group = _get_global_group() if group is None else group return _reduce_scatter_tensor_in_dygraph( out_tensor, in_tensor, diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py index f76b1d5232..40795ac258 100644 --- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py +++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py @@ -62,7 +62,7 @@ def _c_identity(tensor, group=None): @staticmethod def backward(ctx, dy): op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity") - group.process_group.allreduce_on_calc_stream(dy, op_type) + group.process_group.all_reduce_on_calc_stream(dy, op_type) return dy return c_identity_eager.apply(tensor) @@ -255,7 +255,7 @@ def _mp_allreduce( if use_calc_stream: op_type = _get_reduce_op(op, "_mp_allreduce") - group.process_group.allreduce_on_calc_stream( + group.process_group.all_reduce_on_calc_stream( tensor, op_type ) return tensor -- GitLab
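
The renamed process-group bindings in this patch are internal plumbing; the public collective API keeps its existing signatures. A minimal usage sketch of the collectives whose call paths the patch reroutes (an illustration, not part of the patch: it assumes two or more devices launched with python -m paddle.distributed.launch, and the tensor values are arbitrary):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    # all_reduce: every rank contributes its own value and ends up with the
    # sum. After this patch the dygraph wrapper calls the renamed binding
    # group.process_group.all_reduce(tensor, op_type, sync_op).
    x = paddle.to_tensor([float(dist.get_rank() + 1)])
    dist.all_reduce(x)

    # all_gather: collective.py now calls all_gather_into_tensor(out, tensor,
    # sync_op) on the process group and splits the concatenated result back
    # into the user-provided list.
    gathered = []
    dist.all_gather(gathered, x)
    print(gathered)

As the Python wrappers above show, when sync_op is true the wrapper waits on the returned task (task.wait()) so the call behaves synchronously; when it is false the task is handed back to the caller, which is then responsible for synchronization.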