ProcessGroupNCCL.cc 35.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
16

L
lilong12 已提交
17
#include "paddle/fluid/distributed/collective/Common.h"
18
#include "paddle/fluid/distributed/collective/utils.h"
19
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
20
#include "paddle/fluid/platform/place.h"
L
LiYuRio 已提交
21
#include "paddle/phi/api/lib/utils/allocator.h"
22 23 24 25 26 27 28 29 30

// FLAGS_nccl_blocking_wait: when set, NCCLTask::Wait() additionally blocks the
// host, polling until the task's communication event has completed.
DECLARE_bool(nccl_blocking_wait);
// FLAGS_use_stream_safe_cuda_allocator: when set, tensors used on the
// communication stream are registered via memory::RecordStream.
DECLARE_bool(use_stream_safe_cuda_allocator);

// Host polling interval, in milliseconds, used by NCCLTask::Wait() while
// FLAGS_nccl_blocking_wait is set. (The "TImeout" spelling is kept because the
// identifier is referenced by name.)
constexpr int64_t kWaitBlockTImeout = 10;

namespace paddle {
namespace distributed {

31 32 33 34 35 36 37
// Task for the stream-based API. The communication event and the remembered
// place are both bound to `place`, so Wait() can later look up the calculation
// context of the same device.
ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place,
                                     int rank,
                                     CommType comm_type,
                                     bool sync_op,
                                     bool use_calc_stream)
    : TaskStream(rank, comm_type, sync_op, use_calc_stream),
      comm_event_(place),
      task_place_(place) {}
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57

ProcessGroupNCCL::NCCLTask::~NCCLTask() = default;

// Reports whether the communication event has been reached. Wait() polls this
// in a loop when FLAGS_nccl_blocking_wait is set.
bool ProcessGroupNCCL::NCCLTask::IsCompleted() {
  const bool finished = comm_event_.Query();
  return finished;
}

// Re-records the communication event on `ctx`, so later waits on this task
// are ordered after the work queued on that context so far.
void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
    const phi::DeviceContext& ctx) {
  comm_event_.Record(&ctx);
}

// TODO(sheniang03): Add timeout for wait, now timeout unused
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
  // Warning here when use calc stream but also invoke waiting explicitly.
  if (UseCalcStream()) {
    VLOG(3) << "Warning: The communication is on calc stream, wait here is "
               "useless.";
    return true;
  }

W
Wen Sun 已提交
58 59 60
  const auto* calc_ctx =
      platform::DeviceContextPool::Instance().Get(task_place_);
  comm_event_.Wait(platform::Place2DeviceType(task_place_), calc_ctx);
61 62 63 64 65 66

  if (FLAGS_nccl_blocking_wait) {
    // NOTE(shenliang03): It will block host for sync
    while (!IsCompleted()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
    }
67
  }
68

W
Wen Sun 已提交
69
  if (IsBlockCPUInWait()) {
70 71 72 73 74 75 76 77
    // If we use the work to do barrier, we should block cpu
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
  }
  return true;
78 79
}

80 81 82 83 84 85 86
// Equivalent to Wait() with the default timeout; the result is discarded.
void ProcessGroupNCCL::NCCLTask::Synchronize() {
  Wait(kWaitTimeout);
}

// Only records the store. NCCL communicators are created lazily, the first
// time a collective or point-to-point call sees a new place.
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
                                   int rank,
                                   int size,
                                   int gid)
    : ProcessGroupStream(rank, size, gid), store_(store) {}
88 89

void ProcessGroupNCCL::GroupStart() {
90
  NCCL_CHECK(platform::dynload::ncclGroupStart());
91 92 93
}

void ProcessGroupNCCL::GroupEnd() {
94
  NCCL_CHECK(platform::dynload::ncclGroupEnd());
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
}

// Returns the communication (not the calculation) context cached for `place`.
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
    const Place& place) const {
  constexpr bool kUseCalcStream = false;
  return GetDeviceContext(place, kUseCalcStream);
}

// Returns the context cached for `place`: the calculation context when
// `use_calc_stream` is true, otherwise the dedicated communication context.
// Both branches raise NotFound if nothing has been cached for `place` yet.
const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  const std::string& key = GetKeyFromPlace(place);
  if (use_calc_stream) {
    const auto& iter = place_to_calc_ctx_.find(key);
    // Check before dereferencing: *end() would be undefined behavior.
    PADDLE_ENFORCE_NE(
        iter,
        place_to_calc_ctx_.end(),
        platform::errors::NotFound(
            "Cannot find the device context in this process group."));
    return *iter->second;
  }
  const auto& iter = place_to_comm_ctx_.find(key);
  PADDLE_ENFORCE_NE(
      iter,
      place_to_comm_ctx_.end(),
      platform::errors::NotFound(
          "Cannot find the device context in this process group."));
  return *iter->second;
}

