Unverified commit 2169e6fb, authored by Yi Liu, committed by GitHub

Initialize global nccl_comm in PE (#23275)

Parent 012886df
@@ -127,14 +127,9 @@ void NCCLParallelContext::Init() {
   VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
           << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id;
-  PADDLE_ENFORCE(cudaSetDevice(gpu_id));
-  platform::NCCLComm *nccl_comm =
-      platform::NCCLCommContext::Instance().CreateNCCLComm(
-          &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0);
-  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-  auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(pool.Get(place_));
-  dev_ctx->set_nccl_comm(nccl_comm->comm());
+  // it will assign nccl_comm in CUDADeviceContext within ring_id 0
+  platform::NCCLCommContext::Instance().CreateNCCLComm(
+      &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0);
 }
 #endif
...
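The imperative path now relies on `AssignNCCLComm` (further down in this diff) to mirror the ring-0 communicator into the global `CUDADeviceContext`. A minimal sketch of the resulting lookup, assuming the `Get(ring_id)` accessor declared in the header below; variable names are illustrative and not part of this commit:

    // After CreateNCCLComm(..., /*ring_id=*/0) returns, the same ncclComm_t is
    // reachable both from the registry and from the pooled device context.
    auto& registry = platform::NCCLCommContext::Instance();
    ncclComm_t comm = registry.Get(0)->comm();  // ring-0 wrapper
    auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(gpu_id)));
    // dev_ctx->nccl_comm() == comm, set by the ring_id == 0 branch below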
@@ -38,7 +38,8 @@ class NCCLCommImpl : public NCCLComm {
     return boost::get<CUDAPlace>(dev_ctx_->GetPlace()).device;
   }
 
-  ncclComm_t comm() const override { return dev_ctx_->nccl_comm(); }
+  void set_comm(ncclComm_t comm) { comm_ = comm; }
+  ncclComm_t comm() const override { return comm_; }
 
   cudaStream_t stream() const override { return dev_ctx_->stream(); }
 
@@ -50,6 +51,7 @@ class NCCLCommImpl : public NCCLComm {
   int ring_id_;
   int nranks_;
   int rank_;
+  ncclComm_t comm_;
   std::unique_ptr<CUDADeviceContext> dev_ctx_;
 };
@@ -65,34 +67,18 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
   PADDLE_ENFORCE_CUDA_SUCCESS(
       platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank));
+  comm_vec_.push_back(comm);
 
-  std::unique_ptr<CUDADeviceContext> dev_ctx(
-      new CUDADeviceContext(CUDAPlace(dev_id)));
-  dev_ctx->set_nccl_comm(comm);
-
-  NCCLCommImpl* c = new NCCLCommImpl;
-  c->set_ring_id(ring_id);
-  c->set_nranks(nranks);
-  c->set_rank(rank);
-  c->set_dev_ctx(std::move(dev_ctx));
-
-  comm_map_mutex_.lock();
-  if (comm_map_.count(ring_id) == 0) {
-    comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<NCCLComm>>());
-  }
-  auto& dev2comm = comm_map_[ring_id];
-  dev2comm.emplace(dev_id, std::unique_ptr<NCCLComm>(c));
-  comm_map_mutex_.unlock();
+  auto* comm_wrapper = AssignNCCLComm(comm, nranks, rank, dev_id, ring_id);
 
   VLOG(1) << "nccl communicator of rank " << rank << " in ring " << ring_id
-          << " has been created";
+          << " has been created on device " << dev_id;
 
   std::call_once(once_flag_, []() {
     std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
   });
 
-  return comm_map_[ring_id][dev_id].get();
+  return comm_wrapper;
 }
@@ -103,38 +89,56 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids,
   ncclComm_t comms[kDevices];
   PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
       comms, dev_ids.size(), dev_ids.data()));
+  comm_vec_.insert(comm_vec_.end(), comms, comms + kDevices);
 
   PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0);
-  comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<NCCLComm>>());
-
-  auto& dev2comm = comm_map_[ring_id];
   for (size_t i = 0; i < dev_ids.size(); ++i) {
-    std::unique_ptr<CUDADeviceContext> dev_ctx(
-        new CUDADeviceContext(CUDAPlace(dev_ids[i])));
-    dev_ctx->set_nccl_comm(comms[i]);
-
-    NCCLCommImpl* c = new NCCLCommImpl;
-    c->set_ring_id(ring_id);
-    c->set_nranks(dev_ids.size());
-    c->set_rank(i);
-    c->set_dev_ctx(std::move(dev_ctx));
-
-    dev2comm.emplace(dev_ids[i], std::unique_ptr<NCCLComm>(c));
+    AssignNCCLComm(comms[i], dev_ids.size(), i, dev_ids[i], ring_id);
+    VLOG(1) << "nccl communicator of rank " << i << " in ring " << ring_id
+            << " has been created on device " << dev_ids[i];
   }
 
   std::call_once(once_flag_, []() {
     std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
   });
 }
 
+NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank,
+                                          int dev_id, int ring_id) {
+  std::unique_ptr<CUDADeviceContext> dev_ctx(
+      new CUDADeviceContext(CUDAPlace(dev_id)));
+
+  NCCLCommImpl* c = new NCCLCommImpl;
+  c->set_ring_id(ring_id);
+  c->set_nranks(nranks);
+  c->set_rank(rank);
+  c->set_comm(comm);
+  c->set_dev_ctx(std::move(dev_ctx));
+
+  comm_map_mutex_.lock();
+  if (comm_map_.count(ring_id) == 0) {
+    comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<NCCLComm>>());
+  }
+  auto& dev2comm = comm_map_[ring_id];
+  dev2comm.emplace(dev_id, std::unique_ptr<NCCLComm>(c));
+  comm_map_mutex_.unlock();
+
+  if (ring_id == 0) {
+    auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
+        platform::DeviceContextPool::Instance().Get(
+            platform::CUDAPlace(dev_id)));
+    dev_ctx->set_nccl_comm(comm);
+  }
+
+  return comm_map_[ring_id][dev_id].get();
+}
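`CreateAllNCCLComms` goes through the same helper, so the ring-0 mirror into the pooled `CUDADeviceContext` also applies to the single-process multi-GPU path, and its handles likewise land in `comm_vec_` for the exit-time cleanup below. A hedged sketch with illustrative device ids:

    // One process drives all GPUs; ncclCommInitAll needs no ncclUniqueId
    // handshake, and rank i is simply position i in dev_ids.
    std::vector<int> dev_ids = {0, 1, 2, 3};
    platform::NCCLCommContext::Instance().CreateAllNCCLComms(dev_ids,
                                                             /*ring_id=*/0);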
 void NCCLCommContext::ReleaseNCCLComms() {
-  // CUDADeviceContext maintain the lifetime of nccl_comm_t, so we should not
-  // destroy nccl_comm_t explicitly. Please refer to
-  // platform::CUDADeviceContext::~CUDADeviceContext()
-  for (auto& p : comm_map_) {
-    for (auto& q : p.second) {
-      q.second.reset();
-    }
+  for (auto comm : comm_vec_) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::ncclCommDestroy(comm),
+        platform::errors::External("Fail to destroy nccl comm"));
   }
 }
...
@@ -71,6 +71,11 @@ class NCCLCommContext {
   void CreateAllNCCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
 
+  // a latter comm with the same dev_id and the same ring_id
+  // will override the former
+  NCCLComm* AssignNCCLComm(ncclComm_t comm, int nranks, int rank, int dev_id,
+                           int ring_id = 0);
+
   // retrieve a communicator by the ring id in multiprocessing mode
   NCCLComm* Get(int ring_id) const {
     PADDLE_ENFORCE_GT(comm_map_.count(ring_id), 0,
 
@@ -105,6 +110,8 @@ class NCCLCommContext {
   // ring id to dev-NCCLComm
   std::map<int, std::map<int, std::unique_ptr<NCCLComm>>> comm_map_;
 
+  std::vector<ncclComm_t> comm_vec_;
+
   void ReleaseNCCLComms();
 
   NCCLCommContext() = default;
...
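The new `AssignNCCLComm` entry point is what lets `NCCLCommunicator` (next file) publish ParallelExecutor-created communicators to this registry. A hedged sketch of wrapping an externally created handle; the values and the handle's origin are assumptions:

    // Wrap a comm created elsewhere (e.g. by ncclCommInitRank outside this
    // class) so Executor-side ops can find it by ring id; ring 0 is also
    // mirrored into that device's global CUDADeviceContext.
    ncclComm_t raw_comm = /* from an external init */ nullptr;  // illustrative
    platform::NCCLComm* wrapper =
        platform::NCCLCommContext::Instance().AssignNCCLComm(
            raw_comm, /*nranks=*/8, /*rank=*/3, /*dev_id=*/3, /*ring_id=*/0);

Note that `AssignNCCLComm` does not add the handle to `comm_vec_`, so communicators registered this way are not destroyed by `ReleaseNCCLComms`; their creator retains ownership.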
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/dynload/nccl.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
@@ -232,17 +233,30 @@ class NCCLCommunicator {
       auto ptr = new platform::NCCLContextMap(places);
       VLOG(1) << "init local trainer";
       flat_ctxs_.emplace_back(ptr);
-      return;
-    }
-
-    for (size_t i = 0; i < nccl_ids.size(); i++) {
-      auto ptr = new platform::NCCLContextMap(places, nccl_ids[i], trainers_num,
-                                              trainer_id);
-      VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
-      flat_ctxs_.emplace_back(ptr);
+    } else {
+      for (size_t i = 0; i < nccl_ids.size(); i++) {
+        auto ptr = new platform::NCCLContextMap(places, nccl_ids[i],
+                                                trainers_num, trainer_id);
+        VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
+        flat_ctxs_.emplace_back(ptr);
+      }
     }
+
+    // as Executor have no way to use ncclComm created by ParallelExecutor,
+    // we assign all flatten contexts to NCCLCommContext to fix.
+    int nranks = static_cast<int>(trainers_num * places.size());
+    int nrings = static_cast<int>(flat_ctxs_.size());
+    for (int ring_id = 0; ring_id < nrings; ++ring_id) {
+      for (size_t p = 0; p < places.size(); ++p) {
+        int rank = trainer_id * places.size() + p;
+        int dev_id = boost::get<CUDAPlace>(places[p]).device;
+        auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id);
+        NCCLCommContext::Instance().AssignNCCLComm(ctx.comm_, nranks, rank,
+                                                   dev_id, ring_id);
+      }
+    }
   }
 
   void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
                             const std::vector<ncclUniqueId *> &inter_nccl_ids,
                             const std::vector<ncclUniqueId *> &exter_nccl_ids,
...
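The flat-rank mapping in the loop above is `rank = trainer_id * places.size() + p`, giving `nranks = trainers_num * places.size()` global ranks per ring. A self-contained toy check of that arithmetic, with an assumed topology of 2 trainers and 4 places each:

    #include <cstdio>

    int main() {
      const int trainers_num = 2, num_places = 4;    // assumed toy topology
      const int nranks = trainers_num * num_places;  // 8 global ranks
      for (int trainer_id = 0; trainer_id < trainers_num; ++trainer_id)
        for (int p = 0; p < num_places; ++p)
          std::printf("trainer %d, place %d -> rank %d of %d\n", trainer_id, p,
                      trainer_id * num_places + p, nranks);
      return 0;
    }

Every ring repeats the same rank layout; only the underlying `ncclComm_t` differs from ring to ring.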