collective_helper.cc 13.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
//   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/platform/collective_helper.h"

#include <memory>
#include <mutex>
#include <utility>
#include <vector>

#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/gpu/gpu_resource_pool.h"
W
WangXi 已提交
22

23 24
namespace paddle {
namespace platform {
25
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
26 27 28 29 30 31 32 33 34 35 36
class NCCLCommImpl : public NCCLComm {
 public:
  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
  int ring_id() const override { return ring_id_; }

  void set_nranks(int nranks) { nranks_ = nranks; }
  int nranks() const override { return nranks_; }

  void set_rank(int rank) { rank_ = rank; }
  int rank() const override { return rank_; }

37
  int device_id() const override { return dev_ctx_->GetPlace().device; }
38

39 40
  void set_comm(ncclComm_t comm) { comm_ = comm; }
  ncclComm_t comm() const override { return comm_; }
41

42
  gpuStream_t stream() const override { return dev_ctx_->stream(); }
43

L
Leo Chen 已提交
44
  void set_dev_ctx(std::unique_ptr<phi::GPUContext>&& dev_ctx) {
45 46
    dev_ctx_ = std::move(dev_ctx);
  }
L
Leo Chen 已提交
47
  phi::GPUContext* dev_context() const override { return dev_ctx_.get(); }
48

W
WangXi 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61
  gpuEvent_t compute_event() const override { return compute_event_.get(); }

  gpuEvent_t comm_event() const override { return comm_event_.get(); }

  void set_compute_event(
      std::shared_ptr<platform::CudaEventObject>&& compute_event) {
    compute_event_ = std::move(compute_event);
  }

  void set_comm_event(std::shared_ptr<platform::CudaEventObject>&& comm_event) {
    comm_event_ = std::move(comm_event);
  }

62 63 64 65
 private:
  int ring_id_;
  int nranks_;
  int rank_;
66
  ncclComm_t comm_;
L
Leo Chen 已提交
67
  std::unique_ptr<phi::GPUContext> dev_ctx_;
W
WangXi 已提交
68 69 70 71 72 73

  // used for comm wait compute, compute_stream-->event-->comm_stream
  std::shared_ptr<platform::CudaEventObject> compute_event_;

  // used for compute wait comm, comm_stream-->event-->compute_stream
  std::shared_ptr<platform::CudaEventObject> comm_event_;
74 75
};

S
sneaxiy 已提交
76 77 78 79 80
// Returns the process-wide NCCLCommContext. The function-local static is
// constructed on first call; C++11 makes that initialization thread-safe.
NCCLCommContext& NCCLCommContext::Instance() {
  static NCCLCommContext singleton;
  return singleton;
}

81 82
// Creates one NCCL communicator for `rank` out of `nranks` on GPU `dev_id`
// using the shared unique id, registers it under `ring_id`, and returns the
// registered wrapper. On first use it also schedules ReleaseNCCLComms() to run
// at process exit.
NCCLComm* NCCLCommContext::CreateComm(
    ncclUniqueId* nccl_id, int nranks, int rank, int dev_id, int ring_id) {
  // Argument validation, in the same order the errors are reported.
  PADDLE_ENFORCE_NOT_NULL(
      nccl_id,
      platform::errors::InvalidArgument(
          "The nccl unique id should not be null."));
  PADDLE_ENFORCE_GT(nranks,
                    1,
                    platform::errors::InvalidArgument(
                        "Expected nranks > 1. But received nranks is %d.",
                        nranks));
  PADDLE_ENFORCE_GE(rank,
                    0,
                    platform::errors::InvalidArgument(
                        "Expected rank >= 0. But received rank is %d.",
                        rank));
  PADDLE_ENFORCE_LT(
      rank,
      nranks,
      platform::errors::InvalidArgument(
          "Expected rank < nranks. But received rank is %d, nranks is %d.",
          rank,
          nranks));
  PADDLE_ENFORCE_GE(dev_id,
                    0,
                    platform::errors::InvalidArgument(
                        "Expected dev_id >= 0. But received dev_id is %d.",
                        dev_id));

  // Bind the calling thread to the target GPU before initializing NCCL.
  SetDeviceId(dev_id);
  ncclComm_t raw_comm = nullptr;
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclCommInitRank(
      &raw_comm, nranks, *nccl_id, rank));

  NCCLComm* wrapped = AssignNCCLComm(raw_comm, nranks, rank, dev_id, ring_id);
  VLOG(1) << "nccl communicator of rank " << rank << " in ring " << ring_id
          << " has been created on device " << dev_id;