// Returns the NCCL communicator held by the communication context of `place`.
// Raises NotFound if no context has been created for `place` yet.
ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
  const std::string& key = GetKeyFromPlace(place);
  const auto& iter = place_to_comm_ctx_.find(key);
  PADDLE_ENFORCE_NE(
      iter,
      place_to_comm_ctx_.end(),
      platform::errors::NotFound(
          "Cannot find the NCCL communicator in this process group."));
  return iter->second->nccl_comm();
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
133 134
    int64_t offset,
    int64_t numel,
135 136
    bool sync_op,
    bool use_calc_stream) {
137 138 139
  // numel > 0 indicates the tensor need to be sliced
  const phi::DenseTensor& in_tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
140 141
  return Collective(
      out_tensor,
142 143 144 145 146
      in_tensor_maybe_partial,
      [](phi::DenseTensor* output,
         const phi::DenseTensor& input,
         ncclComm_t comm,
         gpuStream_t stream) {
147
        NCCL_CHECK(platform::dynload::ncclAllGather(
148 149 150 151 152
            input.data(),
            output->data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            comm,
153
            stream));
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
      },
      CommType::ALLGATHER,
      sync_op,
      use_calc_stream);
}

// Reduces `in_tensor` across all ranks with opts.reduce_op and writes the
// result into `out_tensor` on every rank.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  auto all_reduce_fn = [&](phi::DenseTensor* output,
                           const phi::DenseTensor& input,
                           ncclComm_t comm,
                           gpuStream_t stream) {
    NCCL_CHECK(platform::dynload::ncclAllReduce(
        input.data(),
        output->data(),
        input.numel(),
        platform::ToNCCLDataType(input.dtype()),
        ToNCCLRedType(opts.reduce_op),
        comm,
        stream));
  };

  return Collective(out_tensor,
                    in_tensor,
                    all_reduce_fn,
                    CommType::ALLREDUCE,
                    sync_op,
                    use_calc_stream);
}

187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
// Validates an all-to-all split description: `size_on_each_rank` must have one
// entry per rank, and its entries must sum to the tensor's leading dimension.
void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
                         const std::vector<int64_t>& size_on_each_rank,
                         int world_size) {
  int length_size_on_each_rank = size_on_each_rank.size();
  PADDLE_ENFORCE_EQ(
      length_size_on_each_rank,
      world_size,
      platform::errors::InvalidArgument(
          "The length of size_on_each_rank must be equal to world_size."));

  // std::accumulate's accumulator has the type of its initial value. A plain
  // `0` would sum in int and truncate/overflow large int64_t sizes.
  int64_t sum_size_on_each_rank = std::accumulate(
      size_on_each_rank.begin(), size_on_each_rank.end(), int64_t{0});
  PADDLE_ENFORCE_EQ(
      sum_size_on_each_rank,
      tensor_dim[0],
      platform::errors::InvalidArgument(
          "The sum of size_on_each_rank must be equal to tensor's dim[0]."));
}

// All-to-all exchange along dim 0. The input is cut into consecutive row
// blocks of in_size_each_rank[i] rows, sent to rank i. The output is filled
// with consecutive blocks of out_size_each_rank[i] rows, received from rank i.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const std::vector<int64_t>& out_size_each_rank,
    const std::vector<int64_t>& in_size_each_rank,
    bool sync_op,
    bool use_calc_stream) {
  const phi::DDim& out_dim = out_tensor->dims();
  const phi::DDim& in_dim = in_tensor.dims();
  CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
  CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);

  return Collective(
      out_tensor,
      in_tensor,
      [&](phi::DenseTensor* output,
          const phi::DenseTensor& input,
          ncclComm_t comm,
          gpuStream_t stream) {
        // Elements per dim-0 row. A tensor with dim[0] == 0 has no rows (its
        // splits were checked above to sum to 0), so use 0 instead of
        // dividing by zero.
        const int64_t in_row_size =
            in_dim[0] == 0 ? 0 : input.numel() / in_dim[0];
        const int64_t out_row_size =
            out_dim[0] == 0 ? 0 : output->numel() / out_dim[0];
        int64_t in_offset = 0;
        int64_t out_offset = 0;
        phi::DenseTensor input_partial, output_partial;

        // Issue all sends/receives inside one NCCL group.
        GroupStart();
        for (auto i = 0; i < size_; i++) {
          const int64_t in_numel = in_size_each_rank[i] * in_row_size;
          input_partial = GetPartialTensor(input, in_offset, in_numel);
          NCCL_CHECK(platform::dynload::ncclSend(
              input_partial.data(),
              in_numel,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          in_offset += in_numel;

          const int64_t out_numel = out_size_each_rank[i] * out_row_size;
          output_partial = GetPartialTensor(*output, out_offset, out_numel);
          NCCL_CHECK(platform::dynload::ncclRecv(
              output_partial.data(),
              out_numel,
              platform::ToNCCLDataType(output->dtype()),
              i,
              comm,
              stream));
          out_offset += out_numel;
        }
        GroupEnd();
      },
      CommType::ALLTOALL,
      sync_op,
      use_calc_stream);
}

261 262
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
    const BarrierOptions& opts) {
263 264 265 266 267
  PADDLE_ENFORCE_GE(opts.device_id,
                    0,
                    platform::errors::PreconditionNotMet(
                        "The barrier device id must greater or equal than 0."));
  platform::CUDAPlace place(opts.device_id);
268
  auto allocator = std::unique_ptr<phi::Allocator>(
269
      new paddle::experimental::DefaultAllocator(place));
270 271 272 273 274 275 276 277 278
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
  phi::DenseTensor barrier_tensor{allocator.get(), meta};

  auto task = AllReduce(&barrier_tensor,
                        barrier_tensor,
                        {},
                        /*sync_op*/ true,
                        /*use_calc_stream*/ false);
  auto nccl_task = dynamic_cast<NCCLTask*>(task.get());
W
Wen Sun 已提交
279
  nccl_task->SetBlockCPUInWait();
280 281 282 283 284 285 286 287 288 289 290 291 292 293 294
  return task;
}

