Unverified commit 88410225, authored by Wen Sun, committed by GitHub

Unify `ProcessGroupNCCL` APIs' underlying implementation (#48163)

* refactor: replace Collective & PointToPoint with NCCLEnv

* refactor: rename to RunFnInNCCLEnv

* refactor: pass std::function by value
Parent 2a47416c
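For orientation before the hunks, here is a minimal, framework-free sketch of the calling convention this commit converges on. The `FakeComm`/`FakeStream` types and the `RunFnInEnv` helper are hypothetical stand-ins, not Paddle APIs: each collective or point-to-point op wraps its NCCL call in a lambda that takes only `(comm, stream)`, captures its tensors by reference, and passes the lambda by value to one shared runner, which is the role `RunFnInNCCLEnv` plays in the diff below.

```cpp
// Hypothetical stand-ins; this sketches the calling convention only.
#include <functional>
#include <iostream>
#include <string>

struct FakeComm {};    // stands in for ncclComm_t
struct FakeStream {};  // stands in for gpuStream_t

// Shared runner: prepares the communication environment once, then invokes
// the op-specific body, mirroring what RunFnInNCCLEnv does in the diff.
void RunFnInEnv(std::function<void(FakeComm, FakeStream)> fn,
                const std::string& comm_type) {
  FakeComm comm;
  FakeStream stream;
  std::cout << "prepare env for " << comm_type << "\n";
  fn(comm, stream);  // op-specific body runs on the prepared comm/stream
  std::cout << "record stream, update wait chain\n";
}

int main() {
  int send_buf = 42;
  int recv_buf = 0;
  // An AllReduce-style caller: captures its buffers, passes only the body.
  RunFnInEnv(
      [&](FakeComm /*comm*/, FakeStream /*stream*/) {
        recv_buf = send_buf;  // placeholder for the real ncclAllReduce call
      },
      "ALLREDUCE");
  std::cout << "recv_buf = " << recv_buf << "\n";
  return 0;
}
```

Passing the op body as a `std::function` by value (the third bullet above) keeps the runner non-templated, so a single out-of-line definition replaces the former `Collective` and `PointToPoint` templates.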
@@ -185,7 +185,8 @@ class ProcessGroup {
           GetBackendName()));
   }
-  virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor*,
+  virtual std::shared_ptr<ProcessGroup::Task> Send(
+      const phi::DenseTensor& tensor,
       int dst_rank,
       int64_t offset,
       int64_t numel,
......
@@ -137,21 +137,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
   // numel > 0 indicates the tensor need to be sliced
   const phi::DenseTensor& in_tensor_maybe_partial =
       numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
-  return Collective(
-      out_tensor,
-      in_tensor_maybe_partial,
-      [](phi::DenseTensor* output,
-         const phi::DenseTensor& input,
-         ncclComm_t comm,
-         gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclAllGather(
-            input.data(),
-            output->data(),
-            input.numel(),
-            platform::ToNCCLDataType(input.dtype()),
+            in_tensor_maybe_partial.data(),
+            out_tensor->data(),
+            in_tensor_maybe_partial.numel(),
+            platform::ToNCCLDataType(in_tensor_maybe_partial.dtype()),
             comm,
             stream));
       },
+      in_tensor_maybe_partial,
       CommType::ALLGATHER,
       sync_op,
       use_calc_stream);
@@ -163,22 +159,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
     const AllreduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclAllReduce(
-            input.data(),
-            output->data(),
-            input.numel(),
-            platform::ToNCCLDataType(input.type()),
+            in_tensor.data(),
+            out_tensor->data(),
+            in_tensor.numel(),
+            platform::ToNCCLDataType(in_tensor.dtype()),
             ToNCCLRedType(opts.reduce_op),
             comm,
             stream));
       },
+      in_tensor,
       CommType::ALLREDUCE,
       sync_op,
       use_calc_stream);
@@ -215,37 +207,32 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
   CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
   CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
-        int64_t in_row_size = input.numel() / in_dim[0],
-                out_row_size = output->numel() / out_dim[0];
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
+        int64_t in_row_size = in_tensor.numel() / in_dim[0],
+                out_row_size = out_tensor->numel() / out_dim[0];
         int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
         phi::DenseTensor input_partial, output_partial;
         GroupStart();
         for (auto i = 0; i < size_; i++) {
           in_numel = in_size_each_rank[i] * in_row_size;
-          input_partial = GetPartialTensor(input, in_offset, in_numel);
+          input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
           NCCL_CHECK(platform::dynload::ncclSend(
               input_partial.data(),
               in_numel,
-              platform::ToNCCLDataType(input.dtype()),
+              platform::ToNCCLDataType(input_partial.dtype()),
              i,
              comm,
              stream));
           in_offset += in_numel;
           out_numel = out_size_each_rank[i] * out_row_size;
-          output_partial = GetPartialTensor(*output, out_offset, out_numel);
+          output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
           NCCL_CHECK(platform::dynload::ncclRecv(
               output_partial.data(),
               out_numel,
-              platform::ToNCCLDataType(output->dtype()),
+              platform::ToNCCLDataType(output_partial.dtype()),
              i,
              comm,
              stream));
@@ -253,6 +240,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
         }
         GroupEnd();
       },
+      in_tensor,
       CommType::ALLTOALL,
       sync_op,
       use_calc_stream);
@@ -286,23 +274,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
     const BroadcastOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         int root = opts.source_rank + opts.source_root;
         NCCL_CHECK(platform::dynload::ncclBroadcast(
-            input.data(),
-            output->data(),
-            input.numel(),
-            platform::ToNCCLDataType(input.type()),
+            in_tensor.data(),
+            out_tensor->data(),
+            in_tensor.numel(),
+            platform::ToNCCLDataType(in_tensor.dtype()),
             root,
             comm,
             stream));
       },
+      in_tensor,
       CommType::BROADCAST,
       sync_op,
       use_calc_stream);
@@ -314,23 +298,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
     const ReduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclReduce(
-            input.data(),
-            output->data(),
-            input.numel(),
-            platform::ToNCCLDataType(input.dtype()),
+            in_tensor.data(),
+            out_tensor->data(),
+            in_tensor.numel(),
+            platform::ToNCCLDataType(in_tensor.dtype()),
             ToNCCLRedType(opts.reduce_op),
             opts.root_rank,
             comm,
             stream));
       },
+      in_tensor,
       CommType::REDUCE,
       sync_op,
       use_calc_stream);
@@ -342,22 +322,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
     const ReduceScatterOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclReduceScatter(
-            input.data(),
-            output->data(),
-            output->numel(),
-            platform::ToNCCLDataType(input.dtype()),
+            in_tensor.data(),
+            out_tensor->data(),
+            out_tensor->numel(),
+            platform::ToNCCLDataType(in_tensor.dtype()),
             ToNCCLRedType(opts.reduce_op),
             comm,
             stream));
       },
+      in_tensor,
       CommType::REDUCE_SCATTER,
       sync_op,
       use_calc_stream);
@@ -369,47 +345,43 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
     const ScatterOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
-  return Collective(
-      out_tensor,
-      in_tensor,
-      [&](phi::DenseTensor* output,
-          const phi::DenseTensor& input,
-          ncclComm_t comm,
-          gpuStream_t stream) {
-        int64_t numel = input.numel() / size_;
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
+        int64_t numel = in_tensor.numel() / size_;
         if (rank_ == opts.root_rank) {
           int64_t offset = 0;
           phi::DenseTensor partial_tensor;
           GroupStart();
           for (auto i = 0; i < size_; i++) {
-            partial_tensor = GetPartialTensor(input, offset, numel);
+            partial_tensor = GetPartialTensor(in_tensor, offset, numel);
             NCCL_CHECK(platform::dynload::ncclSend(
                 partial_tensor.data(),
                 numel,
-                platform::ToNCCLDataType(input.dtype()),
+                platform::ToNCCLDataType(partial_tensor.dtype()),
                 i,
                 comm,
                 stream));
             offset += numel;
           }
           NCCL_CHECK(platform::dynload::ncclRecv(
-              output->data(),
+              out_tensor->data(),
              numel,
-              platform::ToNCCLDataType(output->dtype()),
+              platform::ToNCCLDataType(out_tensor->dtype()),
              opts.root_rank,
              comm,
              stream));
           GroupEnd();
         } else {
           NCCL_CHECK(platform::dynload::ncclRecv(
-              output->data(),
+              out_tensor->data(),
              numel,
-              platform::ToNCCLDataType(output->dtype()),
+              platform::ToNCCLDataType(out_tensor->dtype()),
              opts.root_rank,
              comm,
              stream));
         }
       },
+      in_tensor,
       CommType::SCATTER,
       sync_op,
       use_calc_stream);
@@ -428,54 +400,43 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
     partial_tensor = GetPartialTensor(*tensor, offset, numel);
     tensor = &partial_tensor;
   }
-  return PointToPoint(
-      tensor,
-      src_rank,
-      [](phi::DenseTensor* output,
-         int src,
-         ncclComm_t comm,
-         gpuStream_t stream) {
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclRecv(
-            output->data(),
-            output->numel(),
-            platform::ToNCCLDataType(output->dtype()),
-            src,
+            tensor->data(),
+            tensor->numel(),
+            platform::ToNCCLDataType(tensor->dtype()),
+            src_rank,
             comm,
             stream));
       },
+      *tensor,
       CommType::RECV,
       sync_op,
       use_calc_stream);
 }
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
-    phi::DenseTensor* tensor,
+    const phi::DenseTensor& tensor,
     int dst_rank,
     int64_t offset,
     int64_t numel,
     bool sync_op,
     bool use_calc_stream) {
   // numel > 0 indicates the tensor need to be sliced
-  phi::DenseTensor partial_tensor;
-  if (numel > 0) {
-    partial_tensor = GetPartialTensor(*tensor, offset, numel);
-    tensor = &partial_tensor;
-  }
-  return PointToPoint(
-      tensor,
-      dst_rank,
-      [](phi::DenseTensor* input,
-         int dst,
-         ncclComm_t comm,
-         gpuStream_t stream) {
+  const phi::DenseTensor& tensor_maybe_partial =
+      numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
+  return RunFnInNCCLEnv(
+      [&](ncclComm_t comm, gpuStream_t stream) {
         NCCL_CHECK(platform::dynload::ncclSend(
-            input->data(),
-            input->numel(),
-            platform::ToNCCLDataType(input->dtype()),
-            dst,
+            tensor_maybe_partial.data(),
+            tensor_maybe_partial.numel(),
+            platform::ToNCCLDataType(tensor_maybe_partial.dtype()),
+            dst_rank,
             comm,
             stream));
       },
+      tensor_maybe_partial,
       CommType::SEND,
       sync_op,
       use_calc_stream);
@@ -548,54 +509,13 @@ void ProcessGroupNCCL::SyncCalcStream(const Place& place) {
   calc_event.Wait(platform::Place2DeviceType(place), comm_ctx);
 }
-template <typename Fn>
-std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
-    phi::DenseTensor* out_tensor,
-    const phi::DenseTensor& in_tensor,
-    Fn fn,
-    CommType comm_type,
-    bool sync_op,
-    bool use_calc_stream) {
-  const auto& place = in_tensor.place();
-  const auto& key = GetKeyFromPlace(place);
-  platform::CUDADeviceGuard cuda_guard(place);
-  if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
-    CreateNCCLEnvCache(place, key);
-  }
-  if (!use_calc_stream) {
-    SyncCalcStream(place);
-  }
-  auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream);
-  const auto* calc_ctx = place_to_calc_ctx_.at(key);
-  const auto& comm_ctx = place_to_comm_ctx_.at(key);
-  auto nccl_comm = comm_ctx->nccl_comm();
-  auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
-  fn(out_tensor, in_tensor, nccl_comm, nccl_stream);
-  if (!use_calc_stream) {
-    if (FLAGS_use_stream_safe_cuda_allocator) {
-      memory::RecordStream(in_tensor.Holder(), nccl_stream);
-    }
-    task->UpdateWaitChain(*comm_ctx);
-  }
-  return task;
-}
-template <typename Fn>
-std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
-    phi::DenseTensor* tensor,
-    int rank,
-    Fn fn,
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
+    std::function<void(ncclComm_t, gpuStream_t)> fn,
+    const phi::DenseTensor& tensor,
     CommType comm_type,
     bool sync_op,
     bool use_calc_stream) {
-  const auto& place = tensor->place();
+  const auto& place = tensor.place();
   const auto& key = GetKeyFromPlace(place);
   platform::CUDADeviceGuard cuda_guard(place);
@@ -614,11 +534,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   const auto& comm_ctx = place_to_comm_ctx_.at(key);
   auto nccl_comm = comm_ctx->nccl_comm();
   auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();
-  fn(tensor, rank, nccl_comm, nccl_stream);
+  fn(nccl_comm, nccl_stream);
   if (!use_calc_stream) {
     if (FLAGS_use_stream_safe_cuda_allocator) {
-      memory::RecordStream(tensor->Holder(), nccl_stream);
+      memory::RecordStream(tensor.Holder(), nccl_stream);
     }
     task->UpdateWaitChain(*comm_ctx);
   }
......
@@ -150,7 +150,7 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
                                            bool sync_op,
                                            bool use_calc_stream) override;
-  std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
+  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
                                            int dst_rank,
                                            int64_t offset,
                                            int64_t numel,
@@ -210,23 +210,13 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
   void CreateNCCLEnvCache(const Place& place, const std::string& place_key);
-  template <typename Fn>
-  std::shared_ptr<ProcessGroupStream::Task> Collective(
-      phi::DenseTensor* out_tensor,
-      const phi::DenseTensor& in_tensor,
-      Fn fn,
+  std::shared_ptr<ProcessGroup::Task> RunFnInNCCLEnv(
+      std::function<void(ncclComm_t, gpuStream_t)> fn,
+      const phi::DenseTensor& tensor,
       CommType comm_type,
       bool sync_op,
       bool use_calc_stream);
-  template <typename Fn>
-  std::shared_ptr<ProcessGroup::Task> PointToPoint(phi::DenseTensor* tensor,
-                                                   int rank,
-                                                   Fn fn,
-                                                   CommType op_type,
-                                                   bool sync_op,
-                                                   bool use_calc_stream);
   void SyncCalcStream(const Place& place);
   // TODO(sunyilun): methods below will be removed later
......
@@ -212,7 +212,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
 }
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
-    phi::DenseTensor* tensor,
+    const phi::DenseTensor& tensor,
     int dst_rank,
     int64_t offset,
     int64_t numel,
@@ -226,7 +226,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
 }
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
-    phi::DenseTensor*,
+    const phi::DenseTensor& tensor,
     int dst_rank,
     int64_t offset,
     int64_t numel,
......
@@ -168,13 +168,14 @@ class ProcessGroupStream : public ProcessGroup {
                                            bool sync_op,
                                            bool use_calc_stream);
-  std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
+  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
                                            int dst_rank,
                                            int64_t offset,
                                            int64_t numel,
                                            bool sync_op) override;
-  virtual std::shared_ptr<ProcessGroup::Task> Send(phi::DenseTensor* tensor,
+  virtual std::shared_ptr<ProcessGroup::Task> Send(
+      const phi::DenseTensor& tensor,
       int dst_rank,
       int64_t offset,
      int64_t numel,
......
@@ -226,7 +226,7 @@ struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
         int idx = i + j * n_expert;
         if (cpu_global_count_data[idx]) {
           phi::DenseTensor tmp = *x;
-          pg->Send(&tmp,
+          pg->Send(tmp,
                    j,
                    send_ptr * in_feat,
                    cpu_global_count_data[idx] * in_feat,
......
@@ -224,7 +224,7 @@ struct GlobalScatterProcessGroupFunctor<phi::GPUContext, T> {
         int idx = i + j * n_expert;
         if (cpu_local_count_data[idx]) {
           phi::DenseTensor tmp = *x;
-          pg->Send(&tmp,
+          pg->Send(tmp,
                    j,
                    expert_ptr[idx] * in_feat,
                    cpu_local_count_data[idx] * in_feat,
......
@@ -70,7 +70,7 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
       // Use ProcessGroup
      distributed::ProcessGroup* pg = map->get(rid);
      phi::DenseTensor tmp = *x;
-      auto task = pg->Send(&tmp, peer, offset, send_numel, /*sync_op*/ true);
+      auto task = pg->Send(tmp, peer, offset, send_numel, /*sync_op*/ true);
      task->Wait();
     } else {
      gpuStream_t stream = nullptr;
......
@@ -168,7 +168,7 @@ void BindDistributed(py::module *m) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
             auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            auto *out_dense = p_dense.get();
+            auto out_dense = *p_dense;
             // numel == -1 indicates sending the whole tensor
             return self.Send(
                 out_dense, dst, /*offset*/ 0, /*numel*/ -1, sync_op);
@@ -189,7 +189,7 @@ void BindDistributed(py::module *m) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
             auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            auto *out_dense = p_dense.get();
+            auto out_dense = *p_dense;
             int64_t numel = p_dense->numel();
             int64_t send_numel = numel / nranks;
@@ -1126,7 +1126,7 @@ void BindDistributed(py::module *m) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
             auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            auto *out_dense = p_dense.get();
+            auto out_dense = *p_dense;
             // numel == -1 indicates sending the whole tensor
             return self.Send(out_dense,
                              dst,
@@ -1149,7 +1149,7 @@ void BindDistributed(py::module *m) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
             auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            auto *out_dense = p_dense.get();
+            auto out_dense = *p_dense;
             int64_t numel = p_dense->numel();
             int64_t send_numel = numel / nranks;
......