// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"

#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"

DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);

constexpr int64_t kWaitBlockTimeout = 10;

namespace paddle {
namespace distributed {

ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place,
                                     int rank,
                                     CommType comm_type,
                                     bool sync_op,
                                     bool use_calc_stream)
    : TaskStream(rank, comm_type, sync_op, use_calc_stream),
      comm_event_(place),
      task_place_(place) {}

ProcessGroupNCCL::NCCLTask::~NCCLTask() {}

bool ProcessGroupNCCL::NCCLTask::IsCompleted() { return comm_event_.Query(); }

void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
    const phi::DeviceContext& ctx) {
  comm_event_.Record(&ctx);
}

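// Wait makes the calculation stream depend on the communication event, so it
// normally does not block the host. It only spins on the host when
// FLAGS_nccl_blocking_wait is set, and synchronizes the whole device when the
// task is used as a barrier (see SetBlockCPUInWait).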
// TODO(shenliang03): Add timeout for wait; the timeout is currently unused.
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
  // Warn when the calc stream is used but wait is still invoked explicitly.
  if (UseCalcStream()) {
    VLOG(3) << "Warning: The communication is on the calc stream, waiting "
               "here is useless.";
    return true;
  }

  const auto* calc_ctx =
      platform::DeviceContextPool::Instance().Get(task_place_);
  comm_event_.Wait(platform::Place2DeviceType(task_place_), calc_ctx);

  if (FLAGS_nccl_blocking_wait) {
    // NOTE(shenliang03): It will block host for sync
    while (!IsCompleted()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTimeout));
    }
  }

  if (IsBlockCPUInWait()) {
    // If the work is used as a barrier, block the CPU here.
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
  }
  return true;
}

// Same as Wait
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }

ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
                                   int rank,
                                   int size,
                                   int gid)
    : ProcessGroupStream(rank, size, gid), store_(store) {}

void ProcessGroupNCCL::GroupStart() {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
}

void ProcessGroupNCCL::GroupEnd() {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}

const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
    const Place& place) const {
  return GetDeviceContext(place, /*use_calc_stream*/ false);
}

const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  const std::string& key = GetKeyFromPlace(place);
  if (use_calc_stream) {
    const auto& iter = place_to_calc_ctx_.find(key);
    return *iter->second;
  } else {
    const auto& iter = place_to_comm_ctx_.find(key);
    PADDLE_ENFORCE_NE(
        iter,
        place_to_comm_ctx_.end(),
        platform::errors::NotFound(
            "Cannot find the device context in this process group."));
    return *iter->second;
  }
}

ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
  const std::string& key = GetKeyFromPlace(place);
  const auto& iter = place_to_comm_ctx_.find(key);
  PADDLE_ENFORCE_NE(
      iter,
      place_to_comm_ctx_.end(),
      platform::errors::NotFound(
          "Cannot find the NCCL communicator in this process group."));
  return iter->second->nccl_comm();
}

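// AllGather gathers in_tensor from every rank into out_tensor; numel <= 0
// gathers the whole input, while numel > 0 gathers only the slice
// [offset, offset + numel). A minimal usage sketch (assuming a
// ProcessGroupNCCL instance `pg`, a GPU input tensor `in`, and an output
// tensor `out` holding world_size * in.numel() elements):
//   auto task = pg->AllGather(&out, in, /*offset=*/0, /*numel=*/-1,
//                             /*sync_op=*/true, /*use_calc_stream=*/false);
//   task->Wait();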
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates that the tensor needs to be sliced
  const phi::DenseTensor& in_tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
  return Collective(
      out_tensor,
      in_tensor_maybe_partial,
      [](phi::DenseTensor* output,
         const phi::DenseTensor& input,
         ncclComm_t comm,
         gpuStream_t stream) {
        return platform::dynload::ncclAllGather(
            input.data(),
            output->data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            comm,
            stream);
      },
      CommType::ALLGATHER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return Collective(
      out_tensor,
      in_tensor,
      [&](phi::DenseTensor* output,
          const phi::DenseTensor& input,
          ncclComm_t comm,
          gpuStream_t stream) {
        return platform::dynload::ncclAllReduce(
            input.data(),
            output->data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream);
      },
      CommType::ALLREDUCE,
      sync_op,
      use_calc_stream);
}

void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
                         const std::vector<int64_t>& size_on_each_rank,
                         int world_size) {
  int length_size_on_each_rank = size_on_each_rank.size();
  PADDLE_ENFORCE_EQ(
      length_size_on_each_rank,
      world_size,
      platform::errors::InvalidArgument(
          "The length of size_on_each_rank must be equal to world_size."));

  int64_t sum_size_on_each_rank =
      std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0);
  PADDLE_ENFORCE_EQ(
      sum_size_on_each_rank,
      tensor_dim[0],
      platform::errors::InvalidArgument(
          "The sum of size_on_each_rank must be equal to tensor's dim[0]."));
}

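// AllToAll exchanges row slices: for each peer i, it sends
// in_size_each_rank[i] rows of in_tensor and receives out_size_each_rank[i]
// rows into out_tensor, batching all send/recv calls in one NCCL group.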
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const std::vector<int64_t>& out_size_each_rank,
    const std::vector<int64_t>& in_size_each_rank,
    bool sync_op,
    bool use_calc_stream) {
  const phi::DDim& out_dim = out_tensor->dims();
  const phi::DDim& in_dim = in_tensor.dims();
  CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
  CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);

  return Collective(
      out_tensor,
      in_tensor,
      [&](phi::DenseTensor* output,
          const phi::DenseTensor& input,
          ncclComm_t comm,
          gpuStream_t stream) {
        int64_t in_row_size = input.numel() / in_dim[0],
                out_row_size = output->numel() / out_dim[0];
        int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
        phi::DenseTensor input_partial, output_partial;

        GroupStart();
        for (auto i = 0; i < size_; i++) {
          in_numel = in_size_each_rank[i] * in_row_size;
          input_partial = GetPartialTensor(input, in_offset, in_numel);
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
              input_partial.data(),
              in_numel,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          in_offset += in_numel;

          out_numel = out_size_each_rank[i] * out_row_size;
          output_partial = GetPartialTensor(*output, out_offset, out_numel);
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              output_partial.data(),
              out_numel,
              platform::ToNCCLDataType(output->dtype()),
              i,
              comm,
              stream));
          out_offset += out_numel;
        }
        GroupEnd();
      },
      CommType::ALLTOALL,
      sync_op,
      use_calc_stream);
}

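// Barrier is implemented as an all-reduce on a one-element tensor; the
// returned task is marked to block the CPU in Wait so the barrier is also
// observed on the host.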
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
    const BarrierOptions& opts) {
  PADDLE_ENFORCE_GE(opts.device_id,
                    0,
                    platform::errors::PreconditionNotMet(
                        "The barrier device id must be greater than or equal "
                        "to 0."));
  platform::CUDAPlace place(opts.device_id);
  auto allocator = std::unique_ptr<phi::Allocator>(
      new paddle::experimental::DefaultAllocator(place));
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
  phi::DenseTensor barrier_tensor{allocator.get(), meta};

  auto task = AllReduce(&barrier_tensor,
                        barrier_tensor,
                        {},
                        /*sync_op*/ true,
                        /*use_calc_stream*/ false);
  auto nccl_task = dynamic_cast<NCCLTask*>(task.get());
  nccl_task->SetBlockCPUInWait();
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return Collective(
      out_tensor,
      in_tensor,
      [&](phi::DenseTensor* output,
          const phi::DenseTensor& input,
          ncclComm_t comm,
          gpuStream_t stream) {
        int root = opts.source_rank + opts.source_root;
        return platform::dynload::ncclBroadcast(
            input.data(),
            output->data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            root,
            comm,
            stream);
      },
      CommType::BROADCAST,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
    phi::DenseTensor* tensor,
    int src_rank,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates that the tensor needs to be sliced
  phi::DenseTensor partial_tensor;
  if (numel > 0) {
    partial_tensor = GetPartialTensor(*tensor, offset, numel);
    tensor = &partial_tensor;
  }
  return PointToPoint(
      tensor,
      src_rank,
      [](phi::DenseTensor* output,
         int src,
         ncclComm_t comm,
         gpuStream_t stream) {
        return platform::dynload::ncclRecv(
            output->data(),
            output->numel(),
            platform::ToNCCLDataType(output->dtype()),
            src,
            comm,
            stream);
      },
      CommType::RECV,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
    phi::DenseTensor* tensor,
    int dst_rank,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates that the tensor needs to be sliced
  phi::DenseTensor partial_tensor;
  if (numel > 0) {
    partial_tensor = GetPartialTensor(*tensor, offset, numel);
    tensor = &partial_tensor;
  }
  return PointToPoint(
      tensor,
      dst_rank,
      [](phi::DenseTensor* input,
         int dst,
         ncclComm_t comm,
         gpuStream_t stream) {
        return platform::dynload::ncclSend(
            input->data(),
            input->numel(),
            platform::ToNCCLDataType(input->dtype()),
            dst,
            comm,
            stream);
      },
      CommType::SEND,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    const Place& place,
    int rank,
    CommType comm_type,
    bool is_sync,
    bool use_calc_stream) {
  return std::make_shared<ProcessGroupNCCL::NCCLTask>(
      place, rank, comm_type, is_sync, use_calc_stream);
}

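// Rank 0 generates the unique NCCL id and publishes it through the store;
// all other ranks read it back from the same key.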
void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) {
  const std::string key =
      "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/0";
  if (rank_ == 0) {
    std::vector<uint8_t> nccl_id_wrapper(
        reinterpret_cast<uint8_t*>(nccl_id),
        reinterpret_cast<uint8_t*>(nccl_id) + NCCL_UNIQUE_ID_BYTES);
    store_->set(key, nccl_id_wrapper);
  } else {
    const auto& nccl_id_wrapper = store_->get(key);
    std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
  }
}

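// Create and cache the per-place communication state: an event for the calc
// stream, the calculation device context, and a dedicated phi::GPUContext
// holding the NCCL communicator, all keyed by the place.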
void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
                                          const std::string& place_key) {
  if (place_to_comm_ctx_.size() > 0) {
    VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
  }

  ncclUniqueId nccl_id;
  if (rank_ == 0) {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
  }
  BroadcastUniqueNCCLID(&nccl_id);

  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << place_key
          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);

  auto* calc_ctx = static_cast<phi::GPUContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  auto comm_ctx = std::make_unique<phi::GPUContext>(place);
  ncclComm_t nccl_comm;
  NCCLCHECK(platform::dynload::ncclCommInitRank(
      &nccl_comm, GetSize(), nccl_id, GetRank()));
  comm_ctx->set_nccl_comm(nccl_comm);

  place_to_calc_event_.emplace(place_key, place);
  place_to_calc_ctx_.emplace(place_key, calc_ctx);
  place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));

  // TODO(sunyilun): for compatibility, will be removed later
  std::vector<phi::GPUContext*> comm_ctx_wrapper{
      place_to_comm_ctx_[place_key].get()};
  places_to_ctx_.emplace(place_key, comm_ctx_wrapper);
}

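// Record an event on the calculation stream and make the communication
// context wait on it, so collectives launched on the comm stream observe all
// prior work on the calc stream.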
void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
  const std::string& key = GetKeyFromPlace(place);
  auto& calc_event = place_to_calc_event_.at(key);
  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
  calc_event.Record(calc_ctx);
  calc_event.Wait(platform::Place2DeviceType(place), comm_ctx);
}

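// Common flow for collectives: lazily create the NCCL env for the input's
// place, sync the comm stream with the calc stream (unless the calc stream is
// used directly), launch fn on the chosen stream, then, when running on the
// comm stream, record the input allocation on it (if the stream-safe
// allocator is enabled) and update the task's wait chain.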
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    Fn fn,
    CommType comm_type,
    bool sync_op,
    bool use_calc_stream) {
  const auto& place = in_tensor.place();
  const auto& key = GetKeyFromPlace(place);

  platform::CUDADeviceGuard cuda_guard(place);

  if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
    CreateNCCLEnvCache(place, key);
  }

  if (!use_calc_stream) {
    SyncCalcStream(place);
  }

  auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);

  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto& comm_ctx = place_to_comm_ctx_.at(key);
  auto nccl_comm = comm_ctx->nccl_comm();
  auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
  fn(out_tensor, in_tensor, nccl_comm, nccl_stream);

  if (!use_calc_stream) {
    if (FLAGS_use_stream_safe_cuda_allocator) {
      memory::RecordStream(in_tensor.Holder(), nccl_stream);
    }
    task->UpdateWaitChain(*comm_ctx);
  }

  return task;
}

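// Same flow as Collective above, but for a single send/recv peer.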
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
    phi::DenseTensor* tensor,
    int rank,
    Fn fn,
    CommType comm_type,
    bool sync_op,
    bool use_calc_stream) {
  const auto& place = tensor->place();
  const auto& key = GetKeyFromPlace(place);

  platform::CUDADeviceGuard cuda_guard(place);

  if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
    CreateNCCLEnvCache(place, key);
  }

  if (!use_calc_stream) {
    SyncCalcStream(place);
  }

  auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);

  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto& comm_ctx = place_to_comm_ctx_.at(key);
  auto nccl_comm = comm_ctx->nccl_comm();
  auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
  fn(tensor, rank, nccl_comm, nccl_stream);

  if (!use_calc_stream) {
    if (FLAGS_use_stream_safe_cuda_allocator) {
      memory::RecordStream(tensor->Holder(), nccl_stream);
    }
    task->UpdateWaitChain(*comm_ctx);
  }

  return task;
}

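// If split_sizes is empty, fill it with an even split of the tensor's dim[0];
// otherwise check that it has one entry per rank and sums to dim[0].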
void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
                                       std::vector<int64_t> tensor_shape) {
  int64_t len_size = (*split_sizes).size();
  if (len_size == 0) {
    PADDLE_ENFORCE_EQ(tensor_shape[0] % size_ == 0,
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor's dim[0] must be divisible by group size "
                          "when split_sizes not given."));
    (*split_sizes)
        .insert((*split_sizes).end(),
                size_,
                static_cast<int64_t>(tensor_shape[0] / size_));
  } else {
    PADDLE_ENFORCE_EQ(
        len_size == size_,
        true,
        platform::errors::InvalidArgument(
            "The length of split_sizes must be equal to group size."));
    auto sum_size = std::accumulate(
        (*split_sizes).begin(), (*split_sizes).end(), static_cast<int64_t>(0));
    PADDLE_ENFORCE_EQ(
        sum_size == tensor_shape[0],
        true,
        platform::errors::InvalidArgument(
            "The sum of split_sizes must be equal to tensor's dim[0]."));
  }
}

// TODO(sunyilun): methods below will be removed later
void SyncDefaultStream(const std::vector<Place>& places,
                       platform::DeviceEvent& nccl_event,         // NOLINT
                       std::vector<phi::GPUContext*>& dev_ctx) {  // NOLINT
  for (size_t i = 0; i < places.size(); ++i) {
    auto* default_ctx = static_cast<phi::GPUContext*>(
        platform::DeviceContextPool::Instance().Get(places[i]));
    nccl_event.Record(default_ctx);
    nccl_event.Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
  }
}

std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    std::vector<Place> places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs) {
  return std::make_shared<ProcessGroupNCCL::NCCLTask>(
      places, rank, comm_type, inputs);
}

std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    const std::vector<Place>& places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs,
    bool is_sync,
    bool use_calc_stream) {
  return std::make_shared<ProcessGroupNCCL::NCCLTask>(
      places, rank, comm_type, inputs, is_sync, use_calc_stream);
}

ProcessGroupNCCL::NCCLTask::NCCLTask(
    const std::vector<Place>& places,
    int rank,
    CommType CommType,
    const std::vector<phi::DenseTensor>& inputs)
    : TaskStream(rank, inputs, CommType),
      comm_event_(places[0]),
      task_place_(places[0]) {}

ProcessGroupNCCL::NCCLTask::NCCLTask(
    const std::vector<Place>& places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs,
    bool sync_op,
    bool use_calc_stream)
    : TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream),
      comm_event_(places[0]),
      task_place_(places[0]) {}

// create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache(
    const std::string& places_key, const std::vector<Place>& places) {
  PADDLE_ENFORCE_EQ(places_key.empty(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Not able to create/get the NCCL Communicator since "
                        "the GPU place is not known"));

  ncclUniqueId nccl_id;
  if (rank_ == 0) {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
  }
  BroadcastUniqueNCCLID(&nccl_id);

  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << places_key
          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);

  std::vector<std::unique_ptr<phi::GPUContext>> dev_ctx;
  dev_ctx.resize(places.size());

  std::vector<phi::GPUContext*> dev_ctx_raw;
  dev_ctx_raw.resize(places.size());

  GroupStart();

  for (size_t i = 0; i < places.size(); ++i) {
    platform::CUDADeviceGuard guard(places[i]);

    dev_ctx[i].reset(new phi::GPUContext(places[i]));
    ncclComm_t nccl_comm;
    NCCLCHECK(platform::dynload::ncclCommInitRank(
        &nccl_comm, GetSize(), nccl_id, GetRank()));
    dev_ctx[i]->set_nccl_comm(nccl_comm);
    dev_ctx_raw[i] = dev_ctx[i].get();
  }

  GroupEnd();

  // TODO(sunyilun): for compatibility, will be removed later
  place_to_calc_event_.emplace(places_key, places[0]);
  place_to_calc_ctx_.emplace(
      places_key,
      static_cast<phi::GPUContext*>(
          platform::DeviceContextPool::Instance().Get(places[0])));
  place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0]));

  // These caches are used by sync/wait/communicate
  places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
    std::vector<phi::DenseTensor>& inputs,
    std::vector<phi::DenseTensor>& outputs,
    Fn fn,
    CommType comm_type,
    bool sync_op,
    bool use_calc_stream) {
  const auto& places = GetPlaceList(inputs);
  const auto& key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  if (!use_calc_stream) {
    SyncDefaultStream(
        places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
  }

  auto task =
      CreateTask(places, rank_, comm_type, inputs, sync_op, use_calc_stream);

  platform::CUDADeviceGuard cuda_guard;

  {
    platform::NCCLGroupGuard nccl_guard;
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);

      gpuStream_t nccl_stream;
      if (use_calc_stream) {
        nccl_stream =
            static_cast<phi::GPUContext*>(
                platform::DeviceContextPool::Instance().Get(places[i]))
                ->stream();
      } else {
        nccl_stream = places_to_ctx_.at(key)[i]->stream();
      }

      fn(inputs[i],
         outputs[i],
         places_to_ctx_.at(key)[i]->nccl_comm(),
         nccl_stream);
    }
  }

  if (FLAGS_use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);

      gpuStream_t nccl_stream;
      if (use_calc_stream) {
        nccl_stream =
            static_cast<phi::GPUContext*>(
                platform::DeviceContextPool::Instance().Get(places[i]))
                ->stream();
      } else {
        nccl_stream = places_to_ctx_.at(key)[i]->stream();
      }

      memory::RecordStream(inputs[i].Holder(), nccl_stream);
    }
  }

  // Add the stream event dependency only when the comm stream is used
  if (!use_calc_stream) {
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
    }
  }

  return task;
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
    std::vector<phi::DenseTensor>& inputs,
    std::vector<phi::DenseTensor>& outputs,
    Fn fn,
    CommType op_type) {
  const auto places = GetPlaceList(inputs);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));

  auto task = CreateTask(places, rank_, op_type, inputs);

  // construct an uninitialized guard for the device
  platform::CUDADeviceGuard cuda_guard;

  {
    platform::NCCLGroupGuard nccl_guard;
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
      fn(inputs[i],
         outputs[i],
         places_to_ctx_.at(key)[i]->nccl_comm(),
         nccl_stream);
    }
  }

  if (FLAGS_use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      memory::RecordStream(inputs[i].Holder(),
                           places_to_ctx_.at(key)[i]->stream());
    }
  }

  for (size_t i = 0; i < inputs.size(); ++i) {
    cuda_guard.SetDevice(places[i]);
    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
  }
  return task;
}

template <typename Fn>
void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
                                  phi::DenseTensor* out,
                                  Fn fn,
                                  CommType op_type) {
  std::vector<Place> places;
  places.push_back(in->place());
  const std::string& key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));

  // construct an uninitialized guard for the device
  platform::CUDADeviceGuard cuda_guard;

  if (FLAGS_use_stream_safe_cuda_allocator) {
    cuda_guard.SetDevice(places[0]);
    memory::RecordStream(in->Holder(), places_to_ctx_.at(key)[0]->stream());
  }

  {
    platform::NCCLGroupGuard nccl_guard;
    cuda_guard.SetDevice(places[0]);
    const auto& nccl_stream = places_to_ctx_.at(key)[0]->stream();
    fn(in, out, places_to_ctx_.at(key)[0]->nccl_comm(), nccl_stream);
  }

  cuda_guard.SetDevice(places[0]);
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
    std::vector<phi::DenseTensor>& tensors,
    Fn fn,
    int dst_rank,
    CommType op_type,
    bool sync_op,
    bool use_calc_stream) {
  const auto& places = GetPlaceList(tensors);
  const auto& key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  if (!use_calc_stream) {
    SyncDefaultStream(
        places, place_to_calc_event_.at(key), places_to_ctx_.at(key));
  }

  auto task =
      CreateTask(places, rank_, op_type, tensors, sync_op, use_calc_stream);

  platform::CUDADeviceGuard cuda_guard;

  {
    platform::NCCLGroupGuard nccl_guard;
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      gpuStream_t nccl_stream;
      if (use_calc_stream) {
        nccl_stream =
            static_cast<phi::GPUContext*>(
                platform::DeviceContextPool::Instance().Get(places[i]))
                ->stream();
      } else {
        nccl_stream = places_to_ctx_.at(key)[i]->stream();
      }
      fn(tensors[i],
         places_to_ctx_.at(key)[i]->nccl_comm(),
         nccl_stream,
         dst_rank);
    }
  }

  if (FLAGS_use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      gpuStream_t nccl_stream;
      if (use_calc_stream) {
        nccl_stream =
            static_cast<phi::GPUContext*>(
                platform::DeviceContextPool::Instance().Get(places[i]))
                ->stream();
      } else {
        nccl_stream = places_to_ctx_.at(key)[i]->stream();
      }
      memory::RecordStream(tensors[i].Holder(), nccl_stream);
    }
  }

  if (!use_calc_stream) {
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
    }
  }

  return task;
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
    std::vector<phi::DenseTensor>& tensors,
    Fn fn,
    int dst_rank,
    CommType op_type) {
  const auto places = GetPlaceList(tensors);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));

  auto task = CreateTask(places, rank_, op_type, tensors);

  // construct an uninitialized guard for the device
  platform::CUDADeviceGuard cuda_guard;

  {
    platform::NCCLGroupGuard nccl_guard;
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
      fn(tensors[i],
         places_to_ctx_.at(key)[i]->nccl_comm(),
         nccl_stream,
         dst_rank);
    }
  }

  if (FLAGS_use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      memory::RecordStream(tensors[i].Holder(),
                           places_to_ctx_.at(key)[i]->stream());
    }
  }

  for (size_t i = 0; i < tensors.size(); ++i) {
    cuda_guard.SetDevice(places[i]);
    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
  }
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const AllreduceOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        return platform::dynload::ncclAllReduce(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream);
      },
      CommType::ALLREDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const BroadcastOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));

  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        const auto root =
            opts.source_rank * in_tensors.size() + opts.source_root;
        return platform::dynload::ncclBroadcast(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            root,
            comm,
            stream);
      },
      CommType::BROADCAST);
}

void CheckTensorsInDifferentDevices(
    const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
  PADDLE_ENFORCE_EQ(
      tensors.size() == 0,
      false,
      platform::errors::InvalidArgument("Tensor list must be nonempty."));
  PADDLE_ENFORCE_LE(
      tensors.size(),
      num_devices,
      platform::errors::InvalidArgument(
          "Tensor list mustn't be larger than the number of available GPUs."));

  std::set<Place> used_devices;

  for (const auto& t : tensors) {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(t.place()),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensors must be CUDA dense tensors."));

    const auto inserted = used_devices.insert(t.place()).second;
    PADDLE_ENFORCE_EQ(inserted,
                      true,
                      platform::errors::InvalidArgument(
                          "Tensors must be on distinct GPU devices."));
  }
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
    std::vector<phi::DenseTensor>& tensors, int dst_rank) {
  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

  auto task = PointToPoint(
      tensors,
      [&](phi::DenseTensor& input,
          ncclComm_t comm,
          const gpuStream_t& stream,
          int dst_rank) {
        return platform::dynload::ncclSend(
            input.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            dst_rank,
            comm,
            stream);
      },
      dst_rank,
      CommType::SEND);
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
    std::vector<phi::DenseTensor>& tensors, int src_rank) {
  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

  auto task = PointToPoint(
      tensors,
      [&](phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream,
          int src_rank) {
        return platform::dynload::ncclRecv(
            output.data(),
            output.numel(),
            platform::ToNCCLDataType(output.dtype()),
            src_rank,
            comm,
            stream);
      },
      src_rank,
      CommType::RECV);
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        return platform::dynload::ncclAllGather(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            comm,
            stream);
      },
      CommType::ALLGATHER);
}

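// Advance a raw device pointer by `offset` elements of the given data type.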
void* GetPointerByOffset(void* raw_pointer,
                         size_t offset,
                         experimental::DataType type) {
  if (type == experimental::DataType::FLOAT32) {
    return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT64) {
    return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT32) {
    return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT64) {
    return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT8) {
    return reinterpret_cast<void*>(reinterpret_cast<int8_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::UINT8) {
    return reinterpret_cast<void*>(reinterpret_cast<uint8_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::BOOL) {
    return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::BFLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
                                   offset);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Datatype %s in NCCL is not supported.", type));
  }
  return nullptr;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        size_t offset = 0;
        GroupStart();
        for (auto i = 0; i < size_; i++) {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
              GetPointerByOffset(input.data(), offset, input.dtype()),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              GetPointerByOffset(output.data(), offset, input.dtype()),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          offset += input.numel() / size_;
        }
        GroupEnd();
      },
      CommType::ALLTOALL);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    bool sync_op,
    bool use_calc_stream) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        size_t offset = 0;
        GroupStart();
        for (auto i = 0; i < size_; i++) {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
              GetPointerByOffset(input.data(), offset, input.dtype()),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              GetPointerByOffset(output.data(), offset, input.dtype()),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          offset += input.numel() / size_;
        }
        GroupEnd();
      },
      CommType::ALLTOALL,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ReduceOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            ToNCCLRedType(opts.reduce_op),
            opts.root_rank,
            comm,
            stream));
      },
      CommType::REDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ReduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            ToNCCLRedType(opts.reduce_op),
            opts.root_rank,
            comm,
            stream));
      },
      CommType::REDUCE,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ReduceScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        if (FLAGS_use_stream_safe_cuda_allocator) {
          platform::CUDADeviceGuard cuda_guard;
          cuda_guard.SetDevice(output.place());
          memory::RecordStream(output.Holder(), stream);
        }
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
            input.data(),
            output.data(),
            output.numel(),
            platform::ToNCCLDataType(input.dtype()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream));
      },
      CommType::REDUCE_SCATTER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ScatterOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        size_t offset = 0;
        if (rank_ == opts.root_rank) {
          GroupStart();
          for (auto i = 0; i < size_; i++) {
            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
                GetPointerByOffset(input.data(), offset, input.dtype()),
                input.numel() / size_,
                platform::ToNCCLDataType(input.dtype()),
                i,
                comm,
                stream));
            offset += input.numel() / size_;
          }
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
              comm,
              stream));
          GroupEnd();
        } else {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
              comm,
              stream));
        }
      },
      CommType::SCATTER);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        PADDLE_ENFORCE_EQ(
            output.numel(),
            input.numel() / size_,
            platform::errors::InvalidArgument(
                "The output numel must be equal to the input numel divided "
                "by the group size."));
        size_t offset = 0;
        if (rank_ == opts.root_rank) {
          GroupStart();
          for (auto i = 0; i < size_; i++) {
            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
                GetPointerByOffset(input.data(), offset, input.dtype()),
                input.numel() / size_,
                platform::ToNCCLDataType(input.dtype()),
                i,
                comm,
                stream));
            offset += input.numel() / size_;
          }
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
              comm,
              stream));
          GroupEnd();
        } else {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
              comm,
              stream));
        }
      },
      CommType::SCATTER,
      sync_op,
      use_calc_stream);
}

}  //  namespace distributed
}  //  namespace paddle