You need to sign in or sign up before continuing.
nccl_helper.h 4.3 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

17
#include <memory>
#include <mutex>  // NOLINT
#include <thread>  // NOLINT
#include <typeindex>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace platform {

// Maps a C++ element type to the matching NCCL data type enum.
// Only float, double, int and int64_t are handled; any other type throws
// via PADDLE_THROW("Not supported").
inline ncclDataType_t ToNCCLDataType(std::type_index type) {
  if (type == typeid(float)) return ncclFloat;     // NOLINT
  if (type == typeid(double)) return ncclDouble;   // NOLINT
  if (type == typeid(int)) return ncclInt;         // NOLINT
  if (type == typeid(int64_t)) return ncclInt64;   // NOLINT
  PADDLE_THROW("Not supported");
}

Y
Yu Yang 已提交
40 41
class NCCLGroupGuard {
 public:
Y
Yu Yang 已提交
42 43 44 45 46
  static std::mutex &NCCLMutex() {
    static std::mutex mtx;
    return mtx;
  }

Y
Yu Yang 已提交
47
  inline NCCLGroupGuard() {
Y
Yu Yang 已提交
48
    NCCLMutex().lock();
Y
Yu Yang 已提交
49 50
    PADDLE_ENFORCE(dynload::ncclGroupStart());
  }
Y
Yu Yang 已提交
51 52 53

  inline ~NCCLGroupGuard() {
    PADDLE_ENFORCE(dynload::ncclGroupEnd());
Y
Yu Yang 已提交
54
    NCCLMutex().unlock();
Y
Yu Yang 已提交
55 56 57
  }
};

Y
Yu Yang 已提交
58 59 60 61 62
struct NCCLContext {
  std::unique_ptr<CUDADeviceContext> ctx_;
  ncclComm_t comm_;

  explicit NCCLContext(int dev_id)
Y
Yu Yang 已提交
63
      : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {}
Y
Yu Yang 已提交
64 65 66 67 68 69 70 71

  cudaStream_t stream() const { return ctx_->stream(); }

  int device_id() const {
    return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
  }
};

Y
Yu Yang 已提交
72 73 74 75
struct NCCLContextMap {
  std::unordered_map<int, NCCLContext> contexts_;
  std::vector<int> order_;

T
typhoonzero 已提交
76 77 78
  explicit NCCLContextMap(const std::vector<platform::Place> &places,
                          ncclUniqueId *nccl_id = nullptr,
                          size_t node_count = 0, size_t trainer_id = 0) {
Y
Yu Yang 已提交
79
    PADDLE_ENFORCE(!places.empty());
Y
Yu Yang 已提交
80 81 82 83 84 85 86 87 88 89
    order_.reserve(places.size());
    for (auto &p : places) {
      int dev_id = boost::get<CUDAPlace>(p).device;
      order_.emplace_back(dev_id);
      contexts_.emplace(dev_id, NCCLContext(dev_id));
    }
    PADDLE_ENFORCE_EQ(
        order_.size(), contexts_.size(),
        "NCCL Context Map does not support contain two or more same device");

T
typhoonzero 已提交
90 91 92 93 94 95
    if (places.size() <= 1) {
      return;
    }
    std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
    // if pass nccl_id here, can assume we are doing multi node training
    if (nccl_id == nullptr) {
Y
Yu Yang 已提交
96 97 98 99 100
      {
        std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
        PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
            comms.get(), static_cast<int>(order_.size()), order_.data()));
      }
T
typhoonzero 已提交
101 102 103 104 105 106 107 108 109 110 111 112 113
    } else {
      PADDLE_ENFORCE_GT(node_count, 0);
      PADDLE_ENFORCE_EQ(node_count % places.size(), 0,
                        "must have same number of GPUs on each node");
      {
        std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
        int nranks = node_count * order_.size();
        for (auto &gpu_id : order_) {
          int rank = trainer_id * order_.size() + gpu_id;
          PADDLE_ENFORCE(cudaSetDevice(gpu_id));
          PADDLE_ENFORCE(
              ncclCommInitRank(comms.get() + gpu_id, nranks, *nccl_id, rank));
        }
Y
Yu Yang 已提交
114
      }
Y
Yu Yang 已提交
115
    }
T
typhoonzero 已提交
116 117 118 119
    int i = 0;
    for (auto &dev_id : order_) {
      contexts_.at(dev_id).comm_ = comms[i++];
    }
Y
Yu Yang 已提交
120 121
  }

Y
Yu Yang 已提交
122 123 124
  NCCLContextMap(const NCCLContextMap &other) = delete;
  NCCLContextMap &operator=(const NCCLContextMap &other) = delete;

Y
Yu Yang 已提交
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
  CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }

  CUDADeviceContext *DevCtx(platform::Place p) const {
    return DevCtx(boost::get<CUDAPlace>(p).device);
  }

  const NCCLContext &at(platform::Place p) const {
    return this->at(boost::get<CUDAPlace>(p).device);
  }

  const NCCLContext &at(int dev_id) const { return contexts_.at(dev_id); }

  void WaitAll() {
    for (auto &p : contexts_) {
      p.second.ctx_->Wait();
    }
  }
};

Y
Yu Yang 已提交
144 145
}  // namespace platform
}  // namespace paddle