// Broadcasts the root's `in_tensor` into `out_tensor` on every rank. The NCCL
// root rank is opts.source_rank + opts.source_root.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  auto broadcast_fn = [&](phi::DenseTensor* output,
                          const phi::DenseTensor& input,
                          ncclComm_t comm,
                          gpuStream_t stream) {
    const int root_rank = opts.source_rank + opts.source_root;
    NCCL_CHECK(platform::dynload::ncclBroadcast(
        input.data(),
        output->data(),
        input.numel(),
        platform::ToNCCLDataType(input.dtype()),
        root_rank,
        comm,
        stream));
  };

  return Collective(out_tensor,
                    in_tensor,
                    broadcast_fn,
                    CommType::BROADCAST,
                    sync_op,
                    use_calc_stream);
}

311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417
// Reduces `in_tensor` from all ranks with opts.reduce_op into `out_tensor`
// using ncclReduce with root opts.root_rank.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  auto reduce_fn = [&](phi::DenseTensor* output,
                       const phi::DenseTensor& input,
                       ncclComm_t comm,
                       gpuStream_t stream) {
    NCCL_CHECK(platform::dynload::ncclReduce(
        input.data(),
        output->data(),
        input.numel(),
        platform::ToNCCLDataType(input.dtype()),
        ToNCCLRedType(opts.reduce_op),
        opts.root_rank,
        comm,
        stream));
  };

  return Collective(out_tensor,
                    in_tensor,
                    reduce_fn,
                    CommType::REDUCE,
                    sync_op,
                    use_calc_stream);
}

// Reduce-scatter: each rank receives output->numel() reduced elements (the
// NCCL receive count) into `out_tensor`.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  auto reduce_scatter_fn = [&](phi::DenseTensor* output,
                               const phi::DenseTensor& input,
                               ncclComm_t comm,
                               gpuStream_t stream) {
    NCCL_CHECK(platform::dynload::ncclReduceScatter(
        input.data(),
        output->data(),
        output->numel(),
        platform::ToNCCLDataType(input.dtype()),
        ToNCCLRedType(opts.reduce_op),
        comm,
        stream));
  };

  return Collective(out_tensor,
                    in_tensor,
                    reduce_scatter_fn,
                    CommType::REDUCE_SCATTER,
                    sync_op,
                    use_calc_stream);
}

// Scatters equal chunks of the root's `in_tensor`. The root (opts.root_rank)
// sends chunk i of input.numel() / size_ elements to rank i, itself included,
// inside one NCCL group. Every rank receives its chunk into `out_tensor`.
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  auto scatter_fn = [&](phi::DenseTensor* output,
                        const phi::DenseTensor& input,
                        ncclComm_t comm,
                        gpuStream_t stream) {
    const int64_t chunk_numel = input.numel() / size_;
    const bool is_root = rank_ == opts.root_rank;

    // Only the root sends; its own receive must share the send group.
    if (is_root) {
      GroupStart();
      phi::DenseTensor chunk;
      int64_t chunk_offset = 0;
      for (int peer = 0; peer < size_; ++peer) {
        chunk = GetPartialTensor(input, chunk_offset, chunk_numel);
        NCCL_CHECK(platform::dynload::ncclSend(
            chunk.data(),
            chunk_numel,
            platform::ToNCCLDataType(input.dtype()),
            peer,
            comm,
            stream));
        chunk_offset += chunk_numel;
      }
    }

    NCCL_CHECK(platform::dynload::ncclRecv(
        output->data(),
        chunk_numel,
        platform::ToNCCLDataType(output->dtype()),
        opts.root_rank,
        comm,
        stream));

    if (is_root) {
      GroupEnd();
    }
  };

  return Collective(out_tensor,
                    in_tensor,
                    scatter_fn,
                    CommType::SCATTER,
                    sync_op,
                    use_calc_stream);
}

