未验证 提交 bb2cb762 编写于 作者: L lilong12 提交者: GitHub

Use store for gloo process group (#40629)

上级 70726696
...@@ -171,10 +171,10 @@ ProcessGroupGloo::GlooTask::GlooTask(int rank, ...@@ -171,10 +171,10 @@ ProcessGroupGloo::GlooTask::GlooTask(int rank,
"Only CPU place is supported for ProcessGroupGloo.")); "Only CPU place is supported for ProcessGroupGloo."));
} }
ProcessGroupGloo::ProcessGroupGloo(const std::shared_ptr<GlooStore>& store, ProcessGroupGloo::ProcessGroupGloo(
int rank, int world_size, const std::shared_ptr<paddle::distributed::Store>& store, int rank,
const std::shared_ptr<GlooOptions> options) int world_size, const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size), _tag(0), _store(store) { : ProcessGroup(rank, world_size), _tag(0), _store(new GlooStore(store)) {
_context = std::make_shared<gloo::rendezvous::Context>(rank, world_size); _context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
auto prefix_store = auto prefix_store =
::gloo::rendezvous::PrefixStore(std::to_string(0), *_store); ::gloo::rendezvous::PrefixStore(std::to_string(0), *_store);
......
...@@ -52,8 +52,7 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -52,8 +52,7 @@ class ProcessGroupGloo : public ProcessGroup {
class GlooStore : public ::gloo::rendezvous::Store { class GlooStore : public ::gloo::rendezvous::Store {
public: public:
explicit GlooStore( explicit GlooStore(const std::shared_ptr<paddle::distributed::Store>& store)
const std::shared_ptr<paddle::distributed::TCPStore>& store)
: _store(store) {} : _store(store) {}
~GlooStore() = default; ~GlooStore() = default;
...@@ -87,7 +86,7 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -87,7 +86,7 @@ class ProcessGroupGloo : public ProcessGroup {
} }
protected: protected:
std::shared_ptr<paddle::distributed::TCPStore> _store; std::shared_ptr<paddle::distributed::Store> _store;
}; };
class GlooOptions { class GlooOptions {
...@@ -100,9 +99,9 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -100,9 +99,9 @@ class ProcessGroupGloo : public ProcessGroup {
std::shared_ptr<::gloo::transport::Device> device; std::shared_ptr<::gloo::transport::Device> device;
}; };
explicit ProcessGroupGloo(const std::shared_ptr<GlooStore>& store, int rank, explicit ProcessGroupGloo(
int world_size, const std::shared_ptr<paddle::distributed::Store>& store, int rank,
std::shared_ptr<GlooOptions> options); int world_size, std::shared_ptr<GlooOptions> options);
~ProcessGroupGloo() = default; ~ProcessGroupGloo() = default;
...@@ -145,7 +144,7 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -145,7 +144,7 @@ class ProcessGroupGloo : public ProcessGroup {
protected: protected:
uint32_t _tag; uint32_t _tag;
std::shared_ptr<gloo::rendezvous::Context> _context; std::shared_ptr<gloo::rendezvous::Context> _context;
std::shared_ptr<GlooStore> _store; std::shared_ptr<::gloo::rendezvous::Store> _store;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -235,25 +235,13 @@ void BindDistributed(py::module *m) { ...@@ -235,25 +235,13 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
#if defined(PADDLE_WITH_GLOO) #if defined(PADDLE_WITH_GLOO)
py::class_<GlooOptions>(*m, "GlooOptions")
.def(py::init<>())
.def_readwrite("_device", &GlooOptions::device)
.def_static("create", &GlooOptions::create);
py::class_<GlooStore, std::shared_ptr<GlooStore>>(*m, "GlooStore")
.def(py::init(
[](const std::shared_ptr<paddle::distributed::TCPStore> &store) {
return std::make_shared<GlooStore>(store);
}),
py::call_guard<py::gil_scoped_release>());
py::class_<ProcessGroupGloo, std::shared_ptr<ProcessGroupGloo>>( py::class_<ProcessGroupGloo, std::shared_ptr<ProcessGroupGloo>>(
*m, "ProcessGroupGloo", ProcessGroup) *m, "ProcessGroupGloo", ProcessGroup)
.def(py::init<const std::shared_ptr<GlooStore> &, int, int, .def(py::init<const std::shared_ptr<paddle::distributed::Store> &, int,
std::shared_ptr<GlooOptions> &>(), int, std::shared_ptr<GlooOptions> &>(),
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def(py::init([](const std::shared_ptr<GlooStore> &store, int rank, .def(py::init([](const std::shared_ptr<paddle::distributed::Store> &store,
int world_size) { int rank, int world_size) {
auto opts = GlooOptions::create(); auto opts = GlooOptions::create();
char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str()); char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
if (ifname && strlen(ifname) > 1) { if (ifname && strlen(ifname) > 1) {
......
...@@ -47,9 +47,7 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -47,9 +47,7 @@ class TestProcessGroupFp32(unittest.TestCase):
is_master = True if rank == 0 else False is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", 6172, is_master, store = paddle.fluid.core.TCPStore("127.0.0.1", 6172, is_master,
nranks, datetime.timedelta(0)) nranks, datetime.timedelta(0))
gloo_store = paddle.fluid.core.GlooStore(store) pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks)
opt = paddle.fluid.core.GlooOptions()
pg = paddle.fluid.core.ProcessGroupGloo(gloo_store, rank, nranks)
# test allreduce sum # test allreduce sum
# rank 0 # rank 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册