diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
index 88d8fb69eb69808729b6e0ec3c374569b1575671..67715f410d443c38a1c5d92c560a35a909c5ec1c 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -156,36 +156,27 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
 // Same as Wait
 void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }
 
-ProcessGroupNCCL::ProcessGroupNCCL(const ProcessGroupStrategy& strategy,
+ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
                                    int rank, int size)
-    : ProcessGroup(rank, size), strategy_(strategy) {}
-
-void ProcessGroupNCCL::BcastNCCLId(
-    std::vector<ncclUniqueId>& nccl_ids,  // NOLINT
-    int root, int server_fd) {
-  if (strategy_.local_rank_ == root) {
-    std::vector<std::string> other_trainers;
-    for (auto& ep : strategy_.trainer_endpoints_) {
-      if (ep != strategy_.current_endpoint_) {
-        other_trainers.push_back(ep);
-      }
-    }
-    platform::SendBroadCastCommID(other_trainers, &nccl_ids);
-  } else {
-    platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
-                                  &nccl_ids);
-  }
-}
+    : ProcessGroup(rank, size), store_(store) {}
 
 void ProcessGroupNCCL::BroadcastUniqueNCCLID(
     std::vector<ncclUniqueId>& nccl_ids) {  // NOLINT
-
-  int server_fd = -1;
-  if (rank_ != 0) {
-    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
-                    .socket();
+  if (rank_ == 0) {
+    for (size_t i = 0; i < nccl_ids.size(); i++) {
+      auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(i);
+      auto nccl_id = std::vector<uint8_t>(
+          reinterpret_cast<uint8_t*>(&nccl_ids[i]),
+          reinterpret_cast<uint8_t*>(&nccl_ids[i]) + NCCL_UNIQUE_ID_BYTES);
+      store_->set(key, nccl_id);
+    }
+  } else {
+    for (size_t i = 0; i < nccl_ids.size(); i++) {
+      auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(i);
+      auto ret = store_->get(key);
+      std::memcpy(&nccl_ids[i], ret.data(), ret.size());
+    }
   }
-  BcastNCCLId(nccl_ids, 0, server_fd);
 }
 
 // create NCCLManager cache for places_key
@@ -213,8 +204,8 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
   }
   BroadcastUniqueNCCLID(nccl_ids);
 
-  VLOG(3) << "init nccl rank: " << strategy_.local_rank_
-          << ", nranks: " << strategy_.nranks_ << ", place: " << places_key
+  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
+          << ", place: " << places_key
           << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
 
   std::vector<std::unique_ptr<CUDADeviceContext>> dev_ctx;
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
index d63a5e768382c6fd9141ff9d96a3187b0adab7de..aa2a2b8fa2088cd30729ba5e6184ef7a9c507bf3 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -25,6 +25,7 @@
 
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/distributed/store/store.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/gen_comm_id_helper.h"
 #include "paddle/fluid/platform/place.h"
@@ -75,7 +76,7 @@ class ProcessGroupNCCL : public ProcessGroup {
    private:
  };
 
-  ProcessGroupNCCL(const ProcessGroupStrategy& strategy, int rank, int size);
+  ProcessGroupNCCL(const std::shared_ptr<Store>& store, int rank, int size);
 
   const std::string GetBackendName() const override {
     return std::string(NCCL_BACKEND_NAME);
@@ -118,7 +119,7 @@ class ProcessGroupNCCL : public ProcessGroup {
       const std::vector<Tensor>& inputs);
 
  protected:
-  ProcessGroupStrategy strategy_;
+  std::shared_ptr<Store> store_;
   std::shared_ptr<NCCLCommManager> nccl_comm_;
   std::mutex mutex_;
   std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLCommManager>>>
diff --git a/paddle/fluid/distributed/store/store.h b/paddle/fluid/distributed/store/store.h
index 2581a74d7e8187b0a38b27a2f27e9b84ddf26b53..7b4ae7e70ff6f033e038f1c5214f46e0876257d2 100644
--- a/paddle/fluid/distributed/store/store.h
+++ b/paddle/fluid/distributed/store/store.h
@@ -25,15 +25,26 @@
 namespace distributed {
 
 class Store {
  public:
-  Store() = delete;
+  Store() : _timeout(tcputils::kNoTimeout) {}
   explicit Store(const std::chrono::seconds& timeout) : _timeout(timeout) {}
   virtual ~Store() = default;
 
-  virtual int64_t add(const std::string& key, int64_t value) = 0;
-  virtual std::vector<uint8_t> get(const std::string& key) = 0;
-  virtual void wait(const std::string& key) = 0;
-  virtual void set(const std::string& key,
-                   const std::vector<uint8_t>& value) = 0;
+  virtual int64_t add(const std::string& key, int64_t value) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Implement the add method in the subclass."));
+  }
+  virtual std::vector<uint8_t> get(const std::string& key) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Implement the get method in the subclass."));
+  }
+  virtual void wait(const std::string& key) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Implement the wait method in the subclass."));
+  }
+  virtual void set(const std::string& key, const std::vector<uint8_t>& value) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Implement the set method in the subclass."));
+  }
 
   virtual const std::chrono::seconds& timeout() const { return _timeout; }
diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc
index c01accaf598aa849cf5406e96cc9b5743b46e448..1a6a395545a96b1980cae73ff65de3daef0acafc 100644
--- a/paddle/fluid/pybind/communication.cc
+++ b/paddle/fluid/pybind/communication.cc
@@ -30,18 +30,42 @@ namespace pybind {
 
 using TCPStore = paddle::distributed::TCPStore;
 
-void BindTCPStore(py::module* m) {
-  py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore")
+void BindTCPStore(py::module *m) {
+  auto Store =
+      py::class_<distributed::Store, std::shared_ptr<distributed::Store>>(
+          *m, "Store")
+          .def(py::init<>())
+          .def("set",
+               [](distributed::Store &self, const std::string &key,
+                  const std::string &value) {
+                 std::vector<uint8_t> data(value.begin(), value.end());
+                 self.set(key, data);
+               },
+               py::arg("key"), py::arg("value"),
+               py::call_guard<py::gil_scoped_release>())
+          .def("get",
+               [](distributed::Store &self,
+                  const std::string &key) -> py::bytes {
+                 auto data = self.get(key);
+                 return py::bytes(reinterpret_cast<char *>(data.data()),
+                                  data.size());
+               },
+               py::arg("key"), py::call_guard<py::gil_scoped_release>())
+          .def("add", &distributed::Store::add,
+               py::call_guard<py::gil_scoped_release>())
+          .def("wait", &distributed::Store::wait,
+               py::call_guard<py::gil_scoped_release>());
+
+  py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore", Store)
       .def(py::init([](std::string hostname, uint16_t port, bool is_master,
                        size_t world_size, std::chrono::seconds timeout) {
              return std::make_shared<TCPStore>(hostname, port, is_master,
                                                world_size, timeout);
            }),
            py::arg("hostname"), py::arg("port"), py::arg("is_master"),
-           py::arg("world_size"), py::arg("timeout"),
-           py::call_guard<py::gil_scoped_release>())
-      .def("add", &TCPStore::add)
-      .def("get", &TCPStore::get);
+           py::arg("world_size"),
+           py::arg("timeout") = distributed::tcputils::kNoTimeout,
+           py::call_guard<py::gil_scoped_release>());
 }
 
 }  // namespace pybind
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 17512863357d8dfe342f1a841471e1fdf1ac8072..9870eab8da9023cb6198a4a6da636664def60a17 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -197,7 +197,7 @@ void BindDistributed(py::module *m) {
   py::class_<distributed::ProcessGroupNCCL,
              std::shared_ptr<distributed::ProcessGroupNCCL>>(
       *m, "ProcessGroupNCCL", ProcessGroup)
-      .def(py::init<const distributed::ProcessGroupStrategy &, int, int>(),
+      .def(py::init<const std::shared_ptr<distributed::Store> &, int, int>(),
            py::call_guard<py::gil_scoped_release>());
 #endif
 
@@ -210,44 +210,6 @@ void BindDistributed(py::module *m) {
       .def("synchronize", &distributed::ProcessGroup::Task::Synchronize,
            py::call_guard<py::gil_scoped_release>());
 
-  // define parallel strategy, it will be removed
-  py::class_<distributed::ProcessGroupStrategy> pg_strategy(
-      *m, "ProcessGroupStrategy", "");
-  pg_strategy.def(py::init<>())
-      .def_property("nranks",
-                    [](const distributed::ProcessGroupStrategy &self) {
-                      return self.nranks_;
-                    },
-                    [](distributed::ProcessGroupStrategy &self, int nranks) {
-                      self.nranks_ = nranks;
-                    })
-      .def_property("local_rank",
-                    [](const distributed::ProcessGroupStrategy &self) {
-                      return self.local_rank_;
-                    },
-                    [](distributed::ProcessGroupStrategy &self,
-                       int local_rank) { self.local_rank_ = local_rank; })
-      .def_property(
-          "trainer_endpoints",
-          [](const distributed::ProcessGroupStrategy &self) {
-            return self.trainer_endpoints_;
-          },
-          [](distributed::ProcessGroupStrategy &self,
-             std::vector<std::string> eps) { self.trainer_endpoints_ = eps; })
-      .def_property("current_endpoint",
-                    [](const distributed::ProcessGroupStrategy &self) {
-                      return self.current_endpoint_;
-                    },
-                    [](distributed::ProcessGroupStrategy &self,
-                       const std::string &ep) { self.current_endpoint_ = ep; })
-      .def_property("nrings",
-                    [](const distributed::ProcessGroupStrategy &self) {
-                      return self.nrings_;
-                    },
-                    [](distributed::ProcessGroupStrategy &self, int nrings) {
-                      self.nrings_ = nrings;
-                    });
-
 #if defined(PADDLE_WITH_GLOO)
   py::class_<GlooOptions>(*m, "GlooOptions")
       .def(py::init<>())
@@ -279,9 +241,7 @@ void BindDistributed(py::module *m) {
             return std::make_shared<ProcessGroupGloo>(store, rank, world_size,
                                                        opts);
           }),
-          py::arg("store"), py::arg("rank"),
-          py::arg("world_size"),  // py::arg("timeout") =
-          // kProcessGroupDefaultTimeout,
+          py::arg("store"), py::arg("rank"), py::arg("world_size"),
           py::call_guard<py::gil_scoped_release>())
       .def_static("create_default_device",
                   &ProcessGroupGloo::createDefaultDevice);
diff --git a/python/paddle/fluid/tests/unittests/process_group_nccl.py b/python/paddle/fluid/tests/unittests/process_group_nccl.py
index 4833cea9a8d1ab7aafa82e4bb12f0c52902fa634..b1da0777feb3de1b1d6bb59a868802f736afb8e7 100644
--- a/python/paddle/fluid/tests/unittests/process_group_nccl.py
+++ b/python/paddle/fluid/tests/unittests/process_group_nccl.py
@@ -27,22 +27,13 @@ import paddle.fluid.core as core
 from paddle.fluid.framework import _test_eager_guard
 from paddle.fluid.dygraph.parallel import ParallelEnv
 
-ProcessGroupStrategy = core.ProcessGroupStrategy
-
 
 def init_process_group(strategy=None):
-    # this will remove
-    if strategy is None:
-        strategy = ProcessGroupStrategy()
-        strategy.nranks = ParallelEnv().nranks
-        strategy.local_rank = ParallelEnv().local_rank
-        strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
-        strategy.current_endpoint = ParallelEnv().current_endpoint
-    if strategy.nranks < 2:
-        return
-
-    pg_group = core.ProcessGroupNCCL(strategy, strategy.local_rank,
-                                     strategy.nranks)
+    nranks = ParallelEnv().nranks
+    rank = ParallelEnv().local_rank
+    is_master = True if rank == 0 else False
+    store = paddle.fluid.core.TCPStore("127.0.0.1", 6173, is_master, nranks)
+    pg_group = core.ProcessGroupNCCL(store, rank, nranks)
 
     return pg_group
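For context beyond the unit test above, here is a minimal sketch (not part of the patch) of how the new rendezvous flow is driven from Python after this change: rank 0 hosts a `TCPStore`, every rank connects to it, the store backs the NCCL unique-id exchange inside `ProcessGroupNCCL`, and the newly bound `Store` methods (`set`/`get`/`add`/`wait`) remain available for user key/value traffic. It assumes the script is launched once per rank (e.g. via `python -m paddle.distributed.launch`) so `ParallelEnv` is populated; the address, port, and key names are illustrative only.

```python
# Sketch only: assumes this patch is applied and one process per rank.
import paddle.fluid.core as core
from paddle.fluid.dygraph.parallel import ParallelEnv

nranks = ParallelEnv().nranks   # world size provided by the launcher
rank = ParallelEnv().local_rank

# Rank 0 hosts the store; other ranks connect to it. Host/port are
# illustrative and must be identical on every rank.
store = core.TCPStore("127.0.0.1", 6173, rank == 0, nranks)

# The NCCL process group now takes the store instead of a
# ProcessGroupStrategy; NCCL unique ids are published and fetched
# through store.set()/store.get() internally.
pg = core.ProcessGroupNCCL(store, rank, nranks)

# The Store bindings can also be used directly (key names are made up).
if rank == 0:
    store.set("example/ready", "1")   # value passed as str, stored as bytes
store.wait("example/ready")           # block until rank 0 has published it
flag = store.get("example/ready")     # returns bytes, e.g. b"1"
count = store.add("example/counter", 1)  # integer add on a counter key
```

The design point this illustrates is that any `Store` subclass exposed to Python can now be handed to `ProcessGroupNCCL`, since the pybind `TCPStore` class derives from the new `Store` base binding.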