  std::call_once(once_flag_, [] {
    std::atexit([] { NCCLCommContext::Instance().ReleaseNCCLComms(); });
  });
  return wrapped;
}

void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids,
                                         int ring_id) {
G
GaoWei8 已提交
127
  PADDLE_ENFORCE_GT(
128 129
      dev_ids.size(),
      0,
G
GaoWei8 已提交
130 131 132
      platform::errors::InvalidArgument("Expected the size of dev_ids > 0. But "
                                        "received the size of dev_ids is %d.",
                                        dev_ids.size()));
133 134 135

  const int kDevices = dev_ids.size();
  ncclComm_t comms[kDevices];
136
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclCommInitAll(
137 138
      comms, dev_ids.size(), dev_ids.data()));

139 140
  PADDLE_ENFORCE_EQ(comm_map_.count(ring_id),
                    0,
G
GaoWei8 已提交
141 142 143 144
                    platform::errors::InvalidArgument(
                        "Expected comm_map_.count(ring_id) = 0. But received "
                        "comm_map_.count(ring_id) is %d.",
                        comm_map_.count(ring_id)));
145
  for (size_t i = 0; i < dev_ids.size(); ++i) {
146 147 148
    AssignNCCLComm(comms[i], dev_ids.size(), i, dev_ids[i], ring_id);
    VLOG(1) << "nccl communicator of rank " << i << " in ring " << ring_id
            << " has been created on device " << dev_ids[i];
149 150 151 152 153
  }

  std::call_once(once_flag_, []() {
    std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
  });
154 155
}

Y
yaoxuefeng 已提交
156
void NCCLCommContext::CreateNCCLCommMultiTrainer(
157 158 159 160 161
    const std::vector<int>& dev_ids,
    ncclUniqueId* nccl_id,
    int ntrainers,
    int train_id,
    int ring_id) {
Y
yaoxuefeng 已提交
162
  PADDLE_ENFORCE_GT(
163 164
      dev_ids.size(),
      0,
Y
yaoxuefeng 已提交
165 166 167
      paddle::platform::errors::InvalidArgument(
          "dev ids = [%d], it should greater than 0.", dev_ids.size()));
  const int kDevices = dev_ids.size();
Y
yaoxuefeng 已提交
168
  VLOG(1) << "Begin CreateNCCLCommMultiTrainer. device number: " << kDevices
Y
yaoxuefeng 已提交
169 170 171 172
          << ", ntrainers: " << ntrainers << ", train_id: " << train_id
          << ", rind_id: " << ring_id;
  ncclComm_t comms[kDevices];
  {
173
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclGroupStart());
Y
yaoxuefeng 已提交
174 175
    for (int i = 0; i < kDevices; i++) {
#ifdef PADDLE_WITH_HIP
176
      PADDLE_ENFORCE_GPU_SUCCESS(hipSetDevice(i));
Y
yaoxuefeng 已提交
177
#else
178
      PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(i));
Y
yaoxuefeng 已提交
179
#endif
180 181
      platform::dynload::ncclCommInitRank(
          comms + i, kDevices * ntrainers, *nccl_id, train_id * kDevices + i);
Y
yaoxuefeng 已提交
182
      VLOG(1) << "ncclCommInitRank: " << i;
Y
yaoxuefeng 已提交
183
    }
184
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclGroupEnd());
Y
yaoxuefeng 已提交
185
    VLOG(1) << "nccl group end seccessss";
Y
yaoxuefeng 已提交
186
  }
187 188
  PADDLE_ENFORCE_EQ(comm_map_.count(ring_id),
                    0,
Y
yaoxuefeng 已提交
189 190
                    platform::errors::InvalidArgument(
                        "comm_map_ of ring_id: %s should be 0. %s is provided",
191 192
                        ring_id,
                        comm_map_.count(ring_id)));
Y
yaoxuefeng 已提交
193
  for (int i = 0; i < kDevices; ++i) {
194 195 196 197 198
    AssignNCCLComm(comms[i],
                   kDevices * ntrainers,
                   train_id * kDevices + i,
                   dev_ids[i],
                   ring_id);
Y
yaoxuefeng 已提交
199
    VLOG(1) << "nccl communicator of train_id " << train_id * kDevices + i
Y
yaoxuefeng 已提交
200 201 202 203 204 205 206 207 208
            << " in ring " << ring_id << " has been created on device "
            << dev_ids[i];
  }

  std::call_once(once_flag_, []() {
    std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
  });
}