418 419 420 421
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
    phi::DenseTensor* tensor,
    int src_rank,
    int64_t offset,
422
    int64_t numel,
423 424
    bool sync_op,
    bool use_calc_stream) {
425 426 427 428 429 430
  // numel > 0 indicates the tensor need to be sliced
  phi::DenseTensor partial_tensor;
  if (numel > 0) {
    partial_tensor = GetPartialTensor(*tensor, offset, numel);
    tensor = &partial_tensor;
  }
431
  return PointToPoint(
432
      tensor,
433
      src_rank,
434 435 436 437
      [](phi::DenseTensor* output,
         int src,
         ncclComm_t comm,
         gpuStream_t stream) {
438
        NCCL_CHECK(platform::dynload::ncclRecv(
439 440 441 442 443
            output->data(),
            output->numel(),
            platform::ToNCCLDataType(output->dtype()),
            src,
            comm,
444
            stream));
445 446 447 448 449 450 451 452 453 454
      },
      CommType::RECV,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
    phi::DenseTensor* tensor,
    int dst_rank,
    int64_t offset,
455
    int64_t numel,
456 457
    bool sync_op,
    bool use_calc_stream) {
458 459 460 461 462 463
  // numel > 0 indicates the tensor need to be sliced
  phi::DenseTensor partial_tensor;
  if (numel > 0) {
    partial_tensor = GetPartialTensor(*tensor, offset, numel);
    tensor = &partial_tensor;
  }
464
  return PointToPoint(
465
      tensor,
466
      dst_rank,
467 468 469 470
      [](phi::DenseTensor* input,
         int dst,
         ncclComm_t comm,
         gpuStream_t stream) {
471
        NCCL_CHECK(platform::dynload::ncclSend(
472 473 474 475 476
            input->data(),
            input->numel(),
            platform::ToNCCLDataType(input->dtype()),
            dst,
            comm,
477
            stream));
478 479 480 481 482 483
      },
      CommType::SEND,
      sync_op,
      use_calc_stream);
}

484
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
485
    const Place& place,
486 487 488 489 490
    int rank,
    CommType comm_type,
    bool is_sync,
    bool use_calc_stream) {
  return std::make_shared<ProcessGroupNCCL::NCCLTask>(
491
      place, rank, comm_type, is_sync, use_calc_stream);
492 493
}

494 495 496 497 498 499 500 501 502 503 504 505
// Shares rank 0's NCCL unique id with the other ranks through the store.
// Rank 0 publishes the raw id bytes under a key derived from gid_; every other
// rank fetches them and copies them into `nccl_id`.
void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) {
  const std::string key =
      "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/0";
  if (rank_ == 0) {
    std::vector<uint8_t> nccl_id_wrapper(
        reinterpret_cast<uint8_t*>(nccl_id),
        reinterpret_cast<uint8_t*>(nccl_id) + NCCL_UNIQUE_ID_BYTES);
    store_->set(key, nccl_id_wrapper);
  } else {
    const auto& nccl_id_wrapper = store_->get(key);
    // `nccl_id` holds exactly NCCL_UNIQUE_ID_BYTES bytes. Copying a payload of
    // any other size would overflow it or leave it partially initialized.
    PADDLE_ENFORCE_EQ(
        nccl_id_wrapper.size(),
        static_cast<size_t>(NCCL_UNIQUE_ID_BYTES),
        platform::errors::InvalidArgument(
            "The NCCL unique id fetched from the store has an unexpected "
            "size."));
    std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
  }
}

508 509
void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
                                          const std::string& place_key) {
W
Wen Sun 已提交
510 511 512 513
  if (place_to_comm_ctx_.size() > 0) {
    VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
  }

514 515
  ncclUniqueId nccl_id;
  if (rank_ == 0) {
516
    NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
517 518
  }
  BroadcastUniqueNCCLID(&nccl_id);
519

520 521 522
  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << place_key
          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
523

524 525 526 527
  auto* calc_ctx = static_cast<phi::GPUContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  auto comm_ctx = std::make_unique<phi::GPUContext>(place);
  ncclComm_t nccl_comm;
528
  NCCL_CHECK(platform::dynload::ncclCommInitRank(
529 530 531
      &nccl_comm, GetSize(), nccl_id, GetRank()));
  comm_ctx->set_nccl_comm(nccl_comm);

W
Wen Sun 已提交
532 533 534
  place_to_calc_event_.emplace(place_key, place);
  place_to_calc_ctx_.emplace(place_key, calc_ctx);
  place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));
