From 3f480af2b073e30df7ca048d98792d581e06fe94 Mon Sep 17 00:00:00 2001 From: Wen Sun <35923278+HermitSun@users.noreply.github.com> Date: Thu, 17 Nov 2022 17:35:58 +0800 Subject: [PATCH] Refactor collective communication all_to_all, all_to_all_single C++ API (#48059) --- .../distributed/collective/ProcessGroup.h | 31 +-- .../collective/ProcessGroupNCCL.cc | 222 +++++++----------- .../distributed/collective/ProcessGroupNCCL.h | 22 +- .../collective/ProcessGroupStream.cc | 50 ++-- .../collective/ProcessGroupStream.h | 30 +-- paddle/fluid/pybind/distributed_py.cc | 220 +++++++++-------- paddle/fluid/pybind/process_group_utils.h | 17 +- .../communication/stream/all_to_all.py | 17 +- 8 files changed, 290 insertions(+), 319 deletions(-) diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 029a64a25c..152bb1aa6f 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -46,7 +46,6 @@ enum class CommType : std::uint8_t { SEND = 9, RECV = 10, BARRIER = 11, - ALLTOALL_SINGLE = 12, UNKNOWN = 100, }; @@ -124,6 +123,17 @@ class ProcessGroup { GetBackendName())); } + virtual std::shared_ptr AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op) { + PADDLE_THROW(platform::errors::Unimplemented( + "ProcessGroup%s does not support all_to_all with sync_op flag.", + GetBackendName())); + } + virtual std::shared_ptr Barrier( const BarrierOptions& = BarrierOptions()) { PADDLE_THROW(platform::errors::Unimplemented( @@ -255,25 +265,6 @@ class ProcessGroup { "ProcessGroup%s does not support alltoall", GetBackendName())); } - virtual std::shared_ptr AllToAll_Single( - std::vector&, // NOLINT - std::vector&, // NOLINT - std::vector&, - std::vector&) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support AllToAll_Single", GetBackendName())); - } - - virtual std::shared_ptr AllToAllSingle( - std::vector&, // NOLINT - std::vector&, // NOLINT - std::vector&, - std::vector&, - bool) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support alltoall_single", GetBackendName())); - } - virtual std::shared_ptr Reduce( std::vector&, // NOLINT std::vector&, // NOLINT diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index d7d5beea89..4a70b81e31 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -184,6 +184,80 @@ std::shared_ptr ProcessGroupNCCL::AllReduce( use_calc_stream); } +void CheckSizeOnEachRank(const phi::DDim& tensor_dim, + const std::vector& size_on_each_rank, + int world_size) { + int length_size_on_each_rank = size_on_each_rank.size(); + PADDLE_ENFORCE_EQ( + length_size_on_each_rank, + world_size, + platform::errors::InvalidArgument( + "The length of size_on_each_rank must be equal to world_size.")); + + int64_t sum_size_on_each_rank = + std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0); + PADDLE_ENFORCE_EQ( + sum_size_on_each_rank, + tensor_dim[0], + platform::errors::InvalidArgument( + "The sum of size_on_each_rank must be equal to tensor's dim[0].")); +} + +std::shared_ptr ProcessGroupNCCL::AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op, + bool use_calc_stream) { + const phi::DDim& out_dim = out_tensor->dims(); + const phi::DDim& in_dim = in_tensor.dims(); + CheckSizeOnEachRank(out_dim, out_size_each_rank, size_); + CheckSizeOnEachRank(in_dim, in_size_each_rank, size_); + + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + ncclComm_t comm, + gpuStream_t stream) { + int64_t in_row_size = input.numel() / in_dim[0], + out_row_size = output->numel() / out_dim[0]; + int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0; + phi::DenseTensor input_partial, output_partial; + + GroupStart(); + for (auto i = 0; i < size_; i++) { + in_numel = in_size_each_rank[i] * in_row_size; + input_partial = GetPartialTensor(input, in_offset, in_numel); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + input_partial.data(), + in_numel, + platform::ToNCCLDataType(input.dtype()), + i, + comm, + stream)); + in_offset += in_numel; + + out_numel = out_size_each_rank[i] * out_row_size; + output_partial = GetPartialTensor(*output, out_offset, out_numel); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + output_partial.data(), + out_numel, + platform::ToNCCLDataType(output->dtype()), + i, + comm, + stream)); + out_offset += out_numel; + } + GroupEnd(); + }, + CommType::ALLTOALL, + sync_op, + use_calc_stream); +} + std::shared_ptr ProcessGroupNCCL::Barrier( const BarrierOptions& opts) { PADDLE_ENFORCE_GE(opts.device_id, @@ -551,7 +625,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( std::vector dev_ctx_raw; dev_ctx_raw.resize(places.size()); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + GroupStart(); for (size_t i = 0; i < places.size(); ++i) { platform::CUDADeviceGuard guard(places[i]); @@ -564,7 +638,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( dev_ctx_raw[i] = dev_ctx[i].get(); } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + GroupEnd(); // TODO(sunyilun): for compatibility, will be removed later place_to_calc_event_.emplace(places_key, places[0]); @@ -1086,7 +1160,7 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( ncclComm_t comm, const gpuStream_t& stream) { size_t offset = 0; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + GroupStart(); for (auto i = 0; i < size_; i++) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( GetPointerByOffset(input.data(), offset, input.dtype()), @@ -1104,7 +1178,7 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( stream)); offset += input.numel() / size_; } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + GroupEnd(); }, CommType::ALLTOALL); } @@ -1130,7 +1204,7 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( ncclComm_t comm, const gpuStream_t& stream) { size_t offset = 0; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + GroupStart(); for (auto i = 0; i < size_; i++) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( GetPointerByOffset(input.data(), offset, input.dtype()), @@ -1148,141 +1222,13 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( stream)); offset += input.numel() / size_; } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + GroupEnd(); }, CommType::ALLTOALL, sync_op, use_calc_stream); } -std::shared_ptr ProcessGroupNCCL::AllToAll_Single( - std::vector& in_tensors, - std::vector& out_tensors, - std::vector& in_sizes, - std::vector& out_sizes) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(out_tensors), - true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - PADDLE_ENFORCE_EQ(input.dtype() == output.dtype(), - true, - platform::errors::InvalidArgument( - "The dtypes of input and output must be equal.")); - - std::vector in_dims = phi::vectorize(input.dims()); - std::vector out_dims = phi::vectorize(output.dims()); - CheckSplitSizes(&in_sizes, in_dims); - CheckSplitSizes(&out_sizes, out_dims); - - size_t in_offset = 0, out_offset = 0; - size_t in_length = 0, out_length = 0; - size_t in_row_size = input.numel() / in_dims[0]; - size_t out_row_size = output.numel() / out_dims[0]; - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - for (auto i = 0; i < size_; i++) { - in_length = in_sizes[i] * in_row_size; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( - GetPointerByOffset(input.data(), in_offset, input.dtype()), - in_length, - platform::ToNCCLDataType(input.dtype()), - i, - comm, - stream)); - in_offset += in_length; - - out_length = out_sizes[i] * out_row_size; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( - GetPointerByOffset(output.data(), out_offset, input.dtype()), - out_length, - platform::ToNCCLDataType(input.dtype()), - i, - comm, - stream)); - out_offset += out_length; - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - }, - CommType::ALLTOALL_SINGLE); -} - -std::shared_ptr ProcessGroupNCCL::AllToAllSingle( - std::vector& in_tensors, - std::vector& out_tensors, - std::vector& in_sizes, - std::vector& out_sizes, - bool sync_op, - bool use_calc_stream) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(out_tensors), - true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - PADDLE_ENFORCE_EQ(input.dtype() == output.dtype(), - true, - platform::errors::InvalidArgument( - "The dtypes of input and output must be equal.")); - - std::vector in_dims = phi::vectorize(input.dims()); - std::vector out_dims = phi::vectorize(output.dims()); - CheckSplitSizes(&in_sizes, in_dims); - CheckSplitSizes(&out_sizes, out_dims); - - size_t in_offset = 0, out_offset = 0; - size_t in_length = 0, out_length = 0; - size_t in_row_size = input.numel() / in_dims[0]; - size_t out_row_size = output.numel() / out_dims[0]; - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - for (auto i = 0; i < size_; i++) { - in_length = in_sizes[i] * in_row_size; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( - GetPointerByOffset(input.data(), in_offset, input.dtype()), - in_length, - platform::ToNCCLDataType(input.dtype()), - i, - comm, - stream)); - in_offset += in_length; - - out_length = out_sizes[i] * out_row_size; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( - GetPointerByOffset(output.data(), out_offset, input.dtype()), - out_length, - platform::ToNCCLDataType(input.dtype()), - i, - comm, - stream)); - out_offset += out_length; - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - }, - CommType::ALLTOALL_SINGLE, - sync_op, - use_calc_stream); -} - std::shared_ptr ProcessGroupNCCL::Reduce( std::vector& in_tensors, std::vector& out_tensors, @@ -1396,7 +1342,7 @@ std::shared_ptr ProcessGroupNCCL::Scatter( const gpuStream_t& stream) { size_t offset = 0; if (rank_ == opts.root_rank) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + GroupStart(); for (auto i = 0; i < size_; i++) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( GetPointerByOffset(input.data(), offset, input.dtype()), @@ -1414,7 +1360,7 @@ std::shared_ptr ProcessGroupNCCL::Scatter( opts.root_rank, comm, stream)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + GroupEnd(); } else { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( output.data(), @@ -1456,7 +1402,7 @@ std::shared_ptr ProcessGroupNCCL::Scatter( "Input and output tensors should have the same shape.")); size_t offset = 0; if (rank_ == opts.root_rank) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + GroupStart(); for (auto i = 0; i < size_; i++) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( GetPointerByOffset(input.data(), offset, input.dtype()), @@ -1474,7 +1420,7 @@ std::shared_ptr ProcessGroupNCCL::Scatter( opts.root_rank, comm, stream)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + GroupEnd(); } else { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( output.data(), diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index dab6d94288..a6528be80b 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -109,6 +109,14 @@ class ProcessGroupNCCL final : public ProcessGroupStream { bool sync_op, bool use_calc_stream) override; + std::shared_ptr AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr Barrier( const BarrierOptions& = BarrierOptions()) override; @@ -171,20 +179,6 @@ class ProcessGroupNCCL final : public ProcessGroupStream { bool sync_op, bool use_calc_stream) override; - std::shared_ptr AllToAll_Single( - std::vector& in, - std::vector& out, - std::vector& in_sizes, - std::vector& out_sizes) override; - - std::shared_ptr AllToAllSingle( - std::vector& in_tensors, - std::vector& out_tensors, - std::vector& in_sizes, - std::vector& out_sizes, - bool sync_op, - bool use_calc_stream) override; - std::shared_ptr Reduce( std::vector& tensors, std::vector& out_tensors, diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.cc b/paddle/fluid/distributed/collective/ProcessGroupStream.cc index 2561a4f5b2..3839f70ac1 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.cc @@ -73,6 +73,31 @@ std::shared_ptr ProcessGroupStream::AllReduce( "ProcessGroup%s does not support all_reduce.", GetBackendName())); } +std::shared_ptr ProcessGroupStream::AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op) { + return AllToAll(out_tensor, + in_tensor, + out_size_each_rank, + in_size_each_rank, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::Unimplemented( + "ProcessGroup%s does not support all_to_all.", GetBackendName())); +} + std::shared_ptr ProcessGroupStream::Broadcast( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, @@ -165,31 +190,6 @@ std::shared_ptr ProcessGroupStream::AllToAll( "ProcessGroup%s does not support do alltoall", GetBackendName())); } -std::shared_ptr ProcessGroupStream::AllToAllSingle( - std::vector& in_tensors, - std::vector& out_tensors, - std::vector& in_sizes, - std::vector& out_sizes, - bool sync_op) { - return AllToAllSingle(in_tensors, - out_tensors, - in_sizes, - out_sizes, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::AllToAllSingle( - std::vector& in_tensors, - std::vector& out_tensors, - std::vector& in_sizes, - std::vector& out_sizes, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support do alltoall_single", GetBackendName())); -} - std::shared_ptr ProcessGroupStream::Reduce( std::vector& in_tensors, std::vector& out_tensors, diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.h b/paddle/fluid/distributed/collective/ProcessGroupStream.h index 15b0635c5a..ad37c33068 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.h +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.h @@ -89,6 +89,21 @@ class ProcessGroupStream : public ProcessGroup { bool sync_op, bool use_calc_stream); + std::shared_ptr AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op) override; + + virtual std::shared_ptr AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op, + bool use_calc_stream); + std::shared_ptr Broadcast( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, @@ -140,21 +155,6 @@ class ProcessGroupStream : public ProcessGroup { bool sync_op, bool use_calc_stream); - std::shared_ptr AllToAllSingle( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - std::vector& in_sizes, // NOLINT - std::vector& out_sizes, // NOLINT - bool sync_op) override; - - virtual std::shared_ptr AllToAllSingle( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - std::vector& in_sizes, // NOLINT - std::vector& out_sizes, // NOLINT - bool sync_op, - bool use_calc_stream); - std::shared_ptr Reduce( std::vector& in_tensors, // NOLINT std::vector& out_tensors, // NOLINT diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index a596275015..dbc4c57c65 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -277,7 +277,7 @@ void BindDistributed(py::module *m) { /*offset*/ 0, /*numel*/ -1, sync_op); - distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list); + SplitTensor(dev_ctx, *out_dense, &out_tensor_list); task->UpdateWaitChain(dev_ctx); return task; }, @@ -316,84 +316,96 @@ void BindDistributed(py::module *m) { .def( "all_to_all", [](distributed::ProcessGroup &self, - py::handle py_in_tensor_list, py::handle py_out_tensor_list, + py::handle py_in_tensor_list, bool sync_op) { - auto in_tensor_list = - CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); - Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); - auto in_dense = std::dynamic_pointer_cast( - concat_in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - auto out_tensor_list = CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0); Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( concat_out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto *out_dense = p_out_tensor.get(); + + auto in_tensor_list = + CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); + Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); + auto p_in_tensor = std::dynamic_pointer_cast( + concat_in_tensor.impl()); + auto in_dense = *p_in_tensor; // in_tensor_list should not be empty const auto &dev_ctx = self.GetDeviceContext(in_tensor_list.back().place()); - auto task = self.AllToAll(in_wrapper, out_wrapper, sync_op); - distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list); + int world_size = self.GetSize(); + auto task = + self.AllToAll(out_dense, + in_dense, + GetDefaultSplitSizes(*out_dense, world_size), + GetDefaultSplitSizes(in_dense, world_size), + sync_op); + SplitTensor(dev_ctx, *out_dense, &out_tensor_list); task->UpdateWaitChain(dev_ctx); return task; }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::arg("sync_op"), py::call_guard()) .def( "all_to_all_tensor", [](distributed::ProcessGroup &self, - py::handle py_in_tensor, py::handle py_out_tensor, + py::handle py_in_tensor, bool sync_op) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto *out_dense = p_out_tensor.get(); - return self.AllToAll(in_wrapper, out_wrapper, sync_op); + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto p_in_tensor = std::dynamic_pointer_cast( + in_tensor.impl()); + auto in_dense = *p_in_tensor; + + int world_size = self.GetSize(); + return self.AllToAll( + out_dense, + in_dense, + GetDefaultSplitSizes(*out_dense, world_size), + GetDefaultSplitSizes(in_dense, world_size), + sync_op); }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::arg("sync_op"), py::call_guard()) .def( "all_to_all_single", [](distributed::ProcessGroup &self, - py::handle py_in_tensor, py::handle py_out_tensor, - std::vector &in_sizes, - std::vector &out_sizes, + py::handle py_in_tensor, + const std::vector &out_sizes, + const std::vector &in_sizes, bool sync_op) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto *out_dense = p_out_tensor.get(); + + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto p_in_tensor = std::dynamic_pointer_cast( + in_tensor.impl()); + auto in_dense = *p_in_tensor; - return self.AllToAllSingle( - in_wrapper, out_wrapper, in_sizes, out_sizes, sync_op); + return self.AllToAll( + out_dense, in_dense, out_sizes, in_sizes, sync_op); }, - py::arg("in"), py::arg("out"), - py::arg("in_sizes"), + py::arg("in"), py::arg("out_sizes"), + py::arg("in_sizes"), py::arg("sync_op"), py::call_guard()) @@ -674,18 +686,20 @@ void BindDistributed(py::module *m) { [](distributed::ProcessGroup &self, py::handle py_in_tensor, py::handle py_out_tensor, - std::vector in_sizes, - std::vector out_sizes) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + const std::vector in_sizes, + const std::vector out_sizes) { auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - return self.AllToAll_Single( - in_tensors, out_tensors, in_sizes, out_sizes); + auto *out_dense = p_out_tensor.get(); + + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto p_in_tensor = std::dynamic_pointer_cast( + in_tensor.impl()); + auto in_dense = *p_in_tensor; + + return self.AllToAll( + out_dense, in_dense, out_sizes, in_sizes, /*sync_op*/ true); }, py::arg("in"), py::arg("out"), @@ -765,7 +779,7 @@ void BindDistributed(py::module *m) { /*numel*/ -1, /*sync_op*/ true, /*use_calc_stream*/ true); - distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list); + SplitTensor(dev_ctx, *out_dense, &out_tensor_list); return task; }, py::arg("out"), @@ -856,88 +870,96 @@ void BindDistributed(py::module *m) { .def( "all_to_all_on_calc_stream", [](distributed::ProcessGroupStream &self, - py::handle py_in_tensor_list, - py::handle py_out_tensor_list) { - auto in_tensor_list = - CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); - Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); - auto in_dense = std::dynamic_pointer_cast( - concat_in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - + py::handle py_out_tensor_list, + py::handle py_in_tensor_list) { auto out_tensor_list = CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0); Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( concat_out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto *out_dense = p_out_tensor.get(); - // in_tensor_list must not be empty + auto in_tensor_list = + CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); + Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); + auto p_in_tensor = std::dynamic_pointer_cast( + concat_in_tensor.impl()); + auto in_dense = *p_in_tensor; + + // in_tensor_list should not be empty const auto &dev_ctx = self.GetDeviceContext( in_tensor_list.back().place(), /*use_calc_stream*/ true); - auto task = self.AllToAll(in_wrapper, - out_wrapper, - /*sync_op*/ true, - /*use_calc_stream*/ true); - distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list); + int world_size = self.GetSize(); + auto task = + self.AllToAll(out_dense, + in_dense, + GetDefaultSplitSizes(*out_dense, world_size), + GetDefaultSplitSizes(in_dense, world_size), + /*sync_op*/ true, + /*use_calc_stream*/ true); + SplitTensor(dev_ctx, *out_dense, &out_tensor_list); return task; }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::call_guard()) .def( "all_to_all_tensor_on_calc_stream", [](distributed::ProcessGroupStream &self, - py::handle py_in_tensor, - py::handle py_out_tensor) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - + py::handle py_out_tensor, + py::handle py_in_tensor) { auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto *out_dense = p_out_tensor.get(); - return self.AllToAll(in_wrapper, - out_wrapper, - /*sync_op*/ true, - /*use_calc_stream*/ true); + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto p_in_tensor = std::dynamic_pointer_cast( + in_tensor.impl()); + auto in_dense = *p_in_tensor; + + int world_size = self.GetSize(); + return self.AllToAll( + out_dense, + in_dense, + GetDefaultSplitSizes(*out_dense, world_size), + GetDefaultSplitSizes(in_dense, world_size), + /*sync_op*/ true, + /*use_calc_stream*/ true); }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::call_guard()) .def( "all_to_all_single_on_calc_stream", [](distributed::ProcessGroupStream &self, - py::handle py_in_tensor, py::handle py_out_tensor, - std::vector &in_sizes, - std::vector &out_sizes) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - + py::handle py_in_tensor, + const std::vector &out_sizes, + const std::vector &in_sizes) { auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto *out_dense = p_out_tensor.get(); - return self.AllToAllSingle(in_wrapper, - out_wrapper, - in_sizes, - out_sizes, - /*sync_op*/ true, - /*use_calc_stream*/ true); + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto p_in_tensor = std::dynamic_pointer_cast( + in_tensor.impl()); + auto in_dense = *p_in_tensor; + + return self.AllToAll(out_dense, + in_dense, + out_sizes, + in_sizes, + /*sync_op*/ true, + /*use_calc_stream*/ true); }, - py::arg("in"), py::arg("out"), - py::arg("in_sizes"), + py::arg("in"), py::arg("out_sizes"), + py::arg("in_sizes"), py::call_guard()) .def( diff --git a/paddle/fluid/pybind/process_group_utils.h b/paddle/fluid/pybind/process_group_utils.h index 35a5ef0b1b..0543495754 100644 --- a/paddle/fluid/pybind/process_group_utils.h +++ b/paddle/fluid/pybind/process_group_utils.h @@ -21,7 +21,7 @@ #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" namespace paddle { -namespace distributed { +namespace pybind { template struct ConcatDenseTensor { @@ -113,6 +113,10 @@ void ConcatDenseTensorWithType(const DeviceContext &dev_ctx, ConcatDenseTensor()( dev_ctx, t_list, p_out); break; + case phi::DataType::BFLOAT16: + ConcatDenseTensor()( + dev_ctx, t_list, p_out); + break; case phi::DataType::FLOAT32: ConcatDenseTensor()(dev_ctx, t_list, p_out); break; @@ -150,6 +154,10 @@ void SplitDenseTensorWithType(const DeviceContext &dev_ctx, SplitDenseTensor()( dev_ctx, t_in, p_list); break; + case phi::DataType::BFLOAT16: + SplitDenseTensor()( + dev_ctx, t_in, p_list); + break; case phi::DataType::FLOAT32: SplitDenseTensor()(dev_ctx, t_in, p_list); break; @@ -249,5 +257,10 @@ void SplitTensor(const phi::DeviceContext &dev_ctx, } } -} // namespace distributed +inline std::vector GetDefaultSplitSizes(const phi::DenseTensor &tensor, + int world_size) { + return std::vector(world_size, tensor.dims()[0] / world_size); +} + +} // namespace pybind } // namespace paddle diff --git a/python/paddle/distributed/communication/stream/all_to_all.py b/python/paddle/distributed/communication/stream/all_to_all.py index d05b53564a..2787c6a3d4 100644 --- a/python/paddle/distributed/communication/stream/all_to_all.py +++ b/python/paddle/distributed/communication/stream/all_to_all.py @@ -75,11 +75,11 @@ def _all_to_all_in_dygraph( if use_calc_stream: return group.process_group.all_to_all_on_calc_stream( - in_tensor_list, out_tensor_list + out_tensor_list, in_tensor_list ) task = group.process_group.all_to_all( - in_tensor_list, out_tensor_list, sync_op + out_tensor_list, in_tensor_list, sync_op ) if sync_op: task.wait() @@ -243,18 +243,23 @@ def _alltoall_single_in_dygraph( sync_op, use_calc_stream, ): + world_size = dist.get_world_size() if out_split_sizes is None: - out_split_sizes = [] + out_split_sizes = [ + out_tensor.shape[0] // world_size for _ in range(world_size) + ] if in_split_sizes is None: - in_split_sizes = [] + in_split_sizes = [ + in_tensor.shape[0] // world_size for _ in range(world_size) + ] if use_calc_stream: return group.process_group.all_to_all_single_on_calc_stream( - in_tensor, out_tensor, in_split_sizes, out_split_sizes + out_tensor, in_tensor, out_split_sizes, in_split_sizes ) task = group.process_group.all_to_all_single( - in_tensor, out_tensor, in_split_sizes, out_split_sizes, sync_op + out_tensor, in_tensor, out_split_sizes, in_split_sizes, sync_op ) if sync_op: task.wait() -- GitLab