未验证 提交 0ad25fb9 编写于 作者: L lilong12 提交者: GitHub

initialize processgroupnccl with store (#40181)

上级 f5ec0314
......@@ -156,36 +156,27 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
// Same as Wait
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupNCCL::ProcessGroupNCCL(const ProcessGroupStrategy& strategy,
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int size)
: ProcessGroup(rank, size), strategy_(strategy) {}
: ProcessGroup(rank, size), store_(store) {}
void ProcessGroupNCCL::BcastNCCLId(
std::vector<ncclUniqueId>& nccl_ids, // NOLINT
int root, int server_fd) {
if (strategy_.local_rank_ == root) {
std::vector<std::string> other_trainers;
for (auto& ep : strategy_.trainer_endpoints_) {
if (ep != strategy_.current_endpoint_) {
other_trainers.push_back(ep);
}
void ProcessGroupNCCL::BroadcastUniqueNCCLID(
std::vector<ncclUniqueId>& nccl_ids) { // NOLINT
if (rank_ == 0) {
for (size_t i = 0; i < nccl_ids.size(); i++) {
auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(i);
auto nccl_id = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(&nccl_ids[i]),
reinterpret_cast<uint8_t*>(&nccl_ids[i]) + NCCL_UNIQUE_ID_BYTES);
store_->set(key, nccl_id);
}
platform::SendBroadCastCommID(other_trainers, &nccl_ids);
} else {
platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
&nccl_ids);
for (size_t i = 0; i < nccl_ids.size(); i++) {
auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(i);
auto ret = store_->get(key);
std::memcpy(&nccl_ids[i], ret.data(), ret.size());
}
}
void ProcessGroupNCCL::BroadcastUniqueNCCLID(
std::vector<ncclUniqueId>& nccl_ids) { // NOLINT
int server_fd = -1;
if (rank_ != 0) {
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
}
BcastNCCLId(nccl_ids, 0, server_fd);
}
// create NCCLManager cache for places_key
......@@ -213,8 +204,8 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
}
BroadcastUniqueNCCLID(nccl_ids);
VLOG(3) << "init nccl rank: " << strategy_.local_rank_
<< ", nranks: " << strategy_.nranks_ << ", place: " << places_key
VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << places_key
<< ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
std::vector<std::unique_ptr<CUDADeviceContext>> dev_ctx;
......
......@@ -25,6 +25,7 @@
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
......@@ -75,7 +76,7 @@ class ProcessGroupNCCL : public ProcessGroup {
private:
};
ProcessGroupNCCL(const ProcessGroupStrategy& strategy, int rank, int size);
ProcessGroupNCCL(const std::shared_ptr<Store>& store, int rank, int size);
const std::string GetBackendName() const override {
return std::string(NCCL_BACKEND_NAME);
......@@ -118,7 +119,7 @@ class ProcessGroupNCCL : public ProcessGroup {
const std::vector<Tensor>& inputs);
protected:
ProcessGroupStrategy strategy_;
std::shared_ptr<Store> store_;
std::shared_ptr<NCCLCommManager> nccl_comm_;
std::mutex mutex_;
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLCommManager>>>
......
......@@ -25,15 +25,26 @@ namespace distributed {
class Store {
public:
Store() = delete;
Store() : _timeout(tcputils::kNoTimeout) {}
explicit Store(const std::chrono::seconds& timeout) : _timeout(timeout) {}
virtual ~Store() = default;
virtual int64_t add(const std::string& key, int64_t value) = 0;
virtual std::vector<uint8_t> get(const std::string& key) = 0;
virtual void wait(const std::string& key) = 0;
virtual void set(const std::string& key,
const std::vector<uint8_t>& value) = 0;
virtual int64_t add(const std::string& key, int64_t value) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the add method in the subclass."));
}
virtual std::vector<uint8_t> get(const std::string& key) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the add method in the subclass."));
}
virtual void wait(const std::string& key) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the add method in the subclass."));
}
virtual void set(const std::string& key, const std::vector<uint8_t>& value) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the add method in the subclass."));
}
virtual const std::chrono::seconds& timeout() const { return _timeout; }
......
......@@ -30,18 +30,42 @@ namespace pybind {
using TCPStore = paddle::distributed::TCPStore;
void BindTCPStore(py::module* m) {
py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore")
void BindTCPStore(py::module *m) {
auto Store =
py::class_<distributed::Store, std::shared_ptr<distributed::Store>>(
*m, "Store")
.def(py::init<>())
.def("set",
[](distributed::Store &self, const std::string &key,
const std::string &value) {
std::vector<uint8_t> data(value.begin(), value.end());
self.set(key, data);
},
py::arg("key"), py::arg("value"),
py::call_guard<py::gil_scoped_release>())
.def("get",
[](distributed::Store &self,
const std::string &key) -> py::bytes {
auto data = self.get(key);
return py::bytes(reinterpret_cast<char *>(data.data()),
data.size());
},
py::arg("key"), py::call_guard<py::gil_scoped_release>())
.def("add", &distributed::Store::add,
py::call_guard<py::gil_scoped_release>())
.def("wait", &distributed::Store::wait,
py::call_guard<py::gil_scoped_release>());
py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore", Store)
.def(py::init([](std::string hostname, uint16_t port, bool is_master,
size_t world_size, std::chrono::seconds timeout) {
return std::make_shared<TCPStore>(hostname, port, is_master,
world_size, timeout);
}),
py::arg("hostname"), py::arg("port"), py::arg("is_master"),
py::arg("world_size"), py::arg("timeout"),
py::call_guard<py::gil_scoped_release>())
.def("add", &TCPStore::add)
.def("get", &TCPStore::get);
py::arg("world_size"),
py::arg("timeout") = distributed::tcputils::kNoTimeout,
py::call_guard<py::gil_scoped_release>());
}
} // namespace pybind
......
......@@ -197,7 +197,7 @@ void BindDistributed(py::module *m) {
py::class_<distributed::ProcessGroupNCCL,
std::shared_ptr<distributed::ProcessGroupNCCL>>(
*m, "ProcessGroupNCCL", ProcessGroup)
.def(py::init<const distributed::ProcessGroupStrategy &, int, int>(),
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int>(),
py::call_guard<py::gil_scoped_release>());
#endif
......@@ -210,44 +210,6 @@ void BindDistributed(py::module *m) {
.def("synchronize", &distributed::ProcessGroup::Task::Synchronize,
py::call_guard<py::gil_scoped_release>());
// define parallel strategy, it will be removed
py::class_<distributed::ProcessGroupStrategy> pg_strategy(
*m, "ProcessGroupStrategy", "");
pg_strategy.def(py::init())
.def_property("nranks",
[](const distributed::ProcessGroupStrategy &self) {
return self.nranks_;
},
[](distributed::ProcessGroupStrategy &self, int nranks) {
self.nranks_ = nranks;
})
.def_property("local_rank",
[](const distributed::ProcessGroupStrategy &self) {
return self.local_rank_;
},
[](distributed::ProcessGroupStrategy &self,
int local_rank) { self.local_rank_ = local_rank; })
.def_property(
"trainer_endpoints",
[](const distributed::ProcessGroupStrategy &self) {
return self.trainer_endpoints_;
},
[](distributed::ProcessGroupStrategy &self,
std::vector<std::string> eps) { self.trainer_endpoints_ = eps; })
.def_property("current_endpoint",
[](const distributed::ProcessGroupStrategy &self) {
return self.current_endpoint_;
},
[](distributed::ProcessGroupStrategy &self,
const std::string &ep) { self.current_endpoint_ = ep; })
.def_property("nrings",
[](const distributed::ProcessGroupStrategy &self) {
return self.nrings_;
},
[](distributed::ProcessGroupStrategy &self, int nrings) {
self.nrings_ = nrings;
});
#if defined(PADDLE_WITH_GLOO)
py::class_<GlooOptions>(*m, "GlooOptions")
.def(py::init<>())
......@@ -279,9 +241,7 @@ void BindDistributed(py::module *m) {
return std::make_shared<ProcessGroupGloo>(store, rank, world_size,
opts);
}),
py::arg("store"), py::arg("rank"),
py::arg("world_size"), // py::arg("timeout") =
// kProcessGroupDefaultTimeout,
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::call_guard<py::gil_scoped_release>())
.def_static("create_default_device",
&ProcessGroupGloo::createDefaultDevice);
......
......@@ -27,22 +27,13 @@ import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv
ProcessGroupStrategy = core.ProcessGroupStrategy
def init_process_group(strategy=None):
# this will remove
if strategy is None:
strategy = ProcessGroupStrategy()
strategy.nranks = ParallelEnv().nranks
strategy.local_rank = ParallelEnv().local_rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return
pg_group = core.ProcessGroupNCCL(strategy, strategy.local_rank,
strategy.nranks)
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", 6173, is_master, nranks)
pg_group = core.ProcessGroupNCCL(store, rank, nranks)
return pg_group
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册