535 536

  // TODO(sunyilun): for compatibility, will be removed later
W
Wen Sun 已提交
537 538 539
  std::vector<phi::GPUContext*> comm_ctx_wrapper{
      place_to_comm_ctx_[place_key].get()};
  places_to_ctx_.emplace(place_key, comm_ctx_wrapper);
540 541
}

W
Wen Sun 已提交
542
void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
543
  const std::string& key = GetKeyFromPlace(place);
W
Wen Sun 已提交
544 545 546 547 548
  auto& calc_event = place_to_calc_event_.at(key);
  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
  calc_event.Record(calc_ctx);
  calc_event.Wait(platform::Place2DeviceType(place), comm_ctx);
549 550
}

551 552 553 554 555 556 557 558 559 560 561
// Shared driver for the stream-based collectives. It:
// - lazily builds the NCCL environment for the input's place;
// - unless running on the calc stream, orders the comm stream after the calc
//   stream;
// - invokes `fn` on the chosen stream;
// - when on the comm stream, records the completion event on the returned
//   task.
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    Fn fn,
    CommType comm_type,
    bool sync_op,
    bool use_calc_stream) {
  const auto& place = in_tensor.place();
  const auto& key = GetKeyFromPlace(place);

  platform::CUDADeviceGuard cuda_guard(place);

  const bool env_missing =
      place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end();
  if (env_missing) {
    CreateNCCLEnvCache(place, key);
  }

  const bool on_comm_stream = !use_calc_stream;
  if (on_comm_stream) {
    SyncCalcStream(place);
  }

  auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);

  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto& comm_ctx = place_to_comm_ctx_.at(key);
  auto nccl_comm = comm_ctx->nccl_comm();
  auto nccl_stream =
      on_comm_stream ? comm_ctx->stream() : calc_ctx->stream();
  fn(out_tensor, in_tensor, nccl_comm, nccl_stream);

  if (on_comm_stream) {
    // Presumably keeps the input allocation from being reused before the comm
    // stream is done with it (stream-safe allocator only).
    if (FLAGS_use_stream_safe_cuda_allocator) {
      memory::RecordStream(in_tensor.Holder(), nccl_stream);
    }
    task->UpdateWaitChain(*comm_ctx);
  }

  return task;
}

590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628
// Point-to-point counterpart of the stream-based Collective(). It runs `fn`
// with the single tensor and the peer `rank` on the calc stream or the comm
// stream, and records the completion event when the comm stream is used.
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
    phi::DenseTensor* tensor,
    int rank,
    Fn fn,
    CommType comm_type,
    bool sync_op,
    bool use_calc_stream) {
  const auto& place = tensor->place();
  const auto& key = GetKeyFromPlace(place);

  platform::CUDADeviceGuard cuda_guard(place);

  const bool env_missing =
      place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end();
  if (env_missing) {
    CreateNCCLEnvCache(place, key);
  }

  const bool on_comm_stream = !use_calc_stream;
  if (on_comm_stream) {
    SyncCalcStream(place);
  }

  auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);

  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto& comm_ctx = place_to_comm_ctx_.at(key);
  auto nccl_comm = comm_ctx->nccl_comm();
  auto nccl_stream =
      on_comm_stream ? comm_ctx->stream() : calc_ctx->stream();
  fn(tensor, rank, nccl_comm, nccl_stream);

  if (on_comm_stream) {
    if (FLAGS_use_stream_safe_cuda_allocator) {
      memory::RecordStream(tensor->Holder(), nccl_stream);
    }
    task->UpdateWaitChain(*comm_ctx);
  }

  return task;
}

629 630
// TODO(sunyilun): methods below will be removed later
void SyncDefaultStream(const std::vector<Place>& places,
W
Wen Sun 已提交
631
                       platform::DeviceEvent& nccl_event,         // NOLINT
632 633 634 635
                       std::vector<phi::GPUContext*>& dev_ctx) {  // NOLINT
  for (size_t i = 0; i < places.size(); ++i) {
    auto* default_ctx = static_cast<phi::GPUContext*>(
        platform::DeviceContextPool::Instance().Get(places[i]));
W
Wen Sun 已提交
636 637
    nccl_event.Record(default_ctx);
    nccl_event.Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
B
Baibaifan 已提交
638
  }
639 640
}

