From 093d227a7796e50dc2f7a04094b4725c6f40f399 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 16 Apr 2018 10:33:01 +0800 Subject: [PATCH] Use mutex to stablize ncclCtxMap --- paddle/fluid/platform/nccl_helper.h | 50 +++++++++-------------------- 1 file changed, 16 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index ca9ab2c7a..0013597fd 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -39,20 +39,19 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) { class NCCLGroupGuard { public: + static std::mutex &NCCLMutex() { + static std::mutex mtx; + return mtx; + } + inline NCCLGroupGuard() { - mutex().lock(); + NCCLMutex().lock(); PADDLE_ENFORCE(dynload::ncclGroupStart()); } inline ~NCCLGroupGuard() { PADDLE_ENFORCE(dynload::ncclGroupEnd()); - mutex().unlock(); - } - - private: - static std::mutex &mutex() { - static std::mutex mtx; - return mtx; + NCCLMutex().unlock(); } }; @@ -68,26 +67,6 @@ struct NCCLContext { int device_id() const { return boost::get(ctx_->GetPlace()).device; } - - static void InitNCCLContext(std::unordered_map *contexts, - const std::vector &places) { - std::vector comms; - std::vector devs; - comms.resize(contexts->size()); - devs.reserve(contexts->size()); - - for (auto &p : places) { - devs.push_back(boost::get(p).device); - } - - PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( - &comms[0], static_cast(contexts->size()), &devs[0])); - - int i = 0; - for (auto &dev_id : devs) { - contexts->at(dev_id).comm_ = comms[i++]; - } - } }; struct NCCLContextMap { @@ -107,12 +86,12 @@ struct NCCLContextMap { "NCCL Context Map does not support contain two or more same device"); if (places.size() > 1) { - std::vector comms; - comms.resize(order_.size()); - - PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( - &comms[0], static_cast(order_.size()), &order_[0])); - + std::unique_ptr comms(new ncclComm_t[order_.size()]); + { + std::lock_guard guard(NCCLGroupGuard::NCCLMutex()); + PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( + comms.get(), static_cast(order_.size()), order_.data())); + } int i = 0; for (auto &dev_id : order_) { contexts_.at(dev_id).comm_ = comms[i++]; @@ -120,6 +99,9 @@ struct NCCLContextMap { } } + NCCLContextMap(const NCCLContextMap &other) = delete; + NCCLContextMap &operator=(const NCCLContextMap &other) = delete; + CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); } CUDADeviceContext *DevCtx(platform::Place p) const { -- GitLab