From ae14bad174832be7c2d928988ed6e0288912f54b Mon Sep 17 00:00:00 2001 From: Wen Sun <35923278+HermitSun@users.noreply.github.com> Date: Wed, 9 Nov 2022 11:02:57 +0800 Subject: [PATCH] refactor: ProcessGroupNCCL (#47740) --- .../distributed/collective/ProcessGroup.h | 8 - .../collective/ProcessGroupNCCL.cc | 180 ++++++++---------- .../distributed/collective/ProcessGroupNCCL.h | 23 +-- paddle/fluid/pybind/distributed_py.cc | 21 -- 4 files changed, 85 insertions(+), 147 deletions(-) diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 7550b83265..0076b5ee47 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -350,14 +350,6 @@ class ProcessGroup { GetBackendName())); } - virtual std::shared_ptr _ReduceScatterBase( - phi::DenseTensor&, // NOLINT - phi::DenseTensor&, // NOLINT - const ReduceScatterOptions&) { - PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support ReduceScatter", GetBackendName())); - } - protected: const int rank_; const int size_; diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index 2a13001750..9ed56ae324 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -33,7 +33,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place, bool use_calc_stream) : TaskStream(rank, comm_type, sync_op, use_calc_stream), comm_event_(place), - place_(place) {} + task_place_(place) {} ProcessGroupNCCL::NCCLTask::~NCCLTask() {} @@ -53,8 +53,9 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { return true; } - const auto* calc_ctx = platform::DeviceContextPool::Instance().Get(place_); - comm_event_.Wait(platform::Place2DeviceType(place_), calc_ctx); + const auto* calc_ctx = + platform::DeviceContextPool::Instance().Get(task_place_); + comm_event_.Wait(platform::Place2DeviceType(task_place_), calc_ctx); if (FLAGS_nccl_blocking_wait) { // NOTE(shenliang03): It will block host for sync @@ -63,7 +64,7 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { } } - if (barrier_) { + if (IsBlockCPUInWait()) { // If we use the work to do barrier, we should block cpu #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); @@ -192,7 +193,7 @@ std::shared_ptr ProcessGroupNCCL::Barrier( /*sync_op*/ true, /*use_calc_stream*/ false); auto nccl_task = dynamic_cast(task.get()); - nccl_task->barrier_ = true; + nccl_task->SetBlockCPUInWait(); return task; } @@ -250,6 +251,10 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id) { void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, const std::string& place_key) { + if (place_to_comm_ctx_.size() > 0) { + VLOG(3) << "Warning: Tensors from multiple devices are not supported yet."; + } + ncclUniqueId nccl_id; if (rank_ == 0) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id)); @@ -260,7 +265,6 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, << ", place: " << place_key << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id); - calc_event_ = std::make_shared(place); auto* calc_ctx = static_cast( platform::DeviceContextPool::Instance().Get(place)); auto comm_ctx = std::make_unique(place); @@ -269,20 +273,23 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, &nccl_comm, GetSize(), nccl_id, GetRank())); comm_ctx->set_nccl_comm(nccl_comm); - place_to_calc_ctx_[place_key] = calc_ctx; - place_to_comm_ctx_[place_key] = std::move(comm_ctx); + place_to_calc_event_.emplace(place_key, place); + place_to_calc_ctx_.emplace(place_key, calc_ctx); + place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx)); // TODO(sunyilun): for compatibility, will be removed later - places_to_ctx_[place_key] = {place_to_comm_ctx_[place_key].get()}; + std::vector comm_ctx_wrapper{ + place_to_comm_ctx_[place_key].get()}; + places_to_ctx_.emplace(place_key, comm_ctx_wrapper); } -void ProcessGroupNCCL::SyncCalcStream( - const Place& place, const std::shared_ptr& event) { +void ProcessGroupNCCL::SyncCalcStream(const Place& place) { const std::string& key = GetKeyFromPlace(place); - const auto* calc_ctx = place_to_calc_ctx_[key]; - const auto* comm_ctx = place_to_comm_ctx_[key].get(); - event->Record(calc_ctx); - event->Wait(platform::Place2DeviceType(place), comm_ctx); + auto& calc_event = place_to_calc_event_.at(key); + const auto* calc_ctx = place_to_calc_ctx_.at(key); + const auto* comm_ctx = place_to_comm_ctx_.at(key).get(); + calc_event.Record(calc_ctx); + calc_event.Wait(platform::Place2DeviceType(place), comm_ctx); } template @@ -296,26 +303,29 @@ std::shared_ptr ProcessGroupNCCL::Collective( const auto& place = in_tensor.place(); const auto& key = GetKeyFromPlace(place); - if (!calc_event_) { + platform::CUDADeviceGuard cuda_guard(place); + + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { CreateNCCLEnvCache(place, key); } if (!use_calc_stream) { - SyncCalcStream(place, calc_event_); + SyncCalcStream(place); } auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream); - const auto* calc_ctx = place_to_calc_ctx_[key]; - const auto& comm_ctx = place_to_comm_ctx_[key]; + const auto* calc_ctx = place_to_calc_ctx_.at(key); + const auto& comm_ctx = place_to_comm_ctx_.at(key); + auto nccl_comm = comm_ctx->nccl_comm(); auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); - fn(out_tensor, in_tensor, comm_ctx->nccl_comm(), nccl_stream); + fn(out_tensor, in_tensor, nccl_comm, nccl_stream); if (!use_calc_stream) { if (FLAGS_use_stream_safe_cuda_allocator) { memory::RecordStream(in_tensor.Holder(), nccl_stream); } - task->comm_event_.Record(comm_ctx.get()); + task->UpdateWaitChain(*comm_ctx); } return task; @@ -352,13 +362,13 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector* split_sizes, // TODO(sunyilun): methods below will be removed later void SyncDefaultStream(const std::vector& places, - const std::shared_ptr& nccl_event, + platform::DeviceEvent& nccl_event, // NOLINT std::vector& dev_ctx) { // NOLINT for (size_t i = 0; i < places.size(); ++i) { auto* default_ctx = static_cast( platform::DeviceContextPool::Instance().Get(places[i])); - nccl_event->Record(default_ctx); - nccl_event->Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]); + nccl_event.Record(default_ctx); + nccl_event.Wait(platform::Place2DeviceType(places[i]), dev_ctx[i]); } } @@ -389,7 +399,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask( const std::vector& inputs) : TaskStream(rank, inputs, CommType), comm_event_(places[0]), - place_(places[0]) {} + task_place_(places[0]) {} ProcessGroupNCCL::NCCLTask::NCCLTask( const std::vector& places, @@ -400,7 +410,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask( bool use_calc_stream) : TaskStream(rank, inputs, comm_type, sync_op, use_calc_stream), comm_event_(places[0]), - place_(places[0]) {} + task_place_(places[0]) {} // create NCCLManager cache for places_key void ProcessGroupNCCL::CreateNCCLManagerCache( @@ -437,17 +447,18 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( NCCLCHECK(platform::dynload::ncclCommInitRank( &nccl_comm, GetSize(), nccl_id, GetRank())); dev_ctx[i]->set_nccl_comm(nccl_comm); - dev_ctx_raw[i] = dev_ctx[i].get(); } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); - calc_event_ = std::make_shared(places[0]); // TODO(sunyilun): for compatibility, will be removed later - place_to_calc_ctx_[places_key] = static_cast( - platform::DeviceContextPool::Instance().Get(places[0])); - place_to_comm_ctx_[places_key] = std::move(dev_ctx[0]); + place_to_calc_event_.emplace(places_key, places[0]); + place_to_calc_ctx_.emplace( + places_key, + static_cast( + platform::DeviceContextPool::Instance().Get(places[0]))); + place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0])); // These caches will be useful to process sync/wait/communicate places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw)); @@ -466,13 +477,14 @@ std::shared_ptr ProcessGroupNCCL::Collective( { std::lock_guard lock(mutex_); - if (!calc_event_) { + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { CreateNCCLManagerCache(key, places); } } if (!use_calc_stream) { - SyncDefaultStream(places, calc_event_, places_to_ctx_[key]); + SyncDefaultStream( + places, place_to_calc_event_.at(key), places_to_ctx_.at(key)); } auto task = @@ -492,12 +504,12 @@ std::shared_ptr ProcessGroupNCCL::Collective( platform::DeviceContextPool::Instance().Get(places[i])) ->stream(); } else { - nccl_stream = places_to_ctx_[key][i]->stream(); + nccl_stream = places_to_ctx_.at(key)[i]->stream(); } fn(inputs[i], outputs[i], - places_to_ctx_[key][i]->nccl_comm(), + places_to_ctx_.at(key)[i]->nccl_comm(), nccl_stream); } } @@ -513,7 +525,7 @@ std::shared_ptr ProcessGroupNCCL::Collective( platform::DeviceContextPool::Instance().Get(places[i])) ->stream(); } else { - nccl_stream = places_to_ctx_[key][i]->stream(); + nccl_stream = places_to_ctx_.at(key)[i]->stream(); } memory::RecordStream(inputs[i].Holder(), nccl_stream); @@ -524,7 +536,7 @@ std::shared_ptr ProcessGroupNCCL::Collective( if (!use_calc_stream) { for (size_t i = 0; i < inputs.size(); ++i) { cuda_guard.SetDevice(places[i]); - task->comm_event_.Record(places_to_ctx_[key][i]); + task->UpdateWaitChain(*places_to_ctx_.at(key)[i]); } } @@ -542,12 +554,13 @@ std::shared_ptr ProcessGroupNCCL::Collective( { std::lock_guard lock(mutex_); - if (!calc_event_) { + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { CreateNCCLManagerCache(key, places); } } - SyncDefaultStream(places, calc_event_, places_to_ctx_[key]); + SyncDefaultStream( + places, place_to_calc_event_.at(key), places_to_ctx_.at(key)); auto task = CreateTask(places, rank_, op_type, inputs); @@ -558,10 +571,10 @@ std::shared_ptr ProcessGroupNCCL::Collective( platform::NCCLGroupGuard nccl_guard; for (size_t i = 0; i < inputs.size(); ++i) { cuda_guard.SetDevice(places[i]); - const auto& nccl_stream = places_to_ctx_[key][i]->stream(); + const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream(); fn(inputs[i], outputs[i], - places_to_ctx_[key][i]->nccl_comm(), + places_to_ctx_.at(key)[i]->nccl_comm(), nccl_stream); } } @@ -570,13 +583,13 @@ std::shared_ptr ProcessGroupNCCL::Collective( for (size_t i = 0; i < inputs.size(); ++i) { cuda_guard.SetDevice(places[i]); memory::RecordStream(inputs[i].Holder(), - places_to_ctx_[key][i]->stream()); + places_to_ctx_.at(key)[i]->stream()); } } for (size_t i = 0; i < inputs.size(); ++i) { cuda_guard.SetDevice(places[i]); - task->comm_event_.Record(places_to_ctx_[key][i]); + task->UpdateWaitChain(*places_to_ctx_.at(key)[i]); } return task; } @@ -592,26 +605,27 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in, { std::lock_guard lock(mutex_); - if (!calc_event_) { + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { CreateNCCLManagerCache(key, places); } } - SyncDefaultStream(places, calc_event_, places_to_ctx_[key]); + SyncDefaultStream( + places, place_to_calc_event_.at(key), places_to_ctx_.at(key)); // construct uninitialize guard for device platform::CUDADeviceGuard cuda_guard; if (FLAGS_use_stream_safe_cuda_allocator) { cuda_guard.SetDevice(places[0]); - memory::RecordStream(in->Holder(), places_to_ctx_[key][0]->stream()); + memory::RecordStream(in->Holder(), places_to_ctx_.at(key)[0]->stream()); } { platform::NCCLGroupGuard nccl_guard; cuda_guard.SetDevice(places[0]); - const auto& nccl_stream = places_to_ctx_[key][0]->stream(); - fn(in, out, places_to_ctx_[key][0]->nccl_comm(), nccl_stream); + const auto& nccl_stream = places_to_ctx_.at(key)[0]->stream(); + fn(in, out, places_to_ctx_.at(key)[0]->nccl_comm(), nccl_stream); } cuda_guard.SetDevice(places[0]); @@ -630,13 +644,14 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( { std::lock_guard lock(mutex_); - if (!calc_event_) { + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { CreateNCCLManagerCache(key, places); } } if (!use_calc_stream) { - SyncDefaultStream(places, calc_event_, places_to_ctx_[key]); + SyncDefaultStream( + places, place_to_calc_event_.at(key), places_to_ctx_.at(key)); } auto task = @@ -655,10 +670,10 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( platform::DeviceContextPool::Instance().Get(places[i])) ->stream(); } else { - nccl_stream = places_to_ctx_[key][i]->stream(); + nccl_stream = places_to_ctx_.at(key)[i]->stream(); } fn(tensors[i], - places_to_ctx_[key][i]->nccl_comm(), + places_to_ctx_.at(key)[i]->nccl_comm(), nccl_stream, dst_rank); } @@ -674,7 +689,7 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( platform::DeviceContextPool::Instance().Get(places[i])) ->stream(); } else { - nccl_stream = places_to_ctx_[key][i]->stream(); + nccl_stream = places_to_ctx_.at(key)[i]->stream(); } memory::RecordStream(tensors[i].Holder(), nccl_stream); } @@ -683,7 +698,7 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( if (!use_calc_stream) { for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); - task->comm_event_.Record(places_to_ctx_[key][i]); + task->UpdateWaitChain(*places_to_ctx_.at(key)[i]); } } @@ -701,12 +716,13 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( { std::lock_guard lock(mutex_); - if (!calc_event_) { + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { CreateNCCLManagerCache(key, places); } } - SyncDefaultStream(places, calc_event_, places_to_ctx_[key]); + SyncDefaultStream( + places, place_to_calc_event_.at(key), places_to_ctx_.at(key)); auto task = CreateTask(places, rank_, op_type, tensors); @@ -717,9 +733,9 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( platform::NCCLGroupGuard nccl_guard; for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); - const auto& nccl_stream = places_to_ctx_[key][i]->stream(); + const auto& nccl_stream = places_to_ctx_.at(key)[i]->stream(); fn(tensors[i], - places_to_ctx_[key][i]->nccl_comm(), + places_to_ctx_.at(key)[i]->nccl_comm(), nccl_stream, dst_rank); } @@ -729,13 +745,13 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); memory::RecordStream(tensors[i].Holder(), - places_to_ctx_[key][i]->stream()); + places_to_ctx_.at(key)[i]->stream()); } } for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); - task->comm_event_.Record(places_to_ctx_[key][i]); + task->UpdateWaitChain(*places_to_ctx_.at(key)[i]); } return task; } @@ -1608,49 +1624,5 @@ std::shared_ptr ProcessGroupNCCL::Scatter( use_calc_stream); } -std::shared_ptr ProcessGroupNCCL::_ReduceScatterBase( - phi::DenseTensor& out_tensor, - phi::DenseTensor& in_tensor, - const ReduceScatterOptions& opts) { - // auto tensor = out_tensors.back(); - PADDLE_ENFORCE_EQ( - out_tensor.dtype(), - in_tensor.dtype(), - platform::errors::InvalidArgument( - "Input tensor and output tensor should be same dtype.")); - - PADDLE_ENFORCE_EQ( - out_tensor.numel() * size_, - in_tensor.numel(), - platform::errors::InvalidArgument("input tensor must be the same size as " - "output tensor size times world_size")); - - auto inputs = std::vector{in_tensor}; - auto outputs = std::vector{out_tensor}; - - return Collective( - inputs, - outputs, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - ncclComm_t comm, - const gpuStream_t& stream) { - if (FLAGS_use_stream_safe_cuda_allocator) { - platform::CUDADeviceGuard cuda_guard; - cuda_guard.SetDevice(output.place()); - memory::RecordStream(output.Holder(), stream); - } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter( - input.data(), - output.data(), - output.numel(), - platform::ToNCCLDataType(input.dtype()), - ToNCCLRedType(opts.reduce_op), - comm, - stream)); - }, - CommType::REDUCE_SCATTER); -} - } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index 5ea83b906d..03260b8249 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -15,7 +15,6 @@ #pragma once #include -#include #include #include #include @@ -61,6 +60,9 @@ class ProcessGroupNCCL final : public ProcessGroupStream { void Synchronize() override; void UpdateWaitChain(const phi::DeviceContext& ctx) override; + bool IsBlockCPUInWait() const { return block_cpu_in_wait_; } + void SetBlockCPUInWait() { block_cpu_in_wait_ = true; } + // TODO(sunyilun): methods below will be removed later NCCLTask(const std::vector& places, int rank, @@ -73,12 +75,10 @@ class ProcessGroupNCCL final : public ProcessGroupStream { bool sync_op, bool use_calc_stream); - public: - bool barrier_{false}; - platform::DeviceEvent comm_event_; // event on comm stream - private: - Place place_; + bool block_cpu_in_wait_{false}; + platform::DeviceEvent comm_event_; // event on comm stream + Place task_place_; }; public: @@ -253,11 +253,6 @@ class ProcessGroupNCCL final : public ProcessGroupStream { bool sync_op, bool use_calc_stream) override; - std::shared_ptr _ReduceScatterBase( - phi::DenseTensor&, // NOLINT - phi::DenseTensor&, // NOLINT - const ReduceScatterOptions&) override; - private: std::shared_ptr CreateTask(const Place& place, int rank, @@ -278,8 +273,7 @@ class ProcessGroupNCCL final : public ProcessGroupStream { bool sync_op, bool use_calc_stream); - void SyncCalcStream(const Place& place, - const std::shared_ptr& event); + void SyncCalcStream(const Place& place); // TODO(sunyilun): methods below will be removed later std::shared_ptr CreateTask( @@ -342,7 +336,8 @@ class ProcessGroupNCCL final : public ProcessGroupStream { private: std::shared_ptr store_; - std::shared_ptr calc_event_; // event on calc stream + std::unordered_map + place_to_calc_event_; // event on calc stream std::unordered_map place_to_calc_ctx_; std::unordered_map> place_to_comm_ctx_; diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 06b26d66a6..721995e7b3 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -761,27 +761,6 @@ void BindDistributed(py::module *m) { py::arg("in"), py::arg("out"), py::arg("src"), - py::call_guard()) - - .def( - "_reduce_scatter_base", - [](distributed::ProcessGroup &self, - py::handle py_out_tensor, - py::handle py_in_tensor, - distributed::ReduceOp op) { - auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); - auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); - distributed::ReduceScatterOptions opts; - opts.reduce_op = op; - auto dense_out = std::dynamic_pointer_cast( - out_tensor.impl()); - auto dense_in = std::dynamic_pointer_cast( - in_tensor.impl()); - return self._ReduceScatterBase(*dense_out, *dense_in, opts); - }, - py::arg("out_tensor"), - py::arg("in_tensor"), - py::arg("op") = distributed::ReduceOp::SUM, py::call_guard()); auto ProcessGroupStream = -- GitLab