process_group_nccl.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/process_group_nccl.h"

#include "paddle/fluid/distributed/collective/check.h"
#include "paddle/fluid/distributed/collective/common.h"
#include "paddle/fluid/distributed/collective/nccl_tools.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/enforce.h"

DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);

// set this flag to `true` and recompile to enable dynamic checks
constexpr bool FLAGS_enable_nccl_dynamic_check = false;
constexpr int64_t kWaitBlockTimeout = 10;

namespace paddle {
namespace distributed {

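// NCCLTask wraps one communication launched by this process group. It records
// a CUDA event (comm_event_) on the communication stream so that a later
// Wait() can make the calculation stream (or the CPU, for barriers) wait for
// the communication to finish.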
ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place,
                                     int rank,
                                     CommType comm_type,
                                     bool sync_op,
                                     bool use_calc_stream)
    : TaskStream(rank, comm_type, sync_op, use_calc_stream),
      comm_event_(place),
      task_place_(place) {}

ProcessGroupNCCL::NCCLTask::~NCCLTask() {}

bool ProcessGroupNCCL::NCCLTask::IsCompleted() { return comm_event_.Query(); }

void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
    const phi::DeviceContext& ctx) {
  comm_event_.Record(&ctx);
}

// TODO(shenliang03): Add timeout for wait; the timeout parameter is currently unused
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
  // Warning here when use calc stream but also invoke waiting explicitly.
  if (UseCalcStream()) {
    VLOG(3) << "Warning: The communication is on the calc stream, so waiting "
               "here is useless.";
    return true;
  }

  const auto* calc_ctx =
      platform::DeviceContextPool::Instance().Get(task_place_);
  comm_event_.Wait(platform::Place2DeviceType(task_place_), calc_ctx);

  if (FLAGS_nccl_blocking_wait) {
    // NOTE(shenliang03): It will block host for sync
    while (!IsCompleted()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTimeout));
    }
  }

  if (IsBlockCPUInWait()) {
    // If we use the work to do barrier, we should block cpu
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
  }
  return true;
}

// Same as Wait
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }

ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
                                   int rank,
                                   int size,
                                   int gid)
    : ProcessGroupWithStream(rank, size, gid), store_(store) {}

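// GroupStart/GroupEnd are thin wrappers over ncclGroupStart/ncclGroupEnd; the
// send/recv calls issued between them (e.g. in AllToAll and Scatter below) are
// aggregated and launched as a single NCCL operation.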
void ProcessGroupNCCL::GroupStart() {
  NCCL_CHECK(phi::dynload::ncclGroupStart());
}

void ProcessGroupNCCL::GroupEnd() { NCCL_CHECK(phi::dynload::ncclGroupEnd()); }

phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
    const Place& place) const {
  return GetDeviceContext(place, /*use_calc_stream*/ false);
}

phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  const std::string& key = GetKeyFromPlace(place);
  if (use_calc_stream) {
    const auto& iter = place_to_calc_ctx_.find(key);
    return iter->second;
  } else {
    const auto& iter = place_to_comm_ctx_.find(key);
    PADDLE_ENFORCE_NE(
        iter,
        place_to_comm_ctx_.end(),
        phi::errors::NotFound(
            "Cannot find the device context in this process group."));
    return iter->second.get();
  }
}

ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
  const std::string& key = GetKeyFromPlace(place);
  const auto& iter = place_to_comm_ctx_.find(key);
  PADDLE_ENFORCE_NE(
      iter,
      place_to_comm_ctx_.end(),
      phi::errors::NotFound(
          "Cannot find the NCCL communicator in this process group."));
  return iter->second->nccl_comm();
}

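// The `offset`/`numel` arguments allow gathering only a contiguous slice of
// `in_tensor`: a non-positive `numel` means "use the whole tensor", while a
// positive `numel` selects `numel` elements starting at `offset` via
// GetPartialTensor (see the `numel > 0` check below).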
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates the tensor needs to be sliced
  const phi::DenseTensor& in_tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
  CommStaticCheck::GatherLikeShape(*out_tensor,
                                   in_tensor_maybe_partial,
                                   /*dst_rank*/ rank_,
                                   /*cur_rank*/ rank_,
                                   size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        if (FLAGS_enable_nccl_dynamic_check) {
          CommDynamicCheck::CheckShape(*out_tensor,
                                       /*root_rank*/ 0,
                                       rank_,
                                       comm);
        }
        NCCL_CHECK(phi::dynload::ncclAllGather(
            in_tensor_maybe_partial.data(),
            out_tensor->data(),
            in_tensor_maybe_partial.numel(),
            platform::ToNCCLDataType(in_tensor_maybe_partial.dtype()),
            comm,
            stream));
      },
      in_tensor_maybe_partial,
      CommType::ALLGATHER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  CommStaticCheck::SameShape(*out_tensor,
                             in_tensor,
                             /*dst_rank*/ rank_,
                             /*cur_rank*/ rank_,
                             size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        if (FLAGS_enable_nccl_dynamic_check) {
          CommDynamicCheck::CheckShape(*out_tensor,
                                       /*root_rank*/ 0,
                                       rank_,
                                       comm);
        }
        NCCL_CHECK(phi::dynload::ncclAllReduce(
            in_tensor.data(),
            out_tensor->data(),
            in_tensor.numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream));
      },
      in_tensor,
      CommType::ALLREDUCE,
      sync_op,
      use_calc_stream);
}

void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
                         const std::vector<int64_t>& size_on_each_rank,
                         int world_size) {
  int length_size_on_each_rank = size_on_each_rank.size();
  PADDLE_ENFORCE_EQ(
      length_size_on_each_rank,
      world_size,
      phi::errors::InvalidArgument(
          "The length of size_on_each_rank must be equal to world_size."));

  int64_t sum_size_on_each_rank =
      std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0);
  PADDLE_ENFORCE_EQ(
      sum_size_on_each_rank,
      tensor_dim[0],
      phi::errors::InvalidArgument(
          "The sum of size_on_each_rank must be equal to tensor's dim[0]."));
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const std::vector<int64_t>& out_size_each_rank,
    const std::vector<int64_t>& in_size_each_rank,
    bool sync_op,
    bool use_calc_stream) {
  const phi::DDim& out_dim = out_tensor->dims();
  const phi::DDim& in_dim = in_tensor.dims();
  CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
  CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);

  // NOTE: Since `all_to_all` needs other processes' participation, it cannot
  // simply be covered by static checks. Factors are set to 0 here to skip the
  // shape check. Its shape check will be done by dynamic checks with
  // FLAGS_enable_nccl_dynamic_check.
  CommStaticCheck::CheckShape(*out_tensor,
                              in_tensor,
                              /*dst_rank*/ rank_,
                              /*cur_rank*/ rank_,
                              size_,
                              /*out_size_factor*/ 0,
                              /*in_size_factor*/ 0);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        if (FLAGS_enable_nccl_dynamic_check) {
          CommDynamicCheck::CheckShape(
              *out_tensor, in_tensor, in_size_each_rank, rank_, size_, comm);
        }
        int64_t in_row_size = in_tensor.numel() / in_dim[0],
                out_row_size = out_tensor->numel() / out_dim[0];
        int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
        phi::DenseTensor input_partial, output_partial;

        GroupStart();
        for (auto i = 0; i < size_; i++) {
          in_numel = in_size_each_rank[i] * in_row_size;
          input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
          NCCL_CHECK(phi::dynload::ncclSend(
              input_partial.data(),
              in_numel,
              platform::ToNCCLDataType(input_partial.dtype()),
              i,
              comm,
              stream));
          in_offset += in_numel;

          out_numel = out_size_each_rank[i] * out_row_size;
          output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
          NCCL_CHECK(phi::dynload::ncclRecv(
              output_partial.data(),
              out_numel,
              platform::ToNCCLDataType(output_partial.dtype()),
              i,
              comm,
              stream));
          out_offset += out_numel;
        }
        GroupEnd();
      },
      in_tensor,
      CommType::ALLTOALL,
      sync_op,
      use_calc_stream);
}

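// Barrier is implemented as an all-reduce on a single-element scratch tensor;
// SetBlockCPUInWait() makes the subsequent Wait() call cudaDeviceSynchronize,
// so the host actually blocks until every rank has entered the barrier.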
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
    const BarrierOptions& opts) {
  PADDLE_ENFORCE_GE(opts.device_id,
                    0,
                    phi::errors::PreconditionNotMet(
                        "The barrier device id must be greater than or equal "
                        "to 0."));
  platform::CUDAPlace place(opts.device_id);
  auto allocator = std::unique_ptr<phi::Allocator>(
      new paddle::experimental::DefaultAllocator(place));
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
  phi::DenseTensor barrier_tensor{allocator.get(), meta};

  auto task = AllReduce(&barrier_tensor,
                        barrier_tensor,
                        {},
                        /*sync_op*/ true,
                        /*use_calc_stream*/ false);
  auto nccl_task = dynamic_cast<NCCLTask*>(task.get());
  nccl_task->SetBlockCPUInWait();
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  CommStaticCheck::SameShape(*out_tensor,
                             in_tensor,
                             /*dst_rank*/ rank_,
                             /*cur_rank*/ rank_,
                             size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        int root = opts.source_rank + opts.source_root;
        if (FLAGS_enable_nccl_dynamic_check) {
          CommDynamicCheck::CheckShape(*out_tensor, root, rank_, comm);
        }
        NCCL_CHECK(phi::dynload::ncclBroadcast(
            in_tensor.data(),
            out_tensor->data(),
            in_tensor.numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            root,
            comm,
            stream));
      },
      in_tensor,
      CommType::BROADCAST,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  CommStaticCheck::SameShape(*out_tensor,
                             in_tensor,
                             /*dst_rank*/ opts.root_rank,
                             /*cur_rank*/ rank_,
                             size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        if (FLAGS_enable_nccl_dynamic_check) {
          CommDynamicCheck::CheckShape(*out_tensor,
                                       /*root_rank*/ opts.root_rank,
                                       rank_,
                                       comm);
        }
        NCCL_CHECK(phi::dynload::ncclReduce(
            in_tensor.data(),
            out_tensor->data(),
            in_tensor.numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            ToNCCLRedType(opts.reduce_op),
            opts.root_rank,
            comm,
            stream));
      },
      in_tensor,
      CommType::REDUCE,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  CommStaticCheck::ScatterLikeShape(*out_tensor,
                                    in_tensor,
                                    /*dst_rank*/ rank_,
                                    /*cur_rank*/ rank_,
                                    size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        if (FLAGS_enable_nccl_dynamic_check) {
          CommDynamicCheck::CheckShape(*out_tensor,
                                       /*root_rank*/ 0,
                                       rank_,
                                       comm);
        }
        NCCL_CHECK(phi::dynload::ncclReduceScatter(
            in_tensor.data(),
            out_tensor->data(),
            out_tensor->numel(),
            platform::ToNCCLDataType(in_tensor.dtype()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream));
      },
      in_tensor,
      CommType::REDUCE_SCATTER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  CommStaticCheck::ScatterLikeShape(*out_tensor,
                                    in_tensor,
                                    /*dst_rank*/ opts.root_rank,
                                    /*cur_rank*/ rank_,
                                    size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        if (FLAGS_enable_nccl_dynamic_check) {
          CommDynamicCheck::CheckShape(*out_tensor,
                                       /*root_rank*/ opts.root_rank,
                                       rank_,
                                       comm);
        }
        int64_t numel = in_tensor.numel() / size_;
        if (rank_ == opts.root_rank) {
          int64_t offset = 0;
          phi::DenseTensor partial_tensor;
          GroupStart();
          for (auto i = 0; i < size_; i++) {
            partial_tensor = GetPartialTensor(in_tensor, offset, numel);
            NCCL_CHECK(phi::dynload::ncclSend(
                partial_tensor.data(),
                numel,
                platform::ToNCCLDataType(partial_tensor.dtype()),
                i,
                comm,
                stream));
            offset += numel;
          }
          NCCL_CHECK(phi::dynload::ncclRecv(
              out_tensor->data(),
              numel,
              platform::ToNCCLDataType(out_tensor->dtype()),
              opts.root_rank,
              comm,
              stream));
          GroupEnd();
        } else {
          NCCL_CHECK(phi::dynload::ncclRecv(
              out_tensor->data(),
              numel,
              platform::ToNCCLDataType(out_tensor->dtype()),
              opts.root_rank,
              comm,
              stream));
        }
      },
      in_tensor,
      CommType::SCATTER,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
    phi::DenseTensor* tensor,
    int src_rank,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates the tensor needs to be sliced
  phi::DenseTensor partial_tensor;
  if (numel > 0) {
    partial_tensor = GetPartialTensor(*tensor, offset, numel);
    tensor = &partial_tensor;
  }

  CommStaticCheck::CheckShape(*tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        if (FLAGS_enable_nccl_dynamic_check) {
          CommDynamicCheck::CheckShape(*tensor,
                                       /*root_rank*/ src_rank,
                                       rank_,
                                       comm);
        }
        NCCL_CHECK(
            phi::dynload::ncclRecv(tensor->data(),
                                   tensor->numel(),
                                   platform::ToNCCLDataType(tensor->dtype()),
                                   src_rank,
                                   comm,
                                   stream));
      },
      *tensor,
      CommType::RECV,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
    const phi::DenseTensor& tensor,
    int dst_rank,
    int64_t offset,
    int64_t numel,
    bool sync_op,
    bool use_calc_stream) {
  // numel > 0 indicates the tensor needs to be sliced
  const phi::DenseTensor& tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;

  CommStaticCheck::CheckShape(tensor_maybe_partial, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        if (FLAGS_enable_nccl_dynamic_check) {
          CommDynamicCheck::CheckShape(tensor_maybe_partial,
                                       /*root_rank*/ rank_,
                                       rank_,
                                       comm);
        }
        NCCL_CHECK(phi::dynload::ncclSend(
            tensor_maybe_partial.data(),
            tensor_maybe_partial.numel(),
            platform::ToNCCLDataType(tensor_maybe_partial.dtype()),
            dst_rank,
            comm,
            stream));
      },
      tensor_maybe_partial,
      CommType::SEND,
      sync_op,
      use_calc_stream);
}

std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    const Place& place,
    int rank,
    CommType comm_type,
    bool is_sync,
    bool use_calc_stream) {
  return std::make_shared<ProcessGroupNCCL::NCCLTask>(
      place, rank, comm_type, is_sync, use_calc_stream);
}

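// Rank 0 generates the ncclUniqueId and publishes it through the store under a
// key derived from the group id; every other rank reads the same key, so all
// ranks end up joining the same NCCL communicator.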
void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) {
  const std::string key =
      "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/0";
  if (rank_ == 0) {
    std::vector<uint8_t> nccl_id_wrapper(
        reinterpret_cast<uint8_t*>(nccl_id),
        reinterpret_cast<uint8_t*>(nccl_id) + NCCL_UNIQUE_ID_BYTES);
    store_->set(key, nccl_id_wrapper);
  } else {
    const auto& nccl_id_wrapper = store_->get(key);
    std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
  }
}

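// Lazily creates the per-place communication environment: a dedicated
// phi::GPUContext with its own stream, the NCCL communicator, and the event
// used to order it against the calculation stream. Entries are cached by
// place key and reused by RunFnInNCCLEnv.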
void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
                                          const std::string& place_key) {
  if (place_to_comm_ctx_.size() > 0) {
    VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
  }

  ncclUniqueId nccl_id;
  if (rank_ == 0) {
    NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id));
  }
  BroadcastUniqueNCCLID(&nccl_id);

  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << place_key
          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);

  auto* calc_ctx = static_cast<phi::GPUContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  auto comm_ctx = std::make_unique<phi::GPUContext>(place);
  ncclComm_t nccl_comm;
  NCCL_CHECK(phi::dynload::ncclCommInitRank(
      &nccl_comm, GetSize(), nccl_id, GetRank()));
  comm_ctx->set_nccl_comm(nccl_comm);

  place_to_calc_event_.emplace(place_key, place);
  place_to_calc_ctx_.emplace(place_key, calc_ctx);
  place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));

  // TODO(sunyilun): for compatibility, will be removed later
  std::vector<phi::GPUContext*> comm_ctx_wrapper{
      place_to_comm_ctx_[place_key].get()};
  places_to_ctx_.emplace(place_key, comm_ctx_wrapper);
}

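// Orders communication after computation: an event is recorded on the
// calculation stream and the communication stream waits on it, so collectives
// launched on the comm stream run only after previously enqueued compute
// kernels for this place have finished.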
void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
  const std::string& key = GetKeyFromPlace(place);
  auto& calc_event = place_to_calc_event_.at(key);
  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto* comm_ctx = place_to_comm_ctx_.at(key).get();
  calc_event.Record(calc_ctx);
  calc_event.Wait(platform::Place2DeviceType(place), comm_ctx);
}

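// Common driver for all stream-based collectives above: lazily builds the
// per-place NCCL environment, optionally syncs the comm stream with the calc
// stream, runs `fn` on the chosen stream, then records the stream with the
// allocator and updates the task's wait chain so later Wait() calls behave
// correctly.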
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
    std::function<void(ncclComm_t, gpuStream_t)> fn,
    const phi::DenseTensor& tensor,
    CommType comm_type,
    bool sync_op,
    bool use_calc_stream) {
  const auto& place = tensor.place();
  const auto& key = GetKeyFromPlace(place);

  platform::CUDADeviceGuard cuda_guard(place);

  if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
    CreateNCCLEnvCache(place, key);
  }

  if (!use_calc_stream) {
    SyncCalcStream(place);
  }

  auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);

  const auto* calc_ctx = place_to_calc_ctx_.at(key);
  const auto& comm_ctx = place_to_comm_ctx_.at(key);
  auto nccl_comm = comm_ctx->nccl_comm();
  auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
  fn(nccl_comm, nccl_stream);

  if (!use_calc_stream) {
    if (FLAGS_use_stream_safe_cuda_allocator) {
      memory::RecordStream(tensor.Holder(), nccl_stream);
    }
    task->UpdateWaitChain(*comm_ctx);
  }

  if (FLAGS_enable_nccl_dynamic_check) {
    task->SetBlockCPUInWait();
    task->Wait();
  }
  return task;
}

// TODO(sunyilun): methods below will be removed later
void SyncDefaultStream(const std::vector<Place>& places,
                       platform::DeviceEvent& nccl_event,         // NOLINT
                       std::vector<phi::GPUContext*>& dev_ctx) {  // NOLINT
  for (size_t i = 0; i < places.size(); ++i) {
    auto* default_ctx = static_cast<phi::GPUContext*>(
        platform::DeviceContextPool::Instance().Get(places[i]));
    nccl_event.Record(default_ctx);
    nccl_event.Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]);
  }
}

std::shared_ptr<ProcessGroupNCCL::NCCLTask> ProcessGroupNCCL::CreateTask(
    std::vector<Place> places,
    int rank,
    CommType comm_type,
    const std::vector<phi::DenseTensor>& inputs) {
  return std::make_shared<ProcessGroupNCCL::NCCLTask>(
      places, rank, comm_type, inputs);
}

ProcessGroupNCCL::NCCLTask::NCCLTask(
    const std::vector<Place>& places,
    int rank,
    CommType CommType,
    const std::vector<phi::DenseTensor>& inputs)
    : TaskStream(rank, inputs, CommType),
      comm_event_(places[0]),
      task_place_(places[0]) {}

// create NCCLManager cache for places_key
void ProcessGroupNCCL::CreateNCCLManagerCache(
    const std::string& places_key, const std::vector<Place>& places) {
  PADDLE_ENFORCE_EQ(places_key.empty(),
                    false,
                    phi::errors::PreconditionNotMet(
                        "Not able to create/get the NCCL Communicator since "
                        "the GPU places are not known"));

  ncclUniqueId nccl_id;
  if (rank_ == 0) {
    NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id));
  }
  BroadcastUniqueNCCLID(&nccl_id);

  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
          << ", place: " << places_key
          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);

  std::vector<std::unique_ptr<phi::GPUContext>> dev_ctx;
  dev_ctx.resize(places.size());

  std::vector<phi::GPUContext*> dev_ctx_raw;
  dev_ctx_raw.resize(places.size());

  GroupStart();

  for (size_t i = 0; i < places.size(); ++i) {
    platform::CUDADeviceGuard guard(places[i]);

    dev_ctx[i].reset(new phi::GPUContext(places[i]));
    ncclComm_t nccl_comm;
    NCCL_CHECK(phi::dynload::ncclCommInitRank(
        &nccl_comm, GetSize(), nccl_id, GetRank()));
    dev_ctx[i]->set_nccl_comm(nccl_comm);
    dev_ctx_raw[i] = dev_ctx[i].get();
  }

  GroupEnd();

  // TODO(sunyilun): for compatibility, will be removed later
  place_to_calc_event_.emplace(places_key, places[0]);
  place_to_calc_ctx_.emplace(
      places_key,
      static_cast<phi::GPUContext*>(
          platform::DeviceContextPool::Instance().Get(places[0])));
  place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0]));

  // These caches will be useful to process sync/wait/communicate
  places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw));
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
    std::vector<phi::DenseTensor>& inputs,
    std::vector<phi::DenseTensor>& outputs,
    Fn fn,
    CommType op_type) {
  const auto places = GetPlaceList(inputs);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));

  auto task = CreateTask(places, rank_, op_type, inputs);

  // construct an uninitialized guard for the device
  platform::CUDADeviceGuard cuda_guard;

  {
    platform::NCCLGroupGuard nccl_guard;
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
      fn(inputs[i],
         outputs[i],
         places_to_ctx_.at(key)[i]->nccl_comm(),
         nccl_stream);
    }
  }

  if (FLAGS_use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < inputs.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      memory::RecordStream(inputs[i].Holder(),
                           places_to_ctx_.at(key)[i]->stream());
    }
  }

  for (size_t i = 0; i < inputs.size(); ++i) {
    cuda_guard.SetDevice(places[i]);
    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
  }
  return task;
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
    std::vector<phi::DenseTensor>& tensors,
    Fn fn,
    int dst_rank,
    CommType op_type) {
  const auto places = GetPlaceList(tensors);
  const auto key = GetKeyFromPlaces(places);

  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
      CreateNCCLManagerCache(key, places);
    }
  }

  SyncDefaultStream(
      places, place_to_calc_event_.at(key), places_to_ctx_.at(key));

  auto task = CreateTask(places, rank_, op_type, tensors);

  // construct an uninitialized guard for the device
  platform::CUDADeviceGuard cuda_guard;

  {
    platform::NCCLGroupGuard nccl_guard;
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream();
      fn(tensors[i],
         places_to_ctx_.at(key)[i]->nccl_comm(),
         nccl_stream,
         dst_rank);
    }
  }

  if (FLAGS_use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < tensors.size(); ++i) {
      cuda_guard.SetDevice(places[i]);
      memory::RecordStream(tensors[i].Holder(),
                           places_to_ctx_.at(key)[i]->stream());
    }
  }

  for (size_t i = 0; i < tensors.size(); ++i) {
    cuda_guard.SetDevice(places[i]);
    task->UpdateWaitChain(*places_to_ctx_.at(key)[i]);
  }
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const AllreduceOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        return phi::dynload::ncclAllReduce(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            ToNCCLRedType(opts.reduce_op),
            comm,
            stream);
      },
      CommType::ALLREDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const BroadcastOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      phi::errors::InvalidArgument("All inputs should be in CudaPlace."));

  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        const auto root =
            opts.source_rank * in_tensors.size() + opts.source_root;
        return phi::dynload::ncclBroadcast(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.type()),
            root,
            comm,
            stream);
      },
      CommType::BROADCAST);
}

void CheckTensorsInDifferentDevices(
    const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
  PADDLE_ENFORCE_EQ(
      tensors.size() == 0,
      false,
      phi::errors::InvalidArgument("Tensor list must be nonempty."));
  PADDLE_ENFORCE_LE(
      tensors.size(),
      num_devices,
      phi::errors::InvalidArgument(
          "Tensor list mustn't be larger than the number of available GPUs."));

  std::set<Place> used_devices;

  for (const auto& t : tensors) {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(t.place()),
        true,
        phi::errors::InvalidArgument("Tensors must be CUDA and dense tensor."));

    const auto inserted = used_devices.insert(t.place()).second;
    PADDLE_ENFORCE_EQ(inserted,
                      true,
                      phi::errors::InvalidArgument(
                          "Tensors must be on distinct GPU devices."));
  }
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
    std::vector<phi::DenseTensor>& tensors, int dst_rank) {
  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

  auto task = PointToPoint(
      tensors,
      [&](phi::DenseTensor& input,
          ncclComm_t comm,
          const gpuStream_t& stream,
          int dst_rank) {
        return phi::dynload::ncclSend(input.data(),
                                      input.numel(),
                                      platform::ToNCCLDataType(input.dtype()),
                                      dst_rank,
                                      comm,
                                      stream);
      },
      dst_rank,
      CommType::SEND);
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
    std::vector<phi::DenseTensor>& tensors, int src_rank) {
  CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

  auto task = PointToPoint(
      tensors,
      [&](phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream,
          int src_rank) {
        return phi::dynload::ncclRecv(output.data(),
                                      output.numel(),
                                      platform::ToNCCLDataType(output.dtype()),
                                      src_rank,
                                      comm,
                                      stream);
      },
      src_rank,
      CommType::RECV);
  return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      phi::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        return phi::dynload::ncclAllGather(
            input.data(),
            output.data(),
            input.numel(),
            platform::ToNCCLDataType(input.dtype()),
            comm,
            stream);
      },
      CommType::ALLGATHER);
}

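// Helper for the legacy all-to-all/scatter paths: advances a raw pointer by
// `offset` elements of the given dtype. FLOAT16/BFLOAT16 reuse 16-bit integer
// pointers since only the element size matters for the address arithmetic.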
void* GetPointerByOffset(void* raw_pointer,
                         size_t offset,
                         experimental::DataType type) {
  if (type == experimental::DataType::FLOAT32) {
    return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT64) {
    return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::FLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT32) {
    return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT64) {
    return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::INT8) {
    return reinterpret_cast<void*>(reinterpret_cast<int8_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::UINT8) {
    return reinterpret_cast<void*>(reinterpret_cast<uint8_t*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::BOOL) {
    return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
                                   offset);
  } else if (type == experimental::DataType::BFLOAT16) {
    return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
                                   offset);
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Datatype %s in NCCL is not supported.", type));
  }
  return nullptr;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      phi::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        size_t offset = 0;
        GroupStart();
        for (auto i = 0; i < size_; i++) {
          PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
              GetPointerByOffset(input.data(), offset, input.dtype()),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRecv(
              GetPointerByOffset(output.data(), offset, input.dtype()),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              i,
              comm,
              stream));
          offset += input.numel() / size_;
        }
        GroupEnd();
      },
      CommType::ALLTOALL);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ReduceOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](const phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            phi::dynload::ncclReduce(input.data(),
                                     output.data(),
                                     input.numel(),
                                     platform::ToNCCLDataType(input.dtype()),
                                     ToNCCLRedType(opts.reduce_op),
                                     opts.root_rank,
                                     comm,
                                     stream));
      },
      CommType::REDUCE);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    const ScatterOptions& opts) {
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(in_tensors),
      true,
      phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
  PADDLE_ENFORCE_EQ(
      CheckTensorsInCudaPlace(out_tensors),
      true,
      phi::errors::InvalidArgument("All outputs should be in CudaPlace."));
  return Collective(
      in_tensors,
      out_tensors,
      [&](phi::DenseTensor& input,
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
        size_t offset = 0;
        if (rank_ == opts.root_rank) {
          GroupStart();
          for (auto i = 0; i < size_; i++) {
            PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
                GetPointerByOffset(input.data(), offset, input.dtype()),
                input.numel() / size_,
                platform::ToNCCLDataType(input.dtype()),
                i,
                comm,
                stream));
            offset += input.numel() / size_;
          }
          PADDLE_ENFORCE_GPU_SUCCESS(
              phi::dynload::ncclRecv(output.data(),
                                     input.numel() / size_,
                                     platform::ToNCCLDataType(input.dtype()),
                                     opts.root_rank,
                                     comm,
                                     stream));
          GroupEnd();
        } else {
          PADDLE_ENFORCE_GPU_SUCCESS(
              phi::dynload::ncclRecv(output.data(),
                                     input.numel() / size_,
                                     platform::ToNCCLDataType(input.dtype()),
                                     opts.root_rank,
                                     comm,
                                     stream));
        }
      },
      CommType::SCATTER);
}

std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
  auto process_group =
      std::make_shared<ProcessGroupNCCL>(store, rank, size, gid);
  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
  return process_group;
}

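// A minimal usage sketch (illustrative only; it assumes a concrete Store
// implementation, a valid CUDA device, and default AllreduceOptions, mirroring
// the call pattern used by Barrier above):
//
//   auto pg = ProcessGroupNCCL::CreateProcessGroupNCCL(store, rank, size, gid);
//   // `in` and `out` are phi::DenseTensor objects of the same shape on this
//   // rank's GPU.
//   auto task = pg->AllReduce(&out, in, {}, /*sync_op*/ true,
//                             /*use_calc_stream*/ false);
//   task->Wait();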
}  //  namespace distributed
}  //  namespace paddle