nccl_gpu_common.cc 1.4 KB
Newer Older
D
dzhwinter 已提交
1
#include "paddle/operators/nccl/nccl_gpu_common.h"
D
Dong Zhihong 已提交
2
#include "paddle/platform/gpu_info.h"
D
dzhwinter 已提交
3 4 5 6

namespace paddle {
namespace platform {

D
Dong Zhihong 已提交
7
// Default-construct the manager; the communicator table starts empty and
// entries are created lazily by GetCommunicator().
NCCLManager::NCCLManager() {}
D
dzhwinter 已提交
8

D
Dong Zhihong 已提交
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
// Tear down every cached communicator: for each gpu, wait for outstanding
// work recorded on its event to finish, then destroy the event and the
// NCCL communicator before freeing the Communicator object.
NCCLManager::~NCCLManager() {
  for (auto& p : comm_table) {
    auto* comm = p.second;
    auto& gpus_ = comm->gpus_;
    for (size_t i = 0; i < gpus_.size(); ++i) {  // size_t: avoid signed/unsigned mismatch
      int gid = gpus_[i];
      platform::SetDeviceId(gid);

      // mapping gid to idx
      int idx = gid % gpus_.size();
      // Block the HOST until work recorded on this event completes.
      // The previous cudaStreamWaitEvent call only enqueued a wait on the
      // stream (asynchronous w.r.t. the host), so the event and communicator
      // could be destroyed below before the work actually finished.
      // NOTE(review): these are CUDA runtime calls; they should go through a
      // CUDA error-check macro rather than NCCL_CHECK — confirm against the
      // macro definitions in nccl_gpu_common.h.
      NCCL_CHECK(cudaEventSynchronize(comm->events_[idx]));

      NCCL_CHECK(cudaEventDestroy(comm->events_[idx]));

      NCCL_CHECK(ncclCommDestroy(comm->comms_[idx]));
    }
    delete comm;
  }
}

// Return the (lazily created) Communicator shared by the given gpu set.
// The cache key is order-insensitive: {0,1} and {1,0} map to the same entry.
Communicator* NCCLManager::GetCommunicator(const std::vector<int>& gpus) const {
  // Build a canonical key by sorting the gpu IDS before serializing.
  // The old code sorted the characters of the concatenated string, which
  // scrambles multi-digit ids ({12} and {21} both became "12"); it also had
  // no separator, so {1,12} and {11,2} collided. Fix both.
  std::vector<int> sorted_gpus(gpus);
  std::sort(sorted_gpus.begin(), sorted_gpus.end());
  std::string key;
  for (int id : sorted_gpus) {
    key += std::to_string(id);
    key += ',';  // separator keeps multi-digit ids unambiguous
  }

  // NOTE(review): a function-local non-static mutex synchronized nothing —
  // every call locked its own fresh mutex. `static` makes concurrent calls
  // actually serialize access to comm_table; ideally this should be a
  // member of NCCLManager (header not visible here).
  static std::mutex mu;
  std::lock_guard<std::mutex> lk(mu);
  auto* comm = comm_table[key];
  if (comm == nullptr) {
    comm = new Communicator(gpus.size());
    NCCL_CHECK(ncclCommInitAll(comm->comms_.data(),
                               static_cast<int>(gpus.size()), gpus.data()));

    for (size_t i = 0; i < gpus.size(); ++i) {
      platform::SetDeviceId(gpus[i]);

      // Blocking-sync event (no timing) used to wait for completion at
      // teardown. Fixed: events_ is a member of Communicator (see the
      // destructor's comm->events_ usage), not of NCCLManager — the old
      // bare `events_[i]` referenced the wrong object.
      NCCL_CHECK(cudaEventCreateWithFlags(
          &comm->events_[i], cudaEventBlockingSync | cudaEventDisableTiming));
    }
    comm_table[key] = comm;
  }
  return comm;
}
D
dzhwinter 已提交
56 57 58

}  // namespace operators
}  // namespace paddle