Unverified · Commit 5a6cd05f authored by TaoTao Li, committed by GitHub

update dygraph collective process group (#54863)

* update dygraph collective

fix ut

* remove debug log
Parent: bbcaaffd
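The shape of this refactor is the same across every collective in the hunks below: the RunFnInNCCLEnv callback no longer receives a raw ncclComm_t, only the gpuStream_t, and the communicator is resolved inside the lambda through the group's NCCLCommContext. A condensed before/after sketch using AllReduce (parameter names follow the diff; the old body is summarized in comments, so treat this as an illustration rather than an exact excerpt):

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  // Before this patch: a CommStaticCheck::SameShape() call, then a lambda
  // taking (ncclComm_t comm, gpuStream_t stream) that invoked
  // NCCL_CHECK(phi::dynload::ncclAllReduce(..., comm, stream)).
  // After this patch: the lambda only takes the stream and asks the
  // per-group NCCLCommContext to issue the collective.
  return RunFnInNCCLEnv(
      [&](gpuStream_t stream) {
        auto comm_context = this->GetCommContext();
        comm_context->AllReduce(
            out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream);
      },
      in_tensor,
      CommType::ALLREDUCE,
      sync_op,
      use_calc_stream);
}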
@@ -26,6 +26,8 @@
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
PHI_DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);
@@ -144,26 +146,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
// numel > 0 indicates the tensor need to be sliced
const phi::DenseTensor& in_tensor_maybe_partial =
numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor,
in_tensor_maybe_partial,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor,
/*root_rank*/ 0,
rank_,
comm);
}
NCCL_CHECK(phi::dynload::ncclAllGather(
in_tensor_maybe_partial.data(),
out_tensor->data(),
in_tensor_maybe_partial.numel(),
phi::ToNCCLDataType(in_tensor_maybe_partial.dtype()),
comm,
stream));
[&](gpuStream_t stream) {
auto comm_context = this->GetCommContext();
comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream);
},
in_tensor_maybe_partial,
CommType::ALLGATHER,
@@ -177,27 +163,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor,
/*root_rank*/ 0,
rank_,
comm);
}
NCCL_CHECK(
phi::dynload::ncclAllReduce(in_tensor.data(),
out_tensor->data(),
in_tensor.numel(),
phi::ToNCCLDataType(in_tensor.dtype()),
ToNCCLRedType(opts.reduce_op),
comm,
stream));
[&](gpuStream_t stream) {
auto comm_context = this->GetCommContext();
comm_context->AllReduce(
out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream);
},
in_tensor,
CommType::ALLREDUCE,
@@ -221,49 +191,37 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
// simply be covered by static checks. Factors are set to 0 here to skip the
// shape check. Its shape check will be done by dynamic checks with
// FLAGS_enable_nccl_dynamic_check.
phi::distributed::CommStaticCheck::CheckShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
/*out_size_factor*/ 0,
/*in_size_factor*/ 0);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
[&](gpuStream_t stream) {
auto comm_context = this->GetCommContext();
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckShape(
*out_tensor, in_tensor, in_size_each_rank, rank_, size_, comm);
*out_tensor,
in_tensor,
in_size_each_rank,
rank_,
size_,
comm_context->GetNcclComm());
}
int64_t in_row_size = in_tensor.numel() / in_dim[0],
out_row_size = out_tensor->numel() / out_dim[0];
int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
phi::DenseTensor input_partial, output_partial;
GroupStart();
comm_context->GroupStart();
for (auto i = 0; i < size_; i++) {
in_numel = in_size_each_rank[i] * in_row_size;
input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
NCCL_CHECK(
phi::dynload::ncclSend(input_partial.data(),
in_numel,
phi::ToNCCLDataType(input_partial.dtype()),
i,
comm,
stream));
comm_context->Send(input_partial, in_numel, i, stream);
in_offset += in_numel;
out_numel = out_size_each_rank[i] * out_row_size;
output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
NCCL_CHECK(phi::dynload::ncclRecv(
output_partial.data(),
out_numel,
phi::ToNCCLDataType(output_partial.dtype()),
i,
comm,
stream));
comm_context->Recv(&output_partial, out_numel, i, stream);
out_offset += out_numel;
}
GroupEnd();
comm_context->GroupEnd();
},
in_tensor,
CommType::ALLTOALL,
@@ -299,26 +257,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
[&](gpuStream_t stream) {
int root = opts.source_rank + opts.source_root;
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckShape(
*out_tensor, root, rank_, comm);
}
NCCL_CHECK(
phi::dynload::ncclBroadcast(in_tensor.data(),
out_tensor->data(),
in_tensor.numel(),
phi::ToNCCLDataType(in_tensor.dtype()),
root,
comm,
stream));
auto comm_context = this->GetCommContext();
comm_context->Broadcast(out_tensor, in_tensor, root, stream);
},
in_tensor,
CommType::BROADCAST,
@@ -332,29 +275,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckShape(
*out_tensor,
/*root_rank*/ opts.root_rank,
rank_,
comm);
}
NCCL_CHECK(
phi::dynload::ncclReduce(in_tensor.data(),
out_tensor->data(),
in_tensor.numel(),
phi::ToNCCLDataType(in_tensor.dtype()),
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream));
[&](gpuStream_t stream) {
auto comm_context = this->GetCommContext();
comm_context->Reduce(out_tensor,
in_tensor,
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
stream);
},
in_tensor,
CommType::REDUCE,
@@ -368,27 +296,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::ScatterLikeShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor,
/*root_rank*/ 0,
rank_,
comm);
}
NCCL_CHECK(phi::dynload::ncclReduceScatter(
in_tensor.data(),
out_tensor->data(),
out_tensor->numel(),
phi::ToNCCLDataType(in_tensor.dtype()),
ToNCCLRedType(opts.reduce_op),
comm,
stream));
[&](gpuStream_t stream) {
auto comm_context = this->GetCommContext();
comm_context->ReduceScatter(
out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream);
},
in_tensor,
CommType::REDUCE_SCATTER,
@@ -409,46 +321,30 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
[&](gpuStream_t stream) {
auto comm_context = this->GetCommContext();
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckShape(
*out_tensor,
/*root_rank*/ opts.root_rank,
rank_,
comm);
comm_context->GetNcclComm());
}
int64_t numel = in_tensor.numel() / size_;
if (rank_ == opts.root_rank) {
int64_t offset = 0;
phi::DenseTensor partial_tensor;
GroupStart();
comm_context->GroupStart();
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(in_tensor, offset, numel);
NCCL_CHECK(phi::dynload::ncclSend(
partial_tensor.data(),
numel,
phi::ToNCCLDataType(partial_tensor.dtype()),
i,
comm,
stream));
comm_context->Send(partial_tensor, numel, i, stream);
offset += numel;
}
NCCL_CHECK(
phi::dynload::ncclRecv(out_tensor->data(),
numel,
phi::ToNCCLDataType(out_tensor->dtype()),
opts.root_rank,
comm,
stream));
GroupEnd();
comm_context->Recv(out_tensor, numel, opts.root_rank, stream);
comm_context->GroupEnd();
} else {
NCCL_CHECK(
phi::dynload::ncclRecv(out_tensor->data(),
numel,
phi::ToNCCLDataType(out_tensor->dtype()),
opts.root_rank,
comm,
stream));
comm_context->Recv(out_tensor, numel, opts.root_rank, stream);
}
},
in_tensor,
@@ -489,34 +385,30 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Gather(
"root world size [%d] is less than root rank [%d]",
size_,
opts.root_rank));
auto gather_func = [&](ncclComm_t comm, gpuStream_t stream) {
auto gather_func = [&](gpuStream_t stream) {
auto comm_context = this->GetCommContext();
// shape check
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckGatherShape(
in_tensor, gather_tensors, opts.root_rank, rank_, size_, comm);
in_tensor,
gather_tensors,
opts.root_rank,
rank_,
size_,
comm_context->GetNcclComm());
}
GroupStart();
comm_context->GroupStart();
// root receive from all devices
if (rank_ == opts.root_rank) {
for (auto i = 0; i < size_; i++) {
auto& gather_tensor = gather_tensors[i];
NCCL_CHECK(
phi::dynload::ncclRecv(gather_tensor.data(),
gather_tensor.numel(),
phi::ToNCCLDataType(gather_tensor.dtype()),
i,
comm,
stream));
comm_context->Recv(&gather_tensor, gather_tensor.numel(), i, stream);
}
}
// send to root
NCCL_CHECK(phi::dynload::ncclSend(in_tensor.data(),
in_tensor.numel(),
phi::ToNCCLDataType(in_tensor.dtype()),
opts.root_rank,
comm,
stream));
GroupEnd();
comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream);
comm_context->GroupEnd();
};
return RunFnInNCCLEnv(
gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream);
@@ -536,21 +428,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
tensor = &partial_tensor;
}
phi::distributed::CommStaticCheck::CheckShape(*tensor, rank_, size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckShape(*tensor,
/*root_rank*/ src_rank,
rank_,
comm);
}
NCCL_CHECK(phi::dynload::ncclRecv(tensor->data(),
tensor->numel(),
phi::ToNCCLDataType(tensor->dtype()),
src_rank,
comm,
stream));
[&](gpuStream_t stream) {
auto comm_context = this->GetCommContext();
comm_context->Recv(tensor, tensor->numel(), src_rank, stream);
},
*tensor,
CommType::RECV,
@@ -569,23 +450,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
const phi::DenseTensor& tensor_maybe_partial =
numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
phi::distributed::CommStaticCheck::CheckShape(
tensor_maybe_partial, rank_, size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckShape(tensor_maybe_partial,
/*root_rank*/ rank_,
rank_,
comm);
}
NCCL_CHECK(phi::dynload::ncclSend(
tensor_maybe_partial.data(),
tensor_maybe_partial.numel(),
phi::ToNCCLDataType(tensor_maybe_partial.dtype()),
dst_rank,
comm,
stream));
[&](gpuStream_t stream) {
auto comm_context = this->GetCommContext();
comm_context->Send(tensor_maybe_partial,
tensor_maybe_partial.numel(),
dst_rank,
stream);
},
tensor_maybe_partial,
CommType::SEND,
@@ -623,23 +494,14 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
}
ncclUniqueId nccl_id;
if (rank_ == 0) {
NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id));
}
BroadcastUniqueNCCLID(&nccl_id);
VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << place_key
<< ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
<< ", place: " << place_key;
auto* calc_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
ncclComm_t nccl_comm;
NCCL_CHECK(phi::dynload::ncclCommInitRank(
&nccl_comm, GetSize(), nccl_id, GetRank()));
comm_ctx->set_nccl_comm(nccl_comm);
auto nccl_comm_ctx = this->GetCommContext();
comm_ctx->set_nccl_comm(nccl_comm_ctx->GetNcclComm());
place_to_calc_event_.emplace(place_key, place);
place_to_calc_ctx_.emplace(place_key, calc_ctx);
@@ -661,7 +523,7 @@ void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
std::function<void(ncclComm_t, gpuStream_t)> fn,
std::function<void(gpuStream_t)> fn,
const phi::DenseTensor& tensor,
CommType comm_type,
bool sync_op,
@@ -683,9 +545,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
const auto* calc_ctx = place_to_calc_ctx_.at(key);
const auto& comm_ctx = place_to_comm_ctx_.at(key);
auto nccl_comm = comm_ctx->nccl_comm();
auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
fn(nccl_comm, nccl_stream);
fn(nccl_stream);
if (!use_calc_stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
@@ -900,13 +761,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return phi::dynload::ncclAllReduce(input.data(),
output.data(),
input.numel(),
phi::ToNCCLDataType(input.type()),
ToNCCLRedType(opts.reduce_op),
comm,
stream);
auto comm_context = this->GetCommContext();
comm_context->AllReduce(
&output, input, ToNCCLRedType(opts.reduce_op), stream);
},
CommType::ALLREDUCE);
}
@@ -929,13 +786,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
const gpuStream_t& stream) {
const auto root =
opts.source_rank * in_tensors.size() + opts.source_root;
return phi::dynload::ncclBroadcast(input.data(),
output.data(),
input.numel(),
phi::ToNCCLDataType(input.type()),
root,
comm,
stream);
auto comm_context = this->GetCommContext();
comm_context->Broadcast(&output, input, root, stream);
},
CommType::BROADCAST);
}
@@ -978,12 +830,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
ncclComm_t comm,
const gpuStream_t& stream,
int dst_rank) {
return phi::dynload::ncclSend(input.data(),
input.numel(),
phi::ToNCCLDataType(input.dtype()),
dst_rank,
comm,
stream);
auto comm_context = this->GetCommContext();
comm_context->Send(input, input.numel(), dst_rank, stream);
},
dst_rank,
CommType::SEND);
@@ -1000,12 +848,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
ncclComm_t comm,
const gpuStream_t& stream,
int src_rank) {
return phi::dynload::ncclRecv(output.data(),
output.numel(),
phi::ToNCCLDataType(output.dtype()),
src_rank,
comm,
stream);
auto comm_context = this->GetCommContext();
comm_context->Recv(&output, output.numel(), src_rank, stream);
},
src_rank,
CommType::RECV);
@@ -1030,12 +874,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return phi::dynload::ncclAllGather(input.data(),
output.data(),
input.numel(),
phi::ToNCCLDataType(input.dtype()),
comm,
stream);
auto comm_context = this->GetCommContext();
comm_context->AllGather(&output, input, stream);
},
CommType::ALLGATHER);
}
@@ -1059,25 +899,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
ncclComm_t comm,
const gpuStream_t& stream) {
size_t offset = 0;
GroupStart();
size_t count = input.numel() / size_;
auto comm_context = this->GetCommContext();
comm_context->GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_,
phi::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRecv(
GetPointerByOffset(output.data(), offset, input.dtype()),
input.numel() / size_,
phi::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
offset += input.numel() / size_;
auto input_data = GetPartialTensor(input, offset, count);
comm_context->Send(input_data, count, i, stream);
auto output_data = GetPartialTensor(output, offset, count);
comm_context->Recv(&output_data, count, i, stream);
offset += count;
}
GroupEnd();
comm_context->GroupEnd();
},
CommType::ALLTOALL);
}
@@ -1097,15 +929,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclReduce(input.data(),
output.data(),
input.numel(),
phi::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream));
auto comm_context = this->GetCommContext();
comm_context->Reduce(&output,
input,
ToNCCLRedType(opts.reduce_op),
opts.root_rank,
stream);
},
CommType::REDUCE);
}
@@ -1129,35 +958,21 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
auto comm_context = this->GetCommContext();
size_t offset = 0;
size_t count = input.numel() / size_;
if (rank_ == opts.root_rank) {
GroupStart();
comm_context->GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
input.numel() / size_,
phi::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
offset += input.numel() / size_;
auto input_data = GetPartialTensor(input, offset, count);
comm_context->Send(input_data, count, i, stream);
offset += count;
}
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclRecv(output.data(),
input.numel() / size_,
phi::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
GroupEnd();
comm_context->Recv(&output, count, opts.root_rank, stream);
comm_context->GroupEnd();
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclRecv(output.data(),
input.numel() / size_,
phi::ToNCCLDataType(input.dtype()),
opts.root_rank,
comm,
stream));
comm_context->Recv(&output, count, opts.root_rank, stream);
}
},
CommType::SCATTER);
@@ -1165,14 +980,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
const std::shared_ptr<phi::distributed::Store>& store,
int device_id,
int rank,
int size,
int gid) {
phi::distributed::CommContextManager::CreateNCCLCommContext(
store, device_id, gid, rank, size);
auto process_group =
std::make_shared<ProcessGroupNCCL>(store, rank, size, gid);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
return process_group;
}
phi::distributed::NCCLCommContext* ProcessGroupNCCL::GetCommContext() {
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
auto comm_context = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(this->gid_));
PADDLE_ENFORCE_NE(comm_context,
nullptr,
phi::errors::Unavailable("NCCLCommContext is nullptr"));
return comm_context;
}
} // namespace distributed
} // namespace paddle
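Putting the two ends of the new plumbing together: CreateProcessGroupNCCL registers one NCCLCommContext per group id with the CommContextManager, and GetCommContext resolves it again whenever a collective runs. A minimal sketch of that round trip, assuming store, device_id, gid, rank, size, out_tensor, in_tensor, and stream are already in scope and that the manager accepts the integer group id as in the code above (nothing here is a new API; it only restates the calls from this diff):

// Registration: done once, before the ProcessGroupNCCL object is constructed.
phi::distributed::CommContextManager::CreateNCCLCommContext(
    store, device_id, gid, rank, size);

// Lookup: what GetCommContext() does inside every collective.
const auto& comm_context_manager =
    phi::distributed::CommContextManager::GetInstance();
auto* comm_context = static_cast<phi::distributed::NCCLCommContext*>(
    comm_context_manager.Get(gid));
PADDLE_ENFORCE_NE(comm_context,
                  nullptr,
                  phi::errors::Unavailable("NCCLCommContext is nullptr"));

// The context both issues the collective and still exposes the raw
// communicator for the dynamic shape checks kept in this patch.
comm_context->AllGather(&out_tensor, in_tensor, stream);
ncclComm_t raw_comm = comm_context->GetNcclComm();  // e.g. for NCCLDynamicCheck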
@@ -26,6 +26,7 @@
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/distributed/store/store.h"
namespace paddle {
@@ -68,6 +69,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
public:
static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL(
const std::shared_ptr<phi::distributed::Store>& store,
int device_id,
int rank,
int size,
int gid);
@@ -219,7 +221,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
void SyncCalcStream(const Place& place);
std::shared_ptr<ProcessGroup::Task> RunFnInNCCLEnv(
std::function<void(ncclComm_t, gpuStream_t)> fn,
std::function<void(gpuStream_t)> fn,
const phi::DenseTensor& tensor,
CommType comm_type,
bool sync_op,
@@ -249,6 +251,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
void CreateNCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
phi::distributed::NCCLCommContext* GetCommContext();
private:
std::shared_ptr<phi::distributed::Store> store_;
@@ -1238,6 +1238,7 @@ void BindDistributed(py::module *m) {
.def_static("create",
distributed::ProcessGroupNCCL::CreateProcessGroupNCCL,
py::arg("store"),
py::arg("device_id"),
py::arg("rank"),
py::arg("world_size"),
py::arg("group_id") = 0,
@@ -151,7 +151,10 @@ def _new_process_group_impl(
if backend == "gloo":
pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
elif backend == "nccl":
pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id)
pg = core.ProcessGroupNCCL.create(
store, genv.device_id, rank, world_size, group_id
)
elif backend == "xccl":
pg = core.ProcessGroupCustom.create(
store, genv.device_type, rank, world_size, group_id