641 642 643 644 645 646 647
// Legacy factory for the vector-of-tensors API.
std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    std::vector<Place> places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs) {
  auto task = std::make_shared<ProcessGroupNCCL::NCCLTask>(
      places, rank, comm_type, inputs);
  return task;
}
649

650 651 652 653 654 655 656
// Legacy constructor for the vector-of-tensors API. The event and the stored
// place both come from the first entry of `places`, which must be non-empty.
// (The parameter was renamed from `CommType`, which shadowed the type name.)
ProcessGroupNCCL::NCCLTask::NCCLTask(
    const std::vector<Place>& places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs)
    : TaskStream(rank, inputs, comm_type),
      comm_event_(places.front()),
      task_place_(places.front()) {}
658

659 660 661
// create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache(
    const std::string& places_key, const std::vector<Place>& places) {
662 663
  PADDLE_ENFORCE_EQ(places_key.empty(),
                    false,
664 665 666 667
                    platform::errors::PreconditionNotMet(
                        "Not able to create/get the NCCL Communicator since "
                        "the GPU place are not known"));

668
  ncclUniqueId nccl_id;
669
  if (rank_ == 0) {
670
    NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
671
  }
672
  BroadcastUniqueNCCLID(&nccl_id);
673

674 675
  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << places_key
676 677
          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);

L
Leo Chen 已提交
678
  std::vector<std::unique_ptr<phi::GPUContext>> dev_ctx;
679 680
  dev_ctx.resize(places.size());

681 682 683
  std::vector<phi::GPUContext*> dev_ctx_raw;
  dev_ctx_raw.resize(places.size());

684
  GroupStart();
685 686 687

  for (size_t i = 0; i < places.size(); ++i) {
    platform::CUDADeviceGuard guard(places[i]);
688

L
Leo Chen 已提交
689
    dev_ctx[i].reset(new phi::GPUContext(places[i]));
690
    ncclComm_t nccl_comm;
691
    NCCL_CHECK(platform::dynload::ncclCommInitRank(
692 693 694
        &nccl_comm, GetSize(), nccl_id, GetRank()));
    dev_ctx[i]->set_nccl_comm(nccl_comm);
    dev_ctx_raw[i] = dev_ctx[i].get();
695 696
  }

697
  GroupEnd();
698

699
  // TODO(sunyilun): for compatibility, will be removed later
W
Wen Sun 已提交
700 701 702 703 704 705
  place_to_calc_event_.emplace(places_key, places[0]);
  place_to_calc_ctx_.emplace(
      places_key,
      static_cast<phi::GPUContext*>(
          platform::DeviceContextPool::Instance().Get(places[0])));
  place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0]));
706 707

  // These caches will be useful to process sync/wait/communicate
708
  places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
709 710 711 712
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
713
    std::vector<phi::DenseTensor>& inputs,
714 715 716
    std::vector<phi::DenseTensor>& outputs,
    Fn fn,
    CommType op_type) {
717 718 719 720 721
  const auto places = GetPlaceList(inputs);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
W
Wen Sun 已提交
722
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
723 724 725 726
      CreateNCCLManagerCache(key, places);
    }
  }

W
Wen Sun 已提交
727 728
  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
729 730 731 732 733 734

  auto task = CreateTask(places, rank_, op_type, inputs);

  // construct uninitialize guard for device
  platform::CUDADeviceGuard cuda_guard;

S
ShenLiang 已提交
735 736
  {
    platform::NCCLGroupGuard nccl_guard;
737 738
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
W
Wen Sun 已提交
739
      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
740 741
      fn(inputs[i],
         outputs[i],
W
Wen Sun 已提交
742
         places_to_ctx_.at(key)[i]->nccl_comm(),
743
         nccl_stream);
744 745 746
    }
  }

S
ShenLiang 已提交
747
  if (FLAGS_use_stream_safe_cuda_allocator) {
748 749
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
S
ShenLiang 已提交
750
      memory::RecordStream(inputs[i].Holder(),
W
Wen Sun 已提交
751
                           places_to_ctx_.at(key)[i]->stream());
752 753 754 755 756
    }
  }

  for (size_t i = 0; i < inputs.size(); ++i) {
    cuda_guard.SetDevice(places[i]);
W
Wen Sun 已提交
757
    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
758 759 760 761
  }
  return task;
}

B
Baibaifan 已提交
762 763
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
764 765 766
    std::vector<phi::DenseTensor>& tensors,
    Fn fn,
    int dst_rank,
767
    CommType op_type) {
B
Baibaifan 已提交
768 769 770 771 772
  const auto places = GetPlaceList(tensors);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
W
Wen Sun 已提交
773
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
B
Baibaifan 已提交
774 775 776 777
      CreateNCCLManagerCache(key, places);
    }
  }

W
Wen Sun 已提交
778 779
  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
