From 1be70bc516c2faa5282779aeaf82028f7524f7f6 Mon Sep 17 00:00:00 2001 From: Wen Sun <35923278+HermitSun@users.noreply.github.com> Date: Thu, 5 Jan 2023 11:57:27 +0800 Subject: [PATCH] Refactor `ProcessGroup` to support comm context migration & clang compilation (#49451) * refactor: use base class * fix: incorrect deps * fix: add missing header * refactor: update class structures * fix: bkcl typo * fix: remove redundant def --- .../distributed/collective/CMakeLists.txt | 8 +- paddle/fluid/distributed/collective/check.cc | 2 +- .../distributed/collective/nccl_tools.cc | 9 +- .../fluid/distributed/collective/nccl_tools.h | 9 +- .../distributed/collective/process_group.cc | 26 +- .../distributed/collective/process_group.h | 276 ++++++++++++----- .../collective/process_group_bkcl.cc | 6 +- .../collective/process_group_bkcl.h | 17 +- .../collective/process_group_custom.cc | 5 +- .../collective/process_group_custom.h | 3 +- .../collective/process_group_gloo.cc | 11 +- .../collective/process_group_gloo.h | 13 +- .../collective/process_group_mpi.cc | 4 +- .../collective/process_group_mpi.h | 3 +- .../collective/process_group_nccl.cc | 50 ++-- .../collective/process_group_nccl.h | 15 +- .../collective/process_group_stream.cc | 279 ------------------ .../collective/process_group_stream.h | 203 ------------- .../collective/process_group_with_stream.h | 258 ++++++++++++++++ .../collective/process_group_without_stream.h | 69 +++++ paddle/fluid/pybind/distributed_py.cc | 45 ++- 21 files changed, 625 insertions(+), 686 deletions(-) delete mode 100644 paddle/fluid/distributed/collective/process_group_stream.cc delete mode 100644 paddle/fluid/distributed/collective/process_group_stream.h create mode 100644 paddle/fluid/distributed/collective/process_group_with_stream.h create mode 100644 paddle/fluid/distributed/collective/process_group_without_stream.h diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index 3b76733a61..23e1b48d15 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -2,14 +2,11 @@ cc_library( process_group SRCS process_group.cc DEPS dense_tensor) -cc_library( - process_group_stream - SRCS process_group_stream.cc - DEPS dense_tensor) + cc_library( eager_reducer SRCS reducer.cc - DEPS eager_api process_group process_group_stream phi_api string_helper) + DEPS eager_api process_group phi_api string_helper) if(WITH_DISTRIBUTE) cc_library( @@ -23,7 +20,6 @@ if(WITH_NCCL OR WITH_RCCL) process_group_nccl SRCS process_group_nccl.cc nccl_tools.cc common.cc check.cc DEPS process_group - process_group_stream place enforce collective_helper diff --git a/paddle/fluid/distributed/collective/check.cc b/paddle/fluid/distributed/collective/check.cc index 151d7f3574..a5cd37dbc3 100644 --- a/paddle/fluid/distributed/collective/check.cc +++ b/paddle/fluid/distributed/collective/check.cc @@ -15,9 +15,9 @@ #include "paddle/fluid/distributed/collective/check.h" #include "paddle/fluid/distributed/collective/nccl_tools.h" -#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" #ifdef PADDLE_WITH_HIP diff --git a/paddle/fluid/distributed/collective/nccl_tools.cc b/paddle/fluid/distributed/collective/nccl_tools.cc index ffb51d706d..940c8d47cc 100644 --- a/paddle/fluid/distributed/collective/nccl_tools.cc +++ 
b/paddle/fluid/distributed/collective/nccl_tools.cc @@ -14,13 +14,16 @@ #include "paddle/fluid/distributed/collective/nccl_tools.h" -#include "paddle/fluid/platform/enforce.h" +#include <unordered_map> + +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" namespace paddle { namespace distributed { ncclRedOp_t ToNCCLRedType(ReduceOp reduction) { - static const std::map<ReduceOp, ncclRedOp_t> red_type = { + static const std::unordered_map<ReduceOp, ncclRedOp_t> red_type = { {ReduceOp::MIN, ncclMin}, {ReduceOp::MAX, ncclMax}, {ReduceOp::SUM, ncclSum}, @@ -29,7 +32,7 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction) { auto it = red_type.find(reduction); PADDLE_ENFORCE_EQ(it != red_type.end(), true, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Invalid nccl reduction. Must be ncclMin | ncclMax | " "ncclProd | ncclSum")); return it->second; diff --git a/paddle/fluid/distributed/collective/nccl_tools.h index 2efef5381b..135aadd2a2 100644 --- a/paddle/fluid/distributed/collective/nccl_tools.h +++ b/paddle/fluid/distributed/collective/nccl_tools.h @@ -14,20 +14,15 @@ #pragma once -#ifdef PADDLE_WITH_CUDA -#include <cuda_runtime.h> -#endif -#ifdef PADDLE_WITH_HIP -#include <hip/hip_runtime.h> -#endif - #include <string> #include "paddle/fluid/distributed/collective/types.h" #ifdef PADDLE_WITH_RCCL +#include <hip/hip_runtime.h> #include "paddle/phi/backends/dynload/rccl.h" #else +#include <cuda_runtime.h> #include "paddle/phi/backends/dynload/nccl.h" #endif diff --git a/paddle/fluid/distributed/collective/process_group.cc b/paddle/fluid/distributed/collective/process_group.cc index fc05028059..d670477f2d 100644 --- a/paddle/fluid/distributed/collective/process_group.cc +++ b/paddle/fluid/distributed/collective/process_group.cc @@ -17,44 +17,20 @@ namespace paddle { namespace distributed { -ProcessGroup::Task::Task(int rank, CommType comm_type, bool sync_op) - : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {} - -ProcessGroup::Task::~Task() = default; bool ProcessGroup::Task::IsCompleted() { std::lock_guard<std::mutex> lock(mutex_); return is_completed_; } -bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) { - return false; -} - -void ProcessGroup::Task::Synchronize() {} - -void ProcessGroup::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {} ProcessGroup::ProcessGroup(int rank, int size, int gid) : rank_(rank), size_(size), gid_(gid) { - if (gid != IGNORE_ID) { + if (gid != kIgnoreId) { auto map = ProcessGroupMapFromGid::getInstance(); map->insert(gid_, this); } } // TODO(sunyilun): methods below will be removed later -ProcessGroup::Task::Task(int rank, - const std::vector<phi::DenseTensor>& inputs, - CommType comm_type) - : rank_(rank), comm_type_(comm_type) {} - -ProcessGroup::Task::Task(int rank, - const std::vector<phi::DenseTensor>& inputs, - CommType comm_type, - bool sync_op) - : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {} - ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() { static ProcessGroupIdMap instance; return instance; diff --git a/paddle/fluid/distributed/collective/process_group.h b/paddle/fluid/distributed/collective/process_group.h index fa02ed22ee..4980dfc307 100644 --- a/paddle/fluid/distributed/collective/process_group.h +++ b/paddle/fluid/distributed/collective/process_group.h @@ -17,21 +17,22 @@ #include <chrono> #include <memory> #include <string> +#include <unordered_map> #include <vector> #include "paddle/fluid/distributed/collective/types.h" -#include "paddle/fluid/eager/api/utils/tensor_utils.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/platform/enforce.h" +#include
"paddle/fluid/eager/api/utils/tensor_utils.h" // NOTE: this header is required somewhere +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" constexpr auto kWaitTimeout = std::chrono::milliseconds(0); namespace paddle { namespace distributed { -constexpr int IGNORE_ID = -1; -using Tensor = paddle::experimental::Tensor; +constexpr int kIgnoreId = -1; enum class CommType : std::uint8_t { BROADCAST = 0, @@ -53,23 +54,28 @@ class ProcessGroup { public: class Task { public: - Task(int rank, CommType comm_type, bool sync_op); + Task(int rank, CommType comm_type, bool sync_op) + : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {} + virtual ~Task() = default; - virtual ~Task(); virtual bool IsCompleted(); - virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout); - virtual void Synchronize(); - virtual void UpdateWaitChain(const phi::DeviceContext& ctx); + virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) { + return false; + } + virtual void Synchronize() {} + virtual void UpdateWaitChain(const phi::DeviceContext& ctx) {} bool IsSync() const { return sync_op_; } // TODO(sunyilun): methods below will be removed later Task(int rank, const std::vector& inputs, - CommType comm_type); + CommType comm_type) + : rank_(rank), comm_type_(comm_type) {} Task(int rank, const std::vector& inputs, CommType comm_type, - bool sync_op); + bool sync_op) + : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {} protected: const int rank_; @@ -92,20 +98,26 @@ class ProcessGroup { virtual std::string GetBackendName() const = 0; virtual phi::DeviceContext* GetDeviceContext(const Place& place) const { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support get device_context.", GetBackendName())); } + virtual phi::DeviceContext* GetDeviceContext(const Place& place, + bool use_calc_stream) const { + PADDLE_THROW(phi::errors::Unimplemented( + "ProcessGroup%s does not support get device_context.", + GetBackendName())); + } + + // without stream APIs virtual std::shared_ptr AllGather( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, bool sync_op) { - return AllGather(out_tensor, - in_tensor, - /*offset*/ 0, - /*numel*/ -1, // -1 indicates the whole tensor - sync_op); + PADDLE_THROW(phi::errors::Unimplemented( + "ProcessGroup%s does not support all_gather with sync_op flag.", + GetBackendName())); } virtual std::shared_ptr AllGather( @@ -114,7 +126,7 @@ class ProcessGroup { int64_t offset, int64_t numel, bool sync_op) { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support all_gather with sync_op flag.", GetBackendName())); } @@ -124,7 +136,7 @@ class ProcessGroup { const phi::DenseTensor& in_tensor, const AllreduceOptions& opts, bool sync_op) { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support all_reduce with sync_op flag.", GetBackendName())); } @@ -135,14 +147,14 @@ class ProcessGroup { const std::vector& out_size_each_rank, const std::vector& in_size_each_rank, bool sync_op) { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support all_to_all with sync_op flag.", GetBackendName())); } virtual std::shared_ptr Barrier( const BarrierOptions& = BarrierOptions()) { - 
PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support barrier.", GetBackendName())); } @@ -151,7 +163,7 @@ class ProcessGroup { const phi::DenseTensor& in_tensor, const BroadcastOptions& opts, bool sync_op) { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support broadcast with sync_op flag", GetBackendName())); } @@ -161,7 +173,7 @@ class ProcessGroup { const phi::DenseTensor& in_tensor, const ReduceOptions& opts, bool sync_op) { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support reduce with sync_op flag.", GetBackendName())); } @@ -171,7 +183,7 @@ class ProcessGroup { const phi::DenseTensor& in_tensor, const ReduceScatterOptions& opts, bool sync_op) { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support reduce_scatter with sync_op flag.", GetBackendName())); } @@ -181,7 +193,7 @@ class ProcessGroup { const phi::DenseTensor& in_tensor, const ScatterOptions& opts, bool sync_op) { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support scatter with sync_op flag.", GetBackendName())); } @@ -189,11 +201,9 @@ class ProcessGroup { virtual std::shared_ptr Recv(phi::DenseTensor* tensor, int src_rank, bool sync_op) { - return Recv(tensor, - src_rank, - /*offset*/ 0, - /*numel*/ -1, // -1 indicates the whole tensor - sync_op); + PADDLE_THROW(phi::errors::Unimplemented( + "ProcessGroup%s does not support recv with sync_op flag.", + GetBackendName())); } virtual std::shared_ptr Recv(phi::DenseTensor* tensor, @@ -201,18 +211,16 @@ class ProcessGroup { int64_t offset, int64_t numel, bool sync_op) { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support recv with sync_op flag.", GetBackendName())); } virtual std::shared_ptr Send( const phi::DenseTensor& tensor, int dst_rank, bool sync_op) { - return Send(tensor, - dst_rank, - /*offset*/ 0, - /*numel*/ -1, // -1 indicates the whole tensor - sync_op); + PADDLE_THROW(phi::errors::Unimplemented( + "ProcessGroup%s does not support send with sync_op flag.", + GetBackendName())); } virtual std::shared_ptr Send( @@ -221,17 +229,162 @@ class ProcessGroup { int64_t offset, int64_t numel, bool sync_op) { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "ProcessGroup%s does not support send with sync_op flag.", GetBackendName())); } + // stream APIs + virtual std::shared_ptr AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support all_gather " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support all_gather " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr AllReduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + 
phi::errors::Unimplemented("ProcessGroup%s does not support all_reduce " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support all_to_all " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support broadcast " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr Reduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support reduce " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(phi::errors::Unimplemented( + "ProcessGroup%s does not support reduce_scatter " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr Scatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support scatter " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr Recv(phi::DenseTensor* tensor, + int src_rank, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support recv with " + "sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr Recv(phi::DenseTensor* tensor, + int src_rank, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support recv " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr Send( + const phi::DenseTensor& tensor, + int dst_rank, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support send " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr Send( + const phi::DenseTensor& tensor, + int dst_rank, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support send " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + // legacy APIs // TODO(liyurui): This API will be moved later virtual std::shared_ptr AllReduce( std::vector& /* input tensors */, // NOLINT std::vector& /* output tensors */, // NOLINT const AllreduceOptions& = AllreduceOptions()) { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support allreduce", GetBackendName())); } @@ -240,7 +393,7 @@ class ProcessGroup { std::vector& /* output tensors */, // 
NOLINT const AllreduceOptions&, bool) { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support allreduce with sync_op flag", GetBackendName())); } @@ -250,7 +403,7 @@ class ProcessGroup { std::vector& /* input tensors */, // NOLINT std::vector& /* output tensors */, // NOLINT const BroadcastOptions& = BroadcastOptions()) { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support broadcast", GetBackendName())); } @@ -259,27 +412,27 @@ class ProcessGroup { std::vector& /* output tensors */, // NOLINT const BroadcastOptions&, bool) { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support broadcast with sync_op flag", GetBackendName())); } virtual std::shared_ptr Send( std::vector&, int) { // NOLINT - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support send", GetBackendName())); } virtual std::shared_ptr Recv( std::vector&, int) { // NOLINT - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support recv", GetBackendName())); } virtual std::shared_ptr AllGather( std::vector&, // NOLINT std::vector&) { // NOLINT - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support all_gather", GetBackendName())); } @@ -287,7 +440,7 @@ class ProcessGroup { std::vector&, // NOLINT std::vector&, // NOLINT bool) { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support all_gather with sync_op flag", GetBackendName())); } @@ -295,7 +448,7 @@ class ProcessGroup { virtual std::shared_ptr AllToAll( std::vector&, // NOLINT std::vector&) { // NOLINT - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support AllToAll", GetBackendName())); } @@ -303,7 +456,7 @@ class ProcessGroup { std::vector&, // NOLINT std::vector&, // NOLINT const ReduceOptions& opts) { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support reduce", GetBackendName())); } @@ -311,7 +464,7 @@ class ProcessGroup { std::vector&, // NOLINT std::vector&, // NOLINT const ScatterOptions&) { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "ProcessGroup%s does not support scatter", GetBackendName())); } @@ -330,28 +483,11 @@ class ProcessGroupIdMap // TODO(dev): The following method will be removed soon. 
class ProcessGroupMapFromGid { public: - bool has(int gid) { - auto it = map_.find(gid); - return it != map_.end(); - } + bool has(int gid) { return map_.find(gid) != map_.end(); } - void insert(int gid, ProcessGroup* pg) { - // TODO(sandyhouse): address ut and uncomment the following codes - // PADDLE_ENFORCE_EQ(has(gid), false, - // platform::errors::PreconditionNotMet( - // "The process group with id %d doesnot exist.", - // gid)); - map_[gid] = pg; - } + void insert(int gid, ProcessGroup* pg) { map_[gid] = pg; } - ProcessGroup* get(int gid) { - // TODO(sandyhouse): address ut and uncomment the following codes - // PADDLE_ENFORCE_EQ(has(gid), true, - // platform::errors::PreconditionNotMet( - // "The process group with id %d doesnot exist.", - // gid)); - return map_.find(gid)->second; - } + ProcessGroup* get(int gid) { return map_.find(gid)->second; } static std::shared_ptr<ProcessGroupMapFromGid> getInstance() { static auto s_instance = std::make_shared<ProcessGroupMapFromGid>(); diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.cc b/paddle/fluid/distributed/collective/process_group_bkcl.cc index 356b5127bb..e1f35ecd5e 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.cc +++ b/paddle/fluid/distributed/collective/process_group_bkcl.cc @@ -18,8 +18,8 @@ #include "paddle/fluid/distributed/collective/common.h" #include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #include "paddle/fluid/platform/device/xpu/xpu_info.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/place.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" namespace paddle { namespace distributed { @@ -76,7 +76,7 @@ ProcessGroupBKCL::ProcessGroupBKCL(const std::shared_ptr<Store>& store, int rank, int size, int gid) - : ProcessGroupStream(rank, size, gid), store_(store) {} + : ProcessGroupWithStream(rank, size, gid), store_(store) {} void ProcessGroupBKCL::GroupStart() { PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start()); diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.h b/paddle/fluid/distributed/collective/process_group_bkcl.h index cb7132ac9f..15c908554b 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.h +++ b/paddle/fluid/distributed/collective/process_group_bkcl.h @@ -15,17 +15,16 @@ #pragma once #include <chrono> -#include <map> #include <memory> #include <string> -#include "paddle/fluid/distributed/collective/process_group_stream.h" +#include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/fluid/distributed/collective/process_group_with_stream.h" #include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h" -#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" -#include "paddle/fluid/platform/place.h" +#include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" #if defined(PADDLE_WITH_XPU) constexpr const char* BKCL_BACKEND_NAME = "BKCL"; namespace paddle { namespace distributed { -using Place = paddle::platform::Place; +using Place = phi::Place; // BKCL funcs use separate communication stream by default -class ProcessGroupBKCL : public ProcessGroupStream { +class ProcessGroupBKCL : public ProcessGroupWithStream { public: - class BKCLTask final : public ProcessGroupStream::TaskStream, + class BKCLTask final : public ProcessGroupWithStream::TaskStream, public std::enable_shared_from_this<BKCLTask> { public: BKCLTask(const Place& place, @@ -161,12 +160,12 @@ class ProcessGroupBKCL : public
ProcessGroupStream { bool sync_op, bool use_calc_stream); - void BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id); // NOLINT + void BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id); void CreateBKCLEnvCache(const Place& place, const std::string& place_key); template - std::shared_ptr Collective( + std::shared_ptr Collective( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, Fn fn, diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index 9f7c2eeb2a..2fb23b455c 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/common/place.h" DECLARE_bool(xccl_blocking_wait); @@ -102,7 +103,9 @@ ProcessGroupCustom::ProcessGroupCustom(const std::shared_ptr& store, int rank, int size, int gid) - : ProcessGroup(rank, size, gid), store_(store), device_type_(device_type) {} + : ProcessGroupWithoutStream(rank, size, gid), + store_(store), + device_type_(device_type) {} void ProcessGroupCustom::BroadcastUniqueCustomID( std::vector& ccl_ids) { // NOLINT diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h index 3e55d150d3..3169b9d5bc 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.h +++ b/paddle/fluid/distributed/collective/process_group_custom.h @@ -23,6 +23,7 @@ #include "paddle/fluid/distributed/collective/custom_ccl_tools.h" #include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/fluid/distributed/collective/process_group_without_stream.h" #include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/platform/device/npu/npu_stream.h" #include "paddle/fluid/platform/device_context.h" @@ -34,7 +35,7 @@ namespace paddle { namespace distributed { using Place = paddle::platform::Place; using CustomDeviceContext = paddle::platform::CustomDeviceContext; -class ProcessGroupCustom : public ProcessGroup { +class ProcessGroupCustom : public ProcessGroupWithoutStream { public: class CustomTask : public ProcessGroup::Task, public std::enable_shared_from_this { diff --git a/paddle/fluid/distributed/collective/process_group_gloo.cc b/paddle/fluid/distributed/collective/process_group_gloo.cc index d547754938..8e3fcc8ec5 100644 --- a/paddle/fluid/distributed/collective/process_group_gloo.cc +++ b/paddle/fluid/distributed/collective/process_group_gloo.cc @@ -182,7 +182,7 @@ ProcessGroupGloo::ProcessGroupGloo( int world_size, int gid, const std::shared_ptr options) - : ProcessGroup(rank, world_size, gid), + : ProcessGroupWithoutStream(rank, world_size, gid), _tag(0), _store(new GlooStore(store)) { _context = std::make_shared(rank, world_size); @@ -400,15 +400,6 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask { } }; -std::shared_ptr ProcessGroupGloo::AllGather( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - bool sync_op) { - std::vector in_wrapper{in_tensor}; - std::vector out_wrapper{*out_tensor}; - return AllGather(in_wrapper, out_wrapper, true); -} - std::shared_ptr ProcessGroupGloo::AllGather( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, diff --git a/paddle/fluid/distributed/collective/process_group_gloo.h 
b/paddle/fluid/distributed/collective/process_group_gloo.h index 97b8338720..4a72a58ee1 100644 --- a/paddle/fluid/distributed/collective/process_group_gloo.h +++ b/paddle/fluid/distributed/collective/process_group_gloo.h @@ -19,18 +19,18 @@ #include #include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/fluid/distributed/collective/process_group_without_stream.h" +#include "paddle/fluid/distributed/store/store.h" +#include "paddle/fluid/distributed/store/tcp_store.h" #ifdef PADDLE_WITH_GLOO #include "paddle/fluid/framework/fleet/gloo_wrapper.h" #endif -#include "paddle/fluid/distributed/store/store.h" -#include "paddle/fluid/distributed/store/tcp_store.h" - namespace paddle { namespace distributed { -class ProcessGroupGloo : public ProcessGroup { +class ProcessGroupGloo : public ProcessGroupWithoutStream { public: class GlooTask : public ProcessGroup::Task, public std::enable_shared_from_this { @@ -120,11 +120,6 @@ class ProcessGroupGloo : public ProcessGroup { int64_t /*numel*/, // for compatibility, no use now bool sync_op) override; - std::shared_ptr AllGather( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - bool sync_op) override; - std::shared_ptr AllReduce( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, diff --git a/paddle/fluid/distributed/collective/process_group_mpi.cc b/paddle/fluid/distributed/collective/process_group_mpi.cc index 796e0bb692..771c745865 100644 --- a/paddle/fluid/distributed/collective/process_group_mpi.cc +++ b/paddle/fluid/distributed/collective/process_group_mpi.cc @@ -174,7 +174,9 @@ std::shared_ptr ProcessGroupMPI::CreateProcessGroupMPI( } ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pg_comm, int gid) - : ProcessGroup(rank, size, gid), stop_(false), pg_comm(pg_comm) { + : ProcessGroupWithoutStream(rank, size, gid), + stop_(false), + pg_comm(pg_comm) { PADDLE_ENFORCE_EQ( pg_comm == MPI_COMM_NULL, false, diff --git a/paddle/fluid/distributed/collective/process_group_mpi.h b/paddle/fluid/distributed/collective/process_group_mpi.h index dd6793ed07..0c497da02b 100644 --- a/paddle/fluid/distributed/collective/process_group_mpi.h +++ b/paddle/fluid/distributed/collective/process_group_mpi.h @@ -26,6 +26,7 @@ #include #include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/fluid/distributed/collective/process_group_without_stream.h" #include "paddle/fluid/distributed/collective/types.h" #include "paddle/fluid/platform/device_context.h" @@ -57,7 +58,7 @@ struct TaskEntry { std::function&)> run_; }; -class ProcessGroupMPI : public ProcessGroup { +class ProcessGroupMPI : public ProcessGroupWithoutStream { public: class MPITask : public ProcessGroup::Task { public: diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index 1353ea719a..425edf6e37 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -18,9 +18,11 @@ #include "paddle/fluid/distributed/collective/common.h" #include "paddle/fluid/distributed/collective/nccl_tools.h" #include "paddle/fluid/distributed/collective/utils.h" +#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/fluid/platform/place.h" #include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/core/enforce.h" DECLARE_bool(nccl_blocking_wait); DECLARE_bool(use_stream_safe_cuda_allocator); @@ -88,7 
+90,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store, int rank, int size, int gid) - : ProcessGroupStream(rank, size, gid), store_(store) {} + : ProcessGroupWithStream(rank, size, gid), store_(store) {} void ProcessGroupNCCL::GroupStart() { NCCL_CHECK(phi::dynload::ncclGroupStart()); @@ -112,7 +114,7 @@ phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext( PADDLE_ENFORCE_NE( iter, place_to_comm_ctx_.end(), - platform::errors::NotFound( + phi::errors::NotFound( "Cannot find the device context in this process group.")); return iter->second.get(); } @@ -124,7 +126,7 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const { PADDLE_ENFORCE_NE( iter, place_to_comm_ctx_.end(), - platform::errors::NotFound( + phi::errors::NotFound( "Cannot find the NCCL communicator in this process group.")); return iter->second->nccl_comm(); } @@ -207,7 +209,7 @@ void CheckSizeOnEachRank(const phi::DDim& tensor_dim, PADDLE_ENFORCE_EQ( length_size_on_each_rank, world_size, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The length of size_on_each_rank must be equal to world_size.")); int64_t sum_size_on_each_rank = @@ -215,7 +217,7 @@ void CheckSizeOnEachRank(const phi::DDim& tensor_dim, PADDLE_ENFORCE_EQ( sum_size_on_each_rank, tensor_dim[0], - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The sum of size_on_each_rank must be equal to tensor's dim[0].")); } @@ -289,7 +291,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier( const BarrierOptions& opts) { PADDLE_ENFORCE_GE(opts.device_id, 0, - platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "The barrier device id must greater or equal than 0.")); platform::CUDAPlace place(opts.device_id); auto allocator = std::unique_ptr<phi::Allocator>( new paddle::experimental::DefaultAllocator(place)); @@ -681,7 +683,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( const std::string& places_key, const std::vector<Place>& places) { PADDLE_ENFORCE_EQ(places_key.empty(), false, - platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "Not able to create/get the NCCL Communicator since " "the GPU place are not known")); @@ -837,7 +839,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce( PADDLE_ENFORCE_EQ( CheckTensorsInCudaPlace(in_tensors), true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); return Collective( in_tensors, out_tensors, @@ -864,7 +866,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast( PADDLE_ENFORCE_EQ( CheckTensorsInCudaPlace(in_tensors), true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); return Collective( in_tensors, @@ -892,25 +894,25 @@ void CheckTensorsInDifferentDevices( PADDLE_ENFORCE_EQ( tensors.size() == 0, false, - platform::errors::InvalidArgument("Tensor list must be nonempty.")); + phi::errors::InvalidArgument("Tensor list must be nonempty.")); PADDLE_ENFORCE_LE( tensors.size(), num_devices, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Tensor list mustn't be larger than the number of available GPUs.")); std::set<Place> used_devices; for (const auto& t : tensors) { - PADDLE_ENFORCE_EQ(platform::is_gpu_place(t.place()), - true, - platform::errors::InvalidArgument( - "Tensors must be CUDA and dense tensor.")); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(t.place()), + true, + phi::errors::InvalidArgument("Tensors must be CUDA and dense tensor.")); const auto inserted =
used_devices.insert(t.place()).second; PADDLE_ENFORCE_EQ(inserted, true, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Tensors must be on distinct GPU devices.")); } } @@ -965,11 +967,11 @@ std::shared_ptr ProcessGroupNCCL::AllGather( PADDLE_ENFORCE_EQ( CheckTensorsInCudaPlace(in_tensors), true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); PADDLE_ENFORCE_EQ( CheckTensorsInCudaPlace(out_tensors), true, - platform::errors::InvalidArgument("All outputs should be in CudaPlace.")); + phi::errors::InvalidArgument("All outputs should be in CudaPlace.")); return Collective( in_tensors, out_tensors, @@ -1019,7 +1021,7 @@ void* GetPointerByOffset(void* raw_pointer, return reinterpret_cast(reinterpret_cast(raw_pointer) + offset); } else { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "Datatype %s in NCCL is not supported.", type)); } return nullptr; @@ -1031,11 +1033,11 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( PADDLE_ENFORCE_EQ( CheckTensorsInCudaPlace(in_tensors), true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); PADDLE_ENFORCE_EQ( CheckTensorsInCudaPlace(out_tensors), true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); return Collective( in_tensors, out_tensors, @@ -1074,7 +1076,7 @@ std::shared_ptr ProcessGroupNCCL::Reduce( PADDLE_ENFORCE_EQ( CheckTensorsInCudaPlace(in_tensors), true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); return Collective( in_tensors, out_tensors, @@ -1102,11 +1104,11 @@ std::shared_ptr ProcessGroupNCCL::Scatter( PADDLE_ENFORCE_EQ( CheckTensorsInCudaPlace(in_tensors), true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); PADDLE_ENFORCE_EQ( CheckTensorsInCudaPlace(out_tensors), true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); return Collective( in_tensors, out_tensors, diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index 816a0d2ec9..9d268cb03f 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -20,11 +20,10 @@ #include #include -#include "paddle/fluid/distributed/collective/process_group_stream.h" +#include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/fluid/distributed/collective/process_group_with_stream.h" #include "paddle/fluid/distributed/store/store.h" -#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device_event.h" -#include "paddle/fluid/platform/enforce.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" @@ -41,11 +40,11 @@ namespace paddle { namespace distributed { -using Place = paddle::platform::Place; +using Place = phi::Place; -class ProcessGroupNCCL final : public ProcessGroupStream { +class ProcessGroupNCCL final : public ProcessGroupWithStream { public: - class NCCLTask final : public ProcessGroupStream::TaskStream, + class 
NCCLTask final : public ProcessGroupWithStream::TaskStream, public std::enable_shared_from_this { public: NCCLTask(const Place& place, @@ -86,11 +85,11 @@ class ProcessGroupNCCL final : public ProcessGroupStream { std::string GetBackendName() const override { return "NCCL"; } + phi::DeviceContext* GetDeviceContext(const Place& place) const override; + phi::DeviceContext* GetDeviceContext(const Place& place, bool use_calc_stream) const override; - phi::DeviceContext* GetDeviceContext(const Place& place) const override; - std::shared_ptr AllGather( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, diff --git a/paddle/fluid/distributed/collective/process_group_stream.cc b/paddle/fluid/distributed/collective/process_group_stream.cc deleted file mode 100644 index 2b69cf51fe..0000000000 --- a/paddle/fluid/distributed/collective/process_group_stream.cc +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/distributed/collective/process_group_stream.h" - -namespace paddle { -namespace distributed { - -ProcessGroupStream::ProcessGroupStream(int rank, int size, int gid) - : ProcessGroup(rank, size, gid) {} - -phi::DeviceContext* ProcessGroupStream::GetDeviceContext( - const Place& place, bool use_calc_stream) const { - PADDLE_THROW(platform::errors::Unimplemented( - "ProcessGroup%s does not support get device_context.", GetBackendName())); -} - -std::shared_ptr ProcessGroupStream::AllGather( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - int64_t offset, - int64_t numel, - bool sync_op) { - return AllGather(out_tensor, - in_tensor, - offset, - numel, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::AllGather( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - bool sync_op, - bool use_calc_stream) { - return AllGather(out_tensor, - in_tensor, - /*offset*/ 0, - /*numel*/ -1, // -1 indicates the whole tensor - sync_op, - use_calc_stream); -} - -std::shared_ptr ProcessGroupStream::AllGather( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - int64_t offset, - int64_t numel, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::Unimplemented( - "ProcessGroup%s does not support all_gather.", GetBackendName())); -} - -std::shared_ptr ProcessGroupStream::AllReduce( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const AllreduceOptions& opts, - bool sync_op) { - return AllReduce(out_tensor, - in_tensor, - opts, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::AllReduce( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const AllreduceOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::Unimplemented( - "ProcessGroup%s does not support all_reduce.", GetBackendName())); -} - -std::shared_ptr ProcessGroupStream::AllToAll( - phi::DenseTensor* 
out_tensor, - const phi::DenseTensor& in_tensor, - const std::vector& out_size_each_rank, - const std::vector& in_size_each_rank, - bool sync_op) { - return AllToAll(out_tensor, - in_tensor, - out_size_each_rank, - in_size_each_rank, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::AllToAll( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const std::vector& out_size_each_rank, - const std::vector& in_size_each_rank, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::Unimplemented( - "ProcessGroup%s does not support all_to_all.", GetBackendName())); -} - -std::shared_ptr ProcessGroupStream::Broadcast( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const BroadcastOptions& opts, - bool sync_op) { - return Broadcast(out_tensor, - in_tensor, - opts, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::Broadcast( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const BroadcastOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::Unimplemented( - "ProcessGroup%s does not support broadcast.", GetBackendName())); -} - -std::shared_ptr ProcessGroupStream::Reduce( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ReduceOptions& opts, - bool sync_op) { - return Reduce(out_tensor, - in_tensor, - opts, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::Reduce( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ReduceOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::Unimplemented( - "ProcessGroup%s does not support reduce.", GetBackendName())); -} - -std::shared_ptr ProcessGroupStream::ReduceScatter( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ReduceScatterOptions& opts, - bool sync_op) { - return ReduceScatter(out_tensor, - in_tensor, - opts, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::ReduceScatter( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ReduceScatterOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::Unimplemented( - "ProcessGroup%s does not support reduce_scatter.", GetBackendName())); -} - -std::shared_ptr ProcessGroupStream::Scatter( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ScatterOptions& opts, - bool sync_op) { - return Scatter(out_tensor, - in_tensor, - opts, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::Scatter( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ScatterOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::Unimplemented( - "ProcessGroup%s does not support scatter.", GetBackendName())); -} - -std::shared_ptr ProcessGroupStream::Recv( - phi::DenseTensor* tensor, - int src_rank, - int64_t offset, - int64_t numel, - bool sync_op) { - return Recv(tensor, - src_rank, - offset, - numel, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::Recv( - phi::DenseTensor* tensor, - int src_rank, - bool sync_op, - bool use_calc_stream) { - return Recv(tensor, - src_rank, - /*offset*/ 0, - /*numel*/ -1, // -1 indicates sending the whole tensor - sync_op, - use_calc_stream); -} - -std::shared_ptr ProcessGroupStream::Recv( - phi::DenseTensor* tensor, - int src_rank, - int64_t 
offset, - int64_t numel, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::Unimplemented( - "ProcessGroup%s does not support recv.", GetBackendName())); -} - -std::shared_ptr ProcessGroupStream::Send( - const phi::DenseTensor& tensor, - int dst_rank, - int64_t offset, - int64_t numel, - bool sync_op) { - return Send(tensor, - dst_rank, - offset, - numel, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::Send( - const phi::DenseTensor& tensor, - int dst_rank, - bool sync_op, - bool use_calc_stream) { - return Send(tensor, - dst_rank, - /*offset*/ 0, - /*numel*/ -1, // -1 indicates receiving the whole tensor - sync_op, - use_calc_stream); -} - -std::shared_ptr ProcessGroupStream::Send( - const phi::DenseTensor& tensor, - int dst_rank, - int64_t offset, - int64_t numel, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::Unimplemented( - "ProcessGroup%s does not support send.", GetBackendName())); -} - -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group_stream.h b/paddle/fluid/distributed/collective/process_group_stream.h deleted file mode 100644 index 43827669cd..0000000000 --- a/paddle/fluid/distributed/collective/process_group_stream.h +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/distributed/collective/process_group.h" - -namespace paddle { -namespace distributed { - -// NOTE(liyurui): Notice that some backends use `stream` as an abstract -// conception of hardward resource. We provide this base class allowing users to -// put communications on calculation stream. In some scenorios, we found this -// will save the time of switching streams. -class ProcessGroupStream : public ProcessGroup { - public: - class TaskStream : public ProcessGroup::Task { - public: - TaskStream(int rank, CommType comm_type, bool sync_op, bool use_calc_stream) - : Task(rank, comm_type, sync_op), use_calc_stream_(use_calc_stream) {} - - virtual ~TaskStream() = default; - - // TODO(liyurui): This constructor is temporary here for compatible reason, - // will be deleted soon. 
- TaskStream(int rank, - const std::vector& inputs, - CommType comm_type) - : Task(rank, inputs, comm_type) {} - - TaskStream(int rank, - const std::vector& inputs, - CommType comm_type, - bool sync_op, - bool use_calc_stream) - : Task(rank, inputs, comm_type, sync_op), - use_calc_stream_(use_calc_stream) {} - - protected: - bool UseCalcStream() const { return use_calc_stream_; } - - private: - bool use_calc_stream_{false}; - }; - - public: - ProcessGroupStream(int rank, int size, int gid); - virtual ~ProcessGroupStream() = default; - using ProcessGroup::GetDeviceContext; - - virtual phi::DeviceContext* GetDeviceContext(const Place& place, - bool use_calc_stream) const; - - std::shared_ptr AllGather( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - int64_t offset, - int64_t numel, - bool sync_op) override; - - virtual std::shared_ptr AllGather( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - bool sync_op, - bool use_calc_stream); - - virtual std::shared_ptr AllGather( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - int64_t offset, - int64_t numel, - bool sync_op, - bool use_calc_stream); - - std::shared_ptr AllReduce( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const AllreduceOptions& opts, - bool sync_op) override; - - virtual std::shared_ptr AllReduce( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const AllreduceOptions& opts, - bool sync_op, - bool use_calc_stream); - - std::shared_ptr AllToAll( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const std::vector& out_size_each_rank, - const std::vector& in_size_each_rank, - bool sync_op) override; - - virtual std::shared_ptr AllToAll( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const std::vector& out_size_each_rank, - const std::vector& in_size_each_rank, - bool sync_op, - bool use_calc_stream); - - std::shared_ptr Broadcast( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const BroadcastOptions& opts, - bool sync_op) override; - - virtual std::shared_ptr Broadcast( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const BroadcastOptions& opts, - bool sync_op, - bool use_calc_stream); - - std::shared_ptr Reduce(phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ReduceOptions& opts, - bool sync_op) override; - - virtual std::shared_ptr Reduce( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ReduceOptions& opts, - bool sync_op, - bool use_calc_stream); - - std::shared_ptr ReduceScatter( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ReduceScatterOptions& opts, - bool sync_op) override; - - virtual std::shared_ptr ReduceScatter( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ReduceScatterOptions& opts, - bool sync_op, - bool use_calc_stream); - - std::shared_ptr Scatter(phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ScatterOptions& opts, - bool sync_op) override; - - virtual std::shared_ptr Scatter( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ScatterOptions& opts, - bool sync_op, - bool use_calc_stream); - - std::shared_ptr Recv(phi::DenseTensor* tensor, - int src_rank, - int64_t offset, - int64_t numel, - bool sync_op) override; - - virtual std::shared_ptr Recv(phi::DenseTensor* tensor, - int src_rank, - bool sync_op, - bool use_calc_stream); - - virtual 
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor, - int src_rank, - int64_t offset, - int64_t numel, - bool sync_op, - bool use_calc_stream); - - std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor, - int dst_rank, - int64_t offset, - int64_t numel, - bool sync_op) override; - - std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor, - int dst_rank, - bool sync_op, - bool use_calc_stream); - - virtual std::shared_ptr<ProcessGroup::Task> Send( - const phi::DenseTensor& tensor, - int dst_rank, - int64_t offset, - int64_t numel, - bool sync_op, - bool use_calc_stream); -}; - -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group_with_stream.h b/paddle/fluid/distributed/collective/process_group_with_stream.h new file mode 100644 index 0000000000..375d230cf6 --- /dev/null +++ b/paddle/fluid/distributed/collective/process_group_with_stream.h @@ -0,0 +1,258 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" + +namespace paddle { +namespace distributed { + +// NOTE: Notice that some backends use `stream` as an abstract conception of +// hardware resource. We provide this base class allowing users to put +// communications on the calculation stream. In some scenarios, we found this +// will save the time of switching streams. +class ProcessGroupWithStream : public ProcessGroup { + public: + class TaskStream : public ProcessGroup::Task { + public: + TaskStream(int rank, CommType comm_type, bool sync_op, bool use_calc_stream) + : Task(rank, comm_type, sync_op), use_calc_stream_(use_calc_stream) {} + + virtual ~TaskStream() = default; + + // TODO(liyurui): This constructor is temporary here for compatibility reasons; + // it will be deleted soon.
+ TaskStream(int rank, + const std::vector& inputs, + CommType comm_type) + : Task(rank, inputs, comm_type) {} + + TaskStream(int rank, + const std::vector& inputs, + CommType comm_type, + bool sync_op, + bool use_calc_stream) + : Task(rank, inputs, comm_type, sync_op), + use_calc_stream_(use_calc_stream) {} + + protected: + bool UseCalcStream() const { return use_calc_stream_; } + + private: + bool use_calc_stream_{false}; + }; + + public: + ProcessGroupWithStream(int rank, int size, int gid) + : ProcessGroup(rank, size, gid) {} + + virtual ~ProcessGroupWithStream() = default; + + // methods from base class + using ProcessGroup::AllGather; + using ProcessGroup::AllReduce; + using ProcessGroup::AllToAll; + using ProcessGroup::Broadcast; + using ProcessGroup::Recv; + using ProcessGroup::Reduce; + using ProcessGroup::ReduceScatter; + using ProcessGroup::Scatter; + using ProcessGroup::Send; + + std::shared_ptr AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + bool sync_op) override { + return AllGather(out_tensor, + in_tensor, + /*offset*/ 0, + /*numel*/ -1, // -1 indicates the whole tensor + sync_op); + } + + std::shared_ptr AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int64_t offset, + int64_t numel, + bool sync_op) override { + return AllGather(out_tensor, + in_tensor, + offset, + numel, + sync_op, + /*use_calc_stream*/ false); + } + + std::shared_ptr AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + bool sync_op, + bool use_calc_stream) override { + return AllGather(out_tensor, + in_tensor, + /*offset*/ 0, + /*numel*/ -1, // -1 indicates the whole tensor + sync_op, + use_calc_stream); + } + + std::shared_ptr AllReduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, + bool sync_op) override { + return AllReduce(out_tensor, + in_tensor, + opts, + sync_op, + /*use_calc_stream*/ false); + } + + std::shared_ptr AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op) override { + return AllToAll(out_tensor, + in_tensor, + out_size_each_rank, + in_size_each_rank, + sync_op, + /*use_calc_stream*/ false); + } + + std::shared_ptr Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op) override { + return Broadcast(out_tensor, + in_tensor, + opts, + sync_op, + /*use_calc_stream*/ false); + } + + std::shared_ptr Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op) override { + return Reduce(out_tensor, + in_tensor, + opts, + sync_op, + /*use_calc_stream*/ false); + } + + std::shared_ptr ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op) override { + return ReduceScatter(out_tensor, + in_tensor, + opts, + sync_op, + /*use_calc_stream*/ false); + } + + std::shared_ptr Scatter(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op) override { + return Scatter(out_tensor, + in_tensor, + opts, + sync_op, + /*use_calc_stream*/ false); + } + + std::shared_ptr Recv(phi::DenseTensor* tensor, + int src_rank, + bool sync_op) override { + return Recv(tensor, + src_rank, + /*offset*/ 0, + /*numel*/ -1, // -1 indicates the whole tensor + sync_op); + } + + 
+  std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
+                                           int src_rank,
+                                           int64_t offset,
+                                           int64_t numel,
+                                           bool sync_op) override {
+    return Recv(tensor,
+                src_rank,
+                offset,
+                numel,
+                sync_op,
+                /*use_calc_stream*/ false);
+  }
+
+  std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
+                                           int src_rank,
+                                           bool sync_op,
+                                           bool use_calc_stream) override {
+    return Recv(tensor,
+                src_rank,
+                /*offset*/ 0,
+                /*numel*/ -1,  // -1 indicates receiving the whole tensor
+                sync_op,
+                use_calc_stream);
+  }
+
+  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
+                                           int dst_rank,
+                                           bool sync_op) override {
+    return Send(tensor,
+                dst_rank,
+                /*offset*/ 0,
+                /*numel*/ -1,  // -1 indicates the whole tensor
+                sync_op);
+  }
+
+  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
+                                           int dst_rank,
+                                           int64_t offset,
+                                           int64_t numel,
+                                           bool sync_op) override {
+    return Send(tensor,
+                dst_rank,
+                offset,
+                numel,
+                sync_op,
+                /*use_calc_stream*/ false);
+  }
+
+  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
+                                           int dst_rank,
+                                           bool sync_op,
+                                           bool use_calc_stream) override {
+    return Send(tensor,
+                dst_rank,
+                /*offset*/ 0,
+                /*numel*/ -1,  // -1 indicates sending the whole tensor
+                sync_op,
+                use_calc_stream);
+  }
+};
+
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/process_group_without_stream.h b/paddle/fluid/distributed/collective/process_group_without_stream.h
new file mode 100644
index 0000000000..ee05906966
--- /dev/null
+++ b/paddle/fluid/distributed/collective/process_group_without_stream.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/distributed/collective/process_group.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/errors.h"
+
+namespace paddle {
+namespace distributed {
+
+class ProcessGroupWithoutStream : public ProcessGroup {
+ public:
+  ProcessGroupWithoutStream(int rank, int size, int gid)
+      : ProcessGroup(rank, size, gid) {}
+
+  virtual ~ProcessGroupWithoutStream() = default;
+
+  // methods from base class
+  using ProcessGroup::AllGather;
+  using ProcessGroup::Recv;
+  using ProcessGroup::Send;
+
+  std::shared_ptr<ProcessGroup::Task> AllGather(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      bool sync_op) override {
+    return AllGather(out_tensor,
+                     in_tensor,
+                     /*offset*/ 0,
+                     /*numel*/ -1,  // -1 indicates the whole tensor
+                     sync_op);
+  }
+
+  std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
+                                           int src_rank,
+                                           bool sync_op) override {
+    return Recv(tensor,
+                src_rank,
+                /*offset*/ 0,
+                /*numel*/ -1,  // -1 indicates the whole tensor
+                sync_op);
+  }
+
+  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
+                                           int dst_rank,
+                                           bool sync_op) override {
+    return Send(tensor,
+                dst_rank,
+                /*offset*/ 0,
+                /*numel*/ -1,  // -1 indicates the whole tensor
+                sync_op);
+  }
+};
+
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 9515ca7f64..8bac6f92e7 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -22,7 +22,6 @@ limitations under the License. */
 #endif
 
 #include "paddle/fluid/distributed/collective/process_group.h"
-#include "paddle/fluid/distributed/collective/process_group_stream.h"
 #include "paddle/fluid/distributed/collective/reducer.h"
 #include "paddle/fluid/distributed/collective/types.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -733,15 +732,11 @@ void BindDistributed(py::module *m) {
           py::arg("in"),
           py::arg("out"),
           py::arg("src"),
-          py::call_guard<py::gil_scoped_release>());
+          py::call_guard<py::gil_scoped_release>())
 
-  auto ProcessGroupStream =
-      py::class_<distributed::ProcessGroupStream,
-                 std::shared_ptr<distributed::ProcessGroupStream>>(
-          *m, "ProcessGroupStream", ProcessGroup)
       .def(
           "all_gather_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_out_tensor_list,
             py::handle py_in_tensor) {
            auto out_tensor_list =
@@ -770,7 +765,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "all_gather_into_tensor_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_out_tensor,
             py::handle py_in_tensor) {
            auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
@@ -794,7 +789,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "all_gather_partial_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_out_tensor,
             py::handle py_in_tensor,
             int nranks,
@@ -828,7 +823,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "all_reduce_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_tensor,
             distributed::ReduceOp op) {
            auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
@@ -849,7 +844,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "all_to_all_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_out_tensor_list,
             py::handle py_in_tensor_list) {
            auto out_tensor_list =
@@ -886,7 +881,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "all_to_all_tensor_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_out_tensor,
             py::handle py_in_tensor) {
            auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
@@ -914,7 +909,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "all_to_all_single_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_out_tensor,
             py::handle py_in_tensor,
             const std::vector<int64_t> &out_sizes,
@@ -944,7 +939,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "broadcast_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_tensor,
             int src) {
            auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
@@ -965,7 +960,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "reduce_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_tensor,
             int dst,
             distributed::ReduceOp op) {
@@ -988,7 +983,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "reduce_scatter_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_out_tensor,
             py::handle py_in_tensor_list,
             distributed::ReduceOp op) {
@@ -1018,7 +1013,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "reduce_scatter_tensor_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_out_tensor,
             py::handle py_in_tensor,
             distributed::ReduceOp op) {
@@ -1046,7 +1041,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "scatter_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_out_tensor,
             py::handle py_in_tensor_list,
             int src) {
@@ -1076,7 +1071,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "scatter_tensor_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_out_tensor,
             py::handle py_in_tensor,
             int src) {
@@ -1104,7 +1099,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "send_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_tensor,
             int dst) {
            auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
@@ -1122,7 +1117,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "send_partial_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_tensor,
             int dst_rank,
             int nranks,
@@ -1151,7 +1146,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "recv_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_tensor,
             int src) {
            auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
@@ -1169,7 +1164,7 @@ void BindDistributed(py::module *m) {
 
       .def(
           "recv_partial_on_calc_stream",
-          [](distributed::ProcessGroupStream &self,
+          [](distributed::ProcessGroup &self,
             py::handle py_tensor,
             int src_rank,
             int nranks,
@@ -1199,7 +1194,7 @@ void BindDistributed(py::module *m) {
 
 #if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
   py::class_<distributed::ProcessGroupNCCL,
              std::shared_ptr<distributed::ProcessGroupNCCL>>(
-      *m, "ProcessGroupNCCL", ProcessGroupStream)
+      *m, "ProcessGroupNCCL", ProcessGroup)
       .def_static("create",
                   distributed::ProcessGroupNCCL::CreateProcessGroupNCCL,
                   py::arg("store"),
@@ -1250,7 +1245,7 @@ void BindDistributed(py::module *m) {
 
   auto processGroupBKCL =
       py::class_<distributed::ProcessGroupBKCL,
                  std::shared_ptr<distributed::ProcessGroupBKCL>>(
-          *m, "ProcessGroupBKCL", ProcessGroupStream)
+          *m, "ProcessGroupBKCL", ProcessGroup)
       .def_static("create",
                   distributed::ProcessGroupBKCL::CreateProcessGroupBKCL,
                   py::arg("store"),
--
GitLab
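
Note (illustrative, not part of the patch): after this refactor, a stream-based backend derives from ProcessGroupWithStream and implements only the fullest-signature virtuals; every shorter overload is forwarded by the base class with offset = 0, numel = -1 (whole tensor), and use_calc_stream = false. The sketch below shows the extension point with a hypothetical backend; the class name and its body are invented, and only the TaskStream constructor, the protected rank_ member, and CommType::SEND from types.h are assumed from the patched headers.

#include "paddle/fluid/distributed/collective/process_group_with_stream.h"

namespace paddle {
namespace distributed {

// Hypothetical backend, used only to illustrate how the new base class is
// consumed; real backends such as ProcessGroupNCCL implement all collectives.
class FakeProcessGroup : public ProcessGroupWithStream {
 public:
  FakeProcessGroup(int rank, int size, int gid)
      : ProcessGroupWithStream(rank, size, gid) {}

  // Override only the six-argument Send; Send(tensor, dst, sync_op) and
  // Send(tensor, dst, sync_op, use_calc_stream) resolve to the forwarding
  // overloads defined in ProcessGroupWithStream.
  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
                                           int dst_rank,
                                           int64_t offset,
                                           int64_t numel,
                                           bool sync_op,
                                           bool use_calc_stream) override {
    // A real backend would slice the tensor by (offset, numel), enqueue the
    // send on its dedicated comm stream (or on the calculation stream when
    // use_calc_stream is true), and record an event for a later Wait().
    return std::make_shared<TaskStream>(
        rank_, CommType::SEND, sync_op, use_calc_stream);
  }
};

}  // namespace distributed
}  // namespace paddle

This mirrors the pattern ProcessGroupNCCL follows in this patch, while CPU-style backends (e.g. Gloo) derive from ProcessGroupWithoutStream and override only the short-signature collectives.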