diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt
index ac61f6d874c5959d972d5ec651130a6c03e12ecd..ef626ea2985dd378dad14439031aef114546d61f 100644
--- a/paddle/fluid/distributed/collective/CMakeLists.txt
+++ b/paddle/fluid/distributed/collective/CMakeLists.txt
@@ -11,7 +11,7 @@ cc_library(

 if(WITH_DISTRIBUTE)
   cc_library(
     process_group_gloo
-    SRCS process_group_gloo.cc send_recv.cc
+    SRCS process_group_gloo.cc gloo_send_recv.cc
     DEPS phi_api eager_api gloo_wrapper tcp_store)
 endif()
diff --git a/paddle/fluid/distributed/collective/send_recv.cc b/paddle/fluid/distributed/collective/gloo_send_recv.cc
similarity index 95%
rename from paddle/fluid/distributed/collective/send_recv.cc
rename to paddle/fluid/distributed/collective/gloo_send_recv.cc
index 2079e5972a1066ea20af825daaa8e647e4f87eb9..970cb6ec93dc2612b03509572197822f2f5eb60c 100644
--- a/paddle/fluid/distributed/collective/send_recv.cc
+++ b/paddle/fluid/distributed/collective/gloo_send_recv.cc
@@ -18,7 +18,7 @@
 #include "gloo/common/logging.h"
 #include "gloo/math.h"
 #include "gloo/types.h"
-#include "paddle/fluid/distributed/collective/send_recv.h"
+#include "paddle/fluid/distributed/collective/gloo_send_recv.h"

 namespace paddle {
 namespace distributed {
diff --git a/paddle/fluid/distributed/collective/send_recv.h b/paddle/fluid/distributed/collective/gloo_send_recv.h
similarity index 100%
rename from paddle/fluid/distributed/collective/send_recv.h
rename to paddle/fluid/distributed/collective/gloo_send_recv.h
diff --git a/paddle/fluid/distributed/collective/process_group_gloo.cc b/paddle/fluid/distributed/collective/process_group_gloo.cc
index d778652a9ca2740d1f49424094afe0c076ea4398..8a87008484d99340030defb6a454502d973863fb 100644
--- a/paddle/fluid/distributed/collective/process_group_gloo.cc
+++ b/paddle/fluid/distributed/collective/process_group_gloo.cc
@@ -25,12 +25,13 @@
 #endif

 #include <gloo/broadcast.h>
+#include <gloo/gather.h>
 #include <gloo/reduce.h>
 #include <gloo/scatter.h>

 #include "paddle/fluid/distributed/collective/common.h"
+#include "paddle/fluid/distributed/collective/gloo_send_recv.h"
 #include "paddle/fluid/distributed/collective/process_group_gloo.h"
-#include "paddle/fluid/distributed/collective/send_recv.h"
 #include "paddle/fluid/framework/fleet/gloo_wrapper.h"
 #include "paddle/fluid/platform/enforce.h"

@@ -680,6 +681,65 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
   return Scatter(&out_tensors[0], in_tensors[0], opts, true);
 }

+class GatherGlooTask : public ProcessGroupGloo::GlooTask {
+ public:
+  GatherGlooTask(int rank,
+                 const std::shared_ptr<gloo::Context>& context,
+                 const phi::DenseTensor& input,  // NOLINT
+                 phi::DenseTensor* output,       // NOLINT
+                 int src,
+                 uint32_t tag)
+      : ProcessGroupGloo::GlooTask(rank, {input}, CommType::GATHER),
+        _context(context),
+        _input(input),
+        _output(*output),
+        _src(src),
+        _tag(tag) {}
+
+  void Run() override { _do_gather(_input, _output, _src); }
+
+ private:
+  std::shared_ptr<gloo::Context> _context;
+  phi::DenseTensor _input;
+  phi::DenseTensor _output;
+  int _src;
+  uint32_t _tag;
+
+  void _do_gather(phi::DenseTensor& in,   // NOLINT
+                  phi::DenseTensor& out,  // NOLINT
+                  int src) {
+    const auto& dtype = in.dtype();
+    gloo::GatherOptions opts(_context);
+    if (rank_ == src) {
+      GENERATE_FUNC(dtype, set_output, opts, out);
+    }
+    GENERATE_FUNC(dtype, set_input, opts, in);
+
+    opts.setRoot(src);
+    opts.setTag(_tag);
+    gloo::gather(opts);
+  }
+};
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Gather(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    const GatherOptions& opts,
+    bool sync_op,
+    bool use_calc_stream) {
+  PADDLE_ENFORCE_NE(
+      use_calc_stream,
+      true,
+      platform::errors::InvalidArgument("Gloo cannot use use_calc_stream."));
+  std::shared_ptr<GatherGlooTask> task;
+  auto tag = next_tag();
+  auto context = get_context();
+  task = std::make_shared<GatherGlooTask>(
+      rank_, context, in_tensor, out_tensor, opts.root_rank, tag);
+  task->Run();
+  return task;
+}
+
 std::shared_ptr<::gloo::transport::Device>
 ProcessGroupGloo::createDeviceForInterface(const std::string& ifname) {
   ::gloo::transport::tcp::attr attr;
diff --git a/paddle/fluid/distributed/collective/process_group_gloo.h b/paddle/fluid/distributed/collective/process_group_gloo.h
index ba3bad76b273d328dbf9191f65c4822c38e98ea8..c45b3e74d84938563efa96431775929f95a19ceb 100644
--- a/paddle/fluid/distributed/collective/process_group_gloo.h
+++ b/paddle/fluid/distributed/collective/process_group_gloo.h
@@ -150,6 +150,12 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
       const ScatterOptions& opts,
       bool sync_op) override;

+  std::shared_ptr<ProcessGroup::Task> Gather(phi::DenseTensor* out_tensor,
+                                             const phi::DenseTensor& in_tensor,
+                                             const GatherOptions& opts,
+                                             bool sync_op,
+                                             bool use_calc_stream) override;
+
   // TODO(sunyilun): methods below will be removed later
   std::shared_ptr<ProcessGroup::Task> Broadcast(
       std::vector<phi::DenseTensor>& inputs,
@@ -210,6 +216,15 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
     return platform::DeviceContextPool::Instance().Get(place);
   }

+  phi::DeviceContext* GetDeviceContext(const Place& place,
+                                       bool use_calc_stream) const override {
+    PADDLE_ENFORCE_NE(
+        use_calc_stream,
+        true,
+        platform::errors::InvalidArgument("Gloo cannot use use_calc_stream."));
+    return GetDeviceContext(place);
+  }
+
   // Helper functions for Gloo.
   static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
       const std::string& hostname);
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index f3639ced91f1745879c13fd6f6ff5c1d4e49177a..46d690c69a0525682129f64c634591a8bdd16c99 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -499,7 +499,6 @@
             py::arg("src"),
             py::arg("sync_op"),
             py::call_guard<py::gil_scoped_release>())
-
         .def(
             "scatter_tensor",
             [](distributed::ProcessGroup &self,
@@ -547,11 +546,12 @@ void BindDistributed(py::module *m) {

               auto *dev_ctx =
                   self.GetDeviceContext(in_tensor.place(), use_calc_stream);
-              distributed::GatherOptions gather_ops{dst};
+              distributed::GatherOptions gather_opts{dst};
               auto task = self.Gather(
-                  out_dense, in_dense, gather_ops, sync_op, use_calc_stream);
+                  out_dense, in_dense, gather_opts, sync_op, use_calc_stream);
               SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
-              if (!use_calc_stream) {
+              if (!use_calc_stream &&
+                  dev_ctx->GetPlace() != platform::CPUPlace()) {
                 // calculate stream will wait comm stream
                 task->UpdateWaitChain(*dev_ctx);
               }
@@ -561,7 +561,7 @@ void BindDistributed(py::module *m) {
             py::arg("out"),
             py::arg("dst"),
             py::arg("sync_op"),
-            py::arg("use_calc_stream"),
+            py::arg("use_calc_stream") = false,
             py::call_guard<py::gil_scoped_release>())
         .def(
             "barrier",
diff --git a/python/paddle/fluid/tests/unittests/collective/process_group_gloo.py b/python/paddle/fluid/tests/unittests/collective/process_group_gloo.py
index 17aa27f1bc0df2b473e72bdaa1c9a3e6d8ec8135..20dcae5928ad977a627cf87cf07f411923d469c4 100644
--- a/python/paddle/fluid/tests/unittests/collective/process_group_gloo.py
+++ b/python/paddle/fluid/tests/unittests/collective/process_group_gloo.py
@@ -113,15 +113,15 @@ class TestProcessGroupFp32(unittest.TestCase):
         send_recv_result_1 = paddle.assign(tensor_x)
         send_recv_result_2 = paddle.assign(tensor_y_2)
         if pg.rank() == 0:
-            task = pg.send(tensor_x, 1, True)
-        else:
+            task = pg.send(tensor_x, pg.size() - 1, True)
+        elif pg.rank() == pg.size() - 1:
             task = pg.recv(tensor_y_1, 0, True)
             assert np.array_equal(send_recv_result_1, tensor_y_1)

         if pg.rank() == 0:
-            task = pg.recv(tensor_x, 1, True)
+            task = pg.recv(tensor_x, pg.size() - 1, True)
             assert np.array_equal(send_recv_result_2, tensor_x)
-        else:
+        elif pg.rank() == pg.size() - 1:
             task = pg.send(tensor_y_2, 0, True)

         print("test send_recv api ok")
@@ -204,6 +204,30 @@ class TestProcessGroupFp32(unittest.TestCase):
             assert np.array_equal(tensor_y, out2)
         print("test scatter api ok\n")

+        # test Gather
+        def test_gather(root):
+            tensor_x = [
+                paddle.zeros(self.shape).astype(self.dtype)
+                for _ in range(pg.size())
+            ]
+            tensor_y = [
+                paddle.to_tensor(
+                    np.random.random(self.shape).astype(self.dtype)
+                )
+                for _ in range(pg.size())
+            ]
+            if pg.rank() == root:
+                task = pg.gather(tensor_y[root], tensor_x, root, True)
+                task.wait()
+                assert np.array_equal(tensor_x, tensor_y)
+            else:
+                task = pg.gather(tensor_y[pg.rank()], tensor_x, root, True)
+                task.wait()
+
+        test_gather(0)
+        test_gather(pg.size() - 1)
+        print("test gather api ok\n")
+

 if __name__ == "__main__":
     unittest.main()
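
Usage note: a minimal sketch of how the gather binding added above can be driven from Python, mirroring the unit test in this patch. It assumes pg is an already-initialized ProcessGroupGloo instance (set up the same way the test does), and the root, shape, and dtype values are illustrative only; arguments are passed positionally as in the test: (input tensor, output tensor list, dst, sync_op).

    import numpy as np
    import paddle

    root = 0
    shape, dtype = [2, 3], "float32"

    # Each rank contributes one input tensor.
    in_tensor = paddle.to_tensor(np.random.random(shape).astype(dtype))
    # One output slot per rank; only the root's slots are filled by the gather.
    out_tensors = [paddle.zeros(shape).astype(dtype) for _ in range(pg.size())]

    # use_calc_stream now defaults to False in the binding and must stay False
    # on the Gloo backend (the C++ side enforces this).
    task = pg.gather(in_tensor, out_tensors, root, True)
    task.wait()

    if pg.rank() == root:
        # After the wait, out_tensors[i] holds the tensor contributed by rank i.
        assert len(out_tensors) == pg.size()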