209 210
// Wraps an already-initialized NCCL communicator in an NCCLCommImpl, gives it a
// fresh GPU context and two events, and registers it in comm_map_ under
// (ring_id, dev_id). For ring 0 the communicator is also installed on the
// DeviceContextPool context of `dev_id`. Returns the wrapper that is registered
// for (ring_id, dev_id) after the call.
NCCLComm* NCCLCommContext::AssignNCCLComm(
    ncclComm_t comm, int nranks, int rank, int dev_id, int ring_id) {
  // New GPU context for this communicator. Its device allocator is the
  // stream-bound allocator for that context's stream.
  std::unique_ptr<phi::GPUContext> dev_ctx(
      new phi::GPUContext(CUDAPlace(dev_id)));
  dev_ctx->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
                            .GetAllocator(CUDAPlace(dev_id), dev_ctx->stream())
                            .get());
  dev_ctx->SetHostAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(paddle::platform::CPUPlace())
          .get());
  dev_ctx->SetZeroAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetZeroAllocator(CUDAPlace(dev_id))
          .get());
  dev_ctx->SetHostZeroAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetZeroAllocator(paddle::platform::CPUPlace())
          .get());
  dev_ctx->SetPinnedAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(paddle::platform::CUDAPinnedPlace())
          .get());
  dev_ctx->PartialInitWithAllocator();

  std::shared_ptr<platform::CudaEventObject> compute_event(
      platform::CudaEventResourcePool::Instance().New(dev_id));
  std::shared_ptr<platform::CudaEventObject> comm_event(
      platform::CudaEventResourcePool::Instance().New(dev_id));

  // Owned from the moment of creation, so nothing leaks if a later step throws.
  std::unique_ptr<NCCLCommImpl> impl(new NCCLCommImpl);
  impl->set_ring_id(ring_id);
  impl->set_nranks(nranks);
  impl->set_rank(rank);
  impl->set_comm(comm);
  impl->set_dev_ctx(std::move(dev_ctx));
  impl->set_compute_event(std::move(compute_event));
  impl->set_comm_event(std::move(comm_event));

  NCCLComm* registered = nullptr;
  {
    // RAII lock: released on every path, including exceptions. The returned
    // pointer is read while the lock is still held.
    std::lock_guard<decltype(comm_map_mutex_)> guard(comm_map_mutex_);
    // operator[] default-constructs the per-ring map on first use.
    auto& dev2comm = comm_map_[ring_id];
    // If dev_id is already registered for this ring, emplace keeps the
    // existing entry and the new wrapper is destroyed. Either way the
    // registered entry is what gets returned.
    auto slot =
        dev2comm.emplace(dev_id, std::unique_ptr<NCCLComm>(std::move(impl)));
    registered = slot.first->second.get();
  }

  if (ring_id == 0) {
    auto* default_ctx = static_cast<phi::GPUContext*>(
        platform::DeviceContextPool::Instance().Get(
            platform::CUDAPlace(dev_id)));
    default_ctx->set_nccl_comm(comm);
  }
  VLOG(4) << "add nccl comm: " << registered << ", ring_id:" << ring_id
          << ", dev_id:" << dev_id;
  return registered;
}

268
void NCCLCommContext::ReleaseNCCLComms() {
269 270 271 272
  for (auto& p : comm_map_) {
    for (auto& q : p.second) {
      q.second.reset();
    }
273 274 275
  }
}

276 277 278 279 280 281 282 283 284 285 286 287 288 289 290
#endif

#if defined(PADDLE_WITH_XPU_BKCL)

class BKCLCommImpl : public BKCLComm {
 public:
  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
  int ring_id() const override { return ring_id_; }

  void set_nranks(int nranks) { nranks_ = nranks; }
  int nranks() const override { return nranks_; }

  void set_rank(int rank) { rank_ = rank; }
  int rank() const override { return rank_; }

291
  int device_id() const override { return dev_ctx_->GetPlace().device; }
292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312

  void set_comm(BKCLContext_t comm) { comm_ = comm; }
  BKCLContext_t comm() const override { return comm_; }

  XPUStream stream() const override {
    return dev_ctx_->x_context()->xpu_stream;
  }

  void set_dev_ctx(std::unique_ptr<XPUDeviceContext>&& dev_ctx) {
    dev_ctx_ = std::move(dev_ctx);
  }
  XPUDeviceContext* dev_context() const override { return dev_ctx_.get(); }

 private:
  int ring_id_;
  int nranks_;
  int rank_;
  BKCLContext_t comm_;
  std::unique_ptr<XPUDeviceContext> dev_ctx_;
};

