diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc index 5dc43af117825bf95407255e93e1e4600e8ddd9a..cb82677a281e990d9837f081b0d4d2f3b0a34a26 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc @@ -171,10 +171,10 @@ ProcessGroupGloo::GlooTask::GlooTask(int rank, "Only CPU place is supported for ProcessGroupGloo.")); } -ProcessGroupGloo::ProcessGroupGloo(const std::shared_ptr& store, - int rank, int world_size, - const std::shared_ptr options) - : ProcessGroup(rank, world_size), _tag(0), _store(store) { +ProcessGroupGloo::ProcessGroupGloo( + const std::shared_ptr& store, int rank, + int world_size, const std::shared_ptr options) + : ProcessGroup(rank, world_size), _tag(0), _store(new GlooStore(store)) { _context = std::make_shared(rank, world_size); auto prefix_store = ::gloo::rendezvous::PrefixStore(std::to_string(0), *_store); diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.h b/paddle/fluid/distributed/collective/ProcessGroupGloo.h index 24f156571a427128f09cd28e632212f47fa4cd47..71e0a40f8a76181d9f4db13ddd57b31de676910b 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupGloo.h +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.h @@ -52,8 +52,7 @@ class ProcessGroupGloo : public ProcessGroup { class GlooStore : public ::gloo::rendezvous::Store { public: - explicit GlooStore( - const std::shared_ptr& store) + explicit GlooStore(const std::shared_ptr& store) : _store(store) {} ~GlooStore() = default; @@ -87,7 +86,7 @@ class ProcessGroupGloo : public ProcessGroup { } protected: - std::shared_ptr _store; + std::shared_ptr _store; }; class GlooOptions { @@ -100,9 +99,9 @@ class ProcessGroupGloo : public ProcessGroup { std::shared_ptr<::gloo::transport::Device> device; }; - explicit ProcessGroupGloo(const std::shared_ptr& store, int rank, - int world_size, - std::shared_ptr options); + explicit ProcessGroupGloo( + const std::shared_ptr& store, int rank, + int world_size, std::shared_ptr options); ~ProcessGroupGloo() = default; @@ -145,7 +144,7 @@ class ProcessGroupGloo : public ProcessGroup { protected: uint32_t _tag; std::shared_ptr _context; - std::shared_ptr _store; + std::shared_ptr<::gloo::rendezvous::Store> _store; }; } // namespace distributed diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 1df917b8c3594d4505d9e92cd9a8c64bffd50279..e89d8d96342e723724bb867a14bc4262c6ab7b16 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -235,25 +235,13 @@ void BindDistributed(py::module *m) { py::call_guard()); #if defined(PADDLE_WITH_GLOO) - py::class_(*m, "GlooOptions") - .def(py::init<>()) - .def_readwrite("_device", &GlooOptions::device) - .def_static("create", &GlooOptions::create); - - py::class_>(*m, "GlooStore") - .def(py::init( - [](const std::shared_ptr &store) { - return std::make_shared(store); - }), - py::call_guard()); - py::class_>( *m, "ProcessGroupGloo", ProcessGroup) - .def(py::init &, int, int, - std::shared_ptr &>(), + .def(py::init &, int, + int, std::shared_ptr &>(), py::call_guard()) - .def(py::init([](const std::shared_ptr &store, int rank, - int world_size) { + .def(py::init([](const std::shared_ptr &store, + int rank, int world_size) { auto opts = GlooOptions::create(); char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str()); if (ifname && strlen(ifname) > 1) { diff --git a/python/paddle/fluid/tests/unittests/process_group_gloo.py b/python/paddle/fluid/tests/unittests/process_group_gloo.py index c62c4615f74707796946137d3b44efc3cc8aeee9..b1f3a71ab3e94c7db53048b95d73795d155bd122 100644 --- a/python/paddle/fluid/tests/unittests/process_group_gloo.py +++ b/python/paddle/fluid/tests/unittests/process_group_gloo.py @@ -47,9 +47,7 @@ class TestProcessGroupFp32(unittest.TestCase): is_master = True if rank == 0 else False store = paddle.fluid.core.TCPStore("127.0.0.1", 6172, is_master, nranks, datetime.timedelta(0)) - gloo_store = paddle.fluid.core.GlooStore(store) - opt = paddle.fluid.core.GlooOptions() - pg = paddle.fluid.core.ProcessGroupGloo(gloo_store, rank, nranks) + pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks) # test allreduce sum # rank 0