From 2169e6fb5888609c595011b6dd6cccc67a4da587 Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Mon, 30 Mar 2020 19:39:59 +0800 Subject: [PATCH] Initialize global nccl_comm in PE (#23275) --- paddle/fluid/imperative/nccl_context.cc | 11 +-- paddle/fluid/platform/collective_helper.cc | 88 +++++++++++----------- paddle/fluid/platform/collective_helper.h | 7 ++ paddle/fluid/platform/nccl_helper.h | 26 +++++-- 4 files changed, 76 insertions(+), 56 deletions(-) diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index bc711400589..e9987c762b7 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -127,14 +127,9 @@ void NCCLParallelContext::Init() { VLOG(0) << "init nccl context nranks: " << strategy_.nranks_ << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id; - PADDLE_ENFORCE(cudaSetDevice(gpu_id)); - platform::NCCLComm *nccl_comm = - platform::NCCLCommContext::Instance().CreateNCCLComm( - &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0); - - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto *dev_ctx = static_cast(pool.Get(place_)); - dev_ctx->set_nccl_comm(nccl_comm->comm()); + // it will assign nccl_comm in CUDADeviceContext within ring_id 0 + platform::NCCLCommContext::Instance().CreateNCCLComm( + &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0); } #endif diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc index 7e1ad018f3c..4c4a6533449 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -38,7 +38,8 @@ class NCCLCommImpl : public NCCLComm { return boost::get(dev_ctx_->GetPlace()).device; } - ncclComm_t comm() const override { return dev_ctx_->nccl_comm(); } + void set_comm(ncclComm_t comm) { comm_ = comm; } + ncclComm_t comm() const override { return comm_; } cudaStream_t stream() const override { return dev_ctx_->stream(); } @@ -50,6 +51,7 @@ class NCCLCommImpl : public NCCLComm { int ring_id_; int nranks_; int rank_; + ncclComm_t comm_; std::unique_ptr dev_ctx_; }; @@ -65,34 +67,18 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id)); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank)); + comm_vec_.push_back(comm); - std::unique_ptr dev_ctx( - new CUDADeviceContext(CUDAPlace(dev_id))); - dev_ctx->set_nccl_comm(comm); - - NCCLCommImpl* c = new NCCLCommImpl; - c->set_ring_id(ring_id); - c->set_nranks(nranks); - c->set_rank(rank); - c->set_dev_ctx(std::move(dev_ctx)); - - comm_map_mutex_.lock(); - if (comm_map_.count(ring_id) == 0) { - comm_map_.emplace(ring_id, std::map>()); - } - auto& dev2comm = comm_map_[ring_id]; - - dev2comm.emplace(dev_id, std::unique_ptr(c)); - comm_map_mutex_.unlock(); + auto* comm_wrapper = AssignNCCLComm(comm, nranks, rank, dev_id, ring_id); VLOG(1) << "nccl communicator of rank " << rank << " in ring " << ring_id - << " has been created"; + << " has been created on device " << dev_id; std::call_once(once_flag_, []() { std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); }); }); - return comm_map_[ring_id][dev_id].get(); + return comm_wrapper; } void NCCLCommContext::CreateAllNCCLComms(const std::vector& dev_ids, @@ -103,23 +89,13 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector& dev_ids, ncclComm_t comms[kDevices]; PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll( comms, dev_ids.size(), dev_ids.data())); + comm_vec_.insert(comm_vec_.end(), comms, comms + kDevices); PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0); - comm_map_.emplace(ring_id, std::map>()); - - auto& dev2comm = comm_map_[ring_id]; for (size_t i = 0; i < dev_ids.size(); ++i) { - std::unique_ptr dev_ctx( - new CUDADeviceContext(CUDAPlace(dev_ids[i]))); - dev_ctx->set_nccl_comm(comms[i]); - - NCCLCommImpl* c = new NCCLCommImpl; - c->set_ring_id(ring_id); - c->set_nranks(dev_ids.size()); - c->set_rank(i); - c->set_dev_ctx(std::move(dev_ctx)); - - dev2comm.emplace(dev_ids[i], std::unique_ptr(c)); + AssignNCCLComm(comms[i], dev_ids.size(), i, dev_ids[i], ring_id); + VLOG(1) << "nccl communicator of rank " << i << " in ring " << ring_id + << " has been created on device " << dev_ids[i]; } std::call_once(once_flag_, []() { @@ -127,14 +103,42 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector& dev_ids, }); } +NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank, + int dev_id, int ring_id) { + std::unique_ptr dev_ctx( + new CUDADeviceContext(CUDAPlace(dev_id))); + + NCCLCommImpl* c = new NCCLCommImpl; + c->set_ring_id(ring_id); + c->set_nranks(nranks); + c->set_rank(rank); + c->set_comm(comm); + c->set_dev_ctx(std::move(dev_ctx)); + + comm_map_mutex_.lock(); + if (comm_map_.count(ring_id) == 0) { + comm_map_.emplace(ring_id, std::map>()); + } + auto& dev2comm = comm_map_[ring_id]; + + dev2comm.emplace(dev_id, std::unique_ptr(c)); + comm_map_mutex_.unlock(); + + if (ring_id == 0) { + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get( + platform::CUDAPlace(dev_id))); + dev_ctx->set_nccl_comm(comm); + } + + return comm_map_[ring_id][dev_id].get(); +} + void NCCLCommContext::ReleaseNCCLComms() { - // CUDADeviceContext maintain the lifetime of nccl_comm_t, so we should not - // destroy nccl_comm_t explicitly. Please refer to - // platform::CUDADeviceContext::~CUDADeviceContext() - for (auto& p : comm_map_) { - for (auto& q : p.second) { - q.second.reset(); - } + for (auto comm : comm_vec_) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclCommDestroy(comm), + platform::errors::External("Fail to destroy nccl comm")); } } diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h index a2b1e06de1b..8816703d45d 100644 --- a/paddle/fluid/platform/collective_helper.h +++ b/paddle/fluid/platform/collective_helper.h @@ -71,6 +71,11 @@ class NCCLCommContext { void CreateAllNCCLComms(const std::vector& dev_ids, int ring_id = 0); + // a latter comm with the same dev_id and the same ring_id + // will override the former + NCCLComm* AssignNCCLComm(ncclComm_t comm, int nranks, int rank, int dev_id, + int ring_id = 0); + // retrieve a communicator by the ring id in multiprocessing mode NCCLComm* Get(int ring_id) const { PADDLE_ENFORCE_GT(comm_map_.count(ring_id), 0, @@ -105,6 +110,8 @@ class NCCLCommContext { // ring id to dev-NCCLComm std::map>> comm_map_; + std::vector comm_vec_; + void ReleaseNCCLComms(); NCCLCommContext() = default; diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 0d04b997cbb..cfb1616e84c 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -24,6 +24,7 @@ #include #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" @@ -232,14 +233,27 @@ class NCCLCommunicator { auto ptr = new platform::NCCLContextMap(places); VLOG(1) << "init local trainer"; flat_ctxs_.emplace_back(ptr); - return; + } else { + for (size_t i = 0; i < nccl_ids.size(); i++) { + auto ptr = new platform::NCCLContextMap(places, nccl_ids[i], + trainers_num, trainer_id); + VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i; + flat_ctxs_.emplace_back(ptr); + } } - for (size_t i = 0; i < nccl_ids.size(); i++) { - auto ptr = new platform::NCCLContextMap(places, nccl_ids[i], trainers_num, - trainer_id); - VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i; - flat_ctxs_.emplace_back(ptr); + // as Executor have no way to use ncclComm created by ParallelExecutor, + // we assign all flatten contexts to NCCLCommContext to fix. + int nranks = static_cast(trainers_num * places.size()); + int nrings = static_cast(flat_ctxs_.size()); + for (int ring_id = 0; ring_id < nrings; ++ring_id) { + for (size_t p = 0; p < places.size(); ++p) { + int rank = trainer_id * places.size() + p; + int dev_id = boost::get(places[p]).device; + auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id); + NCCLCommContext::Instance().AssignNCCLComm(ctx.comm_, nranks, rank, + dev_id, ring_id); + } } } -- GitLab