diff --git a/paddle/operators/nccl/nccl_gpu_common.cc b/paddle/operators/nccl/nccl_gpu_common.cc
index 492d79ca53f29c0ca0e22d702f15a58909a220d1..80cb66300e98bbe4c30ac6cacb8ea7bb8c2ec44b 100644
--- a/paddle/operators/nccl/nccl_gpu_common.cc
+++ b/paddle/operators/nccl/nccl_gpu_common.cc
@@ -8,27 +8,27 @@ NCCLManager::NCCLManager() {}
 
 NCCLManager::~NCCLManager() {
   for (auto& p : comm_table) {
-    auto* comm = p.second;
+    auto& comm = p.second;
     auto& gpus_ = comm->gpus_;
-    for (int i = 0; i < gpus_.size(); ++i) {
+    for (size_t i = 0; i < gpus_.size(); ++i) {
       int gid = gpus_[i];
       platform::SetDeviceId(gid);
 
       // mapping gid to idx
       int idx = gid % gpus_.size();
       // wait finish
-      NCCL_CHECK(
+      PADDLE_ENFORCE(
           cudaStreamWaitEvent(*comm->streams_[idx], comm->events_[idx], 0));
 
-      NCCL_CHECK(cudaEventDestroy(comm->events_[idx]));
+      PADDLE_ENFORCE(cudaEventDestroy(comm->events_[idx]));
 
-      NCCL_CHECK(ncclCommDestroy(comm->comms_[idx]));
+      PADDLE_ENFORCE(ncclCommDestroy(comm->comms_[idx]));
     }
-    delete comm;
+    comm.reset(nullptr);
   }
 }
 
-Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) const {
+Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) {
   std::string key;
   for (auto& id : gpus) {
     key += std::to_string(id);
@@ -37,21 +37,24 @@ Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) const {
 
   std::mutex mu;
   std::lock_guard<std::mutex> lk(mu);
-  auto* comm = comm_table[key];
-  if (comm == nullptr) {
-    comm = new Communicator(gpus.size());
-    NCCL_CHECK(ncclCommInitAll(comm->comms_.data(), gpus.size(), gpus.data()));
+
+  auto it = comm_table.find(key);
+
+  if (it->second == nullptr) {
+    auto* comm = new Communicator(gpus);
+    PADDLE_ENFORCE(
+        ncclCommInitAll(comm->comms_.data(), gpus.size(), gpus.data()));
 
     for (size_t i = 0; i < gpus.size(); ++i) {
       platform::SetDeviceId(gpus[i]);
 
       // block wait
-      NCCL_CHECK(cudaEventCreateWithFlags(
-          &events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
+      PADDLE_ENFORCE(cudaEventCreateWithFlags(
+          &comm->events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
     }
-    comm_table[key] = comm;
+    comm_table[key].reset(comm);
   }
-  return comm;
+  return comm_table[key].get();
 }
 
 }  // namespace operators
diff --git a/paddle/operators/nccl/nccl_gpu_common.h b/paddle/operators/nccl/nccl_gpu_common.h
index a50490f392b5036ca7fdc43448288e6f2380d462..96b3bb801af6ffc0ebe18485dec8700f3757d814 100644
--- a/paddle/operators/nccl/nccl_gpu_common.h
+++ b/paddle/operators/nccl/nccl_gpu_common.h
@@ -1,5 +1,4 @@
 #pragma once
-#include <algorithm>
 #include <condition_variable>
 #include <memory>
 #include <mutex>
@@ -10,17 +9,11 @@
 #include <vector>
 
 #include "paddle/platform/device_context.h"
+#include "paddle/platform/enforce.h"
 
 namespace paddle {
 namespace platform {
 
-#define NCCL_CHECK(condition)                                              \
-  do {                                                                     \
-    ncclResult_t ret = (condition);                                        \
-    PADDLE_ENFORCE(ret == ncclSuccess, "Error invoking NCCL: ", __FILE__,  \
-                   __LINE__, ncclGetErrorString(ret));                     \
-  } while (0)
-
 class WaitGroup {
  public:
   inline void Add(int n) {
@@ -101,7 +94,7 @@ class NCCLManager {
   ~NCCLManager();
 
   // for each card only have one communicator
-  Communicator* GetCommunicator(const std::vector<int>& gpus) const;
+  Communicator* GetCommunicator(const std::vector<int>& gpus);
 
  private:
   // // the gpu id list available. Note that only support
@@ -109,7 +102,8 @@ class NCCLManager {
   // std::vector<std::vector<int>> _gpu_worlds;
 
   // communicator list
-  std::unordered_map<std::string, Communicator*> comm_table;
+  std::unordered_map<std::string, std::unique_ptr<Communicator>>
+      comm_table;
 };
 
 }  // namespace operators
diff --git a/paddle/operators/nccl/nccl_ops.h b/paddle/operators/nccl/nccl_ops.h
index 7e348a601a7f3ffbc32ec387c3250f36a78bef74..894859f6f0eac52a92d25d413eded1e6ccc6d625 100644
--- a/paddle/operators/nccl/nccl_ops.h
+++ b/paddle/operators/nccl/nccl_ops.h
@@ -54,14 +54,15 @@ class NCCLAllReduceKernel : public framework::OpKernel {
     comm->streams_[idx] = stream;
 
     for (size_t i = 0; i < ins.size(); ++i) {
-      NCCL_CHECK(ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
-                               outs[i]->numel() * sizeof(T),
-                               NCCLTypeWrapper<T>::type, op_type,
-                               &comm->comms_[idx], comm->streams_[idx]));
-      NCCL_CHECK(cudaEventRecord(comm->events_[idx], *comms_->streams_[idx]));
+      PADDLE_ENFORCE(
+          ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
+                        outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type,
+                        op_type, &comm->comms_[idx], comm->streams_[idx]));
+      PADDLE_ENFORCE(
+          cudaEventRecord(comm->events_[idx], *comms_->streams_[idx]));
 
       // wait finish
-      NCCL_CHECK(
+      PADDLE_ENFORCE(
           cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
     }
 
diff --git a/paddle/platform/dynload/nccl.h b/paddle/platform/dynload/nccl.h
index ad050da4ad32e5d2015fcbcaaffaa428ba0c3c42..fbfcec4c985079108cae61bdf7d3b679cb04cd51 100644
--- a/paddle/platform/dynload/nccl.h
+++ b/paddle/platform/dynload/nccl.h
@@ -30,13 +30,13 @@ extern void* nccl_dso_handle;
 #define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name)                           \
   struct DynLoad__##__name {                                             \
     template <typename... Args>                                          \
-    ncclResult_t operator()(Args... args) {                              \
-      typedef ncclResult_t (*ncclFunc)(Args...);                         \
+    auto operator()(Args... args) -> decltype(__name(args...)) {         \
+      using nccl_func = decltype(__name(args...)) (*)(Args...);          \
       std::call_once(nccl_dso_flag,                                      \
                     paddle::platform::dynload::GetNcclDsoHandle,         \
                     &nccl_dso_handle);                                   \
       void* p_##__name = dlsym(nccl_dso_handle, #__name);                \
-      return reinterpret_cast<ncclFunc>(p_##__name)(args...);            \
+      return reinterpret_cast<nccl_func>(p_##__name)(args...);           \
     }                                                                    \
   };                                                                     \
   extern DynLoad__##__name __name
@@ -65,7 +65,7 @@ extern void* nccl_dso_handle;
   __macro(ncclReduce);          \
   __macro(ncclGetErrorString);
 
-NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP);
+NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
 
 }  // namespace dynload
 }  // namespace platform
diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h
index 2f9e7466f1f6de5ad0252ae4f89ee8a3805b1f2e..bfe708748a62ff9ac5d151bc652142e1f4925c83 100644
--- a/paddle/platform/enforce.h
+++ b/paddle/platform/enforce.h
@@ -29,6 +29,8 @@ limitations under the License. */
 #include <cxxabi.h>  // for __cxa_demangle
 #endif
 
+#include <iostream>
+
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/platform/dynload/cublas.h"