nccl_helper.h 14.3 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

17
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
T
typhoonzero 已提交
18
#include <stdio.h>
19

Q
qingqing01 已提交
20
#include <memory>
21
#include <string>
22
#include <thread>  // NOLINT
Y
Yu Yang 已提交
23
#include <typeindex>
Q
qingqing01 已提交
24
#include <unordered_map>
25
#include <vector>
W
Wu Yi 已提交
26

Y
Yu Yang 已提交
27
#include "paddle/fluid/framework/data_type.h"
28
#include "paddle/fluid/platform/collective_helper.h"
29
#ifdef PADDLE_WITH_NCCL
Y
Yu Yang 已提交
30
#include "paddle/fluid/platform/dynload/nccl.h"
31 32 33 34
#endif
#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
35 36
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
Y
Yu Yang 已提交
37
#include "paddle/fluid/platform/enforce.h"
W
Wu Yi 已提交
38
#include "paddle/fluid/platform/float16.h"
Y
Yu Yang 已提交
39

T
typhoonzero 已提交
40 41
#define NCCL_ID_VARNAME "NCCLID"

Y
Yu Yang 已提交
42 43 44
namespace paddle {
namespace platform {

Y
Yu Yang 已提交
45 46
// Maps a framework proto variable type to the NCCL data type used for
// collective communication. UINT8 and BOOL both map to ncclUint8. BF16 is
// only handled when NCCL_VERSION_CODE >= 21000 and CUDA_VERSION >= 11000.
// Any other type raises an Unimplemented error.
inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
  switch (type) {
    case framework::proto::VarType::FP32:
      return ncclFloat;
    case framework::proto::VarType::FP64:
      return ncclDouble;
    case framework::proto::VarType::INT32:
      return ncclInt;
    case framework::proto::VarType::INT64:
      return ncclInt64;
    case framework::proto::VarType::FP16:
      return ncclFloat16;
    case framework::proto::VarType::INT8:
      return ncclInt8;
    case framework::proto::VarType::UINT8:
    case framework::proto::VarType::BOOL:
      return ncclUint8;
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
    case framework::proto::VarType::BF16:
      return ncclBfloat16;
#endif
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "This datatype in nccl is not supported."));
  }
}

72 73
// Maps a phi::DataType to the NCCL data type used for collective
// communication. UINT8 and BOOL both map to ncclUint8. BFLOAT16 is only
// handled when NCCL_VERSION_CODE >= 21000 and CUDA_VERSION >= 11000. Any
// other type raises an Unimplemented error.
inline ncclDataType_t ToNCCLDataType(phi::DataType type) {
  switch (type) {
    case phi::DataType::FLOAT32:
      return ncclFloat;
    case phi::DataType::FLOAT64:
      return ncclDouble;
    case phi::DataType::INT32:
      return ncclInt;
    case phi::DataType::INT64:
      return ncclInt64;
    case phi::DataType::FLOAT16:
      return ncclFloat16;
    case phi::DataType::INT8:
      return ncclInt8;
    case phi::DataType::UINT8:
    case phi::DataType::BOOL:
      return ncclUint8;
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
    case phi::DataType::BFLOAT16:
      return ncclBfloat16;
#endif
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "This datatype in nccl is not supported."));
  }
}

99 100 101 102 103
// NOTE(minqiyang): according to the ncclGroupEnd documentation at
// https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html,
// ncclGroupEnd waits for all communicators to be initialized, which causes
// a blocking problem when a runtime_error is thrown. So guard only the NCCL
// calls themselves when using this class.
Y
Yu Yang 已提交
104 105
class NCCLGroupGuard {
 public:
Y
Yu Yang 已提交
106 107 108 109 110
  static std::mutex &NCCLMutex() {
    static std::mutex mtx;
    return mtx;
  }

Y
Yu Yang 已提交
111
  inline NCCLGroupGuard() {
Y
Yu Yang 已提交
112
    NCCLMutex().lock();
113
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclGroupStart());
Y
Yu Yang 已提交
114
  }
Y
Yu Yang 已提交
115

Z
Zeng Jinle 已提交
116
  inline ~NCCLGroupGuard() PADDLE_MAY_THROW {
117
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclGroupEnd());
Y
Yu Yang 已提交
118
    NCCLMutex().unlock();
Y
Yu Yang 已提交
119 120 121
  }
};

