// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"

#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"

DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);

constexpr int64_t kWaitBlockTimeout = 10;

namespace paddle {
namespace distributed {

ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place,
                                     int rank,
                                     CommType comm_type,
                                     bool sync_op,
                                     bool use_calc_stream)
    : TaskStream(rank, comm_type, sync_op, use_calc_stream),
      comm_event_(place),
      task_place_(place) {}

ProcessGroupNCCL::NCCLTask::~NCCLTask() {}

bool ProcessGroupNCCL::NCCLTask::IsCompleted() { return comm_event_.Query(); }

void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
    const phi::DeviceContext& ctx) {
  comm_event_.Record(&ctx);
}

// TODO(shenliang03): Add timeout support for Wait; the timeout argument is
// currently unused.
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
  // Warn when the calc stream is used but Wait() is still invoked explicitly.
  if (UseCalcStream()) {
    VLOG(3) << "Warning: The communication is on calc stream, wait here is "
               "useless.";
    return true;
  }

  const auto* calc_ctx =
      platform::DeviceContextPool::Instance().Get(task_place_);
  comm_event_.Wait(platform::Place2DeviceType(task_place_), calc_ctx);

  if (FLAGS_nccl_blocking_wait) {
    // NOTE(shenliang03): This blocks the host until the op completes.
    while (!IsCompleted()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTimeout));
    }
  }

  if (IsBlockCPUInWait()) {
    // If this task is used to implement a barrier, block the CPU as well
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
  }
  return true;
}

// Same as Wait
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }
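
// Illustrative usage of the task API (a sketch, not code from this file;
// `pg`, `out`, `in`, and `opts` are assumed caller-side objects):
//
//   auto task = pg->AllReduce(&out, in, opts, /*sync_op*/ true,
//                             /*use_calc_stream*/ false);
//   task->Wait();  // chains the calc stream after the comm stream; it also
//                  // spins on the host when FLAGS_nccl_blocking_wait is set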

ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
                                   int rank,
                                   int size,
                                   int gid)
    : ProcessGroupStream(rank, size, gid), store_(store) {}

void ProcessGroupNCCL::GroupStart() {
  NCCL_CHECK(platform::dynload::ncclGroupStart());
}

void ProcessGroupNCCL::GroupEnd() {
  NCCL_CHECK(platform::dynload::ncclGroupEnd());
}

phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
    const Place& place) const {
  return GetDeviceContext(place, /*use_calc_stream*/ false);
}

phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  const std::string& key = GetKeyFromPlace(place);
  if (use_calc_stream) {
    const auto& iter = place_to_calc_ctx_.find(key);
    return iter->second;
  } else {
    const auto& iter = place_to_comm_ctx_.find(key);
    PADDLE_ENFORCE_NE(
        iter,
        place_to_comm_ctx_.end(),
        platform::errors::NotFound(
            "Cannot find the device context in this process group."));
    return iter->second.get();
  }
}

ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
  const std::string& key = GetKeyFromPlace(place);
  const auto& iter = place_to_comm_ctx_.find(key);
  PADDLE_ENFORCE_NE(
      iter,
      place_to_comm_ctx_.end(),
      platform::errors::NotFound(
          "Cannot find the NCCL communicator in this process group."));
  return iter->second->nccl_comm();
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates that the tensor needs to be sliced
  const phi::DenseTensor& in_tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
  StaticCheckTensorsGatherLikeShape(
      *out_tensor, in_tensor_maybe_partial, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclAllGather(
            in_tensor_maybe_partial.data(),
            out_tensor->data(),
            in_tensor_maybe_partial.numel(),
            platform::ToNCCLDataType(in_tensor_maybe_partial.dtype()),
            comm,
            stream));
      },
      in_tensor_maybe_partial,
      CommType::ALLGATHER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclAllReduce(
            in_tensor.data(),
            out_tensor->data(),
            in_tensor.numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream));
      },
      in_tensor,
      CommType::ALLREDUCE,
      sync_op,
      use_calc_stream);
}

void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
                         const std::vector<int64_t>& size_on_each_rank,
                         int world_size) {
  int length_size_on_each_rank = size_on_each_rank.size();
  PADDLE_ENFORCE_EQ(
      length_size_on_each_rank,
      world_size,
      platform::errors::InvalidArgument(
          "The length of size_on_each_rank must be equal to world_size."));

  // Accumulate with an int64_t init value to avoid int overflow.
  int64_t sum_size_on_each_rank =
      std::accumulate(size_on_each_rank.begin(),
                      size_on_each_rank.end(),
                      static_cast<int64_t>(0));
  PADDLE_ENFORCE_EQ(
      sum_size_on_each_rank,
      tensor_dim[0],
      platform::errors::InvalidArgument(
          "The sum of size_on_each_rank must be equal to tensor's dim[0]."));
}
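
// Worked example (illustrative): with world_size = 3 and
// size_on_each_rank = {2, 3, 5}, the checks above require
// tensor_dim[0] == 2 + 3 + 5 == 10; any mismatch raises InvalidArgument.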

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const std::vector<int64_t>& out_size_each_rank,
    const std::vector<int64_t>& in_size_each_rank,
    bool sync_op,
    bool use_calc_stream) {
  const phi::DDim& out_dim = out_tensor->dims();
  const phi::DDim& in_dim = in_tensor.dims();
  CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
  CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);

  // NOTE: Since `all_to_all` needs the other processes' participation, it
  // cannot simply be covered by static checks. The size factors are set to 0
  // here to skip the static shape check; the shapes are verified by dynamic
  // checks in debug mode instead.
  StaticCheckTensors(*out_tensor,
                     in_tensor,
                     rank_,
                     size_,
                     /*out_size_factor*/ 0,
                     /*in_size_factor*/ 0);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        int64_t in_row_size = in_tensor.numel() / in_dim[0],
                out_row_size = out_tensor->numel() / out_dim[0];
        int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
        phi::DenseTensor input_partial, output_partial;

        GroupStart();
        for (auto i = 0; i < size_; i++) {
          in_numel = in_size_each_rank[i] * in_row_size;
          input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
          NCCL_CHECK(platform::dynload::ncclSend(
              input_partial.data(),
              in_numel,
              platform::ToNCCLDataType(input_partial.dtype()),
              i,
              comm,
              stream));
          in_offset += in_numel;

          out_numel = out_size_each_rank[i] * out_row_size;
          output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
          NCCL_CHECK(platform::dynload::ncclRecv(
              output_partial.data(),
              out_numel,
              platform::ToNCCLDataType(output_partial.dtype()),
              i,
              comm,
              stream));
          out_offset += out_numel;
        }
        GroupEnd();
      },
      in_tensor,
      CommType::ALLTOALL,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
    const BarrierOptions& opts) {
  PADDLE_ENFORCE_GE(opts.device_id,
                    0,
                    platform::errors::PreconditionNotMet(
                        "The barrier device id must be greater than or "
                        "equal to 0."));
  platform::CUDAPlace place(opts.device_id);
  auto allocator = std::unique_ptr<phi::Allocator>(
      new paddle::experimental::DefaultAllocator(place));
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
  phi::DenseTensor barrier_tensor{allocator.get(), meta};

  auto task = AllReduce(&barrier_tensor,
                        barrier_tensor,
                        {},
                        /*sync_op*/ true,
                        /*use_calc_stream*/ false);
  auto nccl_task = dynamic_cast<NCCLTask*>(task.get());
  nccl_task->SetBlockCPUInWait();
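  // Blocking the CPU in Wait() turns the dummy all-reduce above into a real
  // host-side barrier: Wait() will perform a device-wide synchronize instead
  // of only chaining stream events.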
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        int root = opts.source_rank + opts.source_root;
        NCCL_CHECK(platform::dynload::ncclBroadcast(
            in_tensor.data(),
            out_tensor->data(),
            in_tensor.numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            root,
            comm,
            stream));
      },
      in_tensor,
      CommType::BROADCAST,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclReduce(
            in_tensor.data(),
            out_tensor->data(),
            in_tensor.numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            ToNCCLRedType(opts.reduce_op),
            opts.root_rank,
            comm,
            stream));
      },
      in_tensor,
      CommType::REDUCE,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclReduceScatter(
            in_tensor.data(),
            out_tensor->data(),
            out_tensor->numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream));
      },
      in_tensor,
      CommType::REDUCE_SCATTER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        int64_t numel = in_tensor.numel() / size_;
        if (rank_ == opts.root_rank) {
          int64_t offset = 0;
          phi::DenseTensor partial_tensor;
          GroupStart();
          for (auto i = 0; i < size_; i++) {
            partial_tensor = GetPartialTensor(in_tensor, offset, numel);
            NCCL_CHECK(platform::dynload::ncclSend(
                partial_tensor.data(),
                numel,
                platform::ToNCCLDataType(partial_tensor.dtype()),
                i,
                comm,
                stream));
            offset += numel;
          }
          NCCL_CHECK(platform::dynload::ncclRecv(
              out_tensor->data(),
              numel,
              platform::ToNCCLDataType(out_tensor->dtype()),
              opts.root_rank,
              comm,
              stream));
          GroupEnd();
        } else {
          NCCL_CHECK(platform::dynload::ncclRecv(
              out_tensor->data(),
              numel,
              platform::ToNCCLDataType(out_tensor->dtype()),
              opts.root_rank,
              comm,
              stream));
        }
      },
      in_tensor,
      CommType::SCATTER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
    phi::DenseTensor* tensor,
    int src_rank,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates that the tensor needs to be sliced
  phi::DenseTensor partial_tensor;
  if (numel > 0) {
    partial_tensor = GetPartialTensor(*tensor, offset, numel);
    tensor = &partial_tensor;
  }

  StaticCheckTensor(*tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclRecv(
            tensor->data(),
            tensor->numel(),
            platform::ToNCCLDataType(tensor->dtype()),
            src_rank,
            comm,
            stream));
      },
      *tensor,
      CommType::RECV,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
    const phi::DenseTensor& tensor,
    int dst_rank,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates that the tensor needs to be sliced
  const phi::DenseTensor& tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;

  StaticCheckTensor(tensor_maybe_partial, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclSend(
            tensor_maybe_partial.data(),
            tensor_maybe_partial.numel(),
            platform::ToNCCLDataType(tensor_maybe_partial.dtype()),
            dst_rank,
            comm,
            stream));
      },
      tensor_maybe_partial,
      CommType::SEND,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    const Place& place,
    int rank,
    CommType comm_type,
    bool is_sync,
    bool use_calc_stream) {
  return std::make_shared<ProcessGroupNCCL::NCCLTask>(
      place, rank, comm_type, is_sync, use_calc_stream);
}

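// Rendezvous for the NCCL unique id through the configured Store: rank 0
// generates the id and publishes it under "ProcessGroupNCCL/nccl_ids/<gid>/0";
// the other ranks read the bytes back from the store and copy them into
// nccl_id.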
void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) {
  const std::string key =
      "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/0";
  if (rank_ == 0) {
    std::vector<uint8_t> nccl_id_wrapper(
        reinterpret_cast<uint8_t*>(nccl_id),
        reinterpret_cast<uint8_t*>(nccl_id) + NCCL_UNIQUE_ID_BYTES);
    store_->set(key, nccl_id_wrapper);
  } else {
    const auto& nccl_id_wrapper = store_->get(key);
    std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
  }
}

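// Creates and caches the per-place communication environment: a calc-stream
// event, the default (calc) GPUContext, and a dedicated comm GPUContext whose
// ncclComm_t is initialized from the broadcast unique id. Only one device per
// process is handled here (see the warning below).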
void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
                                          const std::string& place_key) {
  if (place_to_comm_ctx_.size() > 0) {
    VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
  }

  ncclUniqueId nccl_id;
  if (rank_ == 0) {
    NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
  }
  BroadcastUniqueNCCLID(&nccl_id);

  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << place_key
          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);

  auto* calc_ctx = static_cast<phi::GPUContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  auto comm_ctx = std::make_unique<phi::GPUContext>(place);
  ncclComm_t nccl_comm;
  NCCL_CHECK(platform::dynload::ncclCommInitRank(
      &nccl_comm, GetSize(), nccl_id, GetRank()));
  comm_ctx->set_nccl_comm(nccl_comm);

  place_to_calc_event_.emplace(place_key, place);
  place_to_calc_ctx_.emplace(place_key, calc_ctx);
  place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));

  // TODO(sunyilun): for compatibility, will be removed later
  std::vector<phi::GPUContext*> comm_ctx_wrapper{
      place_to_comm_ctx_[place_key].get()};
  places_to_ctx_.emplace(place_key, comm_ctx_wrapper);
}

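// Make the comm stream wait for work already queued on the calc stream of
// `place`: record an event on the calc context and wait on it from the comm
// context.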
void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
  const std::string& key = GetKeyFromPlace(place);
  auto& calc_event = place_to_calc_event_.at(key);
  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
  calc_event.Record(calc_ctx);
  calc_event.Wait(platform::Place2DeviceType(place), comm_ctx);
}

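// Common driver for the stream-aware collectives above: it lazily creates the
// per-place NCCL environment, optionally chains the comm stream after the
// calc stream, runs `fn(comm, stream)`, and records events so the returned
// task can be waited on. A minimal caller sketch (illustrative only; `in` and
// `out` are assumed tensors):
//
//   RunFnInNCCLEnv(
//       [&](ncclComm_t comm, gpuStream_t stream) {
//         NCCL_CHECK(platform::dynload::ncclAllReduce(
//             in.data(), out->data(), in.numel(),
//             platform::ToNCCLDataType(in.dtype()), ncclSum, comm, stream));
//       },
//       in, CommType::ALLREDUCE, /*sync_op*/ true, /*use_calc_stream*/ false);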
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
    std::function<void(ncclComm_t, gpuStream_t)> fn,
    const phi::DenseTensor& tensor,
    CommType comm_type,
    bool sync_op,
    bool use_calc_stream) {
  const auto& place = tensor.place();
  const auto& key = GetKeyFromPlace(place);

  platform::CUDADeviceGuard cuda_guard(place);

  if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
    CreateNCCLEnvCache(place, key);
  }

  if (!use_calc_stream) {
    SyncCalcStream(place);
  }

  auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);

  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto& comm_ctx = place_to_comm_ctx_.at(key);
  auto nccl_comm = comm_ctx->nccl_comm();
  auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
  fn(nccl_comm, nccl_stream);

  if (!use_calc_stream) {
    if (FLAGS_use_stream_safe_cuda_allocator) {
      memory::RecordStream(tensor.Holder(), nccl_stream);
    }
    task->UpdateWaitChain(*comm_ctx);
  }

  return task;
}

// TODO(sunyilun): methods below will be removed later
void SyncDefaultStream(const std::vector<Place>& places,
                       platform::DeviceEvent& nccl_event,         // NOLINT
                       std::vector<phi::GPUContext*>& dev_ctx) {  // NOLINT
  for (size_t i = 0; i < places.size(); ++i) {
    auto* default_ctx = static_cast<phi::GPUContext*>(
        platform::DeviceContextPool::Instance().Get(places[i]));
    nccl_event.Record(default_ctx);
    nccl_event.Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
  }
}

std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    std::vector<Place> places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs) {
  return std::make_shared<ProcessGroupNCCL::NCCLTask>(
      places, rank, comm_type, inputs);
}

ProcessGroupNCCL::NCCLTask::NCCLTask(
    const std::vector<Place>& places,
    int rank,
    CommType CommType,
    const std::vector<phi::DenseTensor>& inputs)
    : TaskStream(rank, inputs, CommType),
      comm_event_(places[0]),
      task_place_(places[0]) {}

// create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache(
    const std::string& places_key, const std::vector<Place>& places) {
  PADDLE_ENFORCE_EQ(places_key.empty(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "Not able to create/get the NCCL Communicator since "
                        "the GPU place is not known"));

  ncclUniqueId nccl_id;
  if (rank_ == 0) {
    NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
  }
  BroadcastUniqueNCCLID(&nccl_id);

  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << places_key
          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);

  std::vector<std::unique_ptr<phi::GPUContext>> dev_ctx;
  dev_ctx.resize(places.size());

  std::vector<phi::GPUContext*> dev_ctx_raw;
  dev_ctx_raw.resize(places.size());

  GroupStart();

  for (size_t i = 0; i < places.size(); ++i) {
    platform::CUDADeviceGuard guard(places[i]);

    dev_ctx[i].reset(new phi::GPUContext(places[i]));
    ncclComm_t nccl_comm;
    NCCL_CHECK(platform::dynload::ncclCommInitRank(
        &nccl_comm, GetSize(), nccl_id, GetRank()));
    dev_ctx[i]->set_nccl_comm(nccl_comm);
    dev_ctx_raw[i] = dev_ctx[i].get();
  }

  GroupEnd();

  // TODO(sunyilun): for compatibility, will be removed later
  place_to_calc_event_.emplace(places_key, places[0]);
  place_to_calc_ctx_.emplace(
      places_key,
      static_cast<phi::GPUContext*>(
          platform::DeviceContextPool::Instance().Get(places[0])));
  place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0]));

  // These caches are used later when processing sync/wait/communicate.
  places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
}

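// Legacy multi-place path: looks up (or lazily creates) the cached contexts
// for the key derived from the input places, syncs each default stream with
// its comm context, then runs `fn` once per input tensor inside an
// ncclGroupStart/ncclGroupEnd section.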
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
    std::vector<phi::DenseTensor>& inputs,
    std::vector<phi::DenseTensor>& outputs,
    Fn fn,
    CommType op_type) {
  const auto places = GetPlaceList(inputs);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));

  auto task = CreateTask(places, rank_, op_type, inputs);

  // construct an uninitialized device guard
  platform::CUDADeviceGuard cuda_guard;

  {
    platform::NCCLGroupGuard nccl_guard;
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
      fn(inputs[i],
         outputs[i],
         places_to_ctx_.at(key)[i]->nccl_comm(),
         nccl_stream);
    }
  }

  if (FLAGS_use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      memory::RecordStream(inputs[i].Holder(),
                           places_to_ctx_.at(key)[i]->stream());
    }
  }

  for (size_t i = 0; i < inputs.size(); ++i) {
    cuda_guard.SetDevice(places[i]);
    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
  }
  return task;
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
    std::vector<phi::DenseTensor>& tensors,
    Fn fn,
    int dst_rank,
    CommType op_type) {
  const auto places = GetPlaceList(tensors);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));

  auto task = CreateTask(places, rank_, op_type, tensors);

  // construct an uninitialized device guard
  platform::CUDADeviceGuard cuda_guard;

  {
    platform::NCCLGroupGuard nccl_guard;
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
      fn(tensors[i],
         places_to_ctx_.at(key)[i]->nccl_comm(),
         nccl_stream,
         dst_rank);
    }
  }

  if (FLAGS_use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      memory::RecordStream(tensors[i].Holder(),
                           places_to_ctx_.at(key)[i]->stream());
    }
  }

  for (size_t i = 0; i < tensors.size(); ++i) {
    cuda_guard.SetDevice(places[i]);
    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
  }
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const AllreduceOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        return platform::dynload::ncclAllReduce(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream);
      },
      CommType::ALLREDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const BroadcastOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));

  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        const auto root =
            opts.source_rank * in_tensors.size() + opts.source_root;
        return platform::dynload::ncclBroadcast(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            root,
            comm,
            stream);
      },
      CommType::BROADCAST);
}

void CheckTensorsInDifferentDevices(
    const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
  PADDLE_ENFORCE_EQ(
      tensors.size() == 0,
      false,
      platform::errors::InvalidArgument("Tensor list must be nonempty."));
  PADDLE_ENFORCE_LE(
      tensors.size(),
      num_devices,
      platform::errors::InvalidArgument(
          "Tensor list mustn't be larger than the number of available GPUs."));

  std::set<Place> used_devices;

  for (const auto& t : tensors) {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(t.place()),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensors must be CUDA dense tensors."));

    const auto inserted = used_devices.insert(t.place()).second;
    PADDLE_ENFORCE_EQ(inserted,
                      true,
                      platform::errors::InvalidArgument(
                          "Tensors must be on distinct GPU devices."));
  }
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
    std::vector<phi::DenseTensor>& tensors, int dst_rank) {
  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

  auto task = PointToPoint(
      tensors,
      [&](phi::DenseTensor& input,
          ncclComm_t comm,
          const gpuStream_t& stream,
          int dst_rank) {
        return platform::dynload::ncclSend(
            input.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            dst_rank,
            comm,
            stream);
      },
      dst_rank,
      CommType::SEND);
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
    std::vector<phi::DenseTensor>& tensors, int src_rank) {
  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

  auto task = PointToPoint(
      tensors,
      [&](phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream,
          int src_rank) {
        return platform::dynload::ncclRecv(
            output.data(),
            output.numel(),
            platform::ToNCCLDataType(output.dtype()),
            src_rank,
            comm,
            stream);
      },
      src_rank,
      CommType::RECV);
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        return platform::dynload::ncclAllGather(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            comm,
            stream);
      },
      CommType::ALLGATHER);
}

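// Advances a raw buffer pointer by `offset` *elements* of the given dtype
// (not bytes). FLOAT16/BFLOAT16 are stepped through 16-bit integer pointers,
// since only the element width matters here.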
void* GetPointerByOffset(void* raw_pointer,
                         size_t offset,
                         experimental::DataType type) {
  if (type == experimental::DataType::FLOAT32) {
    return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT64) {
    return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT32) {
    return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT64) {
    return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT8) {
    return reinterpret_cast<void*>(reinterpret_cast<int8_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::UINT8) {
    return reinterpret_cast<void*>(reinterpret_cast<uint8_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::BOOL) {
    return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::BFLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
                                   offset);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Datatype %s in NCCL is not supported.", type));
  }
  return nullptr;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        size_t offset = 0;
        GroupStart();
        for (auto i = 0; i < size_; i++) {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
              GetPointerByOffset(input.data(), offset, input.dtype()),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
              GetPointerByOffset(output.data(), offset, input.dtype()),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          offset += input.numel() / size_;
        }
        GroupEnd();
      },
      CommType::ALLTOALL);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ReduceOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            ToNCCLRedType(opts.reduce_op),
            opts.root_rank,
            comm,
            stream));
      },
      CommType::REDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ScatterOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
1039 1040 1041
          const gpuStream_t& stream) {
        size_t offset = 0;
        if (rank_ == opts.root_rank) {
1042
          GroupStart();
1043 1044
          for (auto i = 0; i < size_; i++) {
            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
1045
                GetPointerByOffset(input.data(), offset, input.dtype()),
1046 1047 1048 1049 1050
                input.numel() / size_,
                platform::ToNCCLDataType(input.dtype()),
                i,
                comm,
                stream));
1051
            offset += input.numel() / size_;
1052 1053
          }
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
1054 1055 1056 1057 1058
              output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
              comm,
1059
              stream));
1060
          GroupEnd();
1061 1062
        } else {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
1063 1064 1065 1066 1067
              output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
              comm,
1068 1069 1070 1071 1072 1073
              stream));
        }
      },
      CommType::SCATTER);
}

std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
  auto process_group =
      std::make_shared<ProcessGroupNCCL>(store, rank, size, gid);
  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
  return process_group;
}

}  //  namespace distributed
}  //  namespace paddle