// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"

#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"

DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);

constexpr int64_t kWaitBlockTimeout = 10;

namespace paddle {
namespace distributed {

ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place,
                                     int rank,
                                     CommType comm_type,
                                     bool sync_op,
                                     bool use_calc_stream)
    : TaskStream(rank, comm_type, sync_op, use_calc_stream),
      comm_event_(place),
      task_place_(place) {}

ProcessGroupNCCL::NCCLTask::~NCCLTask() {}

bool ProcessGroupNCCL::NCCLTask::IsCompleted() { return comm_event_.Query(); }

void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
    const phi::DeviceContext& ctx) {
  comm_event_.Record(&ctx);
}

// TODO(sheniang03): Add timeout for wait; timeout is currently unused.
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
  // Warn when the communication runs on the calc stream but Wait() is still
  // invoked explicitly.
  if (UseCalcStream()) {
    VLOG(3) << "Warning: The communication is on the calc stream, "
               "so waiting here is useless.";
    return true;
  }

  const auto* calc_ctx =
      platform::DeviceContextPool::Instance().Get(task_place_);
  comm_event_.Wait(platform::Place2DeviceType(task_place_), calc_ctx);

  if (FLAGS_nccl_blocking_wait) {
    // NOTE(shenliang03): It will block the host until the communication
    // finishes.
    while (!IsCompleted()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTimeout));
    }
  }

  if (IsBlockCPUInWait()) {
    // If this task is used to implement a barrier, block the CPU as well.
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
  }
  return true;
}

// Same as Wait
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }

ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
                                   int rank,
                                   int size,
                                   int gid)
    : ProcessGroupStream(rank, size, gid), store_(store) {}

void ProcessGroupNCCL::GroupStart() {
  NCCL_CHECK(platform::dynload::ncclGroupStart());
}

void ProcessGroupNCCL::GroupEnd() {
  NCCL_CHECK(platform::dynload::ncclGroupEnd());
}

phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
    const Place& place) const {
  return GetDeviceContext(place, /*use_calc_stream*/ false);
}

phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  const std::string& key = GetKeyFromPlace(place);
  if (use_calc_stream) {
    const auto& iter = place_to_calc_ctx_.find(key);
    PADDLE_ENFORCE_NE(
        iter,
        place_to_calc_ctx_.end(),
        platform::errors::NotFound(
            "Cannot find the device context in this process group."));
    return iter->second;
  } else {
    const auto& iter = place_to_comm_ctx_.find(key);
    PADDLE_ENFORCE_NE(
        iter,
        place_to_comm_ctx_.end(),
        platform::errors::NotFound(
            "Cannot find the device context in this process group."));
    return iter->second.get();
  }
}

ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
  const std::string& key = GetKeyFromPlace(place);
  const auto& iter = place_to_comm_ctx_.find(key);
  PADDLE_ENFORCE_NE(
      iter,
      place_to_comm_ctx_.end(),
      platform::errors::NotFound(
          "Cannot find the NCCL commmunicator in this process group."));
  return iter->second->nccl_comm();
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates the tensor needs to be sliced
  const phi::DenseTensor& in_tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclAllGather(
            in_tensor_maybe_partial.data(),
            out_tensor->data(),
            in_tensor_maybe_partial.numel(),
            platform::ToNCCLDataType(in_tensor_maybe_partial.dtype()),
            comm,
            stream));
      },
      in_tensor_maybe_partial,
      CommType::ALLGATHER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclAllReduce(
            in_tensor.data(),
            out_tensor->data(),
            in_tensor.numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream));
      },
      in_tensor,
      CommType::ALLREDUCE,
      sync_op,
      use_calc_stream);
}

void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
                         const std::vector<int64_t>& size_on_each_rank,
                         int world_size) {
  int length_size_on_each_rank = size_on_each_rank.size();
  PADDLE_ENFORCE_EQ(
      length_size_on_each_rank,
      world_size,
      platform::errors::InvalidArgument(
          "The length of size_on_each_rank must be equal to world_size."));

  int64_t sum_size_on_each_rank = std::accumulate(
      size_on_each_rank.begin(), size_on_each_rank.end(), int64_t{0});
  PADDLE_ENFORCE_EQ(
      sum_size_on_each_rank,
      tensor_dim[0],
      platform::errors::InvalidArgument(
          "The sum of size_on_each_rank must be equal to tensor's dim[0]."));
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const std::vector<int64_t>& out_size_each_rank,
    const std::vector<int64_t>& in_size_each_rank,
    bool sync_op,
    bool use_calc_stream) {
  const phi::DDim& out_dim = out_tensor->dims();
  const phi::DDim& in_dim = in_tensor.dims();
  CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
  CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);

  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        int64_t in_row_size = in_tensor.numel() / in_dim[0],
                out_row_size = out_tensor->numel() / out_dim[0];
        int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
        phi::DenseTensor input_partial, output_partial;

        GroupStart();
        for (auto i = 0; i < size_; i++) {
          in_numel = in_size_each_rank[i] * in_row_size;
          input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
          NCCL_CHECK(platform::dynload::ncclSend(
              input_partial.data(),
              in_numel,
              platform::ToNCCLDataType(input_partial.dtype()),
              i,
              comm,
              stream));
          in_offset += in_numel;

          out_numel = out_size_each_rank[i] * out_row_size;
          output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
          NCCL_CHECK(platform::dynload::ncclRecv(
              output_partial.data(),
              out_numel,
              platform::ToNCCLDataType(output_partial.dtype()),
              i,
              comm,
              stream));
          out_offset += out_numel;
        }
        GroupEnd();
      },
      in_tensor,
      CommType::ALLTOALL,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
    const BarrierOptions& opts) {
  PADDLE_ENFORCE_GE(opts.device_id,
                    0,
                    platform::errors::PreconditionNotMet(
                        "The barrier device id must be greater than or "
                        "equal to 0."));
  platform::CUDAPlace place(opts.device_id);
  auto allocator = std::unique_ptr<phi::Allocator>(
      new paddle::experimental::DefaultAllocator(place));
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
  phi::DenseTensor barrier_tensor{allocator.get(), meta};

  auto task = AllReduce(&barrier_tensor,
                        barrier_tensor,
                        {},
                        /*sync_op*/ true,
                        /*use_calc_stream*/ false);
  auto nccl_task = dynamic_cast<NCCLTask*>(task.get());
  nccl_task->SetBlockCPUInWait();
  return task;
}
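// Note on the barrier above: it is built from an all-reduce on a one-element
// GPU tensor. SetBlockCPUInWait() marks the task so that Wait() also calls
// cudaDeviceSynchronize()/hipDeviceSynchronize(), making the host block until
// the all-reduce has actually finished on the device rather than only queuing
// stream dependencies.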

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        int root = opts.source_rank + opts.source_root;
        NCCL_CHECK(platform::dynload::ncclBroadcast(
            in_tensor.data(),
            out_tensor->data(),
            in_tensor.numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            root,
            comm,
            stream));
      },
      in_tensor,
      CommType::BROADCAST,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclReduce(
            in_tensor.data(),
            out_tensor->data(),
            in_tensor.numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            ToNCCLRedType(opts.reduce_op),
            opts.root_rank,
            comm,
            stream));
      },
      in_tensor,
      CommType::REDUCE,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclReduceScatter(
            in_tensor.data(),
            out_tensor->data(),
            out_tensor->numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream));
      },
      in_tensor,
      CommType::REDUCE_SCATTER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        int64_t numel = in_tensor.numel() / size_;
        if (rank_ == opts.root_rank) {
          int64_t offset = 0;
          phi::DenseTensor partial_tensor;
          GroupStart();
          for (auto i = 0; i < size_; i++) {
            partial_tensor = GetPartialTensor(in_tensor, offset, numel);
            NCCL_CHECK(platform::dynload::ncclSend(
                partial_tensor.data(),
                numel,
                platform::ToNCCLDataType(partial_tensor.dtype()),
                i,
                comm,
                stream));
            offset += numel;
          }
          NCCL_CHECK(platform::dynload::ncclRecv(
              out_tensor->data(),
              numel,
              platform::ToNCCLDataType(out_tensor->dtype()),
              opts.root_rank,
              comm,
              stream));
          GroupEnd();
        } else {
          NCCL_CHECK(platform::dynload::ncclRecv(
              out_tensor->data(),
              numel,
              platform::ToNCCLDataType(out_tensor->dtype()),
              opts.root_rank,
              comm,
              stream));
        }
      },
      in_tensor,
      CommType::SCATTER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
    phi::DenseTensor* tensor,
    int src_rank,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates the tensor needs to be sliced
  phi::DenseTensor partial_tensor;
  if (numel > 0) {
    partial_tensor = GetPartialTensor(*tensor, offset, numel);
    tensor = &partial_tensor;
  }
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclRecv(
            tensor->data(),
            tensor->numel(),
            platform::ToNCCLDataType(tensor->dtype()),
            src_rank,
            comm,
            stream));
      },
      *tensor,
      CommType::RECV,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
    const phi::DenseTensor& tensor,
    int dst_rank,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates the tensor needs to be sliced
  const phi::DenseTensor& tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclSend(
            tensor_maybe_partial.data(),
            tensor_maybe_partial.numel(),
            platform::ToNCCLDataType(tensor_maybe_partial.dtype()),
            dst_rank,
            comm,
            stream));
      },
      tensor_maybe_partial,
      CommType::SEND,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    const Place& place,
    int rank,
    CommType comm_type,
    bool is_sync,
    bool use_calc_stream) {
  return std::make_shared<ProcessGroupNCCL::NCCLTask>(
      place, rank, comm_type, is_sync, use_calc_stream);
}

void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) {
  const std::string key =
      "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/0";
  if (rank_ == 0) {
    std::vector<uint8_t> nccl_id_wrapper(
        reinterpret_cast<uint8_t*>(nccl_id),
        reinterpret_cast<uint8_t*>(nccl_id) + NCCL_UNIQUE_ID_BYTES);
    store_->set(key, nccl_id_wrapper);
  } else {
    const auto& nccl_id_wrapper = store_->get(key);
    std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
  }
}
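// The rendezvous above goes through the Store: rank 0 serializes the
// ncclUniqueId it generated under the gid-scoped key
// "ProcessGroupNCCL/nccl_ids/<gid>/0", and every other rank reads the same
// key via store_->get() and copies the bytes back into its own ncclUniqueId
// before ncclCommInitRank is called.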

void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
                                          const std::string& place_key) {
  if (place_to_comm_ctx_.size() > 0) {
    VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
  }

  ncclUniqueId nccl_id;
  if (rank_ == 0) {
    NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
  }
  BroadcastUniqueNCCLID(&nccl_id);

  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << place_key
          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);

  auto* calc_ctx = static_cast<phi::GPUContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  auto comm_ctx = std::make_unique<phi::GPUContext>(place);
  ncclComm_t nccl_comm;
  NCCL_CHECK(platform::dynload::ncclCommInitRank(
      &nccl_comm, GetSize(), nccl_id, GetRank()));
  comm_ctx->set_nccl_comm(nccl_comm);

  place_to_calc_event_.emplace(place_key, place);
  place_to_calc_ctx_.emplace(place_key, calc_ctx);
  place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));

  // TODO(sunyilun): for compatibility, will be removed later
  std::vector<phi::GPUContext*> comm_ctx_wrapper{
      place_to_comm_ctx_[place_key].get()};
  places_to_ctx_.emplace(place_key, comm_ctx_wrapper);
}

void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
  const std::string& key = GetKeyFromPlace(place);
  auto& calc_event = place_to_calc_event_.at(key);
  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
  calc_event.Record(calc_ctx);
  calc_event.Wait(platform::Place2DeviceType(place), comm_ctx);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
    std::function<void(ncclComm_t, gpuStream_t)> fn,
    const phi::DenseTensor& tensor,
    CommType comm_type,
    bool sync_op,
    bool use_calc_stream) {
  const auto& place = tensor.place();
  const auto& key = GetKeyFromPlace(place);

  platform::CUDADeviceGuard cuda_guard(place);

  if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
    CreateNCCLEnvCache(place, key);
  }

  if (!use_calc_stream) {
    SyncCalcStream(place);
  }

  auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);

  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto& comm_ctx = place_to_comm_ctx_.at(key);
  auto nccl_comm = comm_ctx->nccl_comm();
  auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
  fn(nccl_comm, nccl_stream);

  if (!use_calc_stream) {
    if (FLAGS_use_stream_safe_cuda_allocator) {
      memory::RecordStream(tensor.Holder(), nccl_stream);
    }
    task->UpdateWaitChain(*comm_ctx);
  }

  return task;
}
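// RunFnInNCCLEnv is the common path for the stream-aware collectives above:
// it lazily creates the communicator/context for the tensor's place, makes
// the comm stream wait on the calc stream (unless use_calc_stream is set),
// runs `fn` on the chosen stream, records that stream for the stream-safe
// allocator, and records an event so that task->Wait() can later order the
// calc stream after the communication.
//
// A minimal illustrative caller (hypothetical `pg`, `out`, `in`; not part of
// this file) might look like:
//
//   auto task = pg->AllGather(&out, in, /*offset*/ 0, /*numel*/ 0,
//                             /*sync_op*/ true, /*use_calc_stream*/ false);
//   task->Wait(kWaitTimeout);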

// TODO(sunyilun): methods below will be removed later
void SyncDefaultStream(const std::vector<Place>& places,
                       platform::DeviceEvent& nccl_event,         // NOLINT
                       std::vector<phi::GPUContext*>& dev_ctx) {  // NOLINT
  for (size_t i = 0; i < places.size(); ++i) {
    auto* default_ctx = static_cast<phi::GPUContext*>(
        platform::DeviceContextPool::Instance().Get(places[i]));
    nccl_event.Record(default_ctx);
    nccl_event.Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
  }
}

std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    std::vector<Place> places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs) {
  return std::make_shared<ProcessGroupNCCL::NCCLTask>(
      places, rank, comm_type, inputs);
}

ProcessGroupNCCL::NCCLTask::NCCLTask(
    const std::vector<Place>& places,
    int rank,
    CommType CommType,
    const std::vector<phi::DenseTensor>& inputs)
    : TaskStream(rank, inputs, CommType),
      comm_event_(places[0]),
      task_place_(places[0]) {}

// create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache(
    const std::string& places_key, const std::vector<Place>& places) {
  PADDLE_ENFORCE_EQ(places_key.empty(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Not able to create/get the NCCL Communicator since "
                        "the GPU place is not known"));

  ncclUniqueId nccl_id;
  if (rank_ == 0) {
    NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
  }
  BroadcastUniqueNCCLID(&nccl_id);

  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << places_key
          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);

  std::vector<std::unique_ptr<phi::GPUContext>> dev_ctx;
  dev_ctx.resize(places.size());

  std::vector<phi::GPUContext*> dev_ctx_raw;
  dev_ctx_raw.resize(places.size());

  GroupStart();

  for (size_t i = 0; i < places.size(); ++i) {
    platform::CUDADeviceGuard guard(places[i]);

    dev_ctx[i].reset(new phi::GPUContext(places[i]));
    ncclComm_t nccl_comm;
    NCCL_CHECK(platform::dynload::ncclCommInitRank(
        &nccl_comm, GetSize(), nccl_id, GetRank()));
    dev_ctx[i]->set_nccl_comm(nccl_comm);
    dev_ctx_raw[i] = dev_ctx[i].get();
  }

  GroupEnd();

  // TODO(sunyilun): for compatibility, will be removed later
  place_to_calc_event_.emplace(places_key, places[0]);
  place_to_calc_ctx_.emplace(
      places_key,
      static_cast<phi::GPUContext*>(
          platform::DeviceContextPool::Instance().Get(places[0])));
  place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0]));

  // These caches will be useful to process sync/wait/communicate
  places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
    std::vector<phi::DenseTensor>& inputs,
    std::vector<phi::DenseTensor>& outputs,
    Fn fn,
    CommType op_type) {
  const auto places = GetPlaceList(inputs);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));

  auto task = CreateTask(places, rank_, op_type, inputs);

  // construct an uninitialized guard for the device
  platform::CUDADeviceGuard cuda_guard;

  {
    platform::NCCLGroupGuard nccl_guard;
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
      fn(inputs[i],
         outputs[i],
         places_to_ctx_.at(key)[i]->nccl_comm(),
         nccl_stream);
    }
  }

  if (FLAGS_use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      memory::RecordStream(inputs[i].Holder(),
                           places_to_ctx_.at(key)[i]->stream());
    }
  }

  for (size_t i = 0; i < inputs.size(); ++i) {
    cuda_guard.SetDevice(places[i]);
    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
  }
  return task;
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
    std::vector<phi::DenseTensor>& tensors,
    Fn fn,
    int dst_rank,
    CommType op_type) {
  const auto places = GetPlaceList(tensors);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));

  auto task = CreateTask(places, rank_, op_type, tensors);

  // construct an uninitialized guard for the device
  platform::CUDADeviceGuard cuda_guard;

  {
    platform::NCCLGroupGuard nccl_guard;
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
      fn(tensors[i],
         places_to_ctx_.at(key)[i]->nccl_comm(),
         nccl_stream,
         dst_rank);
    }
  }

  if (FLAGS_use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      memory::RecordStream(tensors[i].Holder(),
                           places_to_ctx_.at(key)[i]->stream());
    }
  }

  for (size_t i = 0; i < tensors.size(); ++i) {
    cuda_guard.SetDevice(places[i]);
    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
  }
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const AllreduceOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        return platform::dynload::ncclAllReduce(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream);
      },
      CommType::ALLREDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const BroadcastOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));

  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        const auto root =
            opts.source_rank * in_tensors.size() + opts.source_root;
        return platform::dynload::ncclBroadcast(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            root,
            comm,
            stream);
      },
      CommType::BROADCAST);
}

void CheckTensorsInDifferentDevices(
    const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
  PADDLE_ENFORCE_EQ(
      tensors.size() == 0,
      false,
      platform::errors::InvalidArgument("Tensor list must be nonempty."));
  PADDLE_ENFORCE_LE(
      tensors.size(),
      num_devices,
      platform::errors::InvalidArgument(
          "Tensor list mustn't be larger than the number of available GPUs."));

  std::set<Place> used_devices;

  for (const auto& t : tensors) {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(t.place()),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensors must be CUDA and dense tensor."));

    const auto inserted = used_devices.insert(t.place()).second;
    PADDLE_ENFORCE_EQ(inserted,
                      true,
                      platform::errors::InvalidArgument(
                          "Tensors must be on distinct GPU devices."));
  }
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
    std::vector<phi::DenseTensor>& tensors, int dst_rank) {
  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

  auto task = PointToPoint(
      tensors,
      [&](phi::DenseTensor& input,
          ncclComm_t comm,
          const gpuStream_t& stream,
          int dst_rank) {
        return platform::dynload::ncclSend(
            input.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            dst_rank,
            comm,
            stream);
      },
      dst_rank,
      CommType::SEND);
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
    std::vector<phi::DenseTensor>& tensors, int src_rank) {
  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

  auto task = PointToPoint(
      tensors,
      [&](phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream,
          int src_rank) {
        return platform::dynload::ncclRecv(
            output.data(),
            output.numel(),
            platform::ToNCCLDataType(output.dtype()),
            src_rank,
            comm,
            stream);
      },
      src_rank,
      CommType::RECV);
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        return platform::dynload::ncclAllGather(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            comm,
            stream);
      },
      CommType::ALLGATHER);
}

void* GetPointerByOffset(void* raw_pointer,
                         size_t offset,
                         experimental::DataType type) {
  if (type == experimental::DataType::FLOAT32) {
    return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT64) {
    return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT32) {
    return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT64) {
    return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT8) {
    return reinterpret_cast<void*>(reinterpret_cast<int8_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::UINT8) {
    return reinterpret_cast<void*>(reinterpret_cast<uint8_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::BOOL) {
    return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::BFLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
                                   offset);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Datatype %s in NCCL is not supported.", type));
  }
  return nullptr;
}
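// Note: GetPointerByOffset advances raw_pointer by `offset` elements, not
// bytes; the cast selects a pointee type whose size matches the dtype, e.g.
// FLOAT16 and BFLOAT16 use 16-bit integer types purely for the pointer
// arithmetic.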

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        size_t offset = 0;
        GroupStart();
        for (auto i = 0; i < size_; i++) {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
              GetPointerByOffset(input.data(), offset, input.dtype()),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              GetPointerByOffset(output.data(), offset, input.dtype()),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          offset += input.numel() / size_;
        }
        GroupEnd();
      },
      CommType::ALLTOALL);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ReduceOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            ToNCCLRedType(opts.reduce_op),
            opts.root_rank,
            comm,
            stream));
      },
      CommType::REDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ScatterOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        size_t offset = 0;
        if (rank_ == opts.root_rank) {
          GroupStart();
          for (auto i = 0; i < size_; i++) {
            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
                GetPointerByOffset(input.data(), offset, input.dtype()),
                input.numel() / size_,
                platform::ToNCCLDataType(input.dtype()),
                i,
                comm,
                stream));
            offset += input.numel() / size_;
          }
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
              comm,
              stream));
          GroupEnd();
        } else {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
              comm,
              stream));
        }
      },
      CommType::SCATTER);
}

std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
  auto process_group =
      std::make_shared<ProcessGroupNCCL>(store, rank, size, gid);
  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
  return process_group;
}
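// Illustrative usage sketch (hypothetical driver code with `store`, `rank`,
// `world_size`, `in`, and `out`; not part of this file): each rank creates
// the group through the same Store and then issues collectives, e.g.
//
//   auto pg = paddle::distributed::ProcessGroupNCCL::CreateProcessGroupNCCL(
//       store, rank, world_size, /*gid*/ 0);
//   auto task = pg->AllReduce(&out, in, paddle::distributed::AllreduceOptions{},
//                             /*sync_op*/ true, /*use_calc_stream*/ false);
//   task->Wait(kWaitTimeout);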

}  //  namespace distributed
}  //  namespace paddle