未验证 提交 bb2cb762 编写于 作者: L lilong12 提交者: GitHub

Use store for gloo process group (#40629)

上级 70726696
......@@ -171,10 +171,10 @@ ProcessGroupGloo::GlooTask::GlooTask(int rank,
"Only CPU place is supported for ProcessGroupGloo."));
}
ProcessGroupGloo::ProcessGroupGloo(const std::shared_ptr<GlooStore>& store,
int rank, int world_size,
const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size), _tag(0), _store(store) {
ProcessGroupGloo::ProcessGroupGloo(
const std::shared_ptr<paddle::distributed::Store>& store, int rank,
int world_size, const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size), _tag(0), _store(new GlooStore(store)) {
_context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
auto prefix_store =
::gloo::rendezvous::PrefixStore(std::to_string(0), *_store);
......
......@@ -52,8 +52,7 @@ class ProcessGroupGloo : public ProcessGroup {
class GlooStore : public ::gloo::rendezvous::Store {
public:
explicit GlooStore(
const std::shared_ptr<paddle::distributed::TCPStore>& store)
explicit GlooStore(const std::shared_ptr<paddle::distributed::Store>& store)
: _store(store) {}
~GlooStore() = default;
......@@ -87,7 +86,7 @@ class ProcessGroupGloo : public ProcessGroup {
}
protected:
std::shared_ptr<paddle::distributed::TCPStore> _store;
std::shared_ptr<paddle::distributed::Store> _store;
};
class GlooOptions {
......@@ -100,9 +99,9 @@ class ProcessGroupGloo : public ProcessGroup {
std::shared_ptr<::gloo::transport::Device> device;
};
explicit ProcessGroupGloo(const std::shared_ptr<GlooStore>& store, int rank,
int world_size,
std::shared_ptr<GlooOptions> options);
explicit ProcessGroupGloo(
const std::shared_ptr<paddle::distributed::Store>& store, int rank,
int world_size, std::shared_ptr<GlooOptions> options);
~ProcessGroupGloo() = default;
......@@ -145,7 +144,7 @@ class ProcessGroupGloo : public ProcessGroup {
protected:
uint32_t _tag;
std::shared_ptr<gloo::rendezvous::Context> _context;
std::shared_ptr<GlooStore> _store;
std::shared_ptr<::gloo::rendezvous::Store> _store;
};
} // namespace distributed
......
......@@ -235,25 +235,13 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>());
#if defined(PADDLE_WITH_GLOO)
py::class_<GlooOptions>(*m, "GlooOptions")
.def(py::init<>())
.def_readwrite("_device", &GlooOptions::device)
.def_static("create", &GlooOptions::create);
py::class_<GlooStore, std::shared_ptr<GlooStore>>(*m, "GlooStore")
.def(py::init(
[](const std::shared_ptr<paddle::distributed::TCPStore> &store) {
return std::make_shared<GlooStore>(store);
}),
py::call_guard<py::gil_scoped_release>());
py::class_<ProcessGroupGloo, std::shared_ptr<ProcessGroupGloo>>(
*m, "ProcessGroupGloo", ProcessGroup)
.def(py::init<const std::shared_ptr<GlooStore> &, int, int,
std::shared_ptr<GlooOptions> &>(),
.def(py::init<const std::shared_ptr<paddle::distributed::Store> &, int,
int, std::shared_ptr<GlooOptions> &>(),
py::call_guard<py::gil_scoped_release>())
.def(py::init([](const std::shared_ptr<GlooStore> &store, int rank,
int world_size) {
.def(py::init([](const std::shared_ptr<paddle::distributed::Store> &store,
int rank, int world_size) {
auto opts = GlooOptions::create();
char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
if (ifname && strlen(ifname) > 1) {
......
......@@ -47,9 +47,7 @@ class TestProcessGroupFp32(unittest.TestCase):
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", 6172, is_master,
nranks, datetime.timedelta(0))
gloo_store = paddle.fluid.core.GlooStore(store)
opt = paddle.fluid.core.GlooOptions()
pg = paddle.fluid.core.ProcessGroupGloo(gloo_store, rank, nranks)
pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks)
# test allreduce sum
# rank 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册