313 314
// Creates one BKCL communicator for `rank` out of `nranks` on XPU `dev_id`
// using the shared unique id, registers it under `ring_id`, and returns the
// registered wrapper. On first use it also schedules ReleaseBKCLComms() to run
// at process exit.
BKCLComm* BKCLCommContext::CreateComm(
    BKCLUniqueId* bkcl_id, int nranks, int rank, int dev_id, int ring_id) {
  // Argument validation, in the same order the errors are reported.
  PADDLE_ENFORCE_NOT_NULL(
      bkcl_id,
      platform::errors::InvalidArgument(
          "The bkcl unique id should not be null."));
  PADDLE_ENFORCE_GT(nranks,
                    1,
                    platform::errors::InvalidArgument(
                        "Expected nranks > 1. But received nranks is %d.",
                        nranks));
  PADDLE_ENFORCE_GE(rank,
                    0,
                    platform::errors::InvalidArgument(
                        "Expected rank >= 0. But received rank is %d.",
                        rank));
  PADDLE_ENFORCE_LT(
      rank,
      nranks,
      platform::errors::InvalidArgument(
          "Expected rank < nranks. But received rank is %d, nranks is %d.",
          rank,
          nranks));
  PADDLE_ENFORCE_GE(dev_id,
                    0,
                    platform::errors::InvalidArgument(
                        "Expected dev_id >= 0. But received dev_id is %d.",
                        dev_id));

  // Bind the calling thread to the target XPU before initializing BKCL.
  platform::SetXPUDeviceId(dev_id);
  BKCLContext_t raw_comm = nullptr;
  PADDLE_ENFORCE_XPU_SUCCESS(bkcl_init_rank(&raw_comm, rank, nranks, bkcl_id));

  BKCLComm* wrapped = AssignBKCLComm(raw_comm, nranks, rank, dev_id, ring_id);
  VLOG(1) << "bkcl communicator of rank " << rank << " in ring " << ring_id
          << " has been created on device " << dev_id;

  std::call_once(once_flag_, [] {
    std::atexit([] { BKCLCommContext::Instance().ReleaseBKCLComms(); });
  });
  return wrapped;
}

356 357
// Wraps an already-initialized BKCL context in a BKCLCommImpl with a fresh XPU
// device context and registers it in comm_map_ under (ring_id, dev_id). For
// ring 0 the BKCL context is also installed on the DeviceContextPool context
// of `dev_id`. Returns the wrapper that is registered for (ring_id, dev_id)
// after the call.
BKCLComm* BKCLCommContext::AssignBKCLComm(
    BKCLContext_t comm, int nranks, int rank, int dev_id, int ring_id) {
  std::unique_ptr<XPUDeviceContext> dev_ctx(
      new XPUDeviceContext(XPUPlace(dev_id)));
  dev_ctx->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
                            .GetAllocator(XPUPlace(dev_id))
                            .get());
  dev_ctx->SetHostAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(paddle::platform::CPUPlace())
          .get());
  dev_ctx->SetZeroAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetZeroAllocator(XPUPlace(dev_id))
          .get());
  dev_ctx->SetHostZeroAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetZeroAllocator(paddle::platform::CPUPlace())
          .get());

  // Owned from the moment of creation, so nothing leaks if a later step throws.
  std::unique_ptr<BKCLCommImpl> impl(new BKCLCommImpl);
  impl->set_ring_id(ring_id);
  impl->set_nranks(nranks);
  impl->set_rank(rank);
  impl->set_comm(comm);
  impl->set_dev_ctx(std::move(dev_ctx));

  BKCLComm* registered = nullptr;
  {
    // RAII lock: released on every path, including exceptions. The returned
    // pointer is read while the lock is still held.
    std::lock_guard<decltype(comm_map_mutex_)> guard(comm_map_mutex_);
    // operator[] default-constructs the per-ring map on first use.
    auto& dev2comm = comm_map_[ring_id];
    // If dev_id is already registered for this ring, emplace keeps the
    // existing entry and the new wrapper is destroyed. Either way the
    // registered entry is what gets returned.
    auto slot =
        dev2comm.emplace(dev_id, std::unique_ptr<BKCLComm>(std::move(impl)));
    registered = slot.first->second.get();
  }

  if (ring_id == 0) {
    auto* default_ctx = static_cast<platform::XPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(
            platform::XPUPlace(dev_id)));
    default_ctx->SetBkclContext(comm);
  }

  return registered;
}

// Destroys every registered BKCLComm wrapper. The map entries themselves are
// kept; each one is left holding a null pointer.
void BKCLCommContext::ReleaseBKCLComms() {
  for (auto& ring_entry : comm_map_) {
    auto& per_device = ring_entry.second;
    for (auto& device_entry : per_device) {
      device_entry.second.reset();
    }
  }
}
409 410

#endif
411 412 413

}  // namespace platform
}  // namespace paddle