Unverified commit 5a6cd05f, authored by TaoTao Li, committed by GitHub

update dygraph collective process group (#54863)

* update dygraph collective

fix ut

* remove debug log
Parent: bbcaaffd
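The core of this diff is that collective lambdas no longer receive a raw ncclComm_t: each collective now resolves a per-group phi::distributed::NCCLCommContext (registered through CommContextManager when the process group is created) inside the lambda, which only takes the stream. Below is a minimal, self-contained C++ sketch of that pattern; FakeCommContext, FakeProcessGroup, Registry, and the other names are invented stand-ins for illustration only, not Paddle's actual classes or APIs.

// Sketch only: mimics the "fetch the comm context inside the lambda" pattern.
#include <functional>
#include <iostream>
#include <map>

using FakeComm = int;    // stand-in for ncclComm_t
using FakeStream = int;  // stand-in for gpuStream_t

// Hypothetical analogue of NCCLCommContext: owns the communicator and wraps
// the collective call.
struct FakeCommContext {
  FakeComm comm;
  void AllReduce(FakeStream stream) {
    std::cout << "AllReduce on comm " << comm << ", stream " << stream << "\n";
  }
};

// Hypothetical analogue of CommContextManager: one context per group id.
std::map<int, FakeCommContext>& Registry() {
  static std::map<int, FakeCommContext> registry;
  return registry;
}

struct FakeProcessGroup {
  int gid;
  FakeCommContext* GetCommContext() { return &Registry().at(gid); }

  // New-style runner: the lambda takes only the stream, no communicator.
  void RunFn(const std::function<void(FakeStream)>& fn) {
    FakeStream stream = 7;  // would be the calc or comm stream
    fn(stream);
  }

  void AllReduce() {
    RunFn([&](FakeStream stream) {
      auto* ctx = this->GetCommContext();  // fetched inside the lambda
      ctx->AllReduce(stream);
    });
  }
};

int main() {
  Registry()[0] = FakeCommContext{42};  // done once at group creation
  FakeProcessGroup pg{0};
  pg.AllReduce();
  return 0;
}

The payoff is visible throughout the hunks below: communicator setup (ncclGetUniqueId / BroadcastUniqueNCCLID / ncclCommInitRank) disappears from CreateNCCLEnvCache, and each collective shrinks to a single comm_context->X(...) call.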
@@ -26,6 +26,8 @@
 #include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
 PHI_DECLARE_bool(nccl_blocking_wait);
 DECLARE_bool(use_stream_safe_cuda_allocator);
@@ -144,26 +146,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
   // numel > 0 indicates the tensor need to be sliced
   const phi::DenseTensor& in_tensor_maybe_partial =
       numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
-  phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor,
-                                                     in_tensor_maybe_partial,
-                                                     /*dst_rank*/ rank_,
-                                                     /*cur_rank*/ rank_,
-                                                     size_);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
-        if (FLAGS_enable_nccl_dynamic_check) {
-          phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor,
-                                                         /*root_rank*/ 0,
-                                                         rank_,
-                                                         comm);
-        }
-        NCCL_CHECK(phi::dynload::ncclAllGather(
-            in_tensor_maybe_partial.data(),
-            out_tensor->data(),
-            in_tensor_maybe_partial.numel(),
-            phi::ToNCCLDataType(in_tensor_maybe_partial.dtype()),
-            comm,
-            stream));
+      [&](gpuStream_t stream) {
+        auto comm_context = this->GetCommContext();
+        comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream);
       },
       in_tensor_maybe_partial,
       CommType::ALLGATHER,
@@ -177,27 +163,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
     const AllreduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  phi::distributed::CommStaticCheck::SameShape(*out_tensor,
-                                               in_tensor,
-                                               /*dst_rank*/ rank_,
-                                               /*cur_rank*/ rank_,
-                                               size_);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
-        if (FLAGS_enable_nccl_dynamic_check) {
-          phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor,
-                                                         /*root_rank*/ 0,
-                                                         rank_,
-                                                         comm);
-        }
-        NCCL_CHECK(
-            phi::dynload::ncclAllReduce(in_tensor.data(),
-                                        out_tensor->data(),
-                                        in_tensor.numel(),
-                                        phi::ToNCCLDataType(in_tensor.dtype()),
-                                        ToNCCLRedType(opts.reduce_op),
-                                        comm,
-                                        stream));
+      [&](gpuStream_t stream) {
+        auto comm_context = this->GetCommContext();
+        comm_context->AllReduce(
+            out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream);
       },
       in_tensor,
       CommType::ALLREDUCE,
@@ -221,49 +191,37 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
   // simply be covered by static checks. Factors are set to 0 here to skip the
   // shape check. Its shape check will be done by dynamic checks with
   // FLAGS_enable_nccl_dynamic_check.
-  phi::distributed::CommStaticCheck::CheckShape(*out_tensor,
-                                                in_tensor,
-                                                /*dst_rank*/ rank_,
-                                                /*cur_rank*/ rank_,
-                                                size_,
-                                                /*out_size_factor*/ 0,
-                                                /*in_size_factor*/ 0);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
+      [&](gpuStream_t stream) {
+        auto comm_context = this->GetCommContext();
         if (FLAGS_enable_nccl_dynamic_check) {
           phi::distributed::NCCLDynamicCheck::CheckShape(
-              *out_tensor, in_tensor, in_size_each_rank, rank_, size_, comm);
+              *out_tensor,
+              in_tensor,
+              in_size_each_rank,
+              rank_,
+              size_,
+              comm_context->GetNcclComm());
         }
         int64_t in_row_size = in_tensor.numel() / in_dim[0],
                 out_row_size = out_tensor->numel() / out_dim[0];
         int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
         phi::DenseTensor input_partial, output_partial;
-        GroupStart();
+        comm_context->GroupStart();
         for (auto i = 0; i < size_; i++) {
           in_numel = in_size_each_rank[i] * in_row_size;
           input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
-          NCCL_CHECK(
-              phi::dynload::ncclSend(input_partial.data(),
-                                     in_numel,
-                                     phi::ToNCCLDataType(input_partial.dtype()),
-                                     i,
-                                     comm,
-                                     stream));
+          comm_context->Send(input_partial, in_numel, i, stream);
           in_offset += in_numel;
           out_numel = out_size_each_rank[i] * out_row_size;
           output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
-          NCCL_CHECK(phi::dynload::ncclRecv(
-              output_partial.data(),
-              out_numel,
-              phi::ToNCCLDataType(output_partial.dtype()),
-              i,
-              comm,
-              stream));
+          comm_context->Recv(&output_partial, out_numel, i, stream);
           out_offset += out_numel;
         }
-        GroupEnd();
+        comm_context->GroupEnd();
       },
       in_tensor,
       CommType::ALLTOALL,
@@ -299,26 +257,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
     const BroadcastOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  phi::distributed::CommStaticCheck::SameShape(*out_tensor,
-                                               in_tensor,
-                                               /*dst_rank*/ rank_,
-                                               /*cur_rank*/ rank_,
-                                               size_);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
+      [&](gpuStream_t stream) {
        int root = opts.source_rank + opts.source_root;
-        if (FLAGS_enable_nccl_dynamic_check) {
-          phi::distributed::NCCLDynamicCheck::CheckShape(
-              *out_tensor, root, rank_, comm);
-        }
-        NCCL_CHECK(
-            phi::dynload::ncclBroadcast(in_tensor.data(),
-                                        out_tensor->data(),
-                                        in_tensor.numel(),
-                                        phi::ToNCCLDataType(in_tensor.dtype()),
-                                        root,
-                                        comm,
-                                        stream));
+        auto comm_context = this->GetCommContext();
+        comm_context->Broadcast(out_tensor, in_tensor, root, stream);
       },
       in_tensor,
       CommType::BROADCAST,
@@ -332,29 +275,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
     const ReduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  phi::distributed::CommStaticCheck::SameShape(*out_tensor,
-                                               in_tensor,
-                                               /*dst_rank*/ opts.root_rank,
-                                               /*cur_rank*/ rank_,
-                                               size_);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
-        if (FLAGS_enable_nccl_dynamic_check) {
-          phi::distributed::NCCLDynamicCheck::CheckShape(
-              *out_tensor,
-              /*root_rank*/ opts.root_rank,
-              rank_,
-              comm);
-        }
-        NCCL_CHECK(
-            phi::dynload::ncclReduce(in_tensor.data(),
-                                     out_tensor->data(),
-                                     in_tensor.numel(),
-                                     phi::ToNCCLDataType(in_tensor.dtype()),
-                                     ToNCCLRedType(opts.reduce_op),
-                                     opts.root_rank,
-                                     comm,
-                                     stream));
+      [&](gpuStream_t stream) {
+        auto comm_context = this->GetCommContext();
+        comm_context->Reduce(out_tensor,
+                             in_tensor,
+                             ToNCCLRedType(opts.reduce_op),
+                             opts.root_rank,
+                             stream);
       },
       in_tensor,
       CommType::REDUCE,
@@ -368,27 +296,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
     const ReduceScatterOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  phi::distributed::CommStaticCheck::ScatterLikeShape(*out_tensor,
-                                                      in_tensor,
-                                                      /*dst_rank*/ rank_,
-                                                      /*cur_rank*/ rank_,
-                                                      size_);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
-        if (FLAGS_enable_nccl_dynamic_check) {
-          phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor,
-                                                         /*root_rank*/ 0,
-                                                         rank_,
-                                                         comm);
-        }
-        NCCL_CHECK(phi::dynload::ncclReduceScatter(
-            in_tensor.data(),
-            out_tensor->data(),
-            out_tensor->numel(),
-            phi::ToNCCLDataType(in_tensor.dtype()),
-            ToNCCLRedType(opts.reduce_op),
-            comm,
-            stream));
+      [&](gpuStream_t stream) {
+        auto comm_context = this->GetCommContext();
+        comm_context->ReduceScatter(
+            out_tensor, in_tensor, ToNCCLRedType(opts.reduce_op), stream);
       },
       in_tensor,
       CommType::REDUCE_SCATTER,
@@ -409,46 +321,30 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
       /*cur_rank*/ rank_,
       size_);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
+      [&](gpuStream_t stream) {
+        auto comm_context = this->GetCommContext();
         if (FLAGS_enable_nccl_dynamic_check) {
           phi::distributed::NCCLDynamicCheck::CheckShape(
               *out_tensor,
               /*root_rank*/ opts.root_rank,
               rank_,
-              comm);
+              comm_context->GetNcclComm());
         }
         int64_t numel = in_tensor.numel() / size_;
         if (rank_ == opts.root_rank) {
           int64_t offset = 0;
           phi::DenseTensor partial_tensor;
-          GroupStart();
+          comm_context->GroupStart();
           for (auto i = 0; i < size_; i++) {
             partial_tensor = GetPartialTensor(in_tensor, offset, numel);
-            NCCL_CHECK(phi::dynload::ncclSend(
-                partial_tensor.data(),
-                numel,
-                phi::ToNCCLDataType(partial_tensor.dtype()),
-                i,
-                comm,
-                stream));
+            comm_context->Send(partial_tensor, numel, i, stream);
             offset += numel;
           }
-          NCCL_CHECK(
-              phi::dynload::ncclRecv(out_tensor->data(),
-                                     numel,
-                                     phi::ToNCCLDataType(out_tensor->dtype()),
-                                     opts.root_rank,
-                                     comm,
-                                     stream));
-          GroupEnd();
+          comm_context->Recv(out_tensor, numel, opts.root_rank, stream);
+          comm_context->GroupEnd();
         } else {
-          NCCL_CHECK(
-              phi::dynload::ncclRecv(out_tensor->data(),
-                                     numel,
-                                     phi::ToNCCLDataType(out_tensor->dtype()),
-                                     opts.root_rank,
-                                     comm,
-                                     stream));
+          comm_context->Recv(out_tensor, numel, opts.root_rank, stream);
         }
       },
       in_tensor,
@@ -489,34 +385,30 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Gather(
           "root world size [%d] is less than root rank [%d]",
           size_,
           opts.root_rank));
-  auto gather_func = [&](ncclComm_t comm, gpuStream_t stream) {
+  auto gather_func = [&](gpuStream_t stream) {
+    auto comm_context = this->GetCommContext();
     // shape check
     if (FLAGS_enable_nccl_dynamic_check) {
       phi::distributed::NCCLDynamicCheck::CheckGatherShape(
-          in_tensor, gather_tensors, opts.root_rank, rank_, size_, comm);
+          in_tensor,
+          gather_tensors,
+          opts.root_rank,
+          rank_,
+          size_,
+          comm_context->GetNcclComm());
     }
-    GroupStart();
+    comm_context->GroupStart();
     // root receive from all devices
     if (rank_ == opts.root_rank) {
       for (auto i = 0; i < size_; i++) {
         auto& gather_tensor = gather_tensors[i];
-        NCCL_CHECK(
-            phi::dynload::ncclRecv(gather_tensor.data(),
-                                   gather_tensor.numel(),
-                                   phi::ToNCCLDataType(gather_tensor.dtype()),
-                                   i,
-                                   comm,
-                                   stream));
+        comm_context->Recv(&gather_tensor, gather_tensor.numel(), i, stream);
       }
     }
     // send to root
-    NCCL_CHECK(phi::dynload::ncclSend(in_tensor.data(),
-                                      in_tensor.numel(),
-                                      phi::ToNCCLDataType(in_tensor.dtype()),
-                                      opts.root_rank,
-                                      comm,
-                                      stream));
-    GroupEnd();
+    comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream);
+    comm_context->GroupEnd();
   };
   return RunFnInNCCLEnv(
       gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream);
@@ -536,21 +428,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
     tensor = &partial_tensor;
   }
-  phi::distributed::CommStaticCheck::CheckShape(*tensor, rank_, size_);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
-        if (FLAGS_enable_nccl_dynamic_check) {
-          phi::distributed::NCCLDynamicCheck::CheckShape(*tensor,
-                                                         /*root_rank*/ src_rank,
-                                                         rank_,
-                                                         comm);
-        }
-        NCCL_CHECK(phi::dynload::ncclRecv(tensor->data(),
-                                          tensor->numel(),
-                                          phi::ToNCCLDataType(tensor->dtype()),
-                                          src_rank,
-                                          comm,
-                                          stream));
+      [&](gpuStream_t stream) {
+        auto comm_context = this->GetCommContext();
+        comm_context->Recv(tensor, tensor->numel(), src_rank, stream);
       },
       *tensor,
       CommType::RECV,
@@ -569,23 +450,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
   const phi::DenseTensor& tensor_maybe_partial =
       numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
-  phi::distributed::CommStaticCheck::CheckShape(
-      tensor_maybe_partial, rank_, size_);
   return RunFnInNCCLEnv(
-      [&](ncclComm_t comm, gpuStream_t stream) {
-        if (FLAGS_enable_nccl_dynamic_check) {
-          phi::distributed::NCCLDynamicCheck::CheckShape(tensor_maybe_partial,
-                                                         /*root_rank*/ rank_,
-                                                         rank_,
-                                                         comm);
-        }
-        NCCL_CHECK(phi::dynload::ncclSend(
-            tensor_maybe_partial.data(),
-            tensor_maybe_partial.numel(),
-            phi::ToNCCLDataType(tensor_maybe_partial.dtype()),
-            dst_rank,
-            comm,
-            stream));
+      [&](gpuStream_t stream) {
+        auto comm_context = this->GetCommContext();
+        comm_context->Send(tensor_maybe_partial,
+                           tensor_maybe_partial.numel(),
+                           dst_rank,
+                           stream);
       },
       tensor_maybe_partial,
       CommType::SEND,
@@ -623,23 +494,14 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
     VLOG(3) << "Warning: Tensors from multiple devices are not supported yet.";
   }
-  ncclUniqueId nccl_id;
-  if (rank_ == 0) {
-    NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id));
-  }
-  BroadcastUniqueNCCLID(&nccl_id);
   VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
-          << ", place: " << place_key
-          << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
+          << ", place: " << place_key;
   auto* calc_ctx = static_cast<phi::GPUContext*>(
       platform::DeviceContextPool::Instance().Get(place));
   auto comm_ctx = std::make_unique<phi::GPUContext>(place);
-  ncclComm_t nccl_comm;
-  NCCL_CHECK(phi::dynload::ncclCommInitRank(
-      &nccl_comm, GetSize(), nccl_id, GetRank()));
-  comm_ctx->set_nccl_comm(nccl_comm);
+  auto nccl_comm_ctx = this->GetCommContext();
+  comm_ctx->set_nccl_comm(nccl_comm_ctx->GetNcclComm());
   place_to_calc_event_.emplace(place_key, place);
   place_to_calc_ctx_.emplace(place_key, calc_ctx);
@@ -661,7 +523,7 @@ void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
 }
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
-    std::function<void(ncclComm_t, gpuStream_t)> fn,
+    std::function<void(gpuStream_t)> fn,
     const phi::DenseTensor& tensor,
     CommType comm_type,
     bool sync_op,
@@ -683,9 +545,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
   const auto* calc_ctx = place_to_calc_ctx_.at(key);
   const auto& comm_ctx = place_to_comm_ctx_.at(key);
-  auto nccl_comm = comm_ctx->nccl_comm();
   auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
-  fn(nccl_comm, nccl_stream);
+  fn(nccl_stream);
   if (!use_calc_stream) {
     if (FLAGS_use_stream_safe_cuda_allocator) {
@@ -900,13 +761,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
           phi::DenseTensor& output,
           ncclComm_t comm,
           const gpuStream_t& stream) {
-        return phi::dynload::ncclAllReduce(input.data(),
-                                           output.data(),
-                                           input.numel(),
-                                           phi::ToNCCLDataType(input.type()),
-                                           ToNCCLRedType(opts.reduce_op),
-                                           comm,
-                                           stream);
+        auto comm_context = this->GetCommContext();
+        comm_context->AllReduce(
+            &output, input, ToNCCLRedType(opts.reduce_op), stream);
       },
       CommType::ALLREDUCE);
 }
@@ -929,13 +786,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
           const gpuStream_t& stream) {
         const auto root =
             opts.source_rank * in_tensors.size() + opts.source_root;
-        return phi::dynload::ncclBroadcast(input.data(),
-                                           output.data(),
-                                           input.numel(),
-                                           phi::ToNCCLDataType(input.type()),
-                                           root,
-                                           comm,
-                                           stream);
+        auto comm_context = this->GetCommContext();
+        comm_context->Broadcast(&output, input, root, stream);
       },
       CommType::BROADCAST);
 }
@@ -978,12 +830,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
           ncclComm_t comm,
           const gpuStream_t& stream,
           int dst_rank) {
-        return phi::dynload::ncclSend(input.data(),
-                                      input.numel(),
-                                      phi::ToNCCLDataType(input.dtype()),
-                                      dst_rank,
-                                      comm,
-                                      stream);
+        auto comm_context = this->GetCommContext();
+        comm_context->Send(input, input.numel(), dst_rank, stream);
       },
       dst_rank,
       CommType::SEND);
@@ -1000,12 +848,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
           ncclComm_t comm,
           const gpuStream_t& stream,
           int src_rank) {
-        return phi::dynload::ncclRecv(output.data(),
-                                      output.numel(),
-                                      phi::ToNCCLDataType(output.dtype()),
-                                      src_rank,
-                                      comm,
-                                      stream);
+        auto comm_context = this->GetCommContext();
+        comm_context->Recv(&output, output.numel(), src_rank, stream);
       },
       src_rank,
       CommType::RECV);
@@ -1030,12 +874,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
           phi::DenseTensor& output,
           ncclComm_t comm,
           const gpuStream_t& stream) {
-        return phi::dynload::ncclAllGather(input.data(),
-                                           output.data(),
-                                           input.numel(),
-                                           phi::ToNCCLDataType(input.dtype()),
-                                           comm,
-                                           stream);
+        auto comm_context = this->GetCommContext();
+        comm_context->AllGather(&output, input, stream);
       },
       CommType::ALLGATHER);
 }
@@ -1059,25 +899,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
           ncclComm_t comm,
           const gpuStream_t& stream) {
         size_t offset = 0;
-        GroupStart();
+        size_t count = input.numel() / size_;
+        auto comm_context = this->GetCommContext();
+        comm_context->GroupStart();
         for (auto i = 0; i < size_; i++) {
-          PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
-              GetPointerByOffset(input.data(), offset, input.dtype()),
-              input.numel() / size_,
-              phi::ToNCCLDataType(input.dtype()),
-              i,
-              comm,
-              stream));
-          PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRecv(
-              GetPointerByOffset(output.data(), offset, input.dtype()),
-              input.numel() / size_,
-              phi::ToNCCLDataType(input.dtype()),
-              i,
-              comm,
-              stream));
-          offset += input.numel() / size_;
+          auto input_data = GetPartialTensor(input, offset, count);
+          comm_context->Send(input_data, count, i, stream);
+          auto output_data = GetPartialTensor(output, offset, count);
+          comm_context->Recv(&output_data, count, i, stream);
+          offset += count;
         }
-        GroupEnd();
+        comm_context->GroupEnd();
       },
       CommType::ALLTOALL);
 }
@@ -1097,15 +929,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
           phi::DenseTensor& output,
           ncclComm_t comm,
           const gpuStream_t& stream) {
-        PADDLE_ENFORCE_GPU_SUCCESS(
-            phi::dynload::ncclReduce(input.data(),
-                                     output.data(),
-                                     input.numel(),
-                                     phi::ToNCCLDataType(input.dtype()),
-                                     ToNCCLRedType(opts.reduce_op),
-                                     opts.root_rank,
-                                     comm,
-                                     stream));
+        auto comm_context = this->GetCommContext();
+        comm_context->Reduce(&output,
+                             input,
+                             ToNCCLRedType(opts.reduce_op),
+                             opts.root_rank,
+                             stream);
       },
       CommType::REDUCE);
 }
@@ -1129,35 +958,21 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
           phi::DenseTensor& output,
           ncclComm_t comm,
           const gpuStream_t& stream) {
+        auto comm_context = this->GetCommContext();
         size_t offset = 0;
+        size_t count = input.numel() / size_;
         if (rank_ == opts.root_rank) {
-          GroupStart();
+          comm_context->GroupStart();
           for (auto i = 0; i < size_; i++) {
-            PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
-                GetPointerByOffset(input.data(), offset, input.dtype()),
-                input.numel() / size_,
-                phi::ToNCCLDataType(input.dtype()),
-                i,
-                comm,
-                stream));
-            offset += input.numel() / size_;
+            auto input_data = reinterpret_cast<phi::DenseTensor*>(
+                GetPointerByOffset(input.data(), offset, input.dtype()));
+            comm_context->Send(*input_data, count, i, stream);
+            offset += count;
           }
-          PADDLE_ENFORCE_GPU_SUCCESS(
-              phi::dynload::ncclRecv(output.data(),
-                                     input.numel() / size_,
-                                     phi::ToNCCLDataType(input.dtype()),
-                                     opts.root_rank,
-                                     comm,
-                                     stream));
-          GroupEnd();
+          comm_context->Recv(&output, count, opts.root_rank, stream);
+          comm_context->GroupEnd();
         } else {
-          PADDLE_ENFORCE_GPU_SUCCESS(
-              phi::dynload::ncclRecv(output.data(),
-                                     input.numel() / size_,
-                                     phi::ToNCCLDataType(input.dtype()),
-                                     opts.root_rank,
-                                     comm,
-                                     stream));
+          comm_context->Recv(&output, count, opts.root_rank, stream);
        }
      },
      CommType::SCATTER);
@@ -1165,14 +980,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
 std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
     const std::shared_ptr<phi::distributed::Store>& store,
+    int device_id,
     int rank,
     int size,
     int gid) {
+  phi::distributed::CommContextManager::CreateNCCLCommContext(
+      store, device_id, gid, rank, size);
   auto process_group =
       std::make_shared<ProcessGroupNCCL>(store, rank, size, gid);
   ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
   return process_group;
 }
+phi::distributed::NCCLCommContext* ProcessGroupNCCL::GetCommContext() {
+  const auto& comm_context_manager =
+      phi::distributed::CommContextManager::GetInstance();
+  auto comm_context = static_cast<phi::distributed::NCCLCommContext*>(
+      comm_context_manager.Get(this->gid_));
+  PADDLE_ENFORCE_NE(comm_context,
+                    nullptr,
+                    phi::errors::Unavailable("NCCLCommContext is nullptr"));
+  return comm_context;
+}
 }  // namespace distributed
 }  // namespace paddle
@@ -26,6 +26,7 @@
 #include "paddle/phi/backends/gpu/forwards.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
 #include "paddle/phi/core/distributed/store/store.h"
 namespace paddle {
@@ -68,6 +69,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
  public:
   static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL(
       const std::shared_ptr<phi::distributed::Store>& store,
+      int device_id,
       int rank,
       int size,
       int gid);
@@ -219,7 +221,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
   void SyncCalcStream(const Place& place);
   std::shared_ptr<ProcessGroup::Task> RunFnInNCCLEnv(
-      std::function<void(ncclComm_t, gpuStream_t)> fn,
+      std::function<void(gpuStream_t)> fn,
       const phi::DenseTensor& tensor,
       CommType comm_type,
       bool sync_op,
@@ -249,6 +251,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
   void CreateNCCLManagerCache(const std::string& places_key,
                               const std::vector<Place>& places);
+  phi::distributed::NCCLCommContext* GetCommContext();
+
  private:
   std::shared_ptr<phi::distributed::Store> store_;
......
@@ -1238,6 +1238,7 @@ void BindDistributed(py::module *m) {
           .def_static("create",
                       distributed::ProcessGroupNCCL::CreateProcessGroupNCCL,
                       py::arg("store"),
+                      py::arg("device_id"),
                       py::arg("rank"),
                       py::arg("world_size"),
                       py::arg("group_id") = 0,
......
@@ -151,7 +151,10 @@ def _new_process_group_impl(
     if backend == "gloo":
         pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
     elif backend == "nccl":
-        pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id)
+        pg = core.ProcessGroupNCCL.create(
+            store, genv.device_id, rank, world_size, group_id
+        )
     elif backend == "xccl":
         pg = core.ProcessGroupCustom.create(
             store, genv.device_type, rank, world_size, group_id
......