Y
Yu Yang 已提交
122
struct NCCLContext {
L
Leo Chen 已提交
123
  std::unique_ptr<phi::GPUContext> ctx_;
Y
Yu Yang 已提交
124 125
  ncclComm_t comm_;

W
Wilber 已提交
126
  explicit NCCLContext(int dev_id) : comm_{nullptr} {
L
Leo Chen 已提交
127
    ctx_.reset(new phi::GPUContext(CUDAPlace(dev_id)));
W
Wilber 已提交
128 129 130 131 132 133 134 135 136 137 138
    ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
                           .GetAllocator(CUDAPlace(dev_id), ctx_->stream())
                           .get());
    ctx_->SetHostAllocator(
        paddle::memory::allocation::AllocatorFacade::Instance()
            .GetAllocator(paddle::platform::CPUPlace())
            .get());
    ctx_->SetZeroAllocator(
        paddle::memory::allocation::AllocatorFacade::Instance()
            .GetZeroAllocator(CUDAPlace(dev_id))
            .get());
139 140 141 142
    ctx_->SetHostZeroAllocator(
        paddle::memory::allocation::AllocatorFacade::Instance()
            .GetZeroAllocator(paddle::platform::CPUPlace())
            .get());
W
wanghuancoder 已提交
143 144 145 146
    ctx_->SetPinnedAllocator(
        paddle::memory::allocation::AllocatorFacade::Instance()
            .GetAllocator(paddle::platform::CUDAPinnedPlace())
            .get());
W
Wilber 已提交
147 148
    ctx_->PartialInitWithAllocator();
  }
Y
Yu Yang 已提交
149

150
  gpuStream_t stream() const { return ctx_->stream(); }
Q
qingqing01 已提交
151 152
  ncclComm_t comm() const { return comm_; }

153
  int device_id() const { return ctx_->GetPlace().device; }
Y
Yu Yang 已提交
154 155
};

Y
Yu Yang 已提交
156 157 158 159
struct NCCLContextMap {
  std::unordered_map<int, NCCLContext> contexts_;
  std::vector<int> order_;

T
typhoonzero 已提交
160 161
  explicit NCCLContextMap(const std::vector<platform::Place> &places,
                          ncclUniqueId *nccl_id = nullptr,
162 163 164 165
                          size_t num_trainers = 1,
                          size_t trainer_id = 0) {
    PADDLE_ENFORCE_EQ(!places.empty(),
                      true,
G
GaoWei8 已提交
166 167
                      platform::errors::InvalidArgument(
                          "The NCCL place should not be empty."));
Y
Yu Yang 已提交
168 169
    order_.reserve(places.size());
    for (auto &p : places) {
170
      int dev_id = p.device;
Y
Yu Yang 已提交
171 172 173 174
      order_.emplace_back(dev_id);
      contexts_.emplace(dev_id, NCCLContext(dev_id));
    }
    PADDLE_ENFORCE_EQ(
175 176
        order_.size(),
        contexts_.size(),
G
GaoWei8 已提交
177 178
        platform::errors::Unavailable("NCCL Context Map does not support "
                                      "contain two or more same device."));
Y
Yu Yang 已提交
179

T
typhoonzero 已提交
180
    std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
W
Wu Yi 已提交
181
    // if num_trainers == 1, should create a new nccl id for local comms.
Y
Yancey1989 已提交
182
    if (num_trainers == 1 && nccl_id == nullptr) {
T
typhoonzero 已提交
183
      std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
184
      PADDLE_RETRY_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
T
typhoonzero 已提交
185
          comms.get(), static_cast<int>(order_.size()), order_.data()));
T
typhoonzero 已提交
186
    } else {
187 188 189
      PADDLE_ENFORCE_NOT_NULL(
          nccl_id,
          platform::errors::InvalidArgument("The NCCL id should not be null."));
Y
Yu Yang 已提交
190
      {
T
typhoonzero 已提交
191
        int nranks = num_trainers * order_.size();
T
typhoonzero 已提交
192
        NCCLGroupGuard gurad;
193 194 195 196 197 198 199 200
        for (size_t i = 0; i < order_.size(); ++i) {
          int gpu_id = order_[i];
          int rank;
          if (order_.size() > 1) {
            rank = trainer_id * order_.size() + i;
          } else {
            rank = trainer_id;
          }
201 202
          VLOG(1) << "init nccl rank:" << rank << ", nranks:" << nranks
                  << ", gpu_id:" << gpu_id << ", dev_id:" << order_[i];
L
Leo Chen 已提交
203
          SetDeviceId(gpu_id);
204
          PADDLE_RETRY_CUDA_SUCCESS(platform::dynload::ncclCommInitRank(
205
              comms.get() + i, nranks, *nccl_id, rank));
T
typhoonzero 已提交
206
        }
Y
Yu Yang 已提交
207
      }
Y
Yu Yang 已提交
208
    }
