Unverified commit 3f480af2, authored by W Wen Sun, committed by GitHub

Refactor collective communication all_to_all, all_to_all_single C++ API (#48059)

Parent dbc63555
...@@ -46,7 +46,6 @@ enum class CommType : std::uint8_t {
SEND = 9,
RECV = 10,
BARRIER = 11,
ALLTOALL_SINGLE = 12,
UNKNOWN = 100,
};
...@@ -124,6 +123,17 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_to_all with sync_op flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::Unimplemented(
...@@ -255,25 +265,6 @@ class ProcessGroup {
"ProcessGroup%s does not support alltoall", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<int64_t>&,
std::vector<int64_t>&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllToAll_Single", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<int64_t>&,
std::vector<int64_t>&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support alltoall_single", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
...
...@@ -184,6 +184,80 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
use_calc_stream);
}
void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
const std::vector<int64_t>& size_on_each_rank,
int world_size) {
int length_size_on_each_rank = size_on_each_rank.size();
PADDLE_ENFORCE_EQ(
length_size_on_each_rank,
world_size,
platform::errors::InvalidArgument(
"The length of size_on_each_rank must be equal to world_size."));
int64_t sum_size_on_each_rank =
std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0);
PADDLE_ENFORCE_EQ(
sum_size_on_each_rank,
tensor_dim[0],
platform::errors::InvalidArgument(
"The sum of size_on_each_rank must be equal to tensor's dim[0]."));
}
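CheckSizeOnEachRank guards the two invariants the unified API relies on: the split-size vector has exactly world_size entries, and those entries sum to dim 0 of the concatenated tensor. Below is a minimal standalone sketch of the same check, using plain std::vector and assert instead of Paddle's enforce macros; the function and variable names are illustrative only, not part of this commit.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Mirrors the intent of CheckSizeOnEachRank: `sizes` describes how dim 0 of a
// concatenated tensor is split across ranks, so it must have one entry per
// rank and its entries must add up to dim 0.
void CheckSplit(int64_t dim0, const std::vector<int64_t>& sizes, int world_size) {
  assert(static_cast<int>(sizes.size()) == world_size);
  // Accumulate with an int64_t seed so large tensors do not overflow an int.
  int64_t sum = std::accumulate(sizes.begin(), sizes.end(), int64_t{0});
  assert(sum == dim0);
}

int main() {
  CheckSplit(/*dim0=*/10, {2, 3, 1, 4}, /*world_size=*/4);  // uneven split, OK
  CheckSplit(/*dim0=*/8, {2, 2, 2, 2}, /*world_size=*/4);   // even split, OK
  return 0;
}
```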
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) {
const phi::DDim& out_dim = out_tensor->dims();
const phi::DDim& in_dim = in_tensor.dims();
CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
return Collective(
out_tensor,
in_tensor,
[&](phi::DenseTensor* output,
const phi::DenseTensor& input,
ncclComm_t comm,
gpuStream_t stream) {
int64_t in_row_size = input.numel() / in_dim[0],
out_row_size = output->numel() / out_dim[0];
int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
phi::DenseTensor input_partial, output_partial;
GroupStart();
for (auto i = 0; i < size_; i++) {
in_numel = in_size_each_rank[i] * in_row_size;
input_partial = GetPartialTensor(input, in_offset, in_numel);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
input_partial.data(),
in_numel,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
in_offset += in_numel;
out_numel = out_size_each_rank[i] * out_row_size;
output_partial = GetPartialTensor(*output, out_offset, out_numel);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output_partial.data(),
out_numel,
platform::ToNCCLDataType(output->dtype()),
i,
comm,
stream));
out_offset += out_numel;
}
GroupEnd();
},
CommType::ALLTOALL,
sync_op,
use_calc_stream);
}
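The new ProcessGroupNCCL::AllToAll walks both tensors rank by rank inside one GroupStart()/GroupEnd() group: slice i of the input (in_size_each_rank[i] rows) goes to rank i via ncclSend, and slice i of the output (out_size_each_rank[i] rows) is filled from rank i via ncclRecv. The sketch below reproduces only that offset arithmetic on the CPU so it compiles and runs standalone; the split sizes and printout are illustrative, and no NCCL calls are made.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int world_size = 4;
  // Rows this rank sends to ranks 0..3 (uneven, as all_to_all_single allows).
  std::vector<int64_t> in_size_each_rank = {1, 2, 3, 2};
  // Rows this rank receives from ranks 0..3.
  std::vector<int64_t> out_size_each_rank = {2, 2, 2, 2};

  int64_t in_offset = 0, out_offset = 0;
  for (int i = 0; i < world_size; ++i) {
    // In ProcessGroupNCCL::AllToAll these slices become
    //   ncclSend(input  slice starting at in_offset  rows, peer i)
    //   ncclRecv(output slice starting at out_offset rows, peer i)
    // issued between GroupStart() and GroupEnd().
    std::printf("peer %d: send rows [%lld, %lld)  recv rows [%lld, %lld)\n", i,
                static_cast<long long>(in_offset),
                static_cast<long long>(in_offset + in_size_each_rank[i]),
                static_cast<long long>(out_offset),
                static_cast<long long>(out_offset + out_size_each_rank[i]));
    in_offset += in_size_each_rank[i];
    out_offset += out_size_each_rank[i];
  }
  return 0;
}
```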
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
const BarrierOptions& opts) {
PADDLE_ENFORCE_GE(opts.device_id,
...@@ -551,7 +625,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
std::vector<phi::GPUContext*> dev_ctx_raw;
dev_ctx_raw.resize(places.size());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
GroupStart();
for (size_t i = 0; i < places.size(); ++i) {
platform::CUDADeviceGuard guard(places[i]);
...@@ -564,7 +638,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
dev_ctx_raw[i] = dev_ctx[i].get();
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
GroupEnd();
// TODO(sunyilun): for compatibility, will be removed later
place_to_calc_event_.emplace(places_key, places[0]);
...@@ -1086,7 +1160,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
ncclComm_t comm,
const gpuStream_t& stream) {
size_t offset = 0;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
...@@ -1104,7 +1178,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
GroupEnd();
},
CommType::ALLTOALL);
}
...@@ -1130,7 +1204,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
ncclComm_t comm,
const gpuStream_t& stream) {
size_t offset = 0;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
...@@ -1148,141 +1222,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
stream));
offset += input.numel() / size_;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
GroupEnd();
},
CommType::ALLTOALL,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll_Single(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_EQ(input.dtype() == output.dtype(),
true,
platform::errors::InvalidArgument(
"The dtypes of input and output must be equal."));
std::vector<int64_t> in_dims = phi::vectorize(input.dims());
std::vector<int64_t> out_dims = phi::vectorize(output.dims());
CheckSplitSizes(&in_sizes, in_dims);
CheckSplitSizes(&out_sizes, out_dims);
size_t in_offset = 0, out_offset = 0;
size_t in_length = 0, out_length = 0;
size_t in_row_size = input.numel() / in_dims[0];
size_t out_row_size = output.numel() / out_dims[0];
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
in_length = in_sizes[i] * in_row_size;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), in_offset, input.dtype()),
in_length,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
in_offset += in_length;
out_length = out_sizes[i] * out_row_size;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
GetPointerByOffset(output.data(), out_offset, input.dtype()),
out_length,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
out_offset += out_length;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
},
CommType::ALLTOALL_SINGLE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes,
bool sync_op,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_EQ(input.dtype() == output.dtype(),
true,
platform::errors::InvalidArgument(
"The dtypes of input and output must be equal."));
std::vector<int64_t> in_dims = phi::vectorize(input.dims());
std::vector<int64_t> out_dims = phi::vectorize(output.dims());
CheckSplitSizes(&in_sizes, in_dims);
CheckSplitSizes(&out_sizes, out_dims);
size_t in_offset = 0, out_offset = 0;
size_t in_length = 0, out_length = 0;
size_t in_row_size = input.numel() / in_dims[0];
size_t out_row_size = output.numel() / out_dims[0];
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
in_length = in_sizes[i] * in_row_size;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), in_offset, input.dtype()),
in_length,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
in_offset += in_length;
out_length = out_sizes[i] * out_row_size;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
GetPointerByOffset(output.data(), out_offset, input.dtype()),
out_length,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
out_offset += out_length;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
},
CommType::ALLTOALL_SINGLE,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
...@@ -1396,7 +1342,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
const gpuStream_t& stream) {
size_t offset = 0;
if (rank_ == opts.root_rank) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
...@@ -1414,7 +1360,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
opts.root_rank,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
GroupEnd();
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(),
...@@ -1456,7 +1402,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
"Input and output tensors should have the same shape."));
size_t offset = 0;
if (rank_ == opts.root_rank) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
GroupStart();
for (auto i = 0; i < size_; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), offset, input.dtype()),
...@@ -1474,7 +1420,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
opts.root_rank,
comm,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
GroupEnd();
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
output.data(),
...
...@@ -109,6 +109,14 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;
...@@ -171,20 +179,6 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
std::vector<phi::DenseTensor>& in,
std::vector<phi::DenseTensor>& out,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes) override;
std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& tensors,
std::vector<phi::DenseTensor>& out_tensors,
...
...@@ -73,6 +73,31 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
"ProcessGroup%s does not support all_reduce.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op) {
return AllToAll(out_tensor,
in_tensor,
out_size_each_rank,
in_size_each_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_to_all.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
...@@ -165,31 +190,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
"ProcessGroup%s does not support do alltoall", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes,
bool sync_op) {
return AllToAllSingle(in_tensors,
out_tensors,
in_sizes,
out_sizes,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do alltoall_single", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
...
...@@ -89,6 +89,21 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
...@@ -140,21 +155,6 @@ class ProcessGroupStream : public ProcessGroup {
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
std::vector<int64_t>& in_sizes, // NOLINT
std::vector<int64_t>& out_sizes, // NOLINT
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
std::vector<int64_t>& in_sizes, // NOLINT
std::vector<int64_t>& out_sizes, // NOLINT
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
...
...@@ -277,7 +277,7 @@ void BindDistributed(py::module *m) {
/*offset*/ 0,
/*numel*/ -1,
sync_op);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(dev_ctx);
return task;
},
...@@ -316,84 +316,96 @@ void BindDistributed(py::module *m) {
.def(
"all_to_all",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor_list,
py::handle py_in_tensor_list,
bool sync_op) {
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
auto in_dense = *p_in_tensor;
// in_tensor_list should not be empty
const auto &dev_ctx =
self.GetDeviceContext(in_tensor_list.back().place());
auto task = self.AllToAll(in_wrapper, out_wrapper, sync_op);
int world_size = self.GetSize();
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
auto task =
self.AllToAll(out_dense,
in_dense,
GetDefaultSplitSizes(*out_dense, world_size),
GetDefaultSplitSizes(in_dense, world_size),
sync_op);
SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
task->UpdateWaitChain(dev_ctx);
return task;
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_to_all_tensor",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_in_tensor,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
return self.AllToAll(in_wrapper, out_wrapper, sync_op);
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
int world_size = self.GetSize();
return self.AllToAll(
out_dense,
in_dense,
GetDefaultSplitSizes(*out_dense, world_size),
GetDefaultSplitSizes(in_dense, world_size),
sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_to_all_single",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> &in_sizes,
py::handle py_in_tensor,
std::vector<int64_t> &out_sizes,
const std::vector<int64_t> &out_sizes,
const std::vector<int64_t> &in_sizes,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllToAllSingle(
return self.AllToAll(
in_wrapper, out_wrapper, in_sizes, out_sizes, sync_op);
out_dense, in_dense, out_sizes, in_sizes, sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("in_sizes"),
py::arg("in"),
py::arg("out_sizes"),
py::arg("in_sizes"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
...@@ -674,18 +686,20 @@ void BindDistributed(py::module *m) {
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> in_sizes,
const std::vector<int64_t> in_sizes,
std::vector<int64_t> out_sizes) {
const std::vector<int64_t> out_sizes) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
auto *out_dense = p_out_tensor.get();
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll_Single(
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
in_tensors, out_tensors, in_sizes, out_sizes);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
return self.AllToAll(
out_dense, in_dense, out_sizes, in_sizes, /*sync_op*/ true);
},
py::arg("in"),
py::arg("out"),
...@@ -765,7 +779,7 @@ void BindDistributed(py::module *m) {
/*numel*/ -1,
/*sync_op*/ true,
/*use_calc_stream*/ true);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
return task;
},
py::arg("out"),
...@@ -856,88 +870,96 @@ void BindDistributed(py::module *m) {
.def(
"all_to_all_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor_list,
py::handle py_out_tensor_list,
py::handle py_out_tensor_list) {
py::handle py_in_tensor_list) {
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
// in_tensor_list must not be empty
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
auto in_dense = *p_in_tensor;
// in_tensor_list should not be empty
const auto &dev_ctx = self.GetDeviceContext(
in_tensor_list.back().place(), /*use_calc_stream*/ true);
auto task = self.AllToAll(in_wrapper,
int world_size = self.GetSize();
out_wrapper,
auto task =
/*sync_op*/ true,
self.AllToAll(out_dense,
/*use_calc_stream*/ true);
in_dense,
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
GetDefaultSplitSizes(*out_dense, world_size),
GetDefaultSplitSizes(in_dense, world_size),
/*sync_op*/ true,
/*use_calc_stream*/ true);
SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
return task;
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_to_all_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
py::handle py_out_tensor) {
py::handle py_in_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
return self.AllToAll(in_wrapper,
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
out_wrapper,
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
/*sync_op*/ true,
in_tensor.impl());
/*use_calc_stream*/ true);
auto in_dense = *p_in_tensor;
int world_size = self.GetSize();
return self.AllToAll(
out_dense,
in_dense,
GetDefaultSplitSizes(*out_dense, world_size),
GetDefaultSplitSizes(in_dense, world_size),
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_to_all_single_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> &in_sizes,
py::handle py_in_tensor,
std::vector<int64_t> &out_sizes) {
const std::vector<int64_t> &out_sizes,
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
const std::vector<int64_t> &in_sizes) {
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
auto *out_dense = p_out_tensor.get();
return self.AllToAllSingle(in_wrapper,
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
out_wrapper,
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_sizes,
in_tensor.impl());
out_sizes,
auto in_dense = *p_in_tensor;
/*sync_op*/ true,
/*use_calc_stream*/ true);
return self.AllToAll(out_dense,
in_dense,
out_sizes,
in_sizes,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::arg("in_sizes"),
py::arg("in"),
py::arg("out_sizes"),
py::arg("in_sizes"),
py::call_guard<py::gil_scoped_release>())
.def(
...
...@@ -21,7 +21,7 @@
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
namespace paddle {
namespace distributed {
namespace pybind {
template <typename DeviceContext, typename T>
struct ConcatDenseTensor {
...@@ -113,6 +113,10 @@ void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
ConcatDenseTensor<DeviceContext, phi::dtype::float16>()(
dev_ctx, t_list, p_out);
break;
case phi::DataType::BFLOAT16:
ConcatDenseTensor<DeviceContext, phi::dtype::bfloat16>()(
dev_ctx, t_list, p_out);
break;
case phi::DataType::FLOAT32:
ConcatDenseTensor<DeviceContext, float>()(dev_ctx, t_list, p_out);
break;
...@@ -150,6 +154,10 @@ void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
SplitDenseTensor<DeviceContext, phi::dtype::float16>()(
dev_ctx, t_in, p_list);
break;
case phi::DataType::BFLOAT16:
SplitDenseTensor<DeviceContext, phi::dtype::bfloat16>()(
dev_ctx, t_in, p_list);
break;
case phi::DataType::FLOAT32:
SplitDenseTensor<DeviceContext, float>()(dev_ctx, t_in, p_list);
break;
...@@ -249,5 +257,10 @@ void SplitTensor(const phi::DeviceContext &dev_ctx,
}
}
} // namespace distributed
inline std::vector<int64_t> GetDefaultSplitSizes(const phi::DenseTensor &tensor,
int world_size) {
return std::vector<int64_t>(world_size, tensor.dims()[0] / world_size);
}
} // namespace pybind
} // namespace paddle
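GetDefaultSplitSizes is what lets the even all_to_all and all_to_all_tensor paths reuse the uneven-split AllToAll: every rank gets dims()[0] / world_size rows, which matches the out_tensor.shape[0] // world_size default the Python wrapper below now fills in. A standalone sketch of the same arithmetic, taking a plain dim value instead of a phi::DenseTensor (the function name here is illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Same arithmetic as GetDefaultSplitSizes in this diff, but parameterized on
// dim 0 directly. Assumes dim0 is divisible by world_size, as an evenly split
// all_to_all requires.
std::vector<int64_t> DefaultSplitSizes(int64_t dim0, int world_size) {
  return std::vector<int64_t>(world_size, dim0 / world_size);
}

int main() {
  for (int64_t s : DefaultSplitSizes(/*dim0=*/8, /*world_size=*/4)) {
    std::printf("%lld ", static_cast<long long>(s));  // prints: 2 2 2 2
  }
  std::printf("\n");
  return 0;
}
```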
...@@ -75,11 +75,11 @@ def _all_to_all_in_dygraph(
if use_calc_stream:
return group.process_group.all_to_all_on_calc_stream(
in_tensor_list, out_tensor_list
out_tensor_list, in_tensor_list
)
task = group.process_group.all_to_all(
in_tensor_list, out_tensor_list, sync_op
out_tensor_list, in_tensor_list, sync_op
)
if sync_op:
task.wait()
...@@ -243,18 +243,23 @@ def _alltoall_single_in_dygraph(
sync_op,
use_calc_stream,
):
world_size = dist.get_world_size()
if out_split_sizes is None:
out_split_sizes = []
out_split_sizes = [
out_tensor.shape[0] // world_size for _ in range(world_size)
]
if in_split_sizes is None:
in_split_sizes = []
in_split_sizes = [
in_tensor.shape[0] // world_size for _ in range(world_size)
]
if use_calc_stream:
return group.process_group.all_to_all_single_on_calc_stream(
in_tensor, out_tensor, in_split_sizes, out_split_sizes
out_tensor, in_tensor, out_split_sizes, in_split_sizes
)
task = group.process_group.all_to_all_single(
in_tensor, out_tensor, in_split_sizes, out_split_sizes, sync_op
out_tensor, in_tensor, out_split_sizes, in_split_sizes, sync_op
)
if sync_op:
task.wait()
...