// nccl_gpu_common.cc
#include "paddle/operators/nccl/nccl_gpu_common.h"

#include <algorithm>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include "paddle/platform/gpu_info.h"

namespace paddle {
namespace platform {

// Default construction: the communicator cache starts empty and is populated
// lazily by GetCommunicator().
NCCLManager::NCCLManager() = default;
// Tears down every cached Communicator. For each device in a communicator it
// makes that device current, then:
//   1. makes the device's stream wait on its event;
//   2. destroys the event and the NCCL communicator;
//   3. finally drops the Communicator object itself.
// NOTE(review): PADDLE_ENFORCE may throw; a throw escaping a destructor calls
// std::terminate. Consider logging failures here instead of enforcing.
NCCLManager::~NCCLManager() {
  for (auto& entry : comm_table) {
    auto& comm = entry.second;
    // Defensive: never dereference an empty cache slot.
    if (comm == nullptr) continue;

    const auto& gpus = comm->gpus_;
    for (size_t i = 0; i < gpus.size(); ++i) {
      platform::SetDeviceId(gpus[i]);

      // Per-device resources are stored by position in the device list.
      // GetCommunicator creates events_[i] for gpus[i], and ncclCommInitAll
      // places comms_[i] on devlist[i]. So index by i. The previous
      // `gid % gpus.size()` mapping collides for non-contiguous ids
      // (e.g. {1, 3}), destroying one slot twice and leaking the other.
      //
      // NOTE(review): cudaStreamWaitEvent only orders the stream; it does not
      // block the host. If host-side completion is required before teardown,
      // cudaEventSynchronize may be what was intended — confirm.
      PADDLE_ENFORCE(
          cudaStreamWaitEvent(*comm->streams_[i], comm->events_[i], 0));
      PADDLE_ENFORCE(cudaEventDestroy(comm->events_[i]));
      PADDLE_ENFORCE(ncclCommDestroy(comm->comms_[i]));
    }
    comm.reset(nullptr);
  }
}

// Returns the cached Communicator for the given set of device ids. On the
// first request for a set it creates one and initializes it:
//   - one NCCL communicator per device (ncclCommInitAll);
//   - one blocking-sync, non-timing CUDA event per device.
// The returned pointer is non-owning; the Communicator is owned by
// comm_table and lives until this manager is destroyed. Lookups and creation
// are serialized by a single function-local static mutex.
Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) {
  // Order-insensitive key: sort the ids themselves, then join them with a
  // separator. Sorting the characters of a concatenation (the old approach)
  // made distinct sets collide, e.g. {1, 23} vs {12, 3}, or {12} vs {1, 2}.
  std::vector<int> sorted_ids(gpus);
  std::sort(sorted_ids.begin(), sorted_ids.end());
  std::string key;
  for (int id : sorted_ids) {
    key += std::to_string(id);
    key += ',';
  }

  // Must be shared across calls to provide mutual exclusion; a per-call
  // local mutex (as before) never excluded anything.
  static std::mutex mu;
  std::lock_guard<std::mutex> lk(mu);

  // Check end() before touching it->second: dereferencing end() on a cache
  // miss is undefined behavior.
  auto it = comm_table.find(key);
  if (it != comm_table.end() && it->second != nullptr) {
    return it->second.get();
  }

  // Hold ownership until initialization succeeds, so that a throwing
  // PADDLE_ENFORCE does not leak the Communicator.
  std::unique_ptr<Communicator> comm(new Communicator(gpus));
  PADDLE_ENFORCE(ncclCommInitAll(comm->comms_.data(),
                                 static_cast<int>(gpus.size()), gpus.data()));

  for (size_t i = 0; i < gpus.size(); ++i) {
    platform::SetDeviceId(gpus[i]);
    // Blocking-sync so host waits yield rather than spin; timing disabled
    // because the event is used only for synchronization.
    PADDLE_ENFORCE(cudaEventCreateWithFlags(
        &comm->events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
  }

  Communicator* result = comm.get();
  comm_table[key].reset(comm.release());
  return result;
}

}  // namespace platform
}  // namespace paddle