diff --git a/paddle/fluid/operators/nccl/nccl_gpu_common.cc b/paddle/fluid/operators/nccl/nccl_gpu_common.cc
index fa6aafceb06aefd9b4d318cbfbc9eb6152262bdd..a3ea0a4f895bb736959c854d0f97d26586dc3c1e 100644
--- a/paddle/fluid/operators/nccl/nccl_gpu_common.cc
+++ b/paddle/fluid/operators/nccl/nccl_gpu_common.cc
@@ -16,5 +16,44 @@ limitations under the License. */
 #include "paddle/fluid/platform/gpu_info.h"
 
 namespace paddle {
-namespace platform {}  // namespace platform
+namespace platform {
+namespace {
+// TODO(panyx0718): Where to destroy them.
+std::unique_ptr<std::vector<ncclComm_t>> global_comms;
+std::unique_ptr<std::unordered_map<int, int>> comm_id_map;
+bool inited = false;
+size_t last_num_gpus = -1;
+}
+
+int Communicator::GetCommId(int device_id) const {
+  return comm_id_map->at(device_id);
+}
+
+void Communicator::InitAll(const std::vector<int>& gpus) {
+  if (inited && last_num_gpus == gpus.size()) {
+    return;
+  }
+  last_num_gpus = gpus.size();
+  if (global_comms) {
+    for (size_t i = 0; i < global_comms->size(); ++i) {
+      // FIXME(dzh) : PADDLE_ENFORCE return void
+      dynload::ncclCommDestroy((*global_comms)[i]);
+    }
+  }
+  global_comms.reset(new std::vector<ncclComm_t>());
+  comm_id_map.reset(new std::unordered_map<int, int>());
+  global_comms->resize(gpus.size());
+  for (size_t i = 0; i < gpus.size(); ++i) {
+    (*comm_id_map)[gpus[i]] = i;
+  }
+  PADDLE_ENFORCE(
+      dynload::ncclCommInitAll(global_comms->data(), gpus.size(), gpus.data()));
+  inited = true;
+}
+
+const std::vector<ncclComm_t>& Communicator::comms() const {
+  return *global_comms;
+}
+
+}  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/operators/nccl/nccl_gpu_common.h b/paddle/fluid/operators/nccl/nccl_gpu_common.h
index be8c8a8f2c3cc53db992e121c1728f8a254988c0..113f93e346681e568524f9fb6a0ab9a56de8569e 100644
--- a/paddle/fluid/operators/nccl/nccl_gpu_common.h
+++ b/paddle/fluid/operators/nccl/nccl_gpu_common.h
@@ -29,39 +29,16 @@ limitations under the License. */
 
 namespace paddle {
 namespace platform {
-
 constexpr int kInvalidGPUId = -1;
 
 struct Communicator {
-  std::vector<ncclComm_t> comms_;
-  std::unordered_map<int, int> comm_id_map_;
-  bool inited_;
-
   Communicator() {}
 
-  int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
-
-  void InitAll(const std::vector<int>& gpus) {
-    comms_.resize(gpus.size());
-    inited_ = false;
-    for (size_t i = 0; i < gpus.size(); ++i) {
-      comm_id_map_[gpus[i]] = i;
-    }
-    PADDLE_ENFORCE(
-        dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
-    inited_ = true;
-  }
+  int GetCommId(int device_id) const;
 
-  ~Communicator() {
-    if (inited_) {
-      for (size_t i = 0; i < comms_.size(); ++i) {
-        // FIXME(dzh) : PADDLE_ENFORCE return void
-        dynload::ncclCommDestroy(comms_[i]);
-      }
-    }
-  }
+  void InitAll(const std::vector<int>& gpus);
 
-  DISABLE_COPY_AND_ASSIGN(Communicator);
+  const std::vector<ncclComm_t>& comms() const;
 };
 
 }  // namespace platform
diff --git a/paddle/fluid/operators/nccl_op.cu.cc b/paddle/fluid/operators/nccl_op.cu.cc
index fc83aa2ac2231c208e1933209985432717b4e36e..683a520e99fe72875d52393a18463324a8b6c3f2 100644
--- a/paddle/fluid/operators/nccl_op.cu.cc
+++ b/paddle/fluid/operators/nccl_op.cu.cc
@@ -78,7 +78,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
           ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
          outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
-          comm->comms_[idx], stream));
+          comm->comms().at(idx), stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
       VLOG(1) << "gpu : "
@@ -127,7 +127,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     std::hash<std::string> hasher;
     for (size_t i = 0; i < ins.size(); ++i) {
       if (root == platform::kInvalidGPUId) {
-        root = hasher(ins_names[i]) % comm->comms_.size();
+        root = hasher(ins_names[i]) % comm->comms().size();
       }
       T* recvbuffer = nullptr;
       if (root == gpu_id) {
@@ -139,7 +139,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
 
       PADDLE_ENFORCE(platform::dynload::ncclReduce(
           ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
-          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
+          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx),
           stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
@@ -176,7 +176,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
         VLOG(1) << " before ncclBcast";
         PADDLE_ENFORCE(platform::dynload::ncclBcast(
             (void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
-            root, comm->comms_[idx], stream));
+            root, comm->comms().at(idx), stream));
         VLOG(1) << " after ncclBcast";
 
         PADDLE_ENFORCE(cudaStreamSynchronize(stream));
@@ -190,7 +190,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
 
         PADDLE_ENFORCE(platform::dynload::ncclBcast(
             outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
-            NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
+            NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream));
         PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
         VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
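
For reference, a minimal sketch (not part of the patch) of how a kernel is expected to drive the refactored Communicator API; the function and variable names below (Example, gpus, device_id) are illustrative only:

// Sketch only: assumes nccl_gpu_common.h and its dependencies are on the include path.
#include <vector>
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"

void Example(const std::vector<int>& gpus, int device_id) {
  paddle::platform::Communicator comm;
  // InitAll is now idempotent for a fixed GPU count: the global ncclComm_t
  // vector is rebuilt only when the number of GPUs changes.
  comm.InitAll(gpus);
  // Kernels no longer read comms_ directly; they map a device id to a slot
  // with GetCommId() and fetch the communicator through comms().
  int idx = comm.GetCommId(device_id);
  ncclComm_t nccl_comm = comm.comms().at(idx);
  (void)nccl_comm;  // would be passed to ncclAllReduce/ncclReduce/ncclBcast
}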