//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef _WIN32
#pragma once

#include <stdio.h>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <thread>  // NOLINT
#include <typeindex>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"

#define NCCL_ID_VARNAME "NCCLID"

namespace paddle {
namespace platform {

inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
  if (type == framework::proto::VarType::FP32) {
    return ncclFloat;
  } else if (type == framework::proto::VarType::FP64) {
    return ncclDouble;
  } else if (type == framework::proto::VarType::INT32) {
    return ncclInt;
  } else if (type == framework::proto::VarType::INT64) {
    return ncclInt64;
  } else if (type == framework::proto::VarType::FP16) {
    return ncclFloat16;
  } else {
    PADDLE_THROW("Not supported");
  }
}
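// Example (illustrative, not part of the original header):
// ToNCCLDataType(framework::proto::VarType::FP32) yields ncclFloat, which can
// be passed directly to NCCL collectives such as ncclAllReduce.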

// NOTE(minqiyang): according to the ncclGroupEnd documentation:
// https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html,
// ncclGroupEnd will wait for all communicators to be initialized, which can
// cause the program to block when a runtime_error is thrown in between, so
// only guard NCCL actions with this class when it is actually needed.
class NCCLGroupGuard {
 public:
  static std::mutex &NCCLMutex() {
    static std::mutex mtx;
    return mtx;
  }

  inline NCCLGroupGuard() {
    NCCLMutex().lock();
    PADDLE_ENFORCE(dynload::ncclGroupStart());
  }

  inline ~NCCLGroupGuard() {
    PADDLE_ENFORCE(dynload::ncclGroupEnd());
    NCCLMutex().unlock();
  }
};

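// Usage sketch (illustrative only, assuming per-device `contexts` and
// device-resident `send`/`recv` buffers of `count` floats): fuse several
// per-device NCCL calls into a single group.
//
//   {
//     NCCLGroupGuard guard;  // ncclGroupStart() under the global NCCL mutex
//     for (auto &ctx : contexts) {
//       platform::dynload::ncclAllReduce(send, recv, count, ncclFloat, ncclSum,
//                                        ctx.comm(), ctx.stream());
//     }
//   }  // ~NCCLGroupGuard() calls ncclGroupEnd() and releases the mutex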
struct NCCLContext {
  std::unique_ptr<CUDADeviceContext> ctx_;
  ncclComm_t comm_;

  explicit NCCLContext(int dev_id)
      : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {}

  cudaStream_t stream() const { return ctx_->stream(); }
  ncclComm_t comm() const { return comm_; }

  int device_id() const {
    return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
  }
};

struct NCCLContextMap {
  std::unordered_map<int, NCCLContext> contexts_;
  std::vector<int> order_;

  explicit NCCLContextMap(const std::vector<platform::Place> &places,
                          ncclUniqueId *nccl_id = nullptr,
                          size_t num_trainers = 1, size_t trainer_id = 0) {
    PADDLE_ENFORCE(!places.empty());
    order_.reserve(places.size());
    for (auto &p : places) {
      int dev_id = boost::get<CUDAPlace>(p).device;
      order_.emplace_back(dev_id);
      contexts_.emplace(dev_id, NCCLContext(dev_id));
    }
    PADDLE_ENFORCE_EQ(
        order_.size(), contexts_.size(),
        "NCCLContextMap does not support two or more contexts on the same "
        "device");

    std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
    // If num_trainers == 1, create a new nccl id for local comms.
    if (num_trainers == 1 && nccl_id == nullptr) {
      std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
      PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
          comms.get(), static_cast<int>(order_.size()), order_.data()));
    } else {
      PADDLE_ENFORCE_NOT_NULL(nccl_id);
      {
        int nranks = num_trainers * order_.size();
        NCCLGroupGuard guard;
        for (size_t i = 0; i < order_.size(); ++i) {
          int gpu_id = order_[i];
          int rank;
          if (order_.size() > 1) {
            rank = trainer_id * order_.size() + i;
          } else {
            rank = trainer_id;
          }
          VLOG(1) << "init nccl rank:" << rank << ", nranks:" << nranks
                  << ", gpu_id:" << gpu_id << ", dev_id:" << order_[i];
          PADDLE_ENFORCE(cudaSetDevice(gpu_id));
          PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
              comms.get() + i, nranks, *nccl_id, rank));
        }
      }
    }
    int i = 0;
    for (auto &dev_id : order_) {
      contexts_.at(dev_id).comm_ = comms[i++];
    }
  }

  NCCLContextMap(const NCCLContextMap &other) = delete;
  NCCLContextMap &operator=(const NCCLContextMap &other) = delete;

  CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }

  CUDADeviceContext *DevCtx(platform::Place p) const {
    return DevCtx(boost::get<CUDAPlace>(p).device);
  }

  const NCCLContext &at(platform::Place p) const {
    return this->at(boost::get<CUDAPlace>(p).device);
  }

  const NCCLContext &at(int dev_id) const { return contexts_.at(dev_id); }

  void WaitAll() {
    for (auto &p : contexts_) {
      p.second.ctx_->Wait();
    }
  }
};
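// Usage sketch (illustrative only): a single-trainer, multi-GPU setup can
// build the map directly from the places and then fetch per-device comms.
//
//   std::vector<platform::Place> places = {platform::CUDAPlace(0),
//                                          platform::CUDAPlace(1)};
//   platform::NCCLContextMap ctx_map(places);  // ncclCommInitAll on both GPUs
//   ncclComm_t comm0 = ctx_map.at(0).comm();
//   cudaStream_t stream0 = ctx_map.at(0).stream();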

inline std::string GetFlatNCCLVarName(size_t pos) {
  if (pos == 0) {
    return NCCL_ID_VARNAME;
  }
  return string::Sprintf("%s_%d", NCCL_ID_VARNAME, static_cast<int>(pos));
}

inline std::string GetHierarchicalExterNCCLVarName(size_t pos) {
  return string::Sprintf("Hierarchical_exter_%s_%d", NCCL_ID_VARNAME,
                         static_cast<int>(pos));
}
inline std::string GetHierarchicalInterNCCLVarName() {
  return string::Sprintf("Hierarchical_inter_%s", NCCL_ID_VARNAME);
}
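// For reference (derived from the helpers above): GetFlatNCCLVarName(0)
// returns "NCCLID", GetFlatNCCLVarName(2) returns "NCCLID_2",
// GetHierarchicalInterNCCLVarName() returns "Hierarchical_inter_NCCLID", and
// GetHierarchicalExterNCCLVarName(1) returns "Hierarchical_exter_NCCLID_1".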

class MultiNCCLContextMap {
 public:
  MultiNCCLContextMap() {}
  virtual ~MultiNCCLContextMap() {}

  NCCLContextMap *DefaultFlatCtx() const {
    if (flat_ctxs_.size() == 0) {
      return nullptr;
    }

    return flat_ctxs_[0].get();
  }

  std::vector<std::unique_ptr<NCCLContextMap>> *GetFlatCtxs() {
    return &flat_ctxs_;
  }

  NCCLContextMap *GetFlatCtx(size_t run_order) const {
    return flat_ctxs_[run_order % flat_ctxs_.size()].get();
  }

  NCCLContextMap *GetRunEnvNCCLCtx(size_t run_order,
                                   bool use_hierarchical_allreduce) const {
    if (!use_hierarchical_allreduce) {
      return GetFlatCtx(run_order);
    }

    return GetHierarchicalInterCtx(run_order);
  }

  void InitFlatCtxs(const std::vector<platform::Place> &places,
                    const std::vector<ncclUniqueId *> &nccl_ids,
                    size_t trainers_num, size_t trainer_id) {
    if (nccl_ids.size() == 0) {
      auto ptr = new platform::NCCLContextMap(places);
      VLOG(1) << "init local trainer";
      flat_ctxs_.emplace_back(ptr);
      return;
    }

    for (size_t i = 0; i < nccl_ids.size(); i++) {
      auto ptr = new platform::NCCLContextMap(places, nccl_ids[i], trainers_num,
                                              trainer_id);
      VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
      flat_ctxs_.emplace_back(ptr);
    }
  }

  void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
                            ncclUniqueId *inter_nccl_id,
                            const std::vector<ncclUniqueId *> &exter_nccl_id,
                            size_t trainers_num, size_t trainer_id,
                            size_t inter_trainers_num,
                            size_t exter_trainers_num) {
    PADDLE_ENFORCE(trainers_num == inter_trainers_num * exter_trainers_num,
                   "trainers_num:%llu != inter_trainers_num:%llu * "
                   "exter_trainers_num:%llu",
                   trainers_num, inter_trainers_num, exter_trainers_num);

    PADDLE_ENFORCE(inter_trainers_num > 1, "inter_trainers_num:%llu must > 1",
                   inter_trainers_num);

    int inter_trainer_id = trainer_id % inter_trainers_num;
    VLOG(1) << "init inter_trainer_id:" << inter_trainer_id;
    auto local = new NCCLContextMap(places, inter_nccl_id, inter_trainers_num,
                                    inter_trainer_id);

    h_inter_ctxs_.emplace_back(local);

    int exter_trainer_id = -1;
    if (trainer_id % inter_trainers_num == 0) {
      exter_trainer_id = trainer_id / inter_trainers_num;
    }

    if (exter_trainer_id >= 0) {
      for (size_t i = 0; i < exter_nccl_id.size(); i++) {
        auto ex = new NCCLContextMap(places, exter_nccl_id[i],
                                     exter_trainers_num, exter_trainer_id);
        VLOG(1) << "init exter_trainer_id:" << exter_trainer_id
                << ", comm no:" << i;
        h_exter_ctxs_.emplace_back(ex);
      }
    }
  }
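  // Worked example (illustrative, not from the original source): with
  // trainers_num = 8, inter_trainers_num = 4 and exter_trainers_num = 2,
  // every trainer joins an inter ring with rank trainer_id % 4, while only
  // trainers 0 and 4 (those with trainer_id % 4 == 0) join the exter ring,
  // with exter ranks 0 and 1 respectively.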

  bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; }

  NCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const {
    return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get();
  }

  NCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const {
    return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get();
  }

  std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalInterCtxs() {
    return &h_inter_ctxs_;
  }

  std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalExterCtxs() {
    return &h_exter_ctxs_;
  }

 protected:
  // Supports multiple NCCL comms on the default NCCL ring, which a single
  // NCCLContextMap cannot.
  std::vector<std::unique_ptr<NCCLContextMap>> flat_ctxs_;

  // h_inter_ctxs_ and h_exter_ctxs_ are used for 2-D (hierarchical) allreduce;
  // h_exter_ctxs_ can also hold multiple comms.
  std::vector<std::unique_ptr<NCCLContextMap>> h_inter_ctxs_;
  std::vector<std::unique_ptr<NCCLContextMap>> h_exter_ctxs_;
};
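// Usage sketch (illustrative only, assuming a `places` vector of CUDAPlace
// entries): a single-node run typically builds one flat ring and reads it
// back through DefaultFlatCtx().
//
//   platform::MultiNCCLContextMap multi_ctxs;
//   multi_ctxs.InitFlatCtxs(places, /*nccl_ids=*/{}, /*trainers_num=*/1,
//                           /*trainer_id=*/0);
//   platform::NCCLContextMap *flat = multi_ctxs.DefaultFlatCtx();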

}  // namespace platform
}  // namespace paddle
#endif  // _WIN32