T
typhoonzero 已提交
209 210 211 212
    int i = 0;
    for (auto &dev_id : order_) {
      contexts_.at(dev_id).comm_ = comms[i++];
    }
Y
Yu Yang 已提交
213 214
  }

Y
Yu Yang 已提交
215 216 217
  NCCLContextMap(const NCCLContextMap &other) = delete;
  NCCLContextMap &operator=(const NCCLContextMap &other) = delete;

L
Leo Chen 已提交
218
  phi::GPUContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
Y
Yu Yang 已提交
219

L
Leo Chen 已提交
220
  phi::GPUContext *DevCtx(platform::Place p) const { return DevCtx(p.device); }
Y
Yu Yang 已提交
221

222
  const NCCLContext &at(platform::Place p) const { return this->at(p.device); }
Y
Yu Yang 已提交
223 224 225 226 227 228 229 230 231 232

  const NCCLContext &at(int dev_id) const { return contexts_.at(dev_id); }

  void WaitAll() {
    for (auto &p : contexts_) {
      p.second.ctx_->Wait();
    }
  }
};

233 234 235 236 237 238 239 240
// Name of the variable holding the flat NCCL id at position `pos`: the bare
// NCCL_ID_VARNAME for position 0, otherwise "<NCCL_ID_VARNAME>_<pos>".
inline std::string GetFlatNCCLVarName(size_t pos) {
  if (pos != 0) {
    const int index = static_cast<int>(pos);
    return string::Sprintf("%s_%d", NCCL_ID_VARNAME, index);
  }
  return NCCL_ID_VARNAME;
}

inline std::string GetHierarchicalExterNCCLVarName(size_t pos) {
241 242
  return string::Sprintf(
      "Hierarchical_exter_%s_%d", NCCL_ID_VARNAME, static_cast<int>(pos));
243
}
G
gongweibao 已提交
244
inline std::string GetHierarchicalInterNCCLVarName(size_t pos) {
245 246
  return string::Sprintf(
      "Hierarchical_inter_%s_%d", NCCL_ID_VARNAME, static_cast<int>(pos));
247 248
}

249
class NCCLCommunicator {
250
 public:
251
  NCCLCommunicator() {}
Z
Zeng Jinle 已提交
252
  virtual ~NCCLCommunicator() PADDLE_MAY_THROW {}
253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278

  NCCLContextMap *DefaultFlatCtx() const {
    if (flat_ctxs_.size() == 0) {
      return nullptr;
    }

    return flat_ctxs_[0].get();
  }

  std::vector<std::unique_ptr<NCCLContextMap>> *GetFlatCtxs() {
    return &flat_ctxs_;
  }

  NCCLContextMap *GetFlatCtx(size_t run_order) const {
    return flat_ctxs_[run_order % flat_ctxs_.size()].get();
  }

  NCCLContextMap *GetRunEnvNCCLCtx(size_t run_order,
                                   bool use_hierarchical_allreduce) const {
    if (!use_hierarchical_allreduce) {
      return GetFlatCtx(run_order);
    }

    return GetHierarchicalInterCtx(run_order);
  }

279 280 281 282 283
  /*
   *When nccl inits nccl comm using ncclCommInitAll, it meets error when
   *allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
   *create a new nccl comm for sync_batch_norm_op. And these codes should be
   *polished with a unified nccl management.
284
   */
285 286 287 288 289 290 291 292 293 294 295 296 297
  NCCLContextMap *GetSyncBatchNormCtx(
      framework::Scope *scope, const std::vector<platform::Place> &places) {
    auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
    if (nccl_id_var != nullptr) {
      return DefaultFlatCtx();
    }

    if (sync_batch_norm_ctx_.get() == nullptr) {
      sync_batch_norm_ctx_.reset(new NCCLContextMap(places));
    }
    return sync_batch_norm_ctx_.get();
  }

298 299
  void InitFlatCtxs(const std::vector<platform::Place> &places,
                    const std::vector<ncclUniqueId *> &nccl_ids,
300 301
                    size_t trainers_num,
                    size_t trainer_id) {
302 303 304 305
    if (nccl_ids.size() == 0) {
      auto ptr = new platform::NCCLContextMap(places);
      VLOG(1) << "init local trainer";
      flat_ctxs_.emplace_back(ptr);
306 307
    } else {
      for (size_t i = 0; i < nccl_ids.size(); i++) {
308 309
        auto ptr = new platform::NCCLContextMap(
            places, nccl_ids[i], trainers_num, trainer_id);
310 311 312
        VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
        flat_ctxs_.emplace_back(ptr);
      }
313 314
    }

315 316 317 318 319 320 321
    // as Executor have no way to use ncclComm created by ParallelExecutor,
    // we assign all flatten contexts to NCCLCommContext to fix.
    int nranks = static_cast<int>(trainers_num * places.size());
    int nrings = static_cast<int>(flat_ctxs_.size());
    for (int ring_id = 0; ring_id < nrings; ++ring_id) {
      for (size_t p = 0; p < places.size(); ++p) {
        int rank = trainer_id * places.size() + p;
322
        int dev_id = places[p].device;
323
        auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id);
324 325
        NCCLCommContext::Instance().AssignNCCLComm(
            ctx.comm_, nranks, rank, dev_id, ring_id);
326
      }
327 328 329 330
    }
  }

  void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
G
gongweibao 已提交
331 332
                            const std::vector<ncclUniqueId *> &inter_nccl_ids,
                            const std::vector<ncclUniqueId *> &exter_nccl_ids,
333 334
                            size_t trainers_num,
                            size_t trainer_id,
335 336
                            size_t inter_trainers_num,
                            size_t exter_trainers_num) {
337 338 339 340 341 342 343 344
    PADDLE_ENFORCE_EQ(trainers_num,
                      inter_trainers_num * exter_trainers_num,
                      platform::errors::InvalidArgument(
                          "trainers_num:%llu != inter_trainers_num:%llu * "
                          "exter_trainers_num:%llu",
                          trainers_num,
                          inter_trainers_num,
                          exter_trainers_num));
345

G
GaoWei8 已提交
346
    PADDLE_ENFORCE_GT(
347 348
        inter_trainers_num,
        1,
G
GaoWei8 已提交
349 350 351
        platform::errors::InvalidArgument(
            "The inter_trainers_num:%llu should be larger than 1.",
            inter_trainers_num));
352 353

    int inter_trainer_id = trainer_id % inter_trainers_num;
G
gongweibao 已提交
354 355 356
    for (size_t i = 0; i < inter_nccl_ids.size(); i++) {
      VLOG(1) << "init inter_trainer_id:" << inter_trainer_id
              << ", comm no:" << i;
357 358
      auto local = new NCCLContextMap(
          places, inter_nccl_ids[i], inter_trainers_num, inter_trainer_id);
359

G
gongweibao 已提交
360 361
      h_inter_ctxs_.emplace_back(local);
    }
362 363 364 365 366 367 368

    int exter_trainer_id = -1;
    if (trainer_id % inter_trainers_num == 0) {
      exter_trainer_id = trainer_id / inter_trainers_num;
    }

    if (exter_trainer_id >= 0) {
G
gongweibao 已提交
369
      for (size_t i = 0; i < exter_nccl_ids.size(); i++) {
370 371
        auto ex = new NCCLContextMap(
            places, exter_nccl_ids[i], exter_trainers_num, exter_trainer_id);
372 373 374 375 376 377 378 379 380 381
        VLOG(1) << "init exter_trainer_id:" << exter_trainer_id
                << ", comm no:" << i;
        h_exter_ctxs_.emplace_back(ex);
      }
    }
  }

  bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; }

  NCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const {
382 383
    PADDLE_ENFORCE_GT(h_inter_ctxs_.size(),
                      0,
G
GaoWei8 已提交
384 385
                      platform::errors::InvalidArgument(
                          "Hierarchical ctxs should be initialized firstly!"));
386 387 388 389
    return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get();
  }

  NCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const {
390 391
    PADDLE_ENFORCE_GT(h_exter_ctxs_.size(),
                      0,
G
GaoWei8 已提交
392 393
                      platform::errors::InvalidArgument(
                          "Hierarchical ctxs should be initialized firstly!"));
394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412
    return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get();
  }

  std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalInterCtxs() {
    return &h_inter_ctxs_;
  }

  std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalExterCtxs() {
    return &h_exter_ctxs_;
  }

 protected:
  // Support multi nccl comm on default nccl ring while NCCLContextMap can't.
  std::vector<std::unique_ptr<NCCLContextMap>> flat_ctxs_;

  // h_inter_ctxs_ and h_exter_ctxs_ are for 2d allreduce.
  // And h_exter_ctxs_ can support multi comm too.
  std::vector<std::unique_ptr<NCCLContextMap>> h_inter_ctxs_;
  std::vector<std::unique_ptr<NCCLContextMap>> h_exter_ctxs_;
413 414 415

  // just used for sync_batch_norm op.
  std::unique_ptr<NCCLContextMap> sync_batch_norm_ctx_;
416 417
};

Y
Yu Yang 已提交
418 419
}  // namespace platform
}  // namespace paddle
P
peizhilin 已提交
420
#endif