#pragma once

#include <algorithm>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace platform {

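// A Go-style wait group: Wait() blocks until the pending counter, driven by
// Add()/Done(), returns to zero.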
class WaitGroup {
 public:
  inline void Add(int n) {
    std::unique_lock<std::mutex> lk(mu_);
    PADDLE_ENFORCE(n >= 0, "WaitGroup::Add requires n >= 0.");
    counter_ += n;
  }

  inline void Done(int n) {
    std::unique_lock<std::mutex> lk(mu_);
    PADDLE_ENFORCE(n <= counter_,
                   "WaitGroup::Done(n) must not exceed the outstanding Add count.");
    counter_ -= n;
    if (counter_ == 0) {
      cv_.notify_all();
    }
  }

  inline void Add() { Add(1); }

  inline void Done() { Done(1); }

  inline void Wait() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [&] { return counter_ == 0; });
  }

  inline int GetCount() {
    std::unique_lock<std::mutex> lk(mu_);
    return counter_;
  }

 private:
  int counter_ = 0;
  std::mutex mu_;
  std::condition_variable cv_;
};
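
// Usage sketch (illustrative only; DoWork and kNumThreads are hypothetical):
//
//   WaitGroup wg;
//   wg.Add(kNumThreads);
//   std::vector<std::thread> workers;
//   for (int i = 0; i < kNumThreads; ++i) {
//     workers.emplace_back([&wg] {
//       DoWork();
//       wg.Done();
//     });
//   }
//   wg.Wait();  // blocks until every worker has called Done()
//   for (auto& t : workers) t.join();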

// class NCCLContext : public DeviceContext {
// public:
//   explicit NCCLContext(GPUPlace place);
//   virtual ~NCCLContext();

// private:
//   std::vector<int> gpu_ids_;
//   std::vector<cudaStream_t> streams_;
// };

// TODO(dzh): unify resource management with the framework.
struct Communicator {
  std::vector<ncclComm_t> comms_;       // one NCCL communicator per GPU
  std::vector<cudaStream_t*> streams_;  // raw pointers; streams are owned elsewhere
  std::vector<cudaEvent_t> events_;     // per-GPU events for synchronization
  std::vector<int> gpus_;               // device ids in this group
  WaitGroup wg_;
  int root_gpu = -1;
  // cudaEvent_t root_monitor;
  explicit Communicator(const std::vector<int>& gpus) : gpus_(gpus) {
    comms_.resize(gpus.size());
    streams_.resize(gpus.size());
    events_.resize(gpus.size());
  }
  // Communicator(int num_device): comms_.resize(num_device) {}

  inline int get_root_gpu() const { return root_gpu; }
  inline void set_root_gpu(int id) { root_gpu = id; }
};
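
// Note: construction only sizes the per-GPU slots above; the NCCL handles
// are presumably initialized elsewhere (e.g. in the .cc), roughly:
//
//   Communicator comm({0, 1});
//   ncclCommInitAll(comm.comms_.data(), static_cast<int>(comm.gpus_.size()),
//                   comm.gpus_.data());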

class NCCLManager {
 public:
  static NCCLManager* Get() {
    static NCCLManager m;
    return &m;
  }

  // Each GPU card has exactly one communicator.
  Communicator* GetCommunicator(const std::vector<int>& gpus);

 private:
  NCCLManager();
  ~NCCLManager();

  // // The available gpu id list. Note that only whole-world
  // // communication is supported.
  // std::vector<int> _gpu_worlds;

  // communicator list
  std::unordered_map<std::string /* key */, std::unique_ptr<Communicator>>
      comm_table;
};
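
// Usage sketch (illustrative; the manager retains ownership of the
// returned Communicator through comm_table):
//
//   std::vector<int> gpus = {0, 1};
//   Communicator* comm = NCCLManager::Get()->GetCommunicator(gpus);
//   // comm->comms_[i] is the NCCL handle bound to device gpus[i].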

}  // namespace platform
}  // namespace paddle