未验证 提交 2169e6fb 编写于 作者: Y Yi Liu 提交者: GitHub

Initialize global nccl_comm in PE (#23275)

上级 012886df
......@@ -127,14 +127,9 @@ void NCCLParallelContext::Init() {
VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id;
PADDLE_ENFORCE(cudaSetDevice(gpu_id));
platform::NCCLComm *nccl_comm =
platform::NCCLCommContext::Instance().CreateNCCLComm(
&nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(pool.Get(place_));
dev_ctx->set_nccl_comm(nccl_comm->comm());
// it will assign nccl_comm in CUDADeviceContext within ring_id 0
platform::NCCLCommContext::Instance().CreateNCCLComm(
&nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0);
}
#endif
......
......@@ -38,7 +38,8 @@ class NCCLCommImpl : public NCCLComm {
return boost::get<CUDAPlace>(dev_ctx_->GetPlace()).device;
}
ncclComm_t comm() const override { return dev_ctx_->nccl_comm(); }
void set_comm(ncclComm_t comm) { comm_ = comm; }
ncclComm_t comm() const override { return comm_; }
cudaStream_t stream() const override { return dev_ctx_->stream(); }
......@@ -50,6 +51,7 @@ class NCCLCommImpl : public NCCLComm {
int ring_id_;
int nranks_;
int rank_;
ncclComm_t comm_;
std::unique_ptr<CUDADeviceContext> dev_ctx_;
};
......@@ -65,34 +67,18 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank));
comm_vec_.push_back(comm);
std::unique_ptr<CUDADeviceContext> dev_ctx(
new CUDADeviceContext(CUDAPlace(dev_id)));
dev_ctx->set_nccl_comm(comm);
NCCLCommImpl* c = new NCCLCommImpl;
c->set_ring_id(ring_id);
c->set_nranks(nranks);
c->set_rank(rank);
c->set_dev_ctx(std::move(dev_ctx));
comm_map_mutex_.lock();
if (comm_map_.count(ring_id) == 0) {
comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<NCCLComm>>());
}
auto& dev2comm = comm_map_[ring_id];
dev2comm.emplace(dev_id, std::unique_ptr<NCCLComm>(c));
comm_map_mutex_.unlock();
auto* comm_wrapper = AssignNCCLComm(comm, nranks, rank, dev_id, ring_id);
VLOG(1) << "nccl communicator of rank " << rank << " in ring " << ring_id
<< " has been created";
<< " has been created on device " << dev_id;
std::call_once(once_flag_, []() {
std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
});
return comm_map_[ring_id][dev_id].get();
return comm_wrapper;
}
void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids,
......@@ -103,23 +89,13 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids,
ncclComm_t comms[kDevices];
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
comms, dev_ids.size(), dev_ids.data()));
comm_vec_.insert(comm_vec_.end(), comms, comms + kDevices);
PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0);
comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<NCCLComm>>());
auto& dev2comm = comm_map_[ring_id];
for (size_t i = 0; i < dev_ids.size(); ++i) {
std::unique_ptr<CUDADeviceContext> dev_ctx(
new CUDADeviceContext(CUDAPlace(dev_ids[i])));
dev_ctx->set_nccl_comm(comms[i]);
NCCLCommImpl* c = new NCCLCommImpl;
c->set_ring_id(ring_id);
c->set_nranks(dev_ids.size());
c->set_rank(i);
c->set_dev_ctx(std::move(dev_ctx));
dev2comm.emplace(dev_ids[i], std::unique_ptr<NCCLComm>(c));
AssignNCCLComm(comms[i], dev_ids.size(), i, dev_ids[i], ring_id);
VLOG(1) << "nccl communicator of rank " << i << " in ring " << ring_id
<< " has been created on device " << dev_ids[i];
}
std::call_once(once_flag_, []() {
......@@ -127,14 +103,42 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids,
});
}
NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank,
int dev_id, int ring_id) {
std::unique_ptr<CUDADeviceContext> dev_ctx(
new CUDADeviceContext(CUDAPlace(dev_id)));
NCCLCommImpl* c = new NCCLCommImpl;
c->set_ring_id(ring_id);
c->set_nranks(nranks);
c->set_rank(rank);
c->set_comm(comm);
c->set_dev_ctx(std::move(dev_ctx));
comm_map_mutex_.lock();
if (comm_map_.count(ring_id) == 0) {
comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<NCCLComm>>());
}
auto& dev2comm = comm_map_[ring_id];
dev2comm.emplace(dev_id, std::unique_ptr<NCCLComm>(c));
comm_map_mutex_.unlock();
if (ring_id == 0) {
auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(dev_id)));
dev_ctx->set_nccl_comm(comm);
}
return comm_map_[ring_id][dev_id].get();
}
void NCCLCommContext::ReleaseNCCLComms() {
// CUDADeviceContext maintain the lifetime of nccl_comm_t, so we should not
// destroy nccl_comm_t explicitly. Please refer to
// platform::CUDADeviceContext::~CUDADeviceContext()
for (auto& p : comm_map_) {
for (auto& q : p.second) {
q.second.reset();
}
for (auto comm : comm_vec_) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::ncclCommDestroy(comm),
platform::errors::External("Fail to destroy nccl comm"));
}
}
......
......@@ -71,6 +71,11 @@ class NCCLCommContext {
void CreateAllNCCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
// a latter comm with the same dev_id and the same ring_id
// will override the former
NCCLComm* AssignNCCLComm(ncclComm_t comm, int nranks, int rank, int dev_id,
int ring_id = 0);
// retrieve a communicator by the ring id in multiprocessing mode
NCCLComm* Get(int ring_id) const {
PADDLE_ENFORCE_GT(comm_map_.count(ring_id), 0,
......@@ -105,6 +110,8 @@ class NCCLCommContext {
// ring id to dev-NCCLComm
std::map<int, std::map<int, std::unique_ptr<NCCLComm>>> comm_map_;
std::vector<ncclComm_t> comm_vec_;
void ReleaseNCCLComms();
NCCLCommContext() = default;
......
......@@ -24,6 +24,7 @@
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
......@@ -232,14 +233,27 @@ class NCCLCommunicator {
auto ptr = new platform::NCCLContextMap(places);
VLOG(1) << "init local trainer";
flat_ctxs_.emplace_back(ptr);
return;
} else {
for (size_t i = 0; i < nccl_ids.size(); i++) {
auto ptr = new platform::NCCLContextMap(places, nccl_ids[i],
trainers_num, trainer_id);
VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
flat_ctxs_.emplace_back(ptr);
}
}
for (size_t i = 0; i < nccl_ids.size(); i++) {
auto ptr = new platform::NCCLContextMap(places, nccl_ids[i], trainers_num,
trainer_id);
VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
flat_ctxs_.emplace_back(ptr);
// as Executor have no way to use ncclComm created by ParallelExecutor,
// we assign all flatten contexts to NCCLCommContext to fix.
int nranks = static_cast<int>(trainers_num * places.size());
int nrings = static_cast<int>(flat_ctxs_.size());
for (int ring_id = 0; ring_id < nrings; ++ring_id) {
for (size_t p = 0; p < places.size(); ++p) {
int rank = trainer_id * places.size() + p;
int dev_id = boost::get<CUDAPlace>(places[p]).device;
auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id);
NCCLCommContext::Instance().AssignNCCLComm(ctx.comm_, nranks, rank,
dev_id, ring_id);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册