B
Baibaifan 已提交
780 781 782 783 784 785

  auto task = CreateTask(places, rank_, op_type, tensors);

  // construct uninitialize guard for device
  platform::CUDADeviceGuard cuda_guard;

786 787
  {
    platform::NCCLGroupGuard nccl_guard;
B
Baibaifan 已提交
788 789
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
W
Wen Sun 已提交
790
      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
791
      fn(tensors[i],
W
Wen Sun 已提交
792
         places_to_ctx_.at(key)[i]->nccl_comm(),
793 794
         nccl_stream,
         dst_rank);
B
Baibaifan 已提交
795 796 797
    }
  }

798
  if (FLAGS_use_stream_safe_cuda_allocator) {
B
Baibaifan 已提交
799 800
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
801
      memory::RecordStream(tensors[i].Holder(),
W
Wen Sun 已提交
802
                           places_to_ctx_.at(key)[i]->stream());
B
Baibaifan 已提交
803 804 805 806 807
    }
  }

  for (size_t i = 0; i < tensors.size(); ++i) {
    cuda_guard.SetDevice(places[i]);
W
Wen Sun 已提交
808
    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
B
Baibaifan 已提交
809 810 811 812
  }
  return task;
}

813
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
814
    std::vector<phi::DenseTensor>& in_tensors,
815 816
    std::vector<phi::DenseTensor>& out_tensors,
    const AllreduceOptions& opts) {
817
  PADDLE_ENFORCE_EQ(
818 819
      CheckTensorsInCudaPlace(in_tensors),
      true,
820
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
821
  return Collective(
822 823 824 825 826 827
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
828
        return platform::dynload::ncclAllReduce(
829 830 831
            input.data(),
            output.data(),
            input.numel(),
832
            platform::ToNCCLDataType(input.type()),
833 834 835
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream);
836 837
      },
      CommType::ALLREDUCE);
838 839 840
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
841
    std::vector<phi::DenseTensor>& in_tensors,
842 843
    std::vector<phi::DenseTensor>& out_tensors,
    const BroadcastOptions& opts) {
844
  PADDLE_ENFORCE_EQ(
845 846
      CheckTensorsInCudaPlace(in_tensors),
      true,
847 848
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));

849
  return Collective(
850 851 852 853 854
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
855 856 857 858
          const gpuStream_t& stream) {
        const auto root =
            opts.source_rank * in_tensors.size() + opts.source_root;
        return platform::dynload::ncclBroadcast(
859 860 861 862 863 864 865
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            root,
            comm,
            stream);
866 867
      },
      CommType::BROADCAST);
868 869
}

870 871
// Validates a tensor list used for point-to-point communication:
//   * the list is non-empty,
//   * it holds at most `num_devices` tensors, and
//   * every tensor is in a GPU place that no other tensor in the list uses.
// Throws InvalidArgument (via PADDLE_ENFORCE_*) on the first violation found.
void CheckTensorsInDifferentDevices(
    const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
  PADDLE_ENFORCE_EQ(
      tensors.empty(),
      false,
      platform::errors::InvalidArgument("Tensor list must be nonempty."));
  PADDLE_ENFORCE_LE(
      tensors.size(),
      num_devices,
      platform::errors::InvalidArgument(
          "Tensor list mustn't be larger than the number of available GPUs."));

  // Places seen so far; a failed insertion means two tensors share a device.
  std::set<Place> seen_places;
  for (const auto& tensor : tensors) {
    const auto& place = tensor.place();
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(place),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensors must be CUDA and dense tensor."));

    const bool is_new_place = seen_places.insert(place).second;
    PADDLE_ENFORCE_EQ(is_new_place,
                      true,
                      platform::errors::InvalidArgument(
                          "Tensors must be on distinct GPU devices."));
  }
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
899
    std::vector<phi::DenseTensor>& tensors, int dst_rank) {
B
Baibaifan 已提交
900 901
  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

902 903
  auto task = PointToPoint(
      tensors,
904 905 906
      [&](phi::DenseTensor& input,
          ncclComm_t comm,
          const gpuStream_t& stream,
907 908
          int dst_rank) {
        return platform::dynload::ncclSend(
909 910 911 912 913 914
            input.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            dst_rank,
            comm,
            stream);
915
      },
916 917
      dst_rank,
      CommType::SEND);
B
Baibaifan 已提交
918 919 920 921
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
922
    std::vector<phi::DenseTensor>& tensors, int src_rank) {
B
Baibaifan 已提交
923 924
  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

925 926
  auto task = PointToPoint(
      tensors,
927 928 929
      [&](phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream,
930 931
          int src_rank) {
        return platform::dynload::ncclRecv(
932 933 934 935 936 937
            output.data(),
            output.numel(),
            platform::ToNCCLDataType(output.dtype()),
            src_rank,
            comm,
            stream);
938
      },
939 940
      src_rank,
      CommType::RECV);
941 942 943
  return task;
}

