From e1a1c35493d57035c857b7e1c5c9e8dff4ac9408 Mon Sep 17 00:00:00 2001 From: Wen Sun <35923278+HermitSun@users.noreply.github.com> Date: Mon, 7 Nov 2022 10:19:53 +0800 Subject: [PATCH] Refactor collective communication all_gather, all_reduce, broadcast & barrier C++ API (#47481) --- .../distributed/collective/CMakeLists.txt | 1 + paddle/fluid/distributed/collective/Common.cc | 2 + paddle/fluid/distributed/collective/Common.h | 2 + .../fluid/distributed/collective/NCCLTools.h | 198 ----- .../distributed/collective/ProcessGroup.cc | 22 +- .../distributed/collective/ProcessGroup.h | 61 +- .../collective/ProcessGroupGloo.cc | 11 + .../distributed/collective/ProcessGroupGloo.h | 7 + .../collective/ProcessGroupNCCL.cc | 692 +++++++++--------- .../distributed/collective/ProcessGroupNCCL.h | 158 ++-- .../collective/ProcessGroupStream.cc | 75 +- .../collective/ProcessGroupStream.h | 55 +- .../operators/fused/fused_attention_op.cu | 7 +- .../operators/fused/fused_feedforward_op.cu | 7 +- paddle/fluid/pybind/distributed_py.cc | 130 ++-- .../communication/stream/all_gather.py | 15 +- 16 files changed, 676 insertions(+), 767 deletions(-) diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index aa816f26f93..e4d7b55d13c 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL) enforce collective_helper device_context + ${DEVICE_EVENT_LIBS} dense_tensor) if(WITH_DISTRIBUTE AND WITH_PSCORE) if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) diff --git a/paddle/fluid/distributed/collective/Common.cc b/paddle/fluid/distributed/collective/Common.cc index e9572f28d32..d968c99e479 100644 --- a/paddle/fluid/distributed/collective/Common.cc +++ b/paddle/fluid/distributed/collective/Common.cc @@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector& places) { return placeList; } +std::string GetKeyFromPlace(const Place& place) { return place.DebugString(); } + bool CheckTensorsInCudaPlace(const std::vector& tensors) { return std::all_of( tensors.cbegin(), tensors.cend(), [&](const phi::DenseTensor& t) { diff --git a/paddle/fluid/distributed/collective/Common.h b/paddle/fluid/distributed/collective/Common.h index 8d5db886989..38a3100b6eb 100644 --- a/paddle/fluid/distributed/collective/Common.h +++ b/paddle/fluid/distributed/collective/Common.h @@ -25,6 +25,8 @@ using Place = paddle::platform::Place; std::vector GetPlaceList(const std::vector& tensors); // Get the deviceList String from the list of devices std::string GetKeyFromPlaces(const std::vector& places); +// Get the device string from one device +std::string GetKeyFromPlace(const Place& place); bool CheckTensorsInCudaPlace(const std::vector& tensors); diff --git a/paddle/fluid/distributed/collective/NCCLTools.h b/paddle/fluid/distributed/collective/NCCLTools.h index c00b081438c..464ae0b6581 100644 --- a/paddle/fluid/distributed/collective/NCCLTools.h +++ b/paddle/fluid/distributed/collective/NCCLTools.h @@ -59,204 +59,6 @@ namespace distributed { } \ } while (0) -// NOTE(shenliang03): EventManager are movable not copyable CudaEvent wrapper. -// EventManage is different from paddle::platform::CudaEvent. -// It uses lazy initialization and is only created when the -// Record() method is called for the first time; it also monitors -// device information to ensure that recorded stream and event -// are on the same device. 
- -class EventManager { - public: - EventManager() {} - explicit EventManager(unsigned int flags) : flags_{flags} {} - - ~EventManager() { - if (is_created_) { - platform::CUDADeviceGuard guard(device_index_); -#ifdef PADDLE_WITH_HIP - hipEventDestroy(event_); -#else - cudaEventDestroy(event_); -#endif - } - } - - EventManager(const EventManager&) = delete; - EventManager& operator=(const EventManager&) = delete; - - EventManager(EventManager&& other) { - std::swap(flags_, other.flags_); - std::swap(is_created_, other.is_created_); - std::swap(device_index_, other.device_index_); - std::swap(event_, other.event_); - } - - EventManager& operator=(EventManager&& other) { - std::swap(flags_, other.flags_); - std::swap(is_created_, other.is_created_); - std::swap(device_index_, other.device_index_); - std::swap(event_, other.event_); - return *this; - } - - bool IsCreated() const { return is_created_; } - bool DeviceId() const { return device_index_; } - gpuEvent_t GetRawCudaEvent() const { return event_; } - - void Record(const phi::GPUContext& ctx) { - auto device_index = ctx.GetPlace().device; - if (!is_created_) { - CreateEvent(device_index); - } - PADDLE_ENFORCE_EQ(device_index, - device_index_, - platform::errors::PreconditionNotMet( - "phi::GPUContext's device %d does not match" - "Event's device %d", - device_index, - device_index_)); - - platform::CUDADeviceGuard guard(device_index_); -#ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event_, ctx.stream())); -#else - PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event_, ctx.stream())); -#endif - } - - bool Query() const { -#ifdef PADDLE_WITH_HIP - gpuError_t err = hipEventQuery(event_); - if (err == hipSuccess) { - return true; - } - if (err == hipErrorNotReady) { - return false; - } -#else - gpuError_t err = cudaEventQuery(event_); - if (err == cudaSuccess) { - return true; - } - if (err == cudaErrorNotReady) { - return false; - } -#endif - PADDLE_ENFORCE_GPU_SUCCESS(err); - return false; - } - - void Synchronize() const { - if (is_created_) { -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipEventSynchronize(event_)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(event_)); -#endif - } - } - - void Block(const phi::GPUContext& ctx) const { - if (is_created_) { - auto device_index = ctx.GetPlace().device; - PADDLE_ENFORCE_EQ(device_index, - device_index_, - platform::errors::PreconditionNotMet( - "phi::GPUContext's device %d does not match" - "Event's device %d", - device_index, - device_index_)); - platform::CUDADeviceGuard guard(device_index_); - -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipStreamWaitEvent(ctx.stream(), event_, 0)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamWaitEvent(ctx.stream(), event_, 0)); -#endif - } - } - - private: -#ifdef PADDLE_WITH_HIP - unsigned int flags_ = hipEventDefault; -#else - unsigned int flags_ = cudaEventDefault; -#endif - - bool is_created_{false}; - gpuEvent_t event_{}; - int8_t device_index_{0}; - - private: - void CreateEvent(int device_index) { - device_index_ = device_index; - platform::CUDADeviceGuard guard(device_index); - -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(&event_, flags_)); -#else - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(&event_, flags_)); -#endif - - is_created_ = true; - } -}; - -// NOTE(shenliang03): NCCLCommManager is more lightweight than -// platform::NCCLComm - -class NCCLCommManager { - public: - explicit NCCLCommManager(ncclComm_t ncclComm) : nccl_comm_(ncclComm) {} - - 
NCCLCommManager() : NCCLCommManager(nullptr) {} - - ~NCCLCommManager() noexcept { - std::unique_lock lock(mutex_); - if (nccl_comm_) { - platform::dynload::ncclCommDestroy(nccl_comm_); - } - } - - static std::shared_ptr Create(int num_ranks, - int rank, - ncclUniqueId comm_id) { - auto nccl_manager = std::make_shared(); - NCCLCHECK(platform::dynload::ncclCommInitRank( - &(nccl_manager->nccl_comm_), num_ranks, comm_id, rank)); - - nccl_manager->nccl_id_ = comm_id; - nccl_manager->rank_ = rank; - return nccl_manager; - } - - ncclUniqueId GetNcclId() const { - std::unique_lock lock(mutex_); - return nccl_id_; - } - - ncclComm_t GetNcclComm() const { - std::unique_lock lock(mutex_); - return nccl_comm_; - } - - NCCLCommManager(const NCCLCommManager&) = delete; - NCCLCommManager& operator=(const NCCLCommManager&) = delete; - NCCLCommManager& operator=(NCCLCommManager&& other) = delete; - - NCCLCommManager(NCCLCommManager&& other) { - std::unique_lock lock(other.mutex_); - std::swap(nccl_comm_, other.nccl_comm_); - } - - protected: - ncclComm_t nccl_comm_; - ncclUniqueId nccl_id_; - int rank_; - mutable std::mutex mutex_; -}; - ncclRedOp_t ToNCCLRedType(ReduceOp reduction); std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID); diff --git a/paddle/fluid/distributed/collective/ProcessGroup.cc b/paddle/fluid/distributed/collective/ProcessGroup.cc index e7942b714e4..72cd66467d9 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.cc +++ b/paddle/fluid/distributed/collective/ProcessGroup.cc @@ -17,15 +17,7 @@ namespace paddle { namespace distributed { -ProcessGroup::Task::Task(int rank, - const std::vector& inputs, - CommType comm_type) - : rank_(rank), comm_type_(comm_type) {} - -ProcessGroup::Task::Task(int rank, - const std::vector& inputs, - CommType comm_type, - bool sync_op) +ProcessGroup::Task::Task(int rank, CommType comm_type, bool sync_op) : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {} ProcessGroup::Task::~Task() = default; @@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid) } } +// TODO(sunyilun): methods below will be removed later +ProcessGroup::Task::Task(int rank, + const std::vector& inputs, + CommType comm_type) + : rank_(rank), comm_type_(comm_type) {} + +ProcessGroup::Task::Task(int rank, + const std::vector& inputs, + CommType comm_type, + bool sync_op) + : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 3926ce16f1c..7550b83265e 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -54,13 +54,7 @@ class ProcessGroup { public: class Task { public: - Task(int rank, - const std::vector& inputs, - CommType comm_type); - Task(int rank, - const std::vector& inputs, - CommType comm_type, - bool sync_op); + Task(int rank, CommType comm_type, bool sync_op); virtual ~Task(); virtual bool IsCompleted(); @@ -69,6 +63,15 @@ class ProcessGroup { virtual void UpdateWaitChain(const phi::DeviceContext& ctx); bool IsSync() const { return sync_op_; } + // TODO(sunyilun): methods below will be removed later + Task(int rank, + const std::vector& inputs, + CommType comm_type); + Task(int rank, + const std::vector& inputs, + CommType comm_type, + bool sync_op); + protected: const int rank_; CommType comm_type_{CommType::UNKNOWN}; @@ -79,6 +82,7 @@ class ProcessGroup { bool sync_op_{true}; }; + public: 
explicit ProcessGroup(int rank, int size, const platform::Place& place, @@ -93,12 +97,48 @@ class ProcessGroup { int GetSize() const { return size_; } virtual std::string GetBackendName() const = 0; + virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const { PADDLE_THROW(platform::errors::InvalidArgument( "Does not support to get device_context from ProcessGroup%s.", GetBackendName())); } + virtual std::shared_ptr AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + bool sync_op) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support all_gather with sync_op flag", + GetBackendName())); + } + + virtual std::shared_ptr AllReduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, + bool sync_op) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support all_reduce with sync_op flag", + GetBackendName())); + } + + virtual std::shared_ptr Barrier( + const BarrierOptions& = BarrierOptions()) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support barrier", GetBackendName())); + } + + virtual std::shared_ptr Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support broadcast with sync_op flag", + GetBackendName())); + } + // TODO(liyurui): This API will be moved later virtual std::shared_ptr AllReduce( std::vector& /* input tensors */, // NOLINT @@ -118,6 +158,7 @@ class ProcessGroup { GetBackendName())); } + // TODO(sunyilun): methods below will be removed later virtual std::shared_ptr Broadcast( std::vector& /* input tensors */, // NOLINT std::vector& /* output tensors */, // NOLINT @@ -136,12 +177,6 @@ class ProcessGroup { GetBackendName())); } - virtual std::shared_ptr Barrier( - const BarrierOptions& = BarrierOptions()) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support barrier", GetBackendName())); - } - virtual std::shared_ptr Send( std::vector&, int) { // NOLINT PADDLE_THROW(platform::errors::InvalidArgument( diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc index 09af416ec56..d6d7f328aee 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc @@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask { } }; +// TODO(sunyilun): for compatibility, will be updated later +std::shared_ptr ProcessGroupGloo::Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op) { + std::vector in_wrapper = {in_tensor}; + std::vector out_wrapper = {*out_tensor}; + return Broadcast(in_wrapper, out_wrapper, opts, true); +} + std::shared_ptr ProcessGroupGloo::Broadcast( std::vector& inputs, std::vector& outputs, diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.h b/paddle/fluid/distributed/collective/ProcessGroupGloo.h index c8959e399ab..72309bf7690 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupGloo.h +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.h @@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup { ~ProcessGroupGloo() = default; + std::shared_ptr Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const 
BroadcastOptions& opts, + bool sync_op) override; + + // TODO(sunyilun): methods below will be removed later std::shared_ptr Broadcast( std::vector& inputs, std::vector& outputs, diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index db713ac304e..2a13001750f 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -15,12 +15,8 @@ #include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h" #include "paddle/fluid/distributed/collective/Common.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#include "paddle/fluid/platform/place.h" #include "paddle/phi/api/lib/utils/allocator.h" -#include "paddle/phi/common/place.h" -#include "paddle/phi/core/device_context.h" DECLARE_bool(nccl_blocking_wait); DECLARE_bool(use_stream_safe_cuda_allocator); @@ -30,89 +26,299 @@ constexpr int64_t kWaitBlockTImeout = 10; namespace paddle { namespace distributed { -void SyncDefaultStream( - const std::vector& places, - std::vector& ncclEvents, // NOLINT - std::vector>& dev_ctx) { // NOLINT - for (size_t i = 0; i < places.size(); ++i) { - auto* default_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(places[i])); - ncclEvents[i].Record(*default_ctx); - ncclEvents[i].Block(*dev_ctx[i]); +ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place, + int rank, + CommType comm_type, + bool sync_op, + bool use_calc_stream) + : TaskStream(rank, comm_type, sync_op, use_calc_stream), + comm_event_(place), + place_(place) {} + +ProcessGroupNCCL::NCCLTask::~NCCLTask() {} + +bool ProcessGroupNCCL::NCCLTask::IsCompleted() { return comm_event_.Query(); } + +void ProcessGroupNCCL::NCCLTask::UpdateWaitChain( + const phi::DeviceContext& ctx) { + comm_event_.Record(&ctx); +} + +// TODO(sheniang03): Add timeout for wait, now timeout unused +bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { + // Warning here when use calc stream but also invoke waiting explicitly. 
+ if (UseCalcStream()) { + VLOG(3) << "Warning: The communication is on calc stream, wait here is " + "useless."; + return true; + } + + const auto* calc_ctx = platform::DeviceContextPool::Instance().Get(place_); + comm_event_.Wait(platform::Place2DeviceType(place_), calc_ctx); + + if (FLAGS_nccl_blocking_wait) { + // NOTE(shenliang03): It will block host for sync + while (!IsCompleted()) { + std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout)); + } } + + if (barrier_) { + // If we use the work to do barrier, we should block cpu +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + } + return true; } -std::shared_ptr ProcessGroupNCCL::CreateTask( - std::vector places, - int rank, - CommType comm_type, - const std::vector& inputs) { - return std::make_shared( - places, rank, comm_type, inputs); +// Same as Wait +void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); } + +ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr& store, + int rank, + int size, + const platform::Place& place, + int gid) + : ProcessGroupStream(rank, size, place, gid), store_(store) { + platform::SetDeviceId(place_.device); +} + +void ProcessGroupNCCL::GroupStart() { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); +} + +void ProcessGroupNCCL::GroupEnd() { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); +} + +const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext( + const Place& place) const { + return GetDeviceContext(place, /*use_calc_stream*/ false); +} + +const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext( + const Place& place, bool use_calc_stream) const { + const std::string& key = GetKeyFromPlace(place); + if (use_calc_stream) { + const auto& iter = place_to_calc_ctx_.find(key); + return *iter->second; + } else { + const auto& iter = place_to_comm_ctx_.find(key); + PADDLE_ENFORCE_NE( + iter, + place_to_comm_ctx_.end(), + platform::errors::NotFound( + "Cannot find the device context in this process group.")); + return *iter->second; + } +} + +ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const { + const std::string& key = GetKeyFromPlace(place); + const auto& iter = place_to_comm_ctx_.find(key); + PADDLE_ENFORCE_NE( + iter, + place_to_comm_ctx_.end(), + platform::errors::NotFound( + "Cannot find the NCCL commmunicator in this process group.")); + return iter->second->nccl_comm(); +} + +std::shared_ptr ProcessGroupNCCL::AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + bool sync_op, + bool use_calc_stream) { + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + ncclComm_t comm, + const gpuStream_t& stream) { + return platform::dynload::ncclAllGather( + input.data(), + output->data(), + input.numel(), + platform::ToNCCLDataType(input.dtype()), + comm, + stream); + }, + CommType::ALLGATHER, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupNCCL::AllReduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + ncclComm_t comm, + const gpuStream_t& stream) { + return platform::dynload::ncclAllReduce( + input.data(), + output->data(), + input.numel(), + platform::ToNCCLDataType(input.type()), + 
ToNCCLRedType(opts.reduce_op), + comm, + stream); + }, + CommType::ALLREDUCE, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupNCCL::Barrier( + const BarrierOptions& opts) { + auto allocator = std::unique_ptr( + new paddle::experimental::DefaultAllocator(place_)); + phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1}); + phi::DenseTensor barrier_tensor{allocator.get(), meta}; + + auto task = AllReduce(&barrier_tensor, + barrier_tensor, + {}, + /*sync_op*/ true, + /*use_calc_stream*/ false); + auto nccl_task = dynamic_cast(task.get()); + nccl_task->barrier_ = true; + return task; +} + +std::shared_ptr ProcessGroupNCCL::Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) { + return Collective( + out_tensor, + in_tensor, + [&](phi::DenseTensor* output, + const phi::DenseTensor& input, + ncclComm_t comm, + const gpuStream_t& stream) { + int root = opts.source_rank + opts.source_root; + return platform::dynload::ncclBroadcast( + input.data(), + output->data(), + input.numel(), + platform::ToNCCLDataType(input.type()), + root, + comm, + stream); + }, + CommType::BROADCAST, + sync_op, + use_calc_stream); } std::shared_ptr ProcessGroupNCCL::CreateTask( - const std::vector& places, + const Place& place, int rank, CommType comm_type, - const std::vector& inputs, bool is_sync, bool use_calc_stream) { return std::make_shared( - places, rank, comm_type, inputs, is_sync, use_calc_stream); + place, rank, comm_type, is_sync, use_calc_stream); } -ProcessGroupNCCL::NCCLTask::NCCLTask( - const std::vector& places, - int rank, - CommType CommType, - const std::vector& inputs) - : TaskStream(rank, inputs, CommType), places_(places) { - control_events_.resize(places.size()); - ncclComms_.resize(places.size()); +void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) { + const std::string key = + "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/0"; + if (rank_ == 0) { + std::vector nccl_id_wrapper( + reinterpret_cast(nccl_id), + reinterpret_cast(nccl_id) + NCCL_UNIQUE_ID_BYTES); + store_->set(key, nccl_id_wrapper); + } else { + const auto& nccl_id_wrapper = store_->get(key); + std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size()); + } } -ProcessGroupNCCL::NCCLTask::NCCLTask( - const std::vector& places, - int rank, - CommType comm_type, - const std::vector& inputs, - bool sync_op, - bool use_calc_stream) - : TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream), - places_(places) { - control_events_.resize(places.size()); - ncclComms_.resize(places.size()); -} +void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, + const std::string& place_key) { + ncclUniqueId nccl_id; + if (rank_ == 0) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id)); + } + BroadcastUniqueNCCLID(&nccl_id); -ProcessGroupNCCL::NCCLTask::~NCCLTask() {} + VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_ + << ", place: " << place_key + << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id); -void ProcessGroupNCCL::NCCLTask::SetOutputs( - std::vector& outputs) { // NOLINT - outputs_ = std::make_shared>(outputs); + calc_event_ = std::make_shared(place); + auto* calc_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + auto comm_ctx = std::make_unique(place); + ncclComm_t nccl_comm; + NCCLCHECK(platform::dynload::ncclCommInitRank( + &nccl_comm, GetSize(), nccl_id, GetRank())); + 
comm_ctx->set_nccl_comm(nccl_comm); + + place_to_calc_ctx_[place_key] = calc_ctx; + place_to_comm_ctx_[place_key] = std::move(comm_ctx); + + // TODO(sunyilun): for compatibility, will be removed later + places_to_ctx_[place_key] = {place_to_comm_ctx_[place_key].get()}; } -void ProcessGroupNCCL::NCCLTask::SynchronizeStreams() { - for (size_t i = 0; i < places_.size(); ++i) { - auto* default_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(places_[i])); - default_ctx->WaitEvent(control_events_[i].GetRawCudaEvent()); - } +void ProcessGroupNCCL::SyncCalcStream( + const Place& place, const std::shared_ptr& event) { + const std::string& key = GetKeyFromPlace(place); + const auto* calc_ctx = place_to_calc_ctx_[key]; + const auto* comm_ctx = place_to_comm_ctx_[key].get(); + event->Record(calc_ctx); + event->Wait(platform::Place2DeviceType(place), comm_ctx); } -bool ProcessGroupNCCL::NCCLTask::IsCompleted() { - for (size_t i = 0; i < places_.size(); ++i) { - if (!control_events_[i].Query()) { - return false; - } +template +std::shared_ptr ProcessGroupNCCL::Collective( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + Fn fn, + CommType comm_type, + bool sync_op, + bool use_calc_stream) { + const auto& place = in_tensor.place(); + const auto& key = GetKeyFromPlace(place); + + if (!calc_event_) { + CreateNCCLEnvCache(place, key); } - return true; -} + if (!use_calc_stream) { + SyncCalcStream(place, calc_event_); + } -void ProcessGroupNCCL::NCCLTask::UpdateWaitChain( - const phi::DeviceContext& ctx) { - control_events_[0].Record(*static_cast(&ctx)); + auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream); + + const auto* calc_ctx = place_to_calc_ctx_[key]; + const auto& comm_ctx = place_to_comm_ctx_[key]; + auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); + fn(out_tensor, in_tensor, comm_ctx->nccl_comm(), nccl_stream); + + if (!use_calc_stream) { + if (FLAGS_use_stream_safe_cuda_allocator) { + memory::RecordStream(in_tensor.Holder(), nccl_stream); + } + task->comm_event_.Record(comm_ctx.get()); + } + + return task; } void ProcessGroupNCCL::CheckSplitSizes(std::vector* split_sizes, @@ -144,70 +350,58 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector* split_sizes, } } -// TODO(sheniang03): Add timeout for wait, now timeout unused -bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { - // Warning here when use calc stream but also invoke waiting explicitly. 
- if (UseCalcStream()) { - VLOG(3) << "Warning: The communication is on calc stream, wait here is " - "useless."; - return true; - } - - SynchronizeStreams(); - if (FLAGS_nccl_blocking_wait) { - // NOTE(shenliang03): It will block host for sync - while (!IsCompleted()) { - std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout)); - } - } - - if (!barrierTensors_.empty()) { - // If we use the work to do barrier, we should block cpu - for (auto& place : places_) { - platform::CUDADeviceGuard gpuGuard(place); -#ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); -#else - PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); -#endif - } +// TODO(sunyilun): methods below will be removed later +void SyncDefaultStream(const std::vector& places, + const std::shared_ptr& nccl_event, + std::vector& dev_ctx) { // NOLINT + for (size_t i = 0; i < places.size(); ++i) { + auto* default_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(places[i])); + nccl_event->Record(default_ctx); + nccl_event->Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]); } - return true; } -// Same as Wait -void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); } - -ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr& store, - int rank, - int size, - const platform::Place& place, - int gid) - : ProcessGroupStream(rank, size, place, gid), store_(store) { - platform::SetDeviceId(place_.device); +std::shared_ptr ProcessGroupNCCL::CreateTask( + std::vector places, + int rank, + CommType comm_type, + const std::vector& inputs) { + return std::make_shared( + places, rank, comm_type, inputs); } -void ProcessGroupNCCL::BroadcastUniqueNCCLID( - std::vector& nccl_ids) { // NOLINT - if (rank_ == 0) { - for (size_t i = 0; i < nccl_ids.size(); i++) { - auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/" + - std::to_string(i); - auto nccl_id = std::vector( - reinterpret_cast(&nccl_ids[i]), - reinterpret_cast(&nccl_ids[i]) + NCCL_UNIQUE_ID_BYTES); - store_->set(key, nccl_id); - } - } else { - for (size_t i = 0; i < nccl_ids.size(); i++) { - auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(gid_) + "/" + - std::to_string(i); - auto ret = store_->get(key); - std::memcpy(&nccl_ids[i], ret.data(), ret.size()); - } - } +std::shared_ptr ProcessGroupNCCL::CreateTask( + const std::vector& places, + int rank, + CommType comm_type, + const std::vector& inputs, + bool is_sync, + bool use_calc_stream) { + return std::make_shared( + places, rank, comm_type, inputs, is_sync, use_calc_stream); } +ProcessGroupNCCL::NCCLTask::NCCLTask( + const std::vector& places, + int rank, + CommType CommType, + const std::vector& inputs) + : TaskStream(rank, inputs, CommType), + comm_event_(places[0]), + place_(places[0]) {} + +ProcessGroupNCCL::NCCLTask::NCCLTask( + const std::vector& places, + int rank, + CommType comm_type, + const std::vector& inputs, + bool sync_op, + bool use_calc_stream) + : TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream), + comm_event_(places[0]), + place_(places[0]) {} + // create NCCLManager cache for places_key void ProcessGroupNCCL::CreateNCCLManagerCache( const std::string& places_key, const std::vector& places) { @@ -217,22 +411,11 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( "Not able to create/get the NCCL Communicator since " "the GPU place are not known")); - std::vector> nccl_comms; - nccl_comms.resize(places.size()); - - // using vector just for broadcast - std::vector nccl_ids; - nccl_ids.resize(1); - auto& nccl_id 
= nccl_ids.front(); - - for (auto& place : places) { - used_place_ids_.insert(place.GetDeviceId()); - } - + ncclUniqueId nccl_id; if (rank_ == 0) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id)); } - BroadcastUniqueNCCLID(nccl_ids); + BroadcastUniqueNCCLID(&nccl_id); VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_ << ", place: " << places_key @@ -241,23 +424,33 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( std::vector> dev_ctx; dev_ctx.resize(places.size()); + std::vector dev_ctx_raw; + dev_ctx_raw.resize(places.size()); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); for (size_t i = 0; i < places.size(); ++i) { platform::CUDADeviceGuard guard(places[i]); - nccl_comms[i] = NCCLCommManager::Create(GetSize(), GetRank(), nccl_id); + dev_ctx[i].reset(new phi::GPUContext(places[i])); + ncclComm_t nccl_comm; + NCCLCHECK(platform::dynload::ncclCommInitRank( + &nccl_comm, GetSize(), nccl_id, GetRank())); + dev_ctx[i]->set_nccl_comm(nccl_comm); + + dev_ctx_raw[i] = dev_ctx[i].get(); } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + calc_event_ = std::make_shared(places[0]); - std::vector events; - events.resize(places.size()); + // TODO(sunyilun): for compatibility, will be removed later + place_to_calc_ctx_[places_key] = static_cast( + platform::DeviceContextPool::Instance().Get(places[0])); + place_to_comm_ctx_[places_key] = std::move(dev_ctx[0]); // These caches will be useful to process sync/wait/communicate - places_to_events_.emplace(places_key, std::move(events)); - places_to_ncclcomm_.emplace(places_key, std::move(nccl_comms)); - places_to_ctx_.emplace(places_key, std::move(dev_ctx)); + places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw)); } template @@ -273,15 +466,13 @@ std::shared_ptr ProcessGroupNCCL::Collective( { std::lock_guard lock(mutex_); - if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { + if (!calc_event_) { CreateNCCLManagerCache(key, places); } } - auto& nccl_comms = places_to_ncclcomm_[key]; - if (!use_calc_stream) { - SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + SyncDefaultStream(places, calc_event_, places_to_ctx_[key]); } auto task = @@ -304,7 +495,10 @@ std::shared_ptr ProcessGroupNCCL::Collective( nccl_stream = places_to_ctx_[key][i]->stream(); } - fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream); + fn(inputs[i], + outputs[i], + places_to_ctx_[key][i]->nccl_comm(), + nccl_stream); } } @@ -330,7 +524,7 @@ std::shared_ptr ProcessGroupNCCL::Collective( if (!use_calc_stream) { for (size_t i = 0; i < inputs.size(); ++i) { cuda_guard.SetDevice(places[i]); - task->control_events_[i].Record(*places_to_ctx_[key][i]); + task->comm_event_.Record(places_to_ctx_[key][i]); } } @@ -348,14 +542,12 @@ std::shared_ptr ProcessGroupNCCL::Collective( { std::lock_guard lock(mutex_); - if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { + if (!calc_event_) { CreateNCCLManagerCache(key, places); } } - auto& nccl_comms = places_to_ncclcomm_[key]; - - SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + SyncDefaultStream(places, calc_event_, places_to_ctx_[key]); auto task = CreateTask(places, rank_, op_type, inputs); @@ -367,7 +559,10 @@ std::shared_ptr ProcessGroupNCCL::Collective( for (size_t i = 0; i < inputs.size(); ++i) { cuda_guard.SetDevice(places[i]); const auto& nccl_stream = places_to_ctx_[key][i]->stream(); - fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream); + fn(inputs[i], + 
outputs[i], + places_to_ctx_[key][i]->nccl_comm(), + nccl_stream); } } @@ -381,7 +576,7 @@ std::shared_ptr ProcessGroupNCCL::Collective( for (size_t i = 0; i < inputs.size(); ++i) { cuda_guard.SetDevice(places[i]); - task->control_events_[i].Record(*places_to_ctx_[key][i]); + task->comm_event_.Record(places_to_ctx_[key][i]); } return task; } @@ -393,18 +588,16 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in, CommType op_type) { std::vector places; places.push_back(in->place()); - const auto key = GetKeyFromPlaces(places); + const std::string& key = GetKeyFromPlaces(places); { std::lock_guard lock(mutex_); - if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { + if (!calc_event_) { CreateNCCLManagerCache(key, places); } } - auto& nccl_comms = places_to_ncclcomm_[key]; - - SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + SyncDefaultStream(places, calc_event_, places_to_ctx_[key]); // construct uninitialize guard for device platform::CUDADeviceGuard cuda_guard; @@ -418,7 +611,7 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in, platform::NCCLGroupGuard nccl_guard; cuda_guard.SetDevice(places[0]); const auto& nccl_stream = places_to_ctx_[key][0]->stream(); - fn(in, out, nccl_comms[0]->GetNcclComm(), nccl_stream); + fn(in, out, places_to_ctx_[key][0]->nccl_comm(), nccl_stream); } cuda_guard.SetDevice(places[0]); @@ -437,15 +630,13 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( { std::lock_guard lock(mutex_); - if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { + if (!calc_event_) { CreateNCCLManagerCache(key, places); } } - auto& nccl_comms = places_to_ncclcomm_[key]; - if (!use_calc_stream) { - SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + SyncDefaultStream(places, calc_event_, places_to_ctx_[key]); } auto task = @@ -466,7 +657,10 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( } else { nccl_stream = places_to_ctx_[key][i]->stream(); } - fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); + fn(tensors[i], + places_to_ctx_[key][i]->nccl_comm(), + nccl_stream, + dst_rank); } } @@ -489,7 +683,7 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( if (!use_calc_stream) { for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); - task->control_events_[i].Record(*places_to_ctx_[key][i]); + task->comm_event_.Record(places_to_ctx_[key][i]); } } @@ -507,14 +701,12 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( { std::lock_guard lock(mutex_); - if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) { + if (!calc_event_) { CreateNCCLManagerCache(key, places); } } - auto& nccl_comms = places_to_ncclcomm_[key]; - - SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); + SyncDefaultStream(places, calc_event_, places_to_ctx_[key]); auto task = CreateTask(places, rank_, op_type, tensors); @@ -526,7 +718,10 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); const auto& nccl_stream = places_to_ctx_[key][i]->stream(); - fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); + fn(tensors[i], + places_to_ctx_[key][i]->nccl_comm(), + nccl_stream, + dst_rank); } } @@ -540,7 +735,7 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); - task->control_events_[i].Record(*places_to_ctx_[key][i]); + task->comm_event_.Record(places_to_ctx_[key][i]); } return task; } @@ 
-572,37 +767,6 @@ std::shared_ptr ProcessGroupNCCL::AllReduce( CommType::ALLREDUCE); } -std::shared_ptr ProcessGroupNCCL::AllReduce( - std::vector& in_tensors, - std::vector& out_tensors, - const AllreduceOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](const phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - return platform::dynload::ncclAllReduce( - input.data(), - output.data(), - input.numel(), - platform::ToNCCLDataType(input.type()), - ToNCCLRedType(opts.reduce_op), - comm, - stream); - }, - CommType::ALLREDUCE, - sync_op, - use_calc_stream); -} - std::shared_ptr ProcessGroupNCCL::Broadcast( std::vector& in_tensors, std::vector& out_tensors, @@ -633,63 +797,6 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( CommType::BROADCAST); } -std::shared_ptr ProcessGroupNCCL::Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); - - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - const auto root = - opts.source_rank * in_tensors.size() + opts.source_root; - return platform::dynload::ncclBroadcast( - input.data(), - output.data(), - input.numel(), - platform::ToNCCLDataType(input.type()), - root, - comm, - stream); - }, - CommType::BROADCAST, - sync_op, - use_calc_stream); -} - -std::shared_ptr ProcessGroupNCCL::Barrier( - const BarrierOptions& opts) { - // Only support single card single process - std::vector places = {place_}; - - std::vector barrierTensors; - barrierTensors.reserve(places.size()); - - platform::CUDADeviceGuard gpuGuard; - for (auto& place : places) { - gpuGuard.SetDeviceIndex(place.GetDeviceId()); - phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1})); - auto allocator = std::unique_ptr( - new paddle::experimental::DefaultAllocator(place)); - barrierTensors.emplace_back(allocator.get(), meta); - } - auto task = ProcessGroupNCCL::AllReduce( - barrierTensors, barrierTensors, AllreduceOptions()); - auto nccl_task = dynamic_cast(task.get()); - nccl_task->barrierTensors_ = std::move(barrierTensors); - return task; -} - void CheckTensorsInDifferentDevices( const std::vector& tensors, const size_t num_devices) { PADDLE_ENFORCE_EQ( @@ -975,39 +1082,6 @@ std::shared_ptr ProcessGroupNCCL::AllGather( CommType::ALLGATHER); } -std::shared_ptr ProcessGroupNCCL::AllGather( - std::vector& in_tensors, - std::vector& out_tensors, - bool sync_op, - bool use_calc_stream) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(in_tensors), - true, - platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInCudaPlace(out_tensors), - true, - platform::errors::InvalidArgument("All outputs should be in CudaPlace.")); - return Collective( - in_tensors, - out_tensors, - [&](const phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - return platform::dynload::ncclAllGather( - input.data(), - output.data(), - input.numel(), - platform::ToNCCLDataType(input.dtype()), - comm, - stream); - }, - 
CommType::ALLGATHER, - sync_op, - use_calc_stream); -} - void* GetPointerByOffset(void* raw_pointer, size_t offset, experimental::DataType type) { @@ -1578,43 +1652,5 @@ std::shared_ptr ProcessGroupNCCL::_ReduceScatterBase( CommType::REDUCE_SCATTER); } -void ProcessGroupNCCL::GroupStart() { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); -} - -void ProcessGroupNCCL::GroupEnd() { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); -} - -ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const { - std::vector places = {place}; - const auto& iter = places_to_ncclcomm_.find(GetKeyFromPlaces(places)); - PADDLE_ENFORCE_NE(iter, - places_to_ncclcomm_.end(), - platform::errors::InvalidArgument( - "Cannot find nccl comm in process group.")); - return iter->second[0]->GetNcclComm(); -} - -const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext( - const Place& place) const { - return GetDeviceContext(place, /*use_calc_stream*/ false); -} - -const phi::DeviceContext& ProcessGroupNCCL::GetDeviceContext( - const Place& place, bool use_calc_stream) const { - if (use_calc_stream) { - return *platform::DeviceContextPool::Instance().Get(place); - } else { - std::vector places = {place}; - const auto& iter = places_to_ctx_.find(GetKeyFromPlaces(places)); - PADDLE_ENFORCE_NE(iter, - places_to_ctx_.end(), - platform::errors::InvalidArgument( - "Cannot find device context in process group.")); - return *iter->second[0]; - } -} - } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index 8be5a63c1ad..5ea83b906dd 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -24,10 +24,10 @@ #include "paddle/fluid/distributed/collective/ProcessGroupStream.h" #include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/platform/cuda_device_guard.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/device_event.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/gen_comm_id_helper.h" -#include "paddle/fluid/platform/place.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/device_context.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/NCCLTools.h" @@ -44,16 +44,28 @@ namespace distributed { using Place = paddle::platform::Place; -class ProcessGroupNCCL : public ProcessGroupStream { +class ProcessGroupNCCL final : public ProcessGroupStream { public: - class NCCLTask : public ProcessGroupStream::TaskStream, - public std::enable_shared_from_this { + class NCCLTask final : public ProcessGroupStream::TaskStream, + public std::enable_shared_from_this { public: + NCCLTask(const Place& place, + int rank, + CommType comm_type, + bool sync_op, + bool use_calc_stream); + virtual ~NCCLTask(); + + bool IsCompleted() override; + bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override; + void Synchronize() override; + void UpdateWaitChain(const phi::DeviceContext& ctx) override; + + // TODO(sunyilun): methods below will be removed later NCCLTask(const std::vector& places, int rank, CommType CommType, const std::vector& inputs); - NCCLTask(const std::vector& places, int rank, CommType comm_type, @@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream { bool sync_op, bool use_calc_stream); - bool IsCompleted(); - - void 
SynchronizeStreams(); - - bool Wait(std::chrono::milliseconds timeout = kWaitTimeout); - - void Synchronize(); - - void SetOutputs(std::vector& outputs); // NOLINT - - virtual ~NCCLTask(); - - void UpdateWaitChain(const phi::DeviceContext& ctx) override; - - std::vector control_events_; - std::vector barrierTensors_; - - protected: - std::vector places_; - std::vector> ncclComms_; - std::shared_ptr> outputs_; + public: + bool barrier_{false}; + platform::DeviceEvent comm_event_; // event on comm stream private: + Place place_; }; + public: ProcessGroupNCCL(const std::shared_ptr& store, int rank, int size, @@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream { const phi::DeviceContext& GetDeviceContext( const Place& place, bool use_calc_stream) const override; + std::shared_ptr AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr AllReduce( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - const AllreduceOptions& options, + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Barrier( + const BarrierOptions& = BarrierOptions()) override; + + std::shared_ptr Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, bool sync_op, bool use_calc_stream) override; + static void GroupStart(); + + static void GroupEnd(); + + ncclComm_t NCCLComm(const Place& place) const; + // TODO(liyurui): This API will be moved later std::shared_ptr AllReduce( std::vector& in_tensors, std::vector& out_tensors, const AllreduceOptions& = AllreduceOptions()) override; + // TODO(sunyilun): methods below will be removed later std::shared_ptr Broadcast( std::vector& in_tensors, std::vector& out_tensors, const BroadcastOptions& = BroadcastOptions()) override; - std::shared_ptr Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& opts, - bool sync_op, - bool use_calc_stream) override; - - std::shared_ptr Barrier( - const BarrierOptions& = BarrierOptions()) override; - std::shared_ptr Send( std::vector& tensors, int dst_rank) override; @@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream { std::vector& in_tensors, std::vector& out_tensors) override; - std::shared_ptr AllGather( - std::vector& in_tensors, - std::vector& out_tensors, - bool sync_op, - bool use_calc_stream) override; - std::shared_ptr AllGather_Partial( std::vector& in_tensors, std::vector& out_tensors, @@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream { phi::DenseTensor&, // NOLINT const ReduceScatterOptions&) override; - static void GroupStart(); + private: + std::shared_ptr CreateTask(const Place& place, + int rank, + CommType op_type, + bool sync_op, + bool use_calc_stream); - static void GroupEnd(); + void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id); - ncclComm_t NCCLComm(const Place& place) const; + void CreateNCCLEnvCache(const Place& place, const std::string& place_key); + + template + std::shared_ptr Collective( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + Fn fn, + CommType comm_type, + bool sync_op, + bool use_calc_stream); - protected: - virtual std::shared_ptr CreateTask( + void SyncCalcStream(const Place& place, + const std::shared_ptr& event); + + // TODO(sunyilun): methods below will be removed later + std::shared_ptr CreateTask( 
std::vector places, int rank, CommType op_type, const std::vector& inputs); - virtual std::shared_ptr CreateTask( + std::shared_ptr CreateTask( const std::vector& places, int rank, CommType op_type, @@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream { bool sync_op, bool use_calc_stream); - protected: - std::shared_ptr store_; - std::shared_ptr nccl_comm_; - std::mutex mutex_; - std::unordered_map>> - places_to_ncclcomm_; - - std::unordered_map> places_to_events_; - - std::unordered_map>> - places_to_ctx_; - - std::set used_place_ids_; - - private: - void BcastNCCLId(std::vector& nccl_ids, // NOLINT - int root, // NOLINT - int server_fd); - - void BroadcastUniqueNCCLID(std::vector& nccl_ids); // NOLINT - template std::shared_ptr Collective( std::vector& inputs, // NOLINT @@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream { void CheckSplitSizes(std::vector* split_sizes, std::vector tensor_shape); + + private: + std::shared_ptr store_; + std::shared_ptr calc_event_; // event on calc stream + std::unordered_map place_to_calc_ctx_; + std::unordered_map> + place_to_comm_ctx_; + + // TODO(sunyilun): attrs below will be removed later + std::mutex mutex_; + std::unordered_map> places_to_ctx_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.cc b/paddle/fluid/distributed/collective/ProcessGroupStream.cc index 11530ab872d..b8f4b25bd28 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.cc @@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext( } std::shared_ptr ProcessGroupStream::AllGather( - std::vector& input_tensors, // NOLINT - std::vector& output_tensors, // NOLINT + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, bool sync_op) { - return AllGather(input_tensors, - output_tensors, + return AllGather(out_tensor, + in_tensor, sync_op, /*use_calc_stream*/ false); } std::shared_ptr ProcessGroupStream::AllGather( - std::vector& input_tensors, // NOLINT - std::vector& output_tensors, // NOLINT + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, bool sync_op, bool use_calc_stream) { PADDLE_THROW(platform::errors::InvalidArgument( @@ -49,27 +49,50 @@ std::shared_ptr ProcessGroupStream::AllGather( } std::shared_ptr ProcessGroupStream::AllReduce( - std::vector& input_tensors, // NOLINT - std::vector& output_tensors, // NOLINT - const AllreduceOptions& options, + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, bool sync_op) { - return AllReduce(input_tensors, - output_tensors, - options, + return AllReduce(out_tensor, + in_tensor, + opts, sync_op, /*use_calc_stream*/ false); } std::shared_ptr ProcessGroupStream::AllReduce( - std::vector& input_tensors, // NOLINT - std::vector& output_tensors, // NOLINT - const AllreduceOptions& options, + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, bool sync_op, bool use_calc_stream) { PADDLE_THROW(platform::errors::InvalidArgument( "ProcessGroup%s does not support do all_reduce", GetBackendName())); } +std::shared_ptr ProcessGroupStream::Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op) { + return Broadcast(out_tensor, + in_tensor, + opts, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::Broadcast( + phi::DenseTensor* 
out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do broadcast", GetBackendName())); +} + +// TODO(sunyilun): methods below will be removed later std::shared_ptr ProcessGroupStream::AllToAll( std::vector& in_tensors, std::vector& out_tensors, @@ -114,28 +137,6 @@ std::shared_ptr ProcessGroupStream::AllToAllSingle( "ProcessGroup%s does not support do alltoall_single", GetBackendName())); } -std::shared_ptr ProcessGroupStream::Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& opts, - bool sync_op) { - return Broadcast(in_tensors, - out_tensors, - opts, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupStream::Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& opts, - bool sync_op, - bool use_calc_stream) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support do broadcast", GetBackendName())); -} - std::shared_ptr ProcessGroupStream::Reduce( std::vector& in_tensors, std::vector& out_tensors, diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.h b/paddle/fluid/distributed/collective/ProcessGroupStream.h index 56799c4bd3e..c4699d40901 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.h +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.h @@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup { public: class TaskStream : public ProcessGroup::Task { public: + TaskStream(int rank, CommType comm_type, bool sync_op, bool use_calc_stream) + : Task(rank, comm_type, sync_op), use_calc_stream_(use_calc_stream) {} + + virtual ~TaskStream() = default; + // TODO(liyurui): This constructor is temporary here for compatible reason, // will be deleted soon. 
    TaskStream(int rank,
@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup {
         : Task(rank, inputs, comm_type, sync_op),
           use_calc_stream_(use_calc_stream) {}
 
-    virtual ~TaskStream() = default;
-
    protected:
     bool UseCalcStream() const { return use_calc_stream_; }
 
@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup {
     bool use_calc_stream_{false};
   };
 
+ public:
   ProcessGroupStream(int rank, int size, const platform::Place& place, int gid);
   virtual ~ProcessGroupStream() = default;
 
@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup {
       const Place& place, bool use_calc_stream) const;
 
   std::shared_ptr<ProcessGroup::Task> AllGather(
-      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
-      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
       bool sync_op) override;
 
   virtual std::shared_ptr<ProcessGroup::Task> AllGather(
-      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
-      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
       bool sync_op,
       bool use_calc_stream);
 
   std::shared_ptr<ProcessGroup::Task> AllReduce(
-      std::vector<phi::DenseTensor>& input_tensors,   // NOLINT
-      std::vector<phi::DenseTensor>& output_tensors,  // NOLINT
-      const AllreduceOptions& options,
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const AllreduceOptions& opts,
       bool sync_op) override;
 
   virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
-      std::vector<phi::DenseTensor>& input_tensors,   // NOLINT
-      std::vector<phi::DenseTensor>& output_tensors,  // NOLINT
-      const AllreduceOptions& options,
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const AllreduceOptions& opts,
+      bool sync_op,
+      bool use_calc_stream);
+
+  std::shared_ptr<ProcessGroup::Task> Broadcast(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const BroadcastOptions& opts,
+      bool sync_op) override;
+
+  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
+      phi::DenseTensor* out_tensor,
+      const phi::DenseTensor& in_tensor,
+      const BroadcastOptions& opts,
       bool sync_op,
       bool use_calc_stream);
 
+  // TODO(sunyilun): methods below will be removed later
   std::shared_ptr<ProcessGroup::Task> AllToAll(
       std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
       std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup {
       bool sync_op,
       bool use_calc_stream);
 
-  std::shared_ptr<ProcessGroup::Task> Broadcast(
-      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
-      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
-      const BroadcastOptions& opts,
-      bool sync_op) override;
-
-  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
-      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
-      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
-      const BroadcastOptions& opts,
-      bool sync_op,
-      bool use_calc_stream);
-
   std::shared_ptr<ProcessGroup::Task> Reduce(
       std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
       std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu
index f37c1a0be8b..a13bfcf12ea 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cu
+++ b/paddle/fluid/operators/fused/fused_attention_op.cu
@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor,  // NOLINT
   if (map->has(ring_id)) {
     paddle::distributed::ProcessGroup *pg = map->get(ring_id);
     auto pg_nccl = static_cast<paddle::distributed::ProcessGroupNCCL *>(pg);
-
-    std::vector<phi::DenseTensor> in_tensor;
-    std::vector<phi::DenseTensor> out_tensor;
-    in_tensor.push_back(tensor);
-    out_tensor.push_back(tensor);
     paddle::distributed::AllreduceOptions opts;
     opts.reduce_op = distributed::ReduceOp::SUM;
-    auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
+    auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
     task->Wait();
   } else {
     auto dtype = platform::ToNCCLDataType(
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu
index 39dfe969e3d..669672084b5 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cu
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor,  // NOLINT
   if (map->has(ring_id)) {
     paddle::distributed::ProcessGroup* pg = map->get(ring_id);
     auto pg_nccl = static_cast<paddle::distributed::ProcessGroupNCCL*>(pg);
-
-    std::vector<phi::DenseTensor> in_tensor;
-    std::vector<phi::DenseTensor> out_tensor;
-    in_tensor.push_back(tensor);
-    out_tensor.push_back(tensor);
     paddle::distributed::AllreduceOptions opts;
     opts.reduce_op = distributed::ReduceOp::SUM;
-    auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
+    auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
     task->Wait();
   } else {
     auto dtype = platform::ToNCCLDataType(
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index b84dd7fcbe1..a4b53922a8e 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) {
              distributed::ReduceOp op,
              bool sync_op) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
-            distributed::AllreduceOptions opts;
-            opts.reduce_op = op;
-            auto dense =
+            auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            std::vector<phi::DenseTensor> tensors = {*dense};
-            return self.AllReduce(tensors, tensors, opts, sync_op);
+            auto *out_dense = p_dense.get();
+            auto in_dense = *p_dense;
+            distributed::AllreduceOptions opts{op};
+            return self.AllReduce(out_dense, in_dense, opts, sync_op);
           },
           py::arg("tensor"),
           py::arg("op"),
@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) {
              int src,
              bool sync_op) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
-            distributed::BroadcastOptions opts{src};
-            auto dense =
+            auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            std::vector<phi::DenseTensor> tensors = {*dense};
-            return self.Broadcast(tensors, tensors, opts, sync_op);
+            auto *out_dense = p_dense.get();
+            auto in_dense = *p_dense;
+            distributed::BroadcastOptions opts{src};
+            return self.Broadcast(out_dense, in_dense, opts, sync_op);
           },
           py::arg("tensor"),
           py::arg("src"),
@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) {
       .def(
           "allgather",
           [](distributed::ProcessGroup &self,
-             py::handle py_in_tensor,
              py::handle py_out_tensor_list,
+             py::handle py_in_tensor,
              bool sync_op) {
-            auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
-            auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-                in_tensor.impl());
-            std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
-
             auto out_tensor_list =
                 CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
             Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
-            auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+            auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                 concat_out_tensor.impl());
-            std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+            auto *out_dense = p_out_tensor.get();
+
+            auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+            auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+                in_tensor.impl());
+            auto in_dense = *p_in_tensor;
 
             const auto &dev_ctx = self.GetDeviceContext(in_tensor.place());
-            auto task = self.AllGather(in_wrapper, out_wrapper, sync_op);
+            auto task = self.AllGather(out_dense, in_dense, sync_op);
             distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
             task->UpdateWaitChain(dev_ctx);
             return task;
           },
-          py::arg("in"),
           py::arg("out"),
+          py::arg("in"),
           py::arg("sync_op"),
           py::call_guard<py::gil_scoped_release>())

      .def(
          "allgather_into_tensor",
          [](distributed::ProcessGroup &self,
-             py::handle py_in_tensor,
             py::handle py_out_tensor,
+             py::handle py_in_tensor,
             bool sync_op) {
-            auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
-            auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-                in_tensor.impl());
-            std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
-
             auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-            auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+            auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                 out_tensor.impl());
-            std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+            auto *out_dense = p_out_tensor.get();
+
+            auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+            auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+                in_tensor.impl());
+            auto in_dense = *p_in_tensor;
 
-            return self.AllGather(in_wrapper, out_wrapper, sync_op);
+            return self.AllGather(out_dense, in_dense, sync_op);
           },
-          py::arg("in"),
           py::arg("out"),
+          py::arg("in"),
           py::arg("sync_op"),
           py::call_guard<py::gil_scoped_release>())
@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) {
       .def(
           "allgather_on_calc_stream",
           [](distributed::ProcessGroupStream &self,
-             py::handle py_in_tensor,
-             py::handle py_out_tensor_list) {
-            auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
-            auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-                in_tensor.impl());
-            std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
-
+             py::handle py_out_tensor_list,
+             py::handle py_in_tensor) {
             auto out_tensor_list =
                 CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
             Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
-            auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+            auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                 concat_out_tensor.impl());
-            std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+            auto *out_dense = p_out_tensor.get();
+
+            auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+            auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+                in_tensor.impl());
+            auto in_dense = *p_in_tensor;
 
             const auto &dev_ctx =
                 self.GetDeviceContext(in_tensor.place(), true);
-            auto task = self.AllGather(in_wrapper,
-                                       out_wrapper,
+            auto task = self.AllGather(out_dense,
+                                       in_dense,
                                        /*sync_op*/ true,
                                        /*use_calc_stream*/ true);
             distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
             return task;
           },
-          py::arg("in"),
           py::arg("out"),
+          py::arg("in"),
           py::call_guard<py::gil_scoped_release>())

      .def(
          "allgather_into_tensor_on_calc_stream",
          [](distributed::ProcessGroupStream &self,
-             py::handle py_in_tensor,
-             py::handle py_out_tensor) {
-            auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
-            auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
-                in_tensor.impl());
-            std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
-
+             py::handle py_out_tensor,
+             py::handle py_in_tensor) {
             auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
-            auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
+            auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
                 out_tensor.impl());
-            std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
+            auto *out_dense = p_out_tensor.get();
 
-            return self.AllGather(in_wrapper,
-                                  out_wrapper,
+            auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+            auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
+                in_tensor.impl());
+            auto in_dense = *p_in_tensor;
+
+            return self.AllGather(out_dense,
+                                  in_dense,
                                   /*sync_op*/ true,
                                   /*use_calc_stream*/ true);
           },
-          py::arg("in"),
           py::arg("out"),
+          py::arg("in"),
           py::call_guard<py::gil_scoped_release>())

      .def(
@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) {
              py::handle py_tensor,
              distributed::ReduceOp op) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
-            distributed::AllreduceOptions opts;
-            opts.reduce_op = op;
-            auto dense =
+            auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            std::vector<phi::DenseTensor> tensors = {*dense};
-            return self.AllReduce(tensors,
-                                  tensors,
+            auto in_dense = *p_dense;
+            auto *out_dense = p_dense.get();
+            distributed::AllreduceOptions opts{op};
+            return self.AllReduce(out_dense,
+                                  in_dense,
                                   opts,
                                   /*sync_op*/ true,
                                   /*use_calc_stream*/ true);
@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) {
              py::handle py_tensor,
              int src) {
             auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
-            distributed::BroadcastOptions opts{src};
-            auto dense =
+            auto p_dense =
                 std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
-            std::vector<phi::DenseTensor> tensors = {*dense};
-            return self.Broadcast(tensors,
-                                  tensors,
+            auto *out_dense = p_dense.get();
+            auto in_dense = *p_dense;
+            distributed::BroadcastOptions opts{src};
+            return self.Broadcast(out_dense,
+                                  in_dense,
                                   opts,
                                   /*sync_op*/ true,
                                   /*use_calc_stream*/ true);
diff --git a/python/paddle/distributed/communication/stream/all_gather.py b/python/paddle/distributed/communication/stream/all_gather.py
index cdd7f98554b..e74623f9486 100644
--- a/python/paddle/distributed/communication/stream/all_gather.py
+++ b/python/paddle/distributed/communication/stream/all_gather.py
@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1):
     expect_shape = list(shape)
     expect_shape[0] *= nranks
     if list(tensor.shape) != expect_shape:
-        raise RuntimeError('The tensor for all_gather is not correctly-sized.')
+        raise RuntimeError("The tensor for all_gather is not correctly-sized.")
 
 
 def _check_tensor_list_shape(tensor_list, shape, nranks=1):
     if len(tensor_list) != nranks:
         raise RuntimeError(
-            'The tensor_list for all_gather is not correctly-sized.'
+            "The tensor_list for all_gather is not correctly-sized."
         )
     for tensor in tensor_list:
         if tensor.shape != shape:
             raise RuntimeError(
-                'The tensor_list for all_gather is not correctly-sized.'
+                "The tensor_list for all_gather is not correctly-sized."
             )
 
 
@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph(
     if use_calc_stream:
         return group.process_group.allgather_into_tensor_on_calc_stream(
-            in_tensor, out_tensor
+            out_tensor,
+            in_tensor,
         )
 
     task = group.process_group.allgather_into_tensor(
-        in_tensor, out_tensor, sync_op
+        out_tensor, in_tensor, sync_op
     )
     if sync_op:
         task.wait()
@@ -68,9 +69,9 @@ def _all_gather_in_dygraph(
         _check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
 
     if use_calc_stream:
-        return group.process_group.allgather_on_calc_stream(tensor, tensor_list)
+        return group.process_group.allgather_on_calc_stream(tensor_list, tensor)
 
-    task = group.process_group.allgather(tensor, tensor_list, sync_op)
+    task = group.process_group.allgather(tensor_list, tensor, sync_op)
    if sync_op:
        task.wait()
-- 
GitLab
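
Usage sketch (reviewer note, not part of the patch): after this refactor the C++ bindings take the output first and the input second, and the Python wrapper in all_gather.py forwards arguments in that order while keeping the user-facing (tensor_list, tensor) convention. Below is a minimal dygraph example. It assumes two ranks launched with python -m paddle.distributed.launch, a Paddle build with NCCL, and that the public all_gather function defined in the patched module keeps a signature along the lines of all_gather(tensor_or_tensor_list, tensor, group=None, sync_op=True); the exact signature is not shown in this hunk.

# Minimal sketch, not part of the patch; run one copy per rank.
import paddle
import paddle.distributed as dist
from paddle.distributed.communication.stream import all_gather

dist.init_parallel_env()

# Each rank contributes a [4, 8] tensor filled with its own rank id.
data = paddle.full([4, 8], float(dist.get_rank()))

# Pre-size the output list so it passes _check_tensor_list_shape above:
# one tensor per rank, each with the same shape as the input.
tensor_list = [paddle.empty_like(data) for _ in range(dist.get_world_size())]

# With sync_op=False the wrapper returns the task without waiting,
# mirroring the _all_gather_in_dygraph helper changed in this patch.
task = all_gather(tensor_list, data, sync_op=False)
task.wait()

print([t.shape for t in tensor_list])  # nranks tensors of shape [4, 8]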