From c64190ecbb211c09054b0ffea25179fdcad50207 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 11 Apr 2018 14:44:22 +0800 Subject: [PATCH] Polish NCCLHelper --- paddle/fluid/platform/nccl_helper.h | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 3a2a423486..ca9ab2c7ae 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -61,7 +61,7 @@ struct NCCLContext { ncclComm_t comm_; explicit NCCLContext(int dev_id) - : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))) {} + : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {} cudaStream_t stream() const { return ctx_->stream(); } @@ -95,6 +95,7 @@ struct NCCLContextMap { std::vector order_; explicit NCCLContextMap(const std::vector &places) { + PADDLE_ENFORCE(!places.empty()); order_.reserve(places.size()); for (auto &p : places) { int dev_id = boost::get(p).device; @@ -105,15 +106,17 @@ struct NCCLContextMap { order_.size(), contexts_.size(), "NCCL Context Map does not support contain two or more same device"); - std::vector comms; - comms.resize(order_.size()); + if (places.size() > 1) { + std::vector comms; + comms.resize(order_.size()); - PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( - &comms[0], static_cast(order_.size()), &order_[0])); + PADDLE_ENFORCE(platform::dynload::ncclCommInitAll( + &comms[0], static_cast(order_.size()), &order_[0])); - int i = 0; - for (auto &dev_id : order_) { - contexts_.at(dev_id).comm_ = comms[i++]; + int i = 0; + for (auto &dev_id : order_) { + contexts_.at(dev_id).comm_ = comms[i++]; + } } } -- GitLab