944
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
945 946
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
947
  PADDLE_ENFORCE_EQ(
948 949
      CheckTensorsInCudaPlace(in_tensors),
      true,
950 951
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
952 953
      CheckTensorsInCudaPlace(out_tensors),
      true,
954
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
955
  return Collective(
956 957 958 959 960 961
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
962
        return platform::dynload::ncclAllGather(
963 964 965 966 967 968
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            comm,
            stream);
969 970
      },
      CommType::ALLGATHER);
971 972
}

973 974
void* GetPointerByOffset(void* raw_pointer,
                         size_t offset,
975 976 977 978 979 980 981
                         experimental::DataType type) {
  if (type == experimental::DataType::FLOAT32) {
    return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT64) {
    return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
                                   offset);
982 983 984
  } else if (type == experimental::DataType::FLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
                                   offset);
985 986 987 988 989 990
  } else if (type == experimental::DataType::INT32) {
    return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT64) {
    return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
                                   offset);
991 992 993 994 995 996 997 998
  } else if (type == experimental::DataType::INT8) {
    return reinterpret_cast<void*>(reinterpret_cast<int8_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::UINT8) {
    return reinterpret_cast<void*>(reinterpret_cast<uint8_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::BOOL) {
    return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
999
                                   offset);
1000 1001 1002
  } else if (type == experimental::DataType::BFLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
                                   offset);
1003 1004
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
1005
        "Datatype %s in NCCL is not supported.", type));
1006
  }
1007
  return nullptr;
1008 1009 1010
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
1011 1012
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
1013
  PADDLE_ENFORCE_EQ(
1014 1015
      CheckTensorsInCudaPlace(in_tensors),
      true,
1016 1017
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
1018 1019
      CheckTensorsInCudaPlace(out_tensors),
      true,
1020 1021
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
1022 1023 1024 1025 1026
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
1027 1028
          const gpuStream_t& stream) {
        size_t offset = 0;
1029
        GroupStart();
1030 1031
        for (auto i = 0; i < size_; i++) {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
1032
              GetPointerByOffset(input.data(), offset, input.dtype()),
1033 1034 1035 1036 1037
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
1038
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
1039
              GetPointerByOffset(output.data(), offset, input.dtype()),
1040 1041 1042 1043 1044
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
1045
          offset += input.numel() / size_;
1046
        }
1047
        GroupEnd();
1048
      },
1049 1050 1051
      CommType::ALLTOALL);
}

1052
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
1053
    std::vector<phi::DenseTensor>& in_tensors,
1054 1055
    std::vector<phi::DenseTensor>& out_tensors,
    const ReduceOptions& opts) {
1056
  PADDLE_ENFORCE_EQ(
1057 1058
      CheckTensorsInCudaPlace(in_tensors),
      true,
1059 1060
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
1061 1062 1063 1064 1065 1066
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
1067
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
1068 1069 1070
            input.data(),
            output.data(),
            input.numel(),
1071
            platform::ToNCCLDataType(input.dtype()),
1072 1073 1074 1075
            ToNCCLRedType(opts.reduce_op),
            opts.root_rank,
            comm,
            stream));
1076 1077 1078 1079 1080
      },
      CommType::REDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
1081
    std::vector<phi::DenseTensor>& in_tensors,
1082 1083
    std::vector<phi::DenseTensor>& out_tensors,
    const ScatterOptions& opts) {
1084
  PADDLE_ENFORCE_EQ(
1085 1086
      CheckTensorsInCudaPlace(in_tensors),
      true,
1087 1088
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
1089 1090
      CheckTensorsInCudaPlace(out_tensors),
      true,
1091 1092
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
1093 1094 1095 1096 1097
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
1098 1099 1100
          const gpuStream_t& stream) {
        size_t offset = 0;
        if (rank_ == opts.root_rank) {
1101
          GroupStart();
1102 1103
          for (auto i = 0; i < size_; i++) {
            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
1104
                GetPointerByOffset(input.data(), offset, input.dtype()),
1105 1106 1107 1108 1109
                input.numel() / size_,
                platform::ToNCCLDataType(input.dtype()),
                i,
                comm,
                stream));
1110
            offset += input.numel() / size_;
1111 1112
          }
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
1113 1114 1115 1116 1117
              output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
              comm,
1118
              stream));
1119
          GroupEnd();
1120 1121
        } else {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
1122 1123 1124 1125 1126
              output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
              comm,
1127 1128 1129 1130 1131 1132
              stream));
        }
      },
      CommType::SCATTER);
}

1133 1134
}  //  namespace distributed
}  //  namespace paddle