diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index ca9ab2c7aecff47924f0198802d710b7661f5576..0013597fd516d15c7d502370eec77e1a6a5dca88 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -39,20 +39,19 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) { class NCCLGroupGuard { public: + static std::mutex &NCCLMutex() { + static std::mutex mtx; + return mtx; + } + inline NCCLGroupGuard() { - mutex().lock(); + NCCLMutex().lock(); PADDLE_ENFORCE(dynload::ncclGroupStart()); } inline ~NCCLGroupGuard() { PADDLE_ENFORCE(dynload::ncclGroupEnd()); - mutex().unlock(); - } - - private: - static std::mutex &mutex() { - static std::mutex mtx; - return mtx; + NCCLMutex().unlock(); } }; @@ -68,26 +67,6 @@ struct NCCLContext { int device_id() const { return boost::get(ctx_->GetPlace()).device; } - - static void InitNCCLContext(std::unordered_map *contexts, - const std::vector &places) { - std::vector comms; - std::vector devs; - comms.resize(contexts->size()); - devs.reserve(contexts->size()); - - for (auto &p : places) { - devs.push_back(boost::get(p).device); - } - - PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( - &comms[0], static_cast(contexts->size()), &devs[0])); - - int i = 0; - for (auto &dev_id : devs) { - contexts->at(dev_id).comm_ = comms[i++]; - } - } }; struct NCCLContextMap { @@ -107,12 +86,12 @@ struct NCCLContextMap { "NCCL Context Map does not support contain two or more same device"); if (places.size() > 1) { - std::vector comms; - comms.resize(order_.size()); - - PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( - &comms[0], static_cast(order_.size()), &order_[0])); - + std::unique_ptr comms(new ncclComm_t[order_.size()]); + { + std::lock_guard guard(NCCLGroupGuard::NCCLMutex()); + PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( + comms.get(), static_cast(order_.size()), order_.data())); + } int i = 0; for (auto &dev_id : order_) { contexts_.at(dev_id).comm_ = comms[i++]; @@ -120,6 +99,9 @@ struct NCCLContextMap { } } + NCCLContextMap(const NCCLContextMap &other) = delete; + NCCLContextMap &operator=(const NCCLContextMap &other) = delete; + CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); } CUDADeviceContext *DevCtx(platform::Place p) const {