//   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/platform/collective_helper.h"

#include <utility>

#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/gpu/gpu_resource_pool.h"

namespace paddle {
namespace platform {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
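
// NCCLCommImpl bundles a raw ncclComm_t with the CUDADeviceContext whose
// stream carries the collective, plus the two CUDA events used to order
// work between the compute stream and the communication stream.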
class NCCLCommImpl : public NCCLComm {
 public:
  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
  int ring_id() const override { return ring_id_; }

  void set_nranks(int nranks) { nranks_ = nranks; }
  int nranks() const override { return nranks_; }

  void set_rank(int rank) { rank_ = rank; }
  int rank() const override { return rank_; }

  int device_id() const override { return dev_ctx_->GetPlace().device; }

  void set_comm(ncclComm_t comm) { comm_ = comm; }
  ncclComm_t comm() const override { return comm_; }

  gpuStream_t stream() const override { return dev_ctx_->stream(); }

  void set_dev_ctx(std::unique_ptr<CUDADeviceContext>&& dev_ctx) {
    dev_ctx_ = std::move(dev_ctx);
  }
  CUDADeviceContext* dev_context() const override { return dev_ctx_.get(); }

  gpuEvent_t compute_event() const override { return compute_event_.get(); }

  gpuEvent_t comm_event() const override { return comm_event_.get(); }

  void set_compute_event(
      std::shared_ptr<platform::CudaEventObject>&& compute_event) {
    compute_event_ = std::move(compute_event);
  }

  void set_comm_event(std::shared_ptr<platform::CudaEventObject>&& comm_event) {
    comm_event_ = std::move(comm_event);
  }

 private:
  int ring_id_;
  int nranks_;
  int rank_;
  ncclComm_t comm_;
  std::unique_ptr<CUDADeviceContext> dev_ctx_;

  // used for comm wait compute: compute_stream --> event --> comm_stream
  std::shared_ptr<platform::CudaEventObject> compute_event_;

  // used for compute wait comm: comm_stream --> event --> compute_stream
  std::shared_ptr<platform::CudaEventObject> comm_event_;
};
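
// A minimal sketch of how the two events above are intended to be used by
// callers (the real call sites live outside this file; the stream and event
// names here are illustrative only):
//
//   cudaEventRecord(compute_event, compute_stream);
//   cudaStreamWaitEvent(comm_stream, compute_event, 0);   // comm waits
//   /* enqueue the collective on comm_stream */
//   cudaEventRecord(comm_event, comm_stream);
//   cudaStreamWaitEvent(compute_stream, comm_event, 0);   // compute waits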

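// Creates a NCCL communicator for this process's (rank, dev_id) pair and
// registers it under ring_id. A typical call sequence, as a sketch (the
// rendezvous that shares nccl_id across ranks happens elsewhere):
//
//   ncclUniqueId nccl_id;
//   platform::dynload::ncclGetUniqueId(&nccl_id);  // rank 0, then broadcast
//   auto* comm = NCCLCommContext::Instance().CreateComm(
//       &nccl_id, /*nranks=*/world_size, /*rank=*/rank, /*dev_id=*/gpu_id,
//       /*ring_id=*/0);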
NCCLComm* NCCLCommContext::CreateComm(
    ncclUniqueId* nccl_id, int nranks, int rank, int dev_id, int ring_id) {
  PADDLE_ENFORCE_NOT_NULL(nccl_id,
                          platform::errors::InvalidArgument(
                              "The nccl unique id should not be null."));
  PADDLE_ENFORCE_GT(
      nranks,
      1,
      platform::errors::InvalidArgument(
          "Expected nranks > 1. But received nranks is %d.", nranks));
  PADDLE_ENFORCE_GE(rank,
                    0,
                    platform::errors::InvalidArgument(
                        "Expected rank >= 0. But received rank is %d.", rank));
  PADDLE_ENFORCE_LT(
      rank,
      nranks,
      platform::errors::InvalidArgument(
          "Expected rank < nranks. But received rank is %d, nranks is %d.",
          rank,
          nranks));
  PADDLE_ENFORCE_GE(
      dev_id,
      0,
      platform::errors::InvalidArgument(
          "Expected dev_id >= 0. But received dev_id is %d.", dev_id));

  ncclComm_t comm = nullptr;
  SetDeviceId(dev_id);
  PADDLE_ENFORCE_GPU_SUCCESS(
      platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank));

  auto* comm_wrapper = AssignNCCLComm(comm, nranks, rank, dev_id, ring_id);

  VLOG(1) << "nccl communicator of rank " << rank << " in ring " << ring_id
          << " has been created on device " << dev_id;

  std::call_once(once_flag_, []() {
    std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
  });

  return comm_wrapper;
}

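// Creates one communicator per device in dev_ids inside a single process via
// ncclCommInitAll, so no cross-process rendezvous on a ncclUniqueId is needed.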
void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids,
                                         int ring_id) {
  PADDLE_ENFORCE_GT(
      dev_ids.size(),
      0,
      platform::errors::InvalidArgument("Expected the size of dev_ids > 0. But "
                                        "received the size of dev_ids is %d.",
                                        dev_ids.size()));

  const int kDevices = dev_ids.size();
  std::vector<ncclComm_t> comms(kDevices);
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclCommInitAll(
      comms.data(), dev_ids.size(), dev_ids.data()));

  PADDLE_ENFORCE_EQ(comm_map_.count(ring_id),
                    0,
                    platform::errors::InvalidArgument(
                        "Expected comm_map_.count(ring_id) = 0. But received "
                        "comm_map_.count(ring_id) is %d.",
                        comm_map_.count(ring_id)));
  for (size_t i = 0; i < dev_ids.size(); ++i) {
    AssignNCCLComm(comms[i], dev_ids.size(), i, dev_ids[i], ring_id);
    VLOG(1) << "nccl communicator of rank " << i << " in ring " << ring_id
            << " has been created on device " << dev_ids[i];
  }

  std::call_once(once_flag_, []() {
    std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
  });
}

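// Creates the communicators for the kDevices GPUs owned by one trainer
// process: the global NCCL rank of local device i is train_id * kDevices + i,
// and the world size is ntrainers * kDevices.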
void NCCLCommContext::CreateNCCLCommMultiTrainer(
    const std::vector<int>& dev_ids,
    ncclUniqueId* nccl_id,
    int ntrainers,
    int train_id,
    int ring_id) {
  PADDLE_ENFORCE_GT(
      dev_ids.size(),
      0,
      paddle::platform::errors::InvalidArgument(
          "Expected the size of dev_ids > 0. But received %d.",
          dev_ids.size()));
  const int kDevices = dev_ids.size();
  VLOG(1) << "Begin CreateNCCLCommMultiTrainer. device number: " << kDevices
          << ", ntrainers: " << ntrainers << ", train_id: " << train_id
          << ", ring_id: " << ring_id;
  std::vector<ncclComm_t> comms(kDevices);
  {
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclGroupStart());
    for (int i = 0; i < kDevices; i++) {
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipSetDevice(i));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(i));
#endif
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclCommInitRank(
          comms.data() + i,
          kDevices * ntrainers,
          *nccl_id,
          train_id * kDevices + i));
      VLOG(1) << "ncclCommInitRank: " << i;
    }
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclGroupEnd());
    VLOG(1) << "nccl group end success";
  }
  PADDLE_ENFORCE_EQ(comm_map_.count(ring_id),
                    0,
                    platform::errors::InvalidArgument(
                        "Expected comm_map_.count(ring_id) = 0 for ring_id %d. "
                        "But received %d.",
                        ring_id,
                        comm_map_.count(ring_id)));
  for (int i = 0; i < kDevices; ++i) {
    AssignNCCLComm(comms[i],
                   kDevices * ntrainers,
                   train_id * kDevices + i,
                   dev_ids[i],
                   ring_id);
    VLOG(1) << "nccl communicator of train_id " << train_id * kDevices + i
            << " in ring " << ring_id << " has been created on device "
            << dev_ids[i];
  }

  std::call_once(once_flag_, []() {
    std::atexit([]() { NCCLCommContext::Instance().ReleaseNCCLComms(); });
  });
}

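// Wraps an already-initialized ncclComm_t in an NCCLComm. A dedicated
// CUDADeviceContext, with its own stream and allocators, is created for the
// communicator, which is then registered under (ring_id, dev_id).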
NCCLComm* NCCLCommContext::AssignNCCLComm(
    ncclComm_t comm, int nranks, int rank, int dev_id, int ring_id) {
  std::unique_ptr<CUDADeviceContext> dev_ctx(
      new CUDADeviceContext(CUDAPlace(dev_id)));
  dev_ctx->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
                            .GetAllocator(CUDAPlace(dev_id), dev_ctx->stream())
                            .get());
  dev_ctx->SetHostAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(paddle::platform::CPUPlace())
          .get());
  dev_ctx->SetZeroAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetZeroAllocator(CUDAPlace(dev_id))
          .get());
  dev_ctx->SetPinnedAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(paddle::platform::CUDAPinnedPlace())
          .get());
  dev_ctx->PartialInitWithAllocator();

  std::shared_ptr<platform::CudaEventObject> compute_event(
      platform::CudaEventResourcePool::Instance().New(dev_id));
  std::shared_ptr<platform::CudaEventObject> comm_event(
      platform::CudaEventResourcePool::Instance().New(dev_id));

  NCCLCommImpl* c = new NCCLCommImpl;
  c->set_ring_id(ring_id);
  c->set_nranks(nranks);
  c->set_rank(rank);
  c->set_comm(comm);
  c->set_dev_ctx(std::move(dev_ctx));
  c->set_compute_event(std::move(compute_event));
  c->set_comm_event(std::move(comm_event));

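  // Publish the communicator: comm_map_ is keyed first by ring_id and then
  // by device id. For ring 0, the raw comm is also handed to the global
  // DeviceContextPool's CUDADeviceContext of this device.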
  comm_map_mutex_.lock();
  if (comm_map_.count(ring_id) == 0) {
    comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<NCCLComm>>());
  }
  auto& dev2comm = comm_map_[ring_id];

  dev2comm.emplace(dev_id, std::unique_ptr<NCCLComm>(c));
  comm_map_mutex_.unlock();

  if (ring_id == 0) {
    auto* dev_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(
            platform::CUDAPlace(dev_id)));
    dev_ctx->set_nccl_comm(comm);
  }

  return comm_map_[ring_id][dev_id].get();
}

void NCCLCommContext::ReleaseNCCLComms() {
  for (auto& p : comm_map_) {
    for (auto& q : p.second) {
      q.second.reset();
    }
  }
}

#endif

#if defined(PADDLE_WITH_XPU_BKCL)

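// BKCLCommImpl mirrors NCCLCommImpl for Baidu Kunlun XPUs, with the
// collective backend provided by BKCL instead of NCCL.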
class BKCLCommImpl : public BKCLComm {
 public:
  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
  int ring_id() const override { return ring_id_; }

  void set_nranks(int nranks) { nranks_ = nranks; }
  int nranks() const override { return nranks_; }

  void set_rank(int rank) { rank_ = rank; }
  int rank() const override { return rank_; }

  int device_id() const override { return dev_ctx_->GetPlace().device; }

  void set_comm(BKCLContext_t comm) { comm_ = comm; }
  BKCLContext_t comm() const override { return comm_; }

  XPUStream stream() const override {
    return dev_ctx_->x_context()->xpu_stream;
  }

  void set_dev_ctx(std::unique_ptr<XPUDeviceContext>&& dev_ctx) {
    dev_ctx_ = std::move(dev_ctx);
  }
  XPUDeviceContext* dev_context() const override { return dev_ctx_.get(); }

 private:
  int ring_id_;
  int nranks_;
  int rank_;
  BKCLContext_t comm_;
  std::unique_ptr<XPUDeviceContext> dev_ctx_;
};

BKCLComm* BKCLCommContext::CreateComm(
    BKCLUniqueId* bkcl_id, int nranks, int rank, int dev_id, int ring_id) {
  PADDLE_ENFORCE_NOT_NULL(bkcl_id,
                          platform::errors::InvalidArgument(
                              "The bkcl unique id should not be null."));
  PADDLE_ENFORCE_GT(
      nranks,
      1,
      platform::errors::InvalidArgument(
          "Expected nranks > 1. But received nranks is %d.", nranks));
  PADDLE_ENFORCE_GE(rank,
                    0,
                    platform::errors::InvalidArgument(
                        "Expected rank >= 0. But received rank is %d.", rank));
  PADDLE_ENFORCE_LT(
      rank,
      nranks,
      platform::errors::InvalidArgument(
          "Expected rank < nranks. But received rank is %d, nranks is %d.",
          rank,
          nranks));
  PADDLE_ENFORCE_GE(
      dev_id,
      0,
      platform::errors::InvalidArgument(
          "Expected dev_id >= 0. But received dev_id is %d.", dev_id));

  BKCLContext_t comm = nullptr;
  platform::SetXPUDeviceId(dev_id);
  PADDLE_ENFORCE_XPU_SUCCESS(bkcl_init_rank(&comm, rank, nranks, bkcl_id));

  auto* comm_wrapper = AssignBKCLComm(comm, nranks, rank, dev_id, ring_id);

  VLOG(1) << "bkcl communicator of rank " << rank << " in ring " << ring_id
          << " has been created on device " << dev_id;

  std::call_once(once_flag_, []() {
    std::atexit([]() { BKCLCommContext::Instance().ReleaseBKCLComms(); });
  });

  return comm_wrapper;
}

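// Same as AssignNCCLComm, except that the communication stream is created
// explicitly with xpu_stream_create and attached to the XPUDeviceContext.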
BKCLComm* BKCLCommContext::AssignBKCLComm(
    BKCLContext_t comm, int nranks, int rank, int dev_id, int ring_id) {
  std::unique_ptr<XPUDeviceContext> dev_ctx(
      new XPUDeviceContext(XPUPlace(dev_id)));
  // Used by BKCL as the comm_stream: every dev_id has one comm_stream per
  // ring. This stream is passed in when invoking collective communication
  // commands such as bkcl_all_reduce.
  XPUStream comm_stream;
  PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&comm_stream));
  dev_ctx->SetXPUStream(comm_stream);

  BKCLCommImpl* c = new BKCLCommImpl;
  c->set_ring_id(ring_id);
  c->set_nranks(nranks);
  c->set_rank(rank);
  c->set_comm(comm);
  c->set_dev_ctx(std::move(dev_ctx));

  comm_map_mutex_.lock();
  if (comm_map_.count(ring_id) == 0) {
    comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<BKCLComm>>());
  }
  auto& dev2comm = comm_map_[ring_id];

  dev2comm.emplace(dev_id, std::unique_ptr<BKCLComm>(c));
  comm_map_mutex_.unlock();

  if (ring_id == 0) {
    auto* dev_ctx = static_cast<platform::XPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(
            platform::XPUPlace(dev_id)));
    dev_ctx->SetBkclContext(comm);
  }

  return comm_map_[ring_id][dev_id].get();
}

void BKCLCommContext::ReleaseBKCLComms() {
  for (auto& p : comm_map_) {
    for (auto& q : p.second) {
      q.second.reset();
    }
  }
}

#endif

}  // namespace platform
}  // namespace paddle