#pragma once

#include <nccl.h>

#include <algorithm>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/platform/device_context.h"

namespace paddle {
namespace platform {

#define NCCL_CHECK(condition)                                             \
  do {                                                                    \
    ncclResult_t ret = (condition);                                       \
    PADDLE_ENFORCE(ret == ncclSuccess, "Error invoking NCCL: ", __FILE__, \
                   __LINE__, ncclGetErrorString(ret));                    \
  } while (0)

class WaitGroup {
 public:
  inline void Add(int n) {
    std::unique_lock<std::mutex> lk(mu_);
    PADDLE_ENFORCE(n >= 0, "add wait must be >= 0.");
    counter_ += n;
  }

  inline void Done(int n) {
    std::unique_lock<std::mutex> lk(mu_);
    PADDLE_ENFORCE(n <= counter_, "wait group Done does not match Add.");
    counter_ -= n;
    if (counter_ == 0) {
      cv_.notify_all();
    }
  }

  inline void Add() { Add(1); }

  inline void Done() { Done(1); }

  inline void Wait() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [&] { return counter_ == 0; });
  }

  inline int GetCount() {
    std::unique_lock<std::mutex> lk(mu_);
    return counter_;
  }

 private:
  int counter_ = 0;
  std::mutex mu_;
  std::condition_variable cv_;
};

// class NCCLContext : public DeviceContext {
//  public:
//   explicit NCCLContext(GPUPlace place);
//   virtual ~NCCLContext();
//
//  private:
//   std::vector<int> gpu_ids_;
//   std::vector<cudaStream_t> streams_;
// };

// TODO(dzh): make resources managed unified with framework
struct Communicator {
  std::vector<ncclComm_t> comms_;
  std::vector<cudaStream_t> streams_;
  std::vector<cudaEvent_t> events_;
  std::vector<int> gpus_;
  WaitGroup wg_;
  int root_gpu = -1;
  // cudaEvent_t root_monitor;

  explicit Communicator(const std::vector<int>& gpus) : gpus_(gpus) {
    comms_.resize(gpus.size());
    streams_.resize(gpus.size());
    events_.resize(gpus.size());
  }
  // Communicator(int num_device): comms_.resize(num_device) {}

  inline int get_root_gpu() const { return root_gpu; }

  inline void set_root_gpu(int id) { root_gpu = id; }
};

class NCCLManager {
 public:
  static NCCLManager* Get() {
    static NCCLManager m;
    return &m;
  }

  NCCLManager();

  ~NCCLManager();

  // Each group of cards has only one communicator.
  Communicator* GetCommunicator(const std::vector<int>& gpus) const;

 private:
  // // the gpu id list available. Note that only whole world
  // // communication is supported.
  // std::vector<int> _gpu_worlds;

  // communicator table, keyed by a string encoding of the gpu id list
  std::unordered_map<std::string, Communicator*> comm_table;
};

}  // namespace platform
}  // namespace paddle
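
// Usage sketch (illustrative only, not part of this header): WaitGroup can
// coordinate one worker thread per GPU, e.g. when every device must finish
// its per-device NCCL call before buffers are reused. The `gpus` vector and
// the per-device launch are hypothetical placeholders.
//
//   paddle::platform::WaitGroup wg;
//   std::vector<std::thread> workers;
//   wg.Add(static_cast<int>(gpus.size()));
//   for (int gpu : gpus) {
//     workers.emplace_back([&wg, gpu] {
//       // ... issue the NCCL call for device `gpu` here ...
//       wg.Done();  // decrement the counter for this worker
//     });
//   }
//   wg.Wait();  // blocks until every worker has called Done()
//   for (auto& t : workers) t.join();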