#pragma once
#include <nccl.h>

#include <algorithm>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/platform/device_context.h"

namespace paddle {
namespace platform {

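// Checks the result of an NCCL call and raises a PADDLE_ENFORCE failure
// carrying the file, line, and NCCL error string when it is not ncclSuccess.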
#define NCCL_CHECK(condition)                                             \
  do {                                                                    \
    ncclResult_t ret = (condition);                                       \
    PADDLE_ENFORCE(ret == ncclSuccess, "Error invoking NCCL: ", __FILE__, \
                   __LINE__, ncclGetErrorString(ret));                    \
  } while (0)

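// A Go-style wait group: Add() registers pending work, Done() marks it
// finished, and Wait() blocks until the internal counter drops back to zero.
// Communicator below embeds one (wg_) to coordinate work across its GPUs.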
class WaitGroup {
 public:
  inline void Add(int n) {
    std::unique_lock<std::mutex> lk(mu_);
    PADDLE_ENFORCE(n >= 0, "WaitGroup Add expects a non-negative count.");
    counter_ += n;
  }

  inline void Done(int n) {
    std::unique_lock<std::mutex> lk(mu_);
    PADDLE_ENFORCE(n <= counter_, "WaitGroup Done count exceeds the pending Add count.");
    counter_ -= n;
    if (counter_ == 0) {
      cv_.notify_all();
    }
  }

  inline void Add() { Add(1); }

  inline void Done() { Done(1); }

  inline void Wait() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [&] { return counter_ == 0; });
  }

  inline int GetCount() {
    std::unique_lock<std::mutex> lk(mu_);
    return counter_;
  }

 private:
  int counter_ = 0;
  std::mutex mu_;
  std::condition_variable cv_;
};

// class NCCLContext : public DeviceContext {
// public:
//   explicit NCCLContext(GPUPlace place);
//   virtual ~NCCLContext();

// private:
//   std::vector<int> gpu_ids_;
//   std::vector<cudaStream_t> streams_;
// };

// TODO(dzh): unify resource management with the framework.
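// Bundles the per-GPU NCCL resources: one ncclComm_t, one stream pointer, and
// one cudaEvent_t per device in gpus_, plus a WaitGroup and the root GPU id.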
struct Communicator {
  std::vector<ncclComm_t> comms_;
  std::vector<cudaStream_t*> streams_;
  std::vector<cudaEvent_t> events_;
  std::vector<int> gpus_;
  WaitGroup wg_;
  int root_gpu = -1;
  // cudaEvent_t root_monitor;
  explicit Communicator(const std::vector<int>& gpus) : gpus_(gpus) {
    comms_.resize(gpus.size());
    streams_.resize(gpus.size());
    events_.resize(gpus.size());
  }
  // Communicator(int num_device): comms_.resize(num_device) {}

  inline int get_root_gpu() const { return root_gpu; }

  inline void set_root_gpu(int id) { root_gpu = id; }
};

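// Process-wide singleton that owns the Communicator instances, cached in
// comm_table so repeated requests for the same GPU set reuse one communicator.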
class NCCLManager {
 public:
  static NCCLManager* Get() {
    static NCCLManager m;
    return &m;
  }

  NCCLManager();

  ~NCCLManager();

  // Each GPU card has only one communicator.
  Communicator* GetCommunicator(const std::vector<int>& gpus) const;

 private:
  // // the available gpu id list. Note that only whole-world
  // // communication is supported.
  // std::vector<int> _gpu_worlds;

  // communicator table
  std::unordered_map<std::string /* key */, Communicator*> comm_table;
};
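
// Usage sketch (illustration only, not part of this header). A hypothetical
// operator-side caller could fetch the shared communicator and launch one
// all-reduce per GPU; `send_bufs`, `recv_bufs`, and `count` are placeholders,
// and the communicators are assumed to be initialized for the same GPU list:
//
//   std::vector<int> gpus = {0, 1};
//   Communicator* comm = NCCLManager::Get()->GetCommunicator(gpus);
//   for (size_t i = 0; i < gpus.size(); ++i) {
//     cudaSetDevice(gpus[i]);
//     NCCL_CHECK(ncclAllReduce(send_bufs[i], recv_bufs[i], count, ncclFloat,
//                              ncclSum, comm->comms_[i], *comm->streams_[i]));
//   }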

}  // namespace platform
}  // namespace paddle