#pragma once
#include <nccl.h>

#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/platform/device_context.h"

namespace paddle {
namespace platform {

// class NCCLContext : public DeviceContext {
// public:
//   explicit NCCLContext(GPUPlace place);
//   virtual ~NCCLContext();

// private:
//   std::vector<int> gpu_ids_;
//   std::vector<cudaStream_t> streams_;
// };

class Communicator;

class NCCLManager {
 public:
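  // Returns the process-wide singleton instance (Meyers singleton).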
  static NCCLManager* Get() {
    static NCCLManager m;
    return &m;
  }

  NCCLManager() {}
  ~NCCLManager() {}

  // Each GPU card has exactly one communicator.
  Communicator* GetCommunicator() const;

 private:
  struct Communicator {
    std::vector<ncclComm_t> comms_;
    std::vector<cudaStream_t*> streams_;  // not owned
    std::vector<cudaEvent_t> events_;
    int root_gpu_;
  };
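
  // A possible initialisation sketch (hypothetical; the actual setup would
  // live in the corresponding .cc file, and `gpu_ids` is a placeholder for
  // the device id list). The per-GPU communicators are typically created
  // together so that they form a single NCCL clique:
  //
  //   std::vector<ncclComm_t> comms(gpu_ids.size());
  //   ncclCommInitAll(comms.data(), static_cast<int>(gpu_ids.size()),
  //                   gpu_ids.data());
  //
  // with one cudaStream_t / cudaEvent_t created per device afterwards.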

  // The list of available GPU ids. Note that only whole-world
  // communication is supported.
  std::vector<int> gpu_worlds_;

  // Communicator cache, keyed by string.
  std::unordered_map<std::string /* key */, Communicator*> comms_;
};
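
// A minimal usage sketch (hypothetical; the real call sites depend on the
// operator implementation). A kernel would fetch the process-wide manager and
// reuse its cached communicator rather than creating a new one per call:
//
//   auto* comm = platform::NCCLManager::Get()->GetCommunicator();
//   // ... issue collective calls (e.g. ncclAllReduce) on the streams and
//   // events associated with `comm`.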

}  // namespace platform
}  // namespace paddle