From 5a6cd05f60000cfb411968fb5784a30dca6d83c6 Mon Sep 17 00:00:00 2001
From: wentao yu
Date: Thu, 29 Jun 2023 15:18:52 +0800
Subject: [PATCH] update dygraph collective process group (#54863)

* update dygraph collective

fix ut

* remove debug log
---
 .../collective/process_group_nccl.cc          | 415 +++++-------------
 .../collective/process_group_nccl.h           |   6 +-
 paddle/fluid/pybind/distributed_py.cc         |   1 +
 python/paddle/distributed/collective.py       |   5 +-
 4 files changed, 132 insertions(+), 295 deletions(-)

diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc
index e7c31e1a181..58c1ae247d7 100644
--- a/paddle/fluid/distributed/collective/process_group_nccl.cc
+++ b/paddle/fluid/distributed/collective/process_group_nccl.cc
@@ -26,6 +26,8 @@
 #include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/utils/data_type.h"
 
+#include "paddle/phi/core/distributed/comm_context_manager.h"
+
 PHI_DECLARE_bool(nccl_blocking_wait);
 
 DECLARE_bool(use_stream_safe_cuda_allocator);
 
@@ -144,26 +146,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
   // numel > 0 indicates the tensor need to be sliced
   const phi::DenseTensor& in_tensor_maybe_partial =
       numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
-  phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor,
-                                                     in_tensor_maybe_partial,
-                                                     /*dst_rank*/ rank_,
-                                                     /*cur_rank*/ rank_,
-                                                     size_);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
-        if (FLAGS_enable_nccl_dynamic_check) {
-          phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor,
-                                                         /*root_rank*/ 0,
-                                                         rank_,
-                                                         comm);
-        }
-        NCCL_CHECK(phi::dynload::ncclAllGather(
-            in_tensor_maybe_partial.data(),
-            out_tensor->data(),
-            in_tensor_maybe_partial.numel(),
-            phi::ToNCCLDataType(in_tensor_maybe_partial.dtype()),
-            comm,
-            stream));
+      [&](gpuStream_t stream) {
+        auto comm_context = this->GetCommContext();
+        comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream);
       },
       in_tensor_maybe_partial,
       CommType::ALLGATHER,
@@ -177,27 +163,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
     const AllreduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  phi::distributed::CommStaticCheck::SameShape(*out_tensor,
-                                               in_tensor,
-                                               /*dst_rank*/ rank_,
-                                               /*cur_rank*/ rank_,
-                                               size_);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
-        if (FLAGS_enable_nccl_dynamic_check) {
-          phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor,
-                                                         /*root_rank*/ 0,
-                                                         rank_,
-                                                         comm);
-        }
-        NCCL_CHECK(
-            phi::dynload::ncclAllReduce(in_tensor.data(),
-                                        out_tensor->data(),
-                                        in_tensor.numel(),
-                                        phi::ToNCCLDataType(in_tensor.dtype()),
-                                        ToNCCLRedType(opts.reduce_op),
-                                        comm,
-                                        stream));
+      [&](gpuStream_t stream) {
+        auto comm_context = this->GetCommContext();
+        comm_context->AllReduce(
+            out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream);
       },
       in_tensor,
       CommType::ALLREDUCE,
@@ -221,49 +191,37 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
   // simply be covered by static checks. Factors are set to 0 here to skip the
   // shape check. Its shape check will be done by dynamic checks with
   // FLAGS_enable_nccl_dynamic_check.
- phi::distributed::CommStaticCheck::CheckShape(*out_tensor, - in_tensor, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_, - /*out_size_factor*/ 0, - /*in_size_factor*/ 0); return RunFnInNCCLEnv( - [&](ncclComm_t comm, gpuStream_t stream) { + [&](gpuStream_t stream) { + auto comm_context = this->GetCommContext(); if (FLAGS_enable_nccl_dynamic_check) { phi::distributed::NCCLDynamicCheck::CheckShape( - *out_tensor, in_tensor, in_size_each_rank, rank_, size_, comm); + *out_tensor, + in_tensor, + in_size_each_rank, + rank_, + size_, + comm_context->GetNcclComm()); } + int64_t in_row_size = in_tensor.numel() / in_dim[0], out_row_size = out_tensor->numel() / out_dim[0]; int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0; phi::DenseTensor input_partial, output_partial; - GroupStart(); + comm_context->GroupStart(); for (auto i = 0; i < size_; i++) { in_numel = in_size_each_rank[i] * in_row_size; input_partial = GetPartialTensor(in_tensor, in_offset, in_numel); - NCCL_CHECK( - phi::dynload::ncclSend(input_partial.data(), - in_numel, - phi::ToNCCLDataType(input_partial.dtype()), - i, - comm, - stream)); + comm_context->Send(input_partial, in_numel, i, stream); in_offset += in_numel; out_numel = out_size_each_rank[i] * out_row_size; output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel); - NCCL_CHECK(phi::dynload::ncclRecv( - output_partial.data(), - out_numel, - phi::ToNCCLDataType(output_partial.dtype()), - i, - comm, - stream)); + comm_context->Recv(&output_partial, out_numel, i, stream); out_offset += out_numel; } - GroupEnd(); + comm_context->GroupEnd(); }, in_tensor, CommType::ALLTOALL, @@ -299,26 +257,11 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( const BroadcastOptions& opts, bool sync_op, bool use_calc_stream) { - phi::distributed::CommStaticCheck::SameShape(*out_tensor, - in_tensor, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_); return RunFnInNCCLEnv( - [&](ncclComm_t comm, gpuStream_t stream) { + [&](gpuStream_t stream) { int root = opts.source_rank + opts.source_root; - if (FLAGS_enable_nccl_dynamic_check) { - phi::distributed::NCCLDynamicCheck::CheckShape( - *out_tensor, root, rank_, comm); - } - NCCL_CHECK( - phi::dynload::ncclBroadcast(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - phi::ToNCCLDataType(in_tensor.dtype()), - root, - comm, - stream)); + auto comm_context = this->GetCommContext(); + comm_context->Broadcast(out_tensor, in_tensor, root, stream); }, in_tensor, CommType::BROADCAST, @@ -332,29 +275,14 @@ std::shared_ptr ProcessGroupNCCL::Reduce( const ReduceOptions& opts, bool sync_op, bool use_calc_stream) { - phi::distributed::CommStaticCheck::SameShape(*out_tensor, - in_tensor, - /*dst_rank*/ opts.root_rank, - /*cur_rank*/ rank_, - size_); return RunFnInNCCLEnv( - [&](ncclComm_t comm, gpuStream_t stream) { - if (FLAGS_enable_nccl_dynamic_check) { - phi::distributed::NCCLDynamicCheck::CheckShape( - *out_tensor, - /*root_rank*/ opts.root_rank, - rank_, - comm); - } - NCCL_CHECK( - phi::dynload::ncclReduce(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - phi::ToNCCLDataType(in_tensor.dtype()), - ToNCCLRedType(opts.reduce_op), - opts.root_rank, - comm, - stream)); + [&](gpuStream_t stream) { + auto comm_context = this->GetCommContext(); + comm_context->Reduce(out_tensor, + in_tensor, + ToNCCLRedType(opts.reduce_op), + opts.root_rank, + stream); }, in_tensor, CommType::REDUCE, @@ -368,27 +296,11 @@ std::shared_ptr ProcessGroupNCCL::ReduceScatter( const ReduceScatterOptions& opts, bool sync_op, bool 
use_calc_stream) { - phi::distributed::CommStaticCheck::ScatterLikeShape(*out_tensor, - in_tensor, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_); return RunFnInNCCLEnv( - [&](ncclComm_t comm, gpuStream_t stream) { - if (FLAGS_enable_nccl_dynamic_check) { - phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor, - /*root_rank*/ 0, - rank_, - comm); - } - NCCL_CHECK(phi::dynload::ncclReduceScatter( - in_tensor.data(), - out_tensor->data(), - out_tensor->numel(), - phi::ToNCCLDataType(in_tensor.dtype()), - ToNCCLRedType(opts.reduce_op), - comm, - stream)); + [&](gpuStream_t stream) { + auto comm_context = this->GetCommContext(); + comm_context->ReduceScatter( + out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream); }, in_tensor, CommType::REDUCE_SCATTER, @@ -409,46 +321,30 @@ std::shared_ptr ProcessGroupNCCL::Scatter( /*cur_rank*/ rank_, size_); return RunFnInNCCLEnv( - [&](ncclComm_t comm, gpuStream_t stream) { + [&](gpuStream_t stream) { + auto comm_context = this->GetCommContext(); if (FLAGS_enable_nccl_dynamic_check) { phi::distributed::NCCLDynamicCheck::CheckShape( *out_tensor, /*root_rank*/ opts.root_rank, rank_, - comm); + comm_context->GetNcclComm()); } + int64_t numel = in_tensor.numel() / size_; if (rank_ == opts.root_rank) { int64_t offset = 0; phi::DenseTensor partial_tensor; - GroupStart(); + comm_context->GroupStart(); for (auto i = 0; i < size_; i++) { partial_tensor = GetPartialTensor(in_tensor, offset, numel); - NCCL_CHECK(phi::dynload::ncclSend( - partial_tensor.data(), - numel, - phi::ToNCCLDataType(partial_tensor.dtype()), - i, - comm, - stream)); + comm_context->Send(partial_tensor, numel, i, stream); offset += numel; } - NCCL_CHECK( - phi::dynload::ncclRecv(out_tensor->data(), - numel, - phi::ToNCCLDataType(out_tensor->dtype()), - opts.root_rank, - comm, - stream)); - GroupEnd(); + comm_context->Recv(out_tensor, numel, opts.root_rank, stream); + comm_context->GroupEnd(); } else { - NCCL_CHECK( - phi::dynload::ncclRecv(out_tensor->data(), - numel, - phi::ToNCCLDataType(out_tensor->dtype()), - opts.root_rank, - comm, - stream)); + comm_context->Recv(out_tensor, numel, opts.root_rank, stream); } }, in_tensor, @@ -489,34 +385,30 @@ std::shared_ptr ProcessGroupNCCL::Gather( "root world size [%d] is less than root rank [%d]", size_, opts.root_rank)); - auto gather_func = [&](ncclComm_t comm, gpuStream_t stream) { + auto gather_func = [&](gpuStream_t stream) { + auto comm_context = this->GetCommContext(); // shape check if (FLAGS_enable_nccl_dynamic_check) { phi::distributed::NCCLDynamicCheck::CheckGatherShape( - in_tensor, gather_tensors, opts.root_rank, rank_, size_, comm); + in_tensor, + gather_tensors, + opts.root_rank, + rank_, + size_, + comm_context->GetNcclComm()); } - GroupStart(); + + comm_context->GroupStart(); // root receive from all devices if (rank_ == opts.root_rank) { for (auto i = 0; i < size_; i++) { auto& gather_tensor = gather_tensors[i]; - NCCL_CHECK( - phi::dynload::ncclRecv(gather_tensor.data(), - gather_tensor.numel(), - phi::ToNCCLDataType(gather_tensor.dtype()), - i, - comm, - stream)); + comm_context->Recv(&gather_tensor, gather_tensor.numel(), i, stream); } } // send to root - NCCL_CHECK(phi::dynload::ncclSend(in_tensor.data(), - in_tensor.numel(), - phi::ToNCCLDataType(in_tensor.dtype()), - opts.root_rank, - comm, - stream)); - GroupEnd(); + comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream); + comm_context->GroupEnd(); }; return RunFnInNCCLEnv( gather_func, in_tensor, CommType::GATHER, sync_op, 
use_calc_stream); @@ -536,21 +428,10 @@ std::shared_ptr ProcessGroupNCCL::Recv( tensor = &partial_tensor; } - phi::distributed::CommStaticCheck::CheckShape(*tensor, rank_, size_); return RunFnInNCCLEnv( - [&](ncclComm_t comm, gpuStream_t stream) { - if (FLAGS_enable_nccl_dynamic_check) { - phi::distributed::NCCLDynamicCheck::CheckShape(*tensor, - /*root_rank*/ src_rank, - rank_, - comm); - } - NCCL_CHECK(phi::dynload::ncclRecv(tensor->data(), - tensor->numel(), - phi::ToNCCLDataType(tensor->dtype()), - src_rank, - comm, - stream)); + [&](gpuStream_t stream) { + auto comm_context = this->GetCommContext(); + comm_context->Recv(tensor, tensor->numel(), src_rank, stream); }, *tensor, CommType::RECV, @@ -569,23 +450,13 @@ std::shared_ptr ProcessGroupNCCL::Send( const phi::DenseTensor& tensor_maybe_partial = numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor; - phi::distributed::CommStaticCheck::CheckShape( - tensor_maybe_partial, rank_, size_); return RunFnInNCCLEnv( - [&](ncclComm_t comm, gpuStream_t stream) { - if (FLAGS_enable_nccl_dynamic_check) { - phi::distributed::NCCLDynamicCheck::CheckShape(tensor_maybe_partial, - /*root_rank*/ rank_, - rank_, - comm); - } - NCCL_CHECK(phi::dynload::ncclSend( - tensor_maybe_partial.data(), - tensor_maybe_partial.numel(), - phi::ToNCCLDataType(tensor_maybe_partial.dtype()), - dst_rank, - comm, - stream)); + [&](gpuStream_t stream) { + auto comm_context = this->GetCommContext(); + comm_context->Send(tensor_maybe_partial, + tensor_maybe_partial.numel(), + dst_rank, + stream); }, tensor_maybe_partial, CommType::SEND, @@ -623,23 +494,14 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, VLOG(3) << "Warning: Tensors from multiple devices are not supported yet."; } - ncclUniqueId nccl_id; - if (rank_ == 0) { - NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id)); - } - BroadcastUniqueNCCLID(&nccl_id); - VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_ - << ", place: " << place_key - << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id); + << ", place: " << place_key; auto* calc_ctx = static_cast( platform::DeviceContextPool::Instance().Get(place)); auto comm_ctx = std::make_unique(place); - ncclComm_t nccl_comm; - NCCL_CHECK(phi::dynload::ncclCommInitRank( - &nccl_comm, GetSize(), nccl_id, GetRank())); - comm_ctx->set_nccl_comm(nccl_comm); + auto nccl_comm_ctx = this->GetCommContext(); + comm_ctx->set_nccl_comm(nccl_comm_ctx->GetNcclComm()); place_to_calc_event_.emplace(place_key, place); place_to_calc_ctx_.emplace(place_key, calc_ctx); @@ -661,7 +523,7 @@ void ProcessGroupNCCL::SyncCalcStream(const Place& place) { } std::shared_ptr ProcessGroupNCCL::RunFnInNCCLEnv( - std::function fn, + std::function fn, const phi::DenseTensor& tensor, CommType comm_type, bool sync_op, @@ -683,9 +545,8 @@ std::shared_ptr ProcessGroupNCCL::RunFnInNCCLEnv( const auto* calc_ctx = place_to_calc_ctx_.at(key); const auto& comm_ctx = place_to_comm_ctx_.at(key); - auto nccl_comm = comm_ctx->nccl_comm(); auto nccl_stream = use_calc_stream ? 
calc_ctx->stream() : comm_ctx->stream(); - fn(nccl_comm, nccl_stream); + fn(nccl_stream); if (!use_calc_stream) { if (FLAGS_use_stream_safe_cuda_allocator) { @@ -900,13 +761,9 @@ std::shared_ptr ProcessGroupNCCL::AllReduce( phi::DenseTensor& output, ncclComm_t comm, const gpuStream_t& stream) { - return phi::dynload::ncclAllReduce(input.data(), - output.data(), - input.numel(), - phi::ToNCCLDataType(input.type()), - ToNCCLRedType(opts.reduce_op), - comm, - stream); + auto comm_context = this->GetCommContext(); + comm_context->AllReduce( + &output, input, ToNCCLRedType(opts.reduce_op), stream); }, CommType::ALLREDUCE); } @@ -929,13 +786,8 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( const gpuStream_t& stream) { const auto root = opts.source_rank * in_tensors.size() + opts.source_root; - return phi::dynload::ncclBroadcast(input.data(), - output.data(), - input.numel(), - phi::ToNCCLDataType(input.type()), - root, - comm, - stream); + auto comm_context = this->GetCommContext(); + comm_context->Broadcast(&output, input, root, stream); }, CommType::BROADCAST); } @@ -978,12 +830,8 @@ std::shared_ptr ProcessGroupNCCL::Send( ncclComm_t comm, const gpuStream_t& stream, int dst_rank) { - return phi::dynload::ncclSend(input.data(), - input.numel(), - phi::ToNCCLDataType(input.dtype()), - dst_rank, - comm, - stream); + auto comm_context = this->GetCommContext(); + comm_context->Send(input, input.numel(), dst_rank, stream); }, dst_rank, CommType::SEND); @@ -1000,12 +848,8 @@ std::shared_ptr ProcessGroupNCCL::Recv( ncclComm_t comm, const gpuStream_t& stream, int src_rank) { - return phi::dynload::ncclRecv(output.data(), - output.numel(), - phi::ToNCCLDataType(output.dtype()), - src_rank, - comm, - stream); + auto comm_context = this->GetCommContext(); + comm_context->Recv(&output, output.numel(), src_rank, stream); }, src_rank, CommType::RECV); @@ -1030,12 +874,8 @@ std::shared_ptr ProcessGroupNCCL::AllGather( phi::DenseTensor& output, ncclComm_t comm, const gpuStream_t& stream) { - return phi::dynload::ncclAllGather(input.data(), - output.data(), - input.numel(), - phi::ToNCCLDataType(input.dtype()), - comm, - stream); + auto comm_context = this->GetCommContext(); + comm_context->AllGather(&output, input, stream); }, CommType::ALLGATHER); } @@ -1059,25 +899,17 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( ncclComm_t comm, const gpuStream_t& stream) { size_t offset = 0; - GroupStart(); + size_t count = input.numel() / size_; + auto comm_context = this->GetCommContext(); + comm_context->GroupStart(); for (auto i = 0; i < size_; i++) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend( - GetPointerByOffset(input.data(), offset, input.dtype()), - input.numel() / size_, - phi::ToNCCLDataType(input.dtype()), - i, - comm, - stream)); - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRecv( - GetPointerByOffset(output.data(), offset, input.dtype()), - input.numel() / size_, - phi::ToNCCLDataType(input.dtype()), - i, - comm, - stream)); - offset += input.numel() / size_; + auto input_data = GetPartialTensor(input, offset, count); + comm_context->Send(input_data, count, i, stream); + auto output_data = GetPartialTensor(output, offset, count); + comm_context->Recv(&output_data, count, i, stream); + offset += count; } - GroupEnd(); + comm_context->GroupEnd(); }, CommType::ALLTOALL); } @@ -1097,15 +929,12 @@ std::shared_ptr ProcessGroupNCCL::Reduce( phi::DenseTensor& output, ncclComm_t comm, const gpuStream_t& stream) { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclReduce(input.data(), - output.data(), - 
input.numel(), - phi::ToNCCLDataType(input.dtype()), - ToNCCLRedType(opts.reduce_op), - opts.root_rank, - comm, - stream)); + auto comm_context = this->GetCommContext(); + comm_context->Reduce(&output, + input, + ToNCCLRedType(opts.reduce_op), + opts.root_rank, + stream); }, CommType::REDUCE); } @@ -1129,35 +958,21 @@ std::shared_ptr ProcessGroupNCCL::Scatter( phi::DenseTensor& output, ncclComm_t comm, const gpuStream_t& stream) { + auto comm_context = this->GetCommContext(); size_t offset = 0; + size_t count = input.numel() / size_; if (rank_ == opts.root_rank) { - GroupStart(); + comm_context->GroupStart(); for (auto i = 0; i < size_; i++) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend( - GetPointerByOffset(input.data(), offset, input.dtype()), - input.numel() / size_, - phi::ToNCCLDataType(input.dtype()), - i, - comm, - stream)); - offset += input.numel() / size_; + auto input_data = reinterpret_cast( + GetPointerByOffset(input.data(), offset, input.dtype())); + comm_context->Send(*input_data, count, i, stream); + offset += count; } - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclRecv(output.data(), - input.numel() / size_, - phi::ToNCCLDataType(input.dtype()), - opts.root_rank, - comm, - stream)); - GroupEnd(); + comm_context->Recv(&output, count, opts.root_rank, stream); + comm_context->GroupEnd(); } else { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclRecv(output.data(), - input.numel() / size_, - phi::ToNCCLDataType(input.dtype()), - opts.root_rank, - comm, - stream)); + comm_context->Recv(&output, count, opts.root_rank, stream); } }, CommType::SCATTER); @@ -1165,14 +980,28 @@ std::shared_ptr ProcessGroupNCCL::Scatter( std::shared_ptr ProcessGroupNCCL::CreateProcessGroupNCCL( const std::shared_ptr& store, + int device_id, int rank, int size, int gid) { + phi::distributed::CommContextManager::CreateNCCLCommContext( + store, device_id, gid, rank, size); auto process_group = std::make_shared(store, rank, size, gid); ProcessGroupIdMap::GetInstance().emplace(gid, process_group); return process_group; } +phi::distributed::NCCLCommContext* ProcessGroupNCCL::GetCommContext() { + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + auto comm_context = static_cast( + comm_context_manager.Get(this->gid_)); + PADDLE_ENFORCE_NE(comm_context, + nullptr, + phi::errors::Unavailable("NCCLCommContext is nullptr")); + return comm_context; +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index d4a159a7f45..f2f5584e658 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -26,6 +26,7 @@ #include "paddle/phi/backends/gpu/forwards.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/distributed/store/store.h" namespace paddle { @@ -68,6 +69,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { public: static std::shared_ptr CreateProcessGroupNCCL( const std::shared_ptr& store, + int device_id, int rank, int size, int gid); @@ -219,7 +221,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { void SyncCalcStream(const Place& place); std::shared_ptr RunFnInNCCLEnv( - std::function fn, + std::function fn, const phi::DenseTensor& tensor, CommType comm_type, bool sync_op, @@ -249,6 +251,8 @@ class 
ProcessGroupNCCL final : public ProcessGroupWithStream {
   void CreateNCCLManagerCache(const std::string& places_key,
                               const std::vector<Place>& places);
 
+  phi::distributed::NCCLCommContext* GetCommContext();
+
  private:
   std::shared_ptr<phi::distributed::Store> store_;
 
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 01df736fb10..362a0d62b1d 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -1238,6 +1238,7 @@ void BindDistributed(py::module *m) {
           .def_static("create",
                       distributed::ProcessGroupNCCL::CreateProcessGroupNCCL,
                       py::arg("store"),
+                      py::arg("device_id"),
                       py::arg("rank"),
                       py::arg("world_size"),
                       py::arg("group_id") = 0,
diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 2011fe19811..14fa116c874 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -151,7 +151,10 @@ def _new_process_group_impl(
     if backend == "gloo":
         pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
     elif backend == "nccl":
-        pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id)
+        pg = core.ProcessGroupNCCL.create(
+            store, genv.device_id, rank, world_size, group_id
+        )
+
     elif backend == "xccl":
         pg = core.ProcessGroupCustom.create(
             store, genv.device_type, rank, world_size, group_id
-- 
GitLab
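
For reference, the sketch below shows one way to exercise the code path this patch changes from user code. It is not part of the patch: the file name demo.py, the two-GPU launch command, and the tensor values are assumptions, while paddle.distributed.init_parallel_env, all_reduce, and get_rank are existing public APIs. In dygraph mode, init_parallel_env() reaches _new_process_group_impl(), which with this patch passes genv.device_id into core.ProcessGroupNCCL.create(...); the create call registers an NCCLCommContext for the group id via CommContextManager before the ProcessGroup is constructed, and the collectives then run through that context.

# demo.py -- minimal sketch, not part of this patch; assumes a Paddle build
# with NCCL and at least two visible GPUs.
import paddle
import paddle.distributed as dist


def main():
    # Creates the default NCCL process group; with this patch the C++ side
    # registers an NCCLCommContext keyed by the group id, using the device
    # id that Python now forwards into ProcessGroupNCCL.create().
    dist.init_parallel_env()

    x = paddle.to_tensor([float(dist.get_rank() + 1)])
    # Dispatches to ProcessGroupNCCL::AllReduce, which after this patch
    # forwards to NCCLCommContext::AllReduce instead of calling
    # ncclAllReduce through phi::dynload directly.
    dist.all_reduce(x)
    print(f"rank {dist.get_rank()}: {x.numpy()}")


if __name__ == "__main__":
    main()

A launch along the lines of `python -m paddle.distributed.launch --gpus 0,1 demo.py` (command assumed) should print the same reduced sum on every rank, confirming that the comm-context-based path behaves like the previous direct-NCCL path.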