From edda13cd88b269c932e1d8fafa5a6fabbbda72a2 Mon Sep 17 00:00:00 2001 From: Wen Sun <35923278+HermitSun@users.noreply.github.com> Date: Fri, 18 Nov 2022 15:34:10 +0800 Subject: [PATCH] Refactor collective communication reduce, scatter, reduce_scatter C++ API (#48115) --- .../fluid/distributed/collective/NCCLTools.h | 3 +- .../distributed/collective/ProcessGroup.h | 60 ++-- .../collective/ProcessGroupGloo.cc | 62 ++-- .../distributed/collective/ProcessGroupGloo.h | 22 +- .../collective/ProcessGroupNCCL.cc | 269 ++++++++---------- .../distributed/collective/ProcessGroupNCCL.h | 40 ++- .../collective/ProcessGroupStream.cc | 132 ++++----- .../collective/ProcessGroupStream.h | 76 +++-- paddle/fluid/pybind/distributed_py.cc | 198 ++++++------- .../communication/stream/reduce_scatter.py | 8 +- .../communication/stream/scatter.py | 8 +- 11 files changed, 432 insertions(+), 446 deletions(-) diff --git a/paddle/fluid/distributed/collective/NCCLTools.h b/paddle/fluid/distributed/collective/NCCLTools.h index 464ae0b6581..37b1e0f114c 100644 --- a/paddle/fluid/distributed/collective/NCCLTools.h +++ b/paddle/fluid/distributed/collective/NCCLTools.h @@ -47,7 +47,7 @@ namespace paddle { namespace distributed { -#define NCCLCHECK(cmd) \ +#define NCCL_CHECK(cmd) \ do { \ ncclResult_t r = cmd; \ if (r != ncclSuccess) { \ @@ -60,6 +60,7 @@ namespace distributed { } while (0) ncclRedOp_t ToNCCLRedType(ReduceOp reduction); + std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID); } // namespace distributed diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 152bb1aa6f9..795a1a91b52 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -150,6 +150,36 @@ class ProcessGroup { GetBackendName())); } + virtual std::shared_ptr Reduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op) { + PADDLE_THROW(platform::errors::Unimplemented( + "ProcessGroup%s does not support reduce with sync_op flag.", + GetBackendName())); + } + + virtual std::shared_ptr ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op) { + PADDLE_THROW(platform::errors::Unimplemented( + "ProcessGroup%s does not support reduce_scatter with sync_op flag.", + GetBackendName())); + } + + virtual std::shared_ptr Scatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op) { + PADDLE_THROW(platform::errors::Unimplemented( + "ProcessGroup%s does not support scatter with sync_op flag.", + GetBackendName())); + } + virtual std::shared_ptr Recv(phi::DenseTensor* tensor, int src_rank, int64_t offset, @@ -273,16 +303,6 @@ class ProcessGroup { "ProcessGroup%s does not support reduce", GetBackendName())); } - virtual std::shared_ptr Reduce( - std::vector& /* input tensors */, // NOLINT - std::vector& /* output tensors */, // NOLINT - const ReduceOptions&, - bool) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support reduce with sync_op flag", - GetBackendName())); - } - virtual std::shared_ptr Scatter( std::vector&, // NOLINT std::vector&, // NOLINT @@ -291,26 +311,6 @@ class ProcessGroup { "ProcessGroup%s does not support scatter", GetBackendName())); } - virtual std::shared_ptr Scatter( - std::vector&, // NOLINT - std::vector&, // NOLINT - const ScatterOptions&, - bool) { - 
PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support scatter with sync_op flag", - GetBackendName())); - } - - virtual std::shared_ptr ReduceScatter( - std::vector&, // NOLINT - std::vector&, // NOLINT - const ReduceScatterOptions&, - bool) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support reduce_scatter with sync_op flag", - GetBackendName())); - } - protected: const int rank_; const int size_; diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc index 2574eb11be2..f0a65b02fb6 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc @@ -234,8 +234,8 @@ std::shared_ptr ProcessGroupGloo::Broadcast( const phi::DenseTensor& in_tensor, const BroadcastOptions& opts, bool sync_op) { - std::vector in_wrapper = {in_tensor}; - std::vector out_wrapper = {*out_tensor}; + std::vector in_wrapper{in_tensor}; + std::vector out_wrapper{*out_tensor}; return Broadcast(in_wrapper, out_wrapper, opts, true); } @@ -396,8 +396,8 @@ std::shared_ptr ProcessGroupGloo::AllGather( int64_t offset, // for compatibility, no use now int64_t numel, // for compatibility, no use now bool sync_op) { - std::vector in_wrapper = {in_tensor}; - std::vector out_wrapper = {*out_tensor}; + std::vector in_wrapper{in_tensor}; + std::vector out_wrapper{*out_tensor}; return AllGather(in_wrapper, out_wrapper, true); } @@ -475,26 +475,34 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask { }; std::shared_ptr ProcessGroupGloo::Reduce( - std::vector& inputs, - std::vector& outputs, - const ReduceOptions& opts) { - return Reduce(inputs, outputs, opts, true); -} - -std::shared_ptr ProcessGroupGloo::Reduce( - std::vector& inputs, - std::vector& outputs, + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, const ReduceOptions& opts, - bool sync_op) { + bool sync_op // for compatibility, no use now +) { std::shared_ptr task; auto tag = next_tag(); auto context = get_context(); - task = std::make_shared( - rank_, context, inputs, outputs, opts.reduce_op, opts.root_rank, tag); + std::vector in_wrapper{in_tensor}; + std::vector out_wrapper{*out_tensor}; + task = std::make_shared(rank_, + context, + in_wrapper, + out_wrapper, + opts.reduce_op, + opts.root_rank, + tag); task->Run(); return task; } +std::shared_ptr ProcessGroupGloo::Reduce( + std::vector& inputs, + std::vector& outputs, + const ReduceOptions& opts) { + return Reduce(&outputs[0], inputs[0], opts, true); +} + class ScatterGlooTask : public ProcessGroupGloo::GlooTask { public: ScatterGlooTask(int rank, @@ -538,26 +546,28 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask { }; std::shared_ptr ProcessGroupGloo::Scatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions& opts) { - return Scatter(in_tensors, out_tensors, opts, true); -} - -std::shared_ptr ProcessGroupGloo::Scatter( - std::vector& in_tensors, - std::vector& out_tensors, + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, const ScatterOptions& opts, bool sync_op) { std::shared_ptr task; auto tag = next_tag(); auto context = get_context(); + std::vector in_wrapper{in_tensor}; + std::vector out_wrapper{*out_tensor}; task = std::make_shared( - rank_, context, in_tensors, out_tensors, opts.root_rank, size_, tag); + rank_, context, in_wrapper, out_wrapper, opts.root_rank, size_, tag); task->Run(); return task; } +std::shared_ptr 
ProcessGroupGloo::Scatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ScatterOptions& opts) { + return Scatter(&out_tensors[0], in_tensors[0], opts, true); +} + std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::createDeviceForInterface(const std::string& ifname) { ::gloo::transport::tcp::attr attr; diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.h b/paddle/fluid/distributed/collective/ProcessGroupGloo.h index 474fb0c027c..fd691e024c4 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupGloo.h +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.h @@ -120,6 +120,16 @@ class ProcessGroupGloo : public ProcessGroup { const BroadcastOptions& opts, bool sync_op) override; + std::shared_ptr Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op) override; + + std::shared_ptr Scatter(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op) override; + // TODO(sunyilun): methods below will be removed later std::shared_ptr Broadcast( std::vector& inputs, @@ -155,23 +165,11 @@ class ProcessGroupGloo : public ProcessGroup { std::vector& out_tensors, bool sync_op) override; - std::shared_ptr Reduce( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceOptions& opts, - bool sync_op) override; - std::shared_ptr Reduce( std::vector& in_tensors, std::vector& out_tensors, const ReduceOptions& opts) override; - std::shared_ptr Scatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions&, - bool sync_op) override; - std::shared_ptr Scatter( std::vector& in_tensors, std::vector& out_tensors, diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index 4a70b81e310..74ebf802059 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -87,11 +87,11 @@ ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr& store, : ProcessGroupStream(rank, size, gid), store_(store) {} void ProcessGroupNCCL::GroupStart() { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + NCCL_CHECK(platform::dynload::ncclGroupStart()); } void ProcessGroupNCCL::GroupEnd() { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + NCCL_CHECK(platform::dynload::ncclGroupEnd()); } const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext( @@ -144,13 +144,13 @@ std::shared_ptr ProcessGroupNCCL::AllGather( const phi::DenseTensor& input, ncclComm_t comm, gpuStream_t stream) { - return platform::dynload::ncclAllGather( + NCCL_CHECK(platform::dynload::ncclAllGather( input.data(), output->data(), input.numel(), platform::ToNCCLDataType(input.dtype()), comm, - stream); + stream)); }, CommType::ALLGATHER, sync_op, @@ -170,14 +170,14 @@ std::shared_ptr ProcessGroupNCCL::AllReduce( const phi::DenseTensor& input, ncclComm_t comm, gpuStream_t stream) { - return platform::dynload::ncclAllReduce( + NCCL_CHECK(platform::dynload::ncclAllReduce( input.data(), output->data(), input.numel(), platform::ToNCCLDataType(input.type()), ToNCCLRedType(opts.reduce_op), comm, - stream); + stream)); }, CommType::ALLREDUCE, sync_op, @@ -231,7 +231,7 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( for (auto i = 0; i < size_; i++) { in_numel = in_size_each_rank[i] * in_row_size; input_partial = GetPartialTensor(input, in_offset, in_numel); - 
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + NCCL_CHECK(platform::dynload::ncclSend( input_partial.data(), in_numel, platform::ToNCCLDataType(input.dtype()), @@ -242,7 +242,7 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( out_numel = out_size_each_rank[i] * out_row_size; output_partial = GetPartialTensor(*output, out_offset, out_numel); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + NCCL_CHECK(platform::dynload::ncclRecv( output_partial.data(), out_numel, platform::ToNCCLDataType(output->dtype()), @@ -294,20 +294,127 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( ncclComm_t comm, gpuStream_t stream) { int root = opts.source_rank + opts.source_root; - return platform::dynload::ncclBroadcast( + NCCL_CHECK(platform::dynload::ncclBroadcast( input.data(), output->data(), input.numel(), platform::ToNCCLDataType(input.type()), root, comm, - stream); + stream)); }, CommType::BROADCAST, sync_op, use_calc_stream); } +std::shared_ptr ProcessGroupNCCL::Reduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + ncclComm_t comm, + gpuStream_t stream) { + NCCL_CHECK(platform::dynload::ncclReduce( + input.data(), + output->data(), + input.numel(), + platform::ToNCCLDataType(input.dtype()), + ToNCCLRedType(opts.reduce_op), + opts.root_rank, + comm, + stream)); + }, + CommType::REDUCE, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupNCCL::ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + ncclComm_t comm, + gpuStream_t stream) { + NCCL_CHECK(platform::dynload::ncclReduceScatter( + input.data(), + output->data(), + output->numel(), + platform::ToNCCLDataType(input.dtype()), + ToNCCLRedType(opts.reduce_op), + comm, + stream)); + }, + CommType::REDUCE_SCATTER, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupNCCL::Scatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + ncclComm_t comm, + gpuStream_t stream) { + int64_t numel = input.numel() / size_; + if (rank_ == opts.root_rank) { + int64_t offset = 0; + phi::DenseTensor partial_tensor; + GroupStart(); + for (auto i = 0; i < size_; i++) { + partial_tensor = GetPartialTensor(input, offset, numel); + NCCL_CHECK(platform::dynload::ncclSend( + partial_tensor.data(), + numel, + platform::ToNCCLDataType(input.dtype()), + i, + comm, + stream)); + offset += numel; + } + NCCL_CHECK(platform::dynload::ncclRecv( + output->data(), + numel, + platform::ToNCCLDataType(output->dtype()), + opts.root_rank, + comm, + stream)); + GroupEnd(); + } else { + NCCL_CHECK(platform::dynload::ncclRecv( + output->data(), + numel, + platform::ToNCCLDataType(output->dtype()), + opts.root_rank, + comm, + stream)); + } + }, + CommType::SCATTER, + sync_op, + use_calc_stream); +} + std::shared_ptr ProcessGroupNCCL::Recv( phi::DenseTensor* tensor, int src_rank, @@ -328,13 +435,13 @@ std::shared_ptr ProcessGroupNCCL::Recv( int src, ncclComm_t comm, gpuStream_t stream) { - return platform::dynload::ncclRecv( + 
NCCL_CHECK(platform::dynload::ncclRecv( output->data(), output->numel(), platform::ToNCCLDataType(output->dtype()), src, comm, - stream); + stream)); }, CommType::RECV, sync_op, @@ -361,13 +468,13 @@ std::shared_ptr ProcessGroupNCCL::Send( int dst, ncclComm_t comm, gpuStream_t stream) { - return platform::dynload::ncclSend( + NCCL_CHECK(platform::dynload::ncclSend( input->data(), input->numel(), platform::ToNCCLDataType(input->dtype()), dst, comm, - stream); + stream)); }, CommType::SEND, sync_op, @@ -406,7 +513,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, ncclUniqueId nccl_id; if (rank_ == 0) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id)); + NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id)); } BroadcastUniqueNCCLID(&nccl_id); @@ -418,7 +525,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, platform::DeviceContextPool::Instance().Get(place)); auto comm_ctx = std::make_unique(place); ncclComm_t nccl_comm; - NCCLCHECK(platform::dynload::ncclCommInitRank( + NCCL_CHECK(platform::dynload::ncclCommInitRank( &nccl_comm, GetSize(), nccl_id, GetRank())); comm_ctx->set_nccl_comm(nccl_comm); @@ -611,7 +718,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( ncclUniqueId nccl_id; if (rank_ == 0) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id)); + NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id)); } BroadcastUniqueNCCLID(&nccl_id); @@ -632,7 +739,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( dev_ctx[i].reset(new phi::GPUContext(places[i])); ncclComm_t nccl_comm; - NCCLCHECK(platform::dynload::ncclCommInitRank( + NCCL_CHECK(platform::dynload::ncclCommInitRank( &nccl_comm, GetSize(), nccl_id, GetRank())); dev_ctx[i]->set_nccl_comm(nccl_comm); dev_ctx_raw[i] = dev_ctx[i].get(); @@ -1257,70 +1364,6 @@ std::shared_ptr ProcessGroupNCCL::Reduce( CommType::REDUCE); } -std::shared_ptr ProcessGroupNCCL::Reduce( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](const phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce( - input.data(), - output.data(), - input.numel(), - platform::ToNCCLDataType(input.dtype()), - ToNCCLRedType(opts.reduce_op), - opts.root_rank, - comm, - stream)); - }, - CommType::REDUCE, - sync_op, - use_calc_stream); -} - -std::shared_ptr ProcessGroupNCCL::ReduceScatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceScatterOptions& opts, - bool sync_op, - bool use_calc_stream) { - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - if (FLAGS_use_stream_safe_cuda_allocator) { - platform::CUDADeviceGuard cuda_guard; - cuda_guard.SetDevice(output.place()); - memory::RecordStream(output.Holder(), stream); - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter( - input.data(), - output.data(), - output.numel(), - platform::ToNCCLDataType(input.dtype()), - ToNCCLRedType(opts.reduce_op), - comm, - stream)); - }, - CommType::REDUCE_SCATTER, - sync_op, - use_calc_stream); -} - std::shared_ptr ProcessGroupNCCL::Scatter( std::vector& in_tensors, std::vector& out_tensors, 
@@ -1374,67 +1417,5 @@ std::shared_ptr ProcessGroupNCCL::Scatter( CommType::SCATTER); } -std::shared_ptr ProcessGroupNCCL::Scatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(out_tensors), - true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - PADDLE_ENFORCE_EQ( - output.numel(), - input.numel() / size_, - platform::errors::InvalidArgument( - "Input and output tensors should have the same shape.")); - size_t offset = 0; - if (rank_ == opts.root_rank) { - GroupStart(); - for (auto i = 0; i < size_; i++) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( - GetPointerByOffset(input.data(), offset, input.dtype()), - input.numel() / size_, - platform::ToNCCLDataType(input.dtype()), - i, - comm, - stream)); - offset += input.numel() / size_; - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( - output.data(), - input.numel() / size_, - platform::ToNCCLDataType(input.dtype()), - opts.root_rank, - comm, - stream)); - GroupEnd(); - } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( - output.data(), - input.numel() / size_, - platform::ToNCCLDataType(input.dtype()), - opts.root_rank, - comm, - stream)); - } - }, - CommType::SCATTER, - sync_op, - use_calc_stream); -} - } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index a6528be80b4..c10c4370b4b 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -127,6 +127,25 @@ class ProcessGroupNCCL final : public ProcessGroupStream { bool sync_op, bool use_calc_stream) override; + std::shared_ptr Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Scatter(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr Recv(phi::DenseTensor* tensor, int src_rank, int64_t offset, @@ -184,32 +203,11 @@ class ProcessGroupNCCL final : public ProcessGroupStream { std::vector& out_tensors, const ReduceOptions& opts) override; - std::shared_ptr Reduce( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceOptions& opts, - bool sync_op, - bool use_calc_stream) override; - - std::shared_ptr ReduceScatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceScatterOptions& opts, - bool sync_op, - bool use_calc_stream) override; - std::shared_ptr Scatter( std::vector& in_tensors, std::vector& out_tensors, const ScatterOptions& opts) override; - std::shared_ptr Scatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions& opts, - bool sync_op, - bool use_calc_stream) override; - private: std::shared_ptr CreateTask(const 
Place& place, int rank, diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.cc b/paddle/fluid/distributed/collective/ProcessGroupStream.cc index 3839f70ac13..9f7b3c1964e 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.cc @@ -120,6 +120,72 @@ std::shared_ptr ProcessGroupStream::Broadcast( "ProcessGroup%s does not support broadcast.", GetBackendName())); } +std::shared_ptr ProcessGroupStream::Reduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op) { + return Reduce(out_tensor, + in_tensor, + opts, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::Reduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::Unimplemented( + "ProcessGroup%s does not support reduce.", GetBackendName())); +} + +std::shared_ptr ProcessGroupStream::ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op) { + return ReduceScatter(out_tensor, + in_tensor, + opts, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::Unimplemented( + "ProcessGroup%s does not support reduce_scatter.", GetBackendName())); +} + +std::shared_ptr ProcessGroupStream::Scatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op) { + return Scatter(out_tensor, + in_tensor, + opts, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::Scatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::Unimplemented( + "ProcessGroup%s does not support scatter.", GetBackendName())); +} + std::shared_ptr ProcessGroupStream::Recv( phi::DenseTensor* tensor, int src_rank, @@ -190,72 +256,6 @@ std::shared_ptr ProcessGroupStream::AllToAll( "ProcessGroup%s does not support do alltoall", GetBackendName())); } -std::shared_ptr ProcessGroupStream::Reduce( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceOptions& opts, - bool sync_op) { - return Reduce(in_tensors, - out_tensors, - opts, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::Reduce( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support do reduce", GetBackendName())); -} - -std::shared_ptr ProcessGroupStream::ReduceScatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceScatterOptions& opts, - bool sync_op) { - return ReduceScatter(in_tensors, - out_tensors, - opts, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::ReduceScatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ReduceScatterOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support do reduce_scatter", GetBackendName())); -} - -std::shared_ptr 
ProcessGroupStream::Scatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions& opts, - bool sync_op) { - return Scatter(in_tensors, - out_tensors, - opts, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::Scatter( - std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support do scatter", GetBackendName())); -} - std::shared_ptr ProcessGroupStream::Recv( std::vector& tensors, int src_rank, bool sync_op) { return Recv(tensors, diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.h b/paddle/fluid/distributed/collective/ProcessGroupStream.h index ad37c330681..d1fd95953f1 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.h +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.h @@ -117,6 +117,43 @@ class ProcessGroupStream : public ProcessGroup { bool sync_op, bool use_calc_stream); + std::shared_ptr Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op) override; + + virtual std::shared_ptr Reduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op) override; + + virtual std::shared_ptr ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr Scatter(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op) override; + + virtual std::shared_ptr Scatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream); + std::shared_ptr Recv(phi::DenseTensor* tensor, int src_rank, int64_t offset, @@ -155,45 +192,6 @@ class ProcessGroupStream : public ProcessGroup { bool sync_op, bool use_calc_stream); - std::shared_ptr Reduce( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - const ReduceOptions& opts, - bool sync_op) override; - - virtual std::shared_ptr Reduce( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - const ReduceOptions& opts, - bool sync_op, - bool use_calc_stream); - - std::shared_ptr ReduceScatter( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - const ReduceScatterOptions& opts, - bool sync_op) override; - - virtual std::shared_ptr ReduceScatter( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - const ReduceScatterOptions& opts, - bool sync_op, - bool use_calc_stream); - - std::shared_ptr Scatter( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - const ScatterOptions& opts, - bool sync_op) override; - - virtual std::shared_ptr Scatter( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - const ScatterOptions& opts, - bool sync_op, - bool use_calc_stream); - std::shared_ptr Recv( std::vector& tensors, // NOLINT int src_rank, diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 52160ea99a0..0634f825a01 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ 
b/paddle/fluid/pybind/distributed_py.cc @@ -412,16 +412,17 @@ void BindDistributed(py::module *m) { .def( "reduce", [](distributed::ProcessGroup &self, - py::handle py_in_tensor, + py::handle py_tensor, int dst, distributed::ReduceOp op, bool sync_op) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto p_dense = + std::dynamic_pointer_cast(tensor.impl()); + auto *out_dense = p_dense.get(); + auto in_dense = *p_dense; distributed::ReduceOptions opts{op, dst}; - auto dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector tensors = {*dense}; - return self.Reduce(tensors, tensors, opts, sync_op); + return self.Reduce(out_dense, in_dense, opts, sync_op); }, py::arg("tensor"), py::arg("dst"), @@ -432,28 +433,27 @@ void BindDistributed(py::module *m) { .def( "reduce_scatter", [](distributed::ProcessGroup &self, - py::handle py_in_tensor_list, py::handle py_out_tensor, + py::handle py_in_tensor_list, distributed::ReduceOp op, bool sync_op) { + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto p_out_tensor = std::dynamic_pointer_cast( + out_tensor.impl()); + auto out_dense = p_out_tensor.get(); + auto in_tensor_list = CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); - auto in_dense = std::dynamic_pointer_cast( + auto p_in_tensor = std::dynamic_pointer_cast( concat_in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( - out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto in_dense = *p_in_tensor; distributed::ReduceScatterOptions opts{op}; - return self.ReduceScatter( - in_wrapper, out_wrapper, opts, sync_op); + return self.ReduceScatter(out_dense, in_dense, opts, sync_op); }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::arg("op"), py::arg("sync_op"), py::call_guard()) @@ -461,26 +461,25 @@ void BindDistributed(py::module *m) { .def( "reduce_scatter_tensor", [](distributed::ProcessGroup &self, - py::handle py_in_tensor, py::handle py_out_tensor, + py::handle py_in_tensor, distributed::ReduceOp op, bool sync_op) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto out_dense = p_out_tensor.get(); + + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto p_in_tensor = std::dynamic_pointer_cast( + in_tensor.impl()); + auto in_dense = *p_in_tensor; distributed::ReduceScatterOptions opts{op}; - return self.ReduceScatter( - in_wrapper, out_wrapper, opts, sync_op); + return self.ReduceScatter(out_dense, in_dense, opts, sync_op); }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::arg("op"), py::arg("sync_op"), py::call_guard()) @@ -488,27 +487,27 @@ void BindDistributed(py::module *m) { .def( "scatter", [](distributed::ProcessGroup &self, - py::handle py_in_tensor_list, py::handle py_out_tensor, + py::handle py_in_tensor_list, int src, bool sync_op) { + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto p_out_tensor = std::dynamic_pointer_cast( + out_tensor.impl()); + auto *out_dense = p_out_tensor.get(); + auto in_tensor_list = 
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); - auto in_dense = std::dynamic_pointer_cast( + auto p_in_tensor = std::dynamic_pointer_cast( concat_in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( - out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto in_dense = *p_in_tensor; distributed::ScatterOptions opts{src}; - return self.Scatter(in_wrapper, out_wrapper, opts, sync_op); + return self.Scatter(out_dense, in_dense, opts, sync_op); }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::arg("src"), py::arg("sync_op"), py::call_guard()) @@ -516,25 +515,25 @@ void BindDistributed(py::module *m) { .def( "scatter_tensor", [](distributed::ProcessGroup &self, - py::handle py_in_tensor, py::handle py_out_tensor, + py::handle py_in_tensor, int src, bool sync_op) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto *out_dense = p_out_tensor.get(); + + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto p_in_tensor = std::dynamic_pointer_cast( + in_tensor.impl()); + auto in_dense = *p_in_tensor; distributed::ScatterOptions opts{src}; - return self.Scatter(in_wrapper, out_wrapper, opts, sync_op); + return self.Scatter(out_dense, in_dense, opts, sync_op); }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::arg("src"), py::arg("sync_op"), py::call_guard()) @@ -986,16 +985,17 @@ void BindDistributed(py::module *m) { .def( "reduce_on_calc_stream", [](distributed::ProcessGroupStream &self, - py::handle py_in_tensor, + py::handle py_tensor, int dst, distributed::ReduceOp op) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + auto p_dense = + std::dynamic_pointer_cast(tensor.impl()); + auto *out_dense = p_dense.get(); + auto in_dense = *p_dense; distributed::ReduceOptions opts{op, dst}; - auto dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector tensors = {*dense}; - return self.Reduce(tensors, - tensors, + return self.Reduce(out_dense, + in_dense, opts, /*sync_op*/ true, /*use_calc_stream*/ true); @@ -1008,116 +1008,116 @@ void BindDistributed(py::module *m) { .def( "reduce_scatter_on_calc_stream", [](distributed::ProcessGroupStream &self, - py::handle py_in_tensor_list, py::handle py_out_tensor, + py::handle py_in_tensor_list, distributed::ReduceOp op) { + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto p_out_tensor = std::dynamic_pointer_cast( + out_tensor.impl()); + auto out_dense = p_out_tensor.get(); + auto in_tensor_list = CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); - auto in_dense = std::dynamic_pointer_cast( + auto p_in_tensor = std::dynamic_pointer_cast( concat_in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( - out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto in_dense = *p_in_tensor; distributed::ReduceScatterOptions opts{op}; - return 
self.ReduceScatter(in_wrapper, - out_wrapper, + return self.ReduceScatter(out_dense, + in_dense, opts, /*sync_op*/ true, /*use_calc_stream*/ true); }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::arg("op"), py::call_guard()) .def( "reduce_scatter_tensor_on_calc_stream", [](distributed::ProcessGroupStream &self, - py::handle py_in_tensor, py::handle py_out_tensor, + py::handle py_in_tensor, distributed::ReduceOp op) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto out_dense = p_out_tensor.get(); + + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto p_in_tensor = std::dynamic_pointer_cast( + in_tensor.impl()); + auto in_dense = *p_in_tensor; distributed::ReduceScatterOptions opts{op}; - return self.ReduceScatter(in_wrapper, - out_wrapper, + return self.ReduceScatter(out_dense, + in_dense, opts, /*sync_op*/ true, /*use_calc_stream*/ true); }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::arg("op"), py::call_guard()) .def( "scatter_on_calc_stream", [](distributed::ProcessGroupStream &self, - py::handle py_in_tensor_list, py::handle py_out_tensor, + py::handle py_in_tensor_list, int src) { + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto p_out_tensor = std::dynamic_pointer_cast( + out_tensor.impl()); + auto *out_dense = p_out_tensor.get(); + auto in_tensor_list = CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); - auto in_dense = std::dynamic_pointer_cast( + auto p_in_tensor = std::dynamic_pointer_cast( concat_in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( - out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto in_dense = *p_in_tensor; distributed::ScatterOptions opts{src}; - return self.Scatter(in_wrapper, - out_wrapper, + return self.Scatter(out_dense, + in_dense, opts, /*sync_op*/ true, /*use_calc_stream*/ true); }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::arg("src"), py::call_guard()) .def( "scatter_tensor_on_calc_stream", [](distributed::ProcessGroupStream &self, - py::handle py_in_tensor, py::handle py_out_tensor, + py::handle py_in_tensor, int src) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto in_dense = std::dynamic_pointer_cast( - in_tensor.impl()); - std::vector in_wrapper = {*in_dense}; - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - auto out_dense = std::dynamic_pointer_cast( + auto p_out_tensor = std::dynamic_pointer_cast( out_tensor.impl()); - std::vector out_wrapper = {*out_dense}; + auto *out_dense = p_out_tensor.get(); + + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto p_in_tensor = std::dynamic_pointer_cast( + in_tensor.impl()); + auto in_dense = *p_in_tensor; distributed::ScatterOptions opts{src}; - return self.Scatter(in_wrapper, - out_wrapper, + return self.Scatter(out_dense, + in_dense, opts, /*sync_op*/ true, /*use_calc_stream*/ true); }, - py::arg("in"), py::arg("out"), + py::arg("in"), py::arg("src"), py::call_guard()) diff --git a/python/paddle/distributed/communication/stream/reduce_scatter.py 
b/python/paddle/distributed/communication/stream/reduce_scatter.py
index aa0d1e9b955..4d26e8d2b66 100644
--- a/python/paddle/distributed/communication/stream/reduce_scatter.py
+++ b/python/paddle/distributed/communication/stream/reduce_scatter.py
@@ -57,11 +57,11 @@ def _reduce_scatter_tensor_in_dygraph(
 
     if use_calc_stream:
         return group.process_group.reduce_scatter_tensor_on_calc_stream(
-            in_tensor, out_tensor, op_type
+            out_tensor, in_tensor, op_type
         )
 
     task = group.process_group.reduce_scatter_tensor(
-        in_tensor, out_tensor, op_type, sync_op
+        out_tensor, in_tensor, op_type, sync_op
     )
     if sync_op:
         task.wait()
@@ -78,11 +78,11 @@ def _reduce_scatter_in_dygraph(
 
     if use_calc_stream:
         return group.process_group.reduce_scatter_on_calc_stream(
-            tensor_list, tensor, op_type
+            tensor, tensor_list, op_type
         )
 
     task = group.process_group.reduce_scatter(
-        tensor_list, tensor, op_type, sync_op
+        tensor, tensor_list, op_type, sync_op
     )
     if sync_op:
         task.wait()
diff --git a/python/paddle/distributed/communication/stream/scatter.py b/python/paddle/distributed/communication/stream/scatter.py
index 75a8ab3909a..5767c2150d8 100644
--- a/python/paddle/distributed/communication/stream/scatter.py
+++ b/python/paddle/distributed/communication/stream/scatter.py
@@ -53,11 +53,11 @@ def _scatter_tensor_in_dygraph(
 
     if use_calc_stream:
         return group.process_group.scatter_tensor_on_calc_stream(
-            in_tensor, out_tensor, src_rank_in_group
+            out_tensor, in_tensor, src_rank_in_group
         )
 
     task = group.process_group.scatter_tensor(
-        in_tensor, out_tensor, src_rank_in_group, sync_op
+        out_tensor, in_tensor, src_rank_in_group, sync_op
     )
     if sync_op:
         task.wait()
@@ -80,11 +80,11 @@ def _scatter_in_dygraph(
 
     if use_calc_stream:
         return group.process_group.scatter_on_calc_stream(
-            tensor_list, tensor, src_rank_in_group
+            tensor, tensor_list, src_rank_in_group
        )
 
     task = group.process_group.scatter(
-        tensor_list, tensor, src_rank_in_group, sync_op
+        tensor, tensor_list, src_rank_in_group, sync_op
    )
    if sync_op:
        task.wait()
-- 
GitLab
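
The common thread across the C++ changes in this patch is the new call convention: each collective (Reduce, ReduceScatter, Scatter) now takes a single output tensor pointer, a single const input tensor reference, its option struct, and a sync_op flag (plus use_calc_stream on the ProcessGroupStream variants), replacing the old std::vector<phi::DenseTensor>& in/out pairs. The same output-first order is what the updated pybind bindings and the Python wrappers above now forward. Below is a minimal, self-contained sketch of that convention only; Tensor, Task, ReduceOptions, and DummyProcessGroup are stand-ins invented for illustration, not Paddle types, and the body is a single-process placeholder rather than a real NCCL/Gloo reduce.

#include <iostream>
#include <memory>
#include <numeric>
#include <vector>

// Stand-ins for phi::DenseTensor, ProcessGroup::Task and ReduceOptions;
// they exist only to keep this sketch compilable outside Paddle.
struct Tensor {
  std::vector<float> data;
};

enum class ReduceOp { SUM };

struct ReduceOptions {
  ReduceOp reduce_op;
  int root_rank;
};

struct Task {
  void Wait() const { std::cout << "task finished\n"; }
};

class DummyProcessGroup {
 public:
  // New-style signature: output pointer first, then const input reference,
  // the op-specific options, and a sync_op flag.
  std::shared_ptr<Task> Reduce(Tensor* out_tensor,
                               const Tensor& in_tensor,
                               const ReduceOptions& opts,
                               bool sync_op) {
    // Single-process placeholder for the collective: copy input to output.
    out_tensor->data = in_tensor.data;
    (void)opts;
    (void)sync_op;
    return std::make_shared<Task>();
  }
};

int main() {
  DummyProcessGroup group;
  Tensor in{{1.0f, 2.0f, 3.0f}};
  Tensor out;
  auto task = group.Reduce(&out, in, {ReduceOp::SUM, /*root_rank=*/0},
                           /*sync_op=*/true);
  task->Wait();
  float sum = std::accumulate(out.data.begin(), out.data.end(), 0.0f);
  std::cout << "reduced sum: " << sum << "\n";
  return 0;
}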
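The patch also renames NCCLCHECK to NCCL_CHECK and uses it for the NCCL calls that previously returned their result directly or went through PADDLE_ENFORCE_GPU_SUCCESS. As background, the sketch below shows only the do { ... } while (0) idiom such a checking macro relies on; FAKE_CHECK, fake_result_t, and fake_collective_call are placeholder names, not the actual macro or NCCL symbols.

#include <cstdio>
#include <cstdlib>

// Placeholder result type and success code standing in for ncclResult_t /
// ncclSuccess; only the macro shape mirrors NCCL_CHECK.
using fake_result_t = int;
constexpr fake_result_t kSuccess = 0;

// Wrapping the body in do { ... } while (0) makes the macro expand to a
// single statement, so it can be used safely in an unbraced if/else.
#define FAKE_CHECK(cmd)                                          \
  do {                                                           \
    fake_result_t r = (cmd);                                     \
    if (r != kSuccess) {                                         \
      std::fprintf(stderr, "Failed, error %d at %s:%d\n", r,     \
                   __FILE__, __LINE__);                          \
      std::exit(EXIT_FAILURE);                                   \
    }                                                            \
  } while (0)

fake_result_t fake_collective_call() { return kSuccess; }

int main() {
  if (true)
    FAKE_CHECK(fake_collective_call());  // no dangling-else surprise
  else
    std::puts("unreachable");
  std::puts("ok");
  return 0;
}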