// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

#include <error.h>

#include <cstdio>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>
#include <utility>

#include "boost/variant.hpp"
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/variable.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif

#include "paddle/fluid/platform/device_context.h"

#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"
#else
#include "paddle/fluid/platform/dynload/nccl.h"
#endif

#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace distributed {

// Evaluates `cmd`, an expression yielding ncclResult_t. If the result is not
// ncclSuccess, prints the call site and NCCL's error string to stderr and
// terminates the process with EXIT_FAILURE.
// The result variable has a deliberately unusual name: a local named `r`
// would already be in scope inside its own initializer, so a call such as
// NCCLCHECK(f(r)) would silently read the new, uninitialized variable
// instead of the caller's `r`.
#define NCCLCHECK(cmd)                                                   \
  do {                                                                   \
    ncclResult_t nccl_check_result_ = (cmd);                             \
    if (nccl_check_result_ != ncclSuccess) {                             \
      std::fprintf(                                                      \
          stderr, "Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
          platform::dynload::ncclGetErrorString(nccl_check_result_));    \
      std::exit(EXIT_FAILURE);                                           \
    }                                                                    \
  } while (0)

// NOTE(shenliang03): EventManager are movable not copyable CudaEvent wrapper.
// EventManage is different from paddle::platform::CudaEvent.
// It uses lazy initialization and is only created when the
// Record() method is called for the first time; it also monitors
// device information to ensure that recorded stream and event
// are on the same device.

class EventManager {
 public:
  EventManager() {}
  explicit EventManager(unsigned int flags) : flags_{flags} {}

  ~EventManager() {
    if (is_created_) {
      platform::CUDADeviceGuard guard(device_index_);
S
ShenLiang 已提交
75 76 77
#ifdef PADDLE_WITH_HIP
      hipEventDestroy(event_);
#else
78
      cudaEventDestroy(event_);
S
ShenLiang 已提交
79
#endif
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
    }
  }

  EventManager(const EventManager&) = delete;
  EventManager& operator=(const EventManager&) = delete;

  EventManager(EventManager&& other) {
    std::swap(flags_, other.flags_);
    std::swap(is_created_, other.is_created_);
    std::swap(device_index_, other.device_index_);
    std::swap(event_, other.event_);
  }

  EventManager& operator=(EventManager&& other) {
    std::swap(flags_, other.flags_);
    std::swap(is_created_, other.is_created_);
    std::swap(device_index_, other.device_index_);
    std::swap(event_, other.event_);
    return *this;
  }

  bool IsCreated() const { return is_created_; }
  bool DeviceId() const { return device_index_; }
  gpuEvent_t GetRawCudaEvent() const { return event_; }

  void Record(const paddle::platform::CUDADeviceContext& ctx) {
    auto device_index = ctx.GetPlace().device;
    if (!is_created_) {
      CreateEvent(device_index);
    }
    PADDLE_ENFORCE_EQ(device_index, device_index_,
                      platform::errors::PreconditionNotMet(
                          "CUDADeviceContext's device %d does not match"
                          "Event's device %d",
                          device_index, device_index_));

    platform::CUDADeviceGuard guard(device_index_);
S
ShenLiang 已提交
117
#ifdef PADDLE_WITH_CUDA
118
    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event_, ctx.stream()));
S
ShenLiang 已提交
119 120 121
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event_, ctx.stream()));
#endif
122 123 124
  }

  bool Query() const {
S
ShenLiang 已提交
125 126 127 128 129 130 131 132 133
#ifdef PADDLE_WITH_HIP
    gpuError_t err = hipEventQuery(event_);
    if (err == hipSuccess) {
      return true;
    }
    if (err == hipErrorNotReady) {
      return false;
    }
#else
134 135 136
    gpuError_t err = cudaEventQuery(event_);
    if (err == cudaSuccess) {
      return true;
S
ShenLiang 已提交
137 138
    }
    if (err == cudaErrorNotReady) {
139 140
      return false;
    }
S
ShenLiang 已提交
141 142 143
#endif
    PADDLE_ENFORCE_GPU_SUCCESS(err);
    return false;
144 145 146 147
  }

  void Synchronize() const {
    if (is_created_) {
S
ShenLiang 已提交
148 149 150
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventSynchronize(event_));
#else
151
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(event_));
S
ShenLiang 已提交
152
#endif
153 154 155 156 157 158 159 160 161 162 163 164
    }
  }

  void Block(const paddle::platform::CUDADeviceContext& ctx) const {
    if (is_created_) {
      auto device_index = ctx.GetPlace().device;
      PADDLE_ENFORCE_EQ(device_index, device_index_,
                        platform::errors::PreconditionNotMet(
                            "CUDADeviceContext's device %d does not match"
                            "Event's device %d",
                            device_index, device_index_));
      platform::CUDADeviceGuard guard(device_index_);
S
ShenLiang 已提交
165 166 167 168

#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(ctx.stream(), event_, 0));
#else
169
      PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(ctx.stream(), event_, 0));
S
ShenLiang 已提交
170
#endif
171 172 173 174
    }
  }

 private:
S
ShenLiang 已提交
175 176 177
#ifdef PADDLE_WITH_HIP
  unsigned int flags_ = hipEventDefault;
#else
178
  unsigned int flags_ = cudaEventDefault;
S
ShenLiang 已提交
179 180
#endif

181 182 183 184 185 186 187 188
  bool is_created_{false};
  gpuEvent_t event_{};
  int8_t device_index_{0};

 private:
  void CreateEvent(int device_index) {
    device_index_ = device_index;
    platform::CUDADeviceGuard guard(device_index);
S
ShenLiang 已提交
189 190 191 192

#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(&event_, flags_));
#else
193
    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(&event_, flags_));
S
ShenLiang 已提交
194 195
#endif

196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
    is_created_ = true;
  }
};

// NOTE(shenliang03): NCCLCommManager is more lightweight than
// platform::NCCLComm

class NCCLCommManager {
 public:
  explicit NCCLCommManager(ncclComm_t ncclComm) : nccl_comm_(ncclComm) {}

  NCCLCommManager() : NCCLCommManager(nullptr) {}

  ~NCCLCommManager() noexcept {
    std::unique_lock<std::mutex> lock(mutex_);
    if (nccl_comm_) {
      platform::dynload::ncclCommDestroy(nccl_comm_);
    }
  }

  static std::shared_ptr<NCCLCommManager> Create(int num_ranks, int rank,
                                                 ncclUniqueId comm_id) {
    auto nccl_manager = std::make_shared<NCCLCommManager>();
    NCCLCHECK(platform::dynload::ncclCommInitRank(&(nccl_manager->nccl_comm_),
                                                  num_ranks, comm_id, rank));

    nccl_manager->nccl_id_ = comm_id;
    nccl_manager->rank_ = rank;
    return nccl_manager;
  }

  ncclUniqueId GetNcclId() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return nccl_id_;
  }

  ncclComm_t GetNcclComm() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return nccl_comm_;
  }

  NCCLCommManager(const NCCLCommManager&) = delete;
  NCCLCommManager& operator=(const NCCLCommManager&) = delete;
  NCCLCommManager& operator=(NCCLCommManager&& other) = delete;

  NCCLCommManager(NCCLCommManager&& other) {
    std::unique_lock<std::mutex> lock(other.mutex_);
    std::swap(nccl_comm_, other.nccl_comm_);
  }

 protected:
  ncclComm_t nccl_comm_;
  ncclUniqueId nccl_id_;
  int rank_;
  mutable std::mutex mutex_;
};

// Declared here; the definition is out of line and not visible in this header.
// Presumably maps a distributed ReduceOp to the matching NCCL reduction op.
// TODO: confirm against the definition.
ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
// Declared here; defined out of line. Presumably returns a string form of
// `ncclID`. TODO: confirm the exact format against the definition.
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);

}  // namespace distributed
}  // namespace paddle