From dba59db7d5d53fabdecb65b0dd93d59e9cf18241 Mon Sep 17 00:00:00 2001
From: WangXi
Date: Mon, 27 Dec 2021 21:15:42 +0800
Subject: [PATCH] [fleet_executor] Add task loop thread pool (#38420)

---
 .../distributed/fleet_executor/CMakeLists.txt |   6 +-
 .../distributed/fleet_executor/carrier.cc     | 102 +++++-------------
 .../distributed/fleet_executor/carrier.h      |  26 +++--
 .../fleet_executor/compute_interceptor.cc     |   2 +-
 .../fleet_executor/fleet_executor.cc          |   5 +
 .../distributed/fleet_executor/interceptor.cc | 102 ++++++++----------
 .../distributed/fleet_executor/interceptor.h  |  26 ++---
 .../distributed/fleet_executor/message_bus.cc |  57 +++++++---
 .../distributed/fleet_executor/message_bus.h  |  11 +-
 .../distributed/fleet_executor/task_loop.cc   |  82 ++++++++++++++
 .../distributed/fleet_executor/task_loop.h    |  81 ++++++++++++++
 .../fleet_executor/task_loop_thread.cc        |  58 ++++++++++
 .../fleet_executor/task_loop_thread.h         |  44 ++++++++
 .../fleet_executor/task_loop_thread_pool.cc   |  66 ++++++++++++
 .../fleet_executor/task_loop_thread_pool.h    |  47 ++++++++
 .../interceptor_ping_pong_with_brpc_test.cc   |   9 +-
 paddle/fluid/framework/blocking_queue.h       |  10 ++
 17 files changed, 544 insertions(+), 190 deletions(-)
 create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop.cc
 create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop.h
 create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
 create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop_thread.h
 create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc
 create mode 100644 paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h

diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
index 82444ae77dc..95ec6b32996 100644
--- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt
@@ -10,10 +10,12 @@ else()
   set(BRPC_DEPS "")
 endif()
 
+cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)
+
 cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
         interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
-        DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
-        executor_gc_helper gflags glog ${BRPC_DEPS})
+        DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
+        op_registry executor_gc_helper gflags glog ${BRPC_DEPS})
 
 if(WITH_DISTRIBUTE)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc
index 3279f954fa5..ea35b36aa4a 100644
--- a/paddle/fluid/distributed/fleet_executor/carrier.cc
+++ b/paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -42,30 +42,17 @@ void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
   place_ = place;
   root_scope_ = root_scope;
   dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
+
+  // TODO(fleet_exe dev): thread pool
+  thread_num_ = 1;
+  thread_pool_.SetThreadNum(thread_num_);
+  thread_pool_.Start();
+
   CreateInterceptors();
   is_init_ = true;
 }
 
-void Carrier::Release() {
-  // NOTE(wangxi): must join before `Derived Interceptor` destruct,
-  // otherwise Derived object will be destructed before thread complete.
-  for (int64_t id : source_interceptor_ids_) {
-    VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
-            << ".";
-    InterceptorMessage stop_msg;
-    // source node STOP is send by carrier, so set src_id=-1
-    stop_msg.set_src_id(-1);
-    stop_msg.set_dst_id(id);
-    stop_msg.set_message_type(STOP);
-    Send(stop_msg);
-  }
-
-  // TODO(wangxi): Maybe need a better to use thread.
-  for (auto& interceptor : interceptor_idx_to_interceptor_) {
-    interceptor.second->Join();
-  }
-}
+void Carrier::Release() {}
 
 Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
 
@@ -75,18 +62,9 @@ bool Carrier::EnqueueInterceptorMessage(
     VLOG(3) << "Receiving control message from rank "
             << interceptor_message.src_id() << " to rank "
             << interceptor_message.dst_id();
+    // for barrier
+    msg_bus_->IncreaseBarrierCount();
   } else {
-    {
-      std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
-      if (creating_interceptors_) {
-        std::unique_lock<std::mutex> lock_message(tmp_message_mutex_);
-        // Cannot handle the message to interceptor since interceptors
-        // are still under creating. Will enqueue into a tmp stack.
-        VLOG(3) << "Receiving message while creating interceptors.";
-        message_tmp_.emplace_back(interceptor_message);
-        return true;
-      }
-    }
     int64_t dst_id = interceptor_message.dst_id();
     Interceptor* dst_interceptor = GetInterceptor(dst_id);
     dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
@@ -94,6 +72,8 @@ bool Carrier::EnqueueInterceptorMessage(
   return true;
 }
 
+void Carrier::Barrier() { msg_bus_->Barrier(); }
+
 Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
   auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
   PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(),
@@ -109,6 +89,11 @@ void Carrier::Wait() {
   cond_var_.wait(lock);
 }
 
+void Carrier::WakeUp() {
+  // probably a double notify, but that is OK for the unit tests
+  cond_var_.notify_all();
+}
+
 void Carrier::Start() {
   PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
                     platform::errors::PreconditionNotMet(
@@ -126,12 +111,11 @@ void Carrier::Start() {
     start_msg.set_message_type(DATA_IS_READY);
     Send(start_msg);
   }
+  // TODO(wangxi): async step
   Wait();
   dev_ctx_->Wait();
 }
 
-std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
-
 bool Carrier::IsInit() const { return is_init_; }
 
 int64_t Carrier::GetRank(int64_t interceptor_id) const {
@@ -183,51 +167,19 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
                         "The interceptor id should be unique.",
                         interceptor_id));
   interceptor->RegisterCarrier(this);
+
+  // TODO(fleet_exe dev): get loop
+  auto* loop = thread_pool_.GetLoop(interceptor_id % thread_num_);
+  PADDLE_ENFORCE_NOT_NULL(
+      loop, platform::errors::Fatal("thread task loop must not be null"));
+  interceptor->RegisterTaskLoop(loop);
+
   auto* ptr = interceptor.get();
   interceptor_idx_to_interceptor_.insert(
       std::make_pair(interceptor_id, std::move(interceptor)));
   return ptr;
 }
 
-void Carrier::SetCreatingFlag(bool flag) {
-  // set the creating flag
-  creating_flag_mutex_.lock();
-  VLOG(3) << "Carrier is set the creating flag from " << creating_interceptors_
-          << " to " << flag << ".";
-  creating_interceptors_ = flag;
-  creating_flag_mutex_.unlock();
-  if (!flag) {
-    for (auto& pair : interceptor_idx_to_interceptor_) {
-      // update the source interceptor id
-      if (std::find(source_interceptor_ids_.begin(),
-                    source_interceptor_ids_.end(),
-                    pair.first) == source_interceptor_ids_.end()) {
-        auto task = pair.second->GetTaskNode();
-        if (task != nullptr && task->upstream().empty()) {
-          source_interceptor_ids_.emplace_back(pair.first);
-        }
-      }
-    }
-    // finish create interceptors outside, handle tmp messsages
-    HandleTmpMessages();
-  }
-}
-
-void Carrier::HandleTmpMessages() {
-  // NOTE: It's ok lock on the tmp_message_mutex_ here, when enter this
-  // `HandleTmpMessages` method, the creating_interceptors_ flag
-  // must be false, therefore, there won't have conflict with the
-  // lock on the tmp_message_mutex_ inside `EnqueueInterceptorMessage`
-  // on the same thread.
-  std::unique_lock<std::mutex> lock(tmp_message_mutex_);
-  VLOG(3) << "Carrier has received " << message_tmp_.size()
-          << " messages during creating interceptors.";
-  for (const auto& msg : message_tmp_) {
-    EnqueueInterceptorMessage(msg);
-  }
-  message_tmp_.clear();
-}
-
 static std::shared_ptr<framework::GarbageCollector> GetGC(
     const platform::Place& place) {
   int64_t max_memory_size = framework::GetEagerDeletionThreshold();
@@ -285,12 +237,6 @@ void Carrier::CreateInterceptors() {
       source_interceptor_ids_.emplace_back(interceptor_id);
     }
   }
-  // The carrier will be always waiting for outside initializer
-  // since there is no interceptor has been created during auto init
-  creating_flag_mutex_.lock();
-  creating_interceptors_ = false;
-  creating_flag_mutex_.unlock();
-  HandleTmpMessages();
 }
 
 }  // namespace distributed
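Note on the dispatch model above: an interceptor no longer owns a dedicated thread. Carrier::SetInterceptor pins each interceptor to one loop of the shared pool via interceptor_id % thread_num_, so all messages for a given interceptor are handled serially on a single loop thread, while different interceptors may run in parallel on different loops. A minimal sketch of that pinning pattern, using only the API added by this patch (the Sketch() wrapper and the work lambda are illustrative, not part of the change):

    #include "paddle/fluid/distributed/fleet_executor/task_loop.h"
    #include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"

    using paddle::distributed::TaskLoop;
    using paddle::distributed::TaskLoopThreadPool;

    void Sketch() {
      TaskLoopThreadPool pool;
      pool.SetThreadNum(4);
      pool.Start();

      for (int64_t id = 0; id < 16; ++id) {
        // The same id always maps to the same loop, so its tasks never race.
        TaskLoop* loop = pool.GetLoop(id % 4);
        loop->QueueInLoop([id]() { /* handle a message for interceptor `id` */ });
      }
    }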
diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h
index 54cf2150030..81643a74550 100644
--- a/paddle/fluid/distributed/fleet_executor/carrier.h
+++ b/paddle/fluid/distributed/fleet_executor/carrier.h
@@ -24,6 +24,7 @@
 
 #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
 #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
+#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/errors.h"
@@ -47,7 +48,11 @@ class Carrier final {
   Carrier() = default;
   Carrier(int64_t rank,
           const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
-      : rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {}
+      : rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {
+    thread_num_ = 1;
+    thread_pool_.SetThreadNum(thread_num_);
+    thread_pool_.Start();
+  }
   ~Carrier();
   void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
             framework::Scope* root_scope, framework::Scope* minibatch_scope,
@@ -56,6 +61,7 @@ class Carrier final {
 
   void Release();
   void Wait();
+  void WakeUp();
 
   // Enqueue a message to corresponding interceptor id
   bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
@@ -67,23 +73,18 @@ class Carrier final {
   Interceptor* SetInterceptor(int64_t interceptor_id,
                               std::unique_ptr<Interceptor>);
 
-  void SetCreatingFlag(bool flag);
+  void SetCreatingFlag(bool flag) {}
 
   void SetMsgBus(const std::shared_ptr<MessageBus>& msg_bus) {
     msg_bus_ = msg_bus;
   }
 
-  std::condition_variable& GetCondVar();
-
   void Start();
 
   bool IsInit() const;
 
   bool Send(const InterceptorMessage& msg);
 
-  // NOTE: This mutex will be used in interceptor's RunOps function.
-  // This mutex is used for avoiding forward ops and backward ops run
-  // simultaneously, which will lead to a random hang for some sync ops.
-  std::mutex run;
+  void Barrier();
 
 private:
   DISABLE_COPY_AND_ASSIGN(Carrier);
 
   // create each Interceptor
   void CreateInterceptors();
 
-  void HandleTmpMessages();
-
   int64_t GetRank(int64_t interceptor_id) const;
 
   // interceptor logic id to actually interceptor
@@ -101,10 +100,6 @@ class Carrier final {
 
   std::vector<int64_t> source_interceptor_ids_;
 
-  std::vector<InterceptorMessage> message_tmp_{};
-  std::mutex tmp_message_mutex_;
-  bool creating_interceptors_{true};
-  std::mutex creating_flag_mutex_;
   bool is_init_{false};
 
   std::mutex running_mutex_;
@@ -118,6 +113,9 @@ class Carrier final {
   std::shared_ptr<MessageBus> msg_bus_;
   int64_t rank_;
   std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
+
+  int thread_num_;
+  TaskLoopThreadPool thread_pool_;
 };
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
index 1f0d3408a3d..d934ab1948e 100644
--- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
+++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
@@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
 }
 
 void ComputeInterceptor::RunOps() {
-  std::unique_lock<std::mutex> lock(carrier_->run);
   VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
           << step_ + 1 << " time.";
   for (auto op : node_->ops()) {
@@ -198,6 +197,7 @@ void ComputeInterceptor::Run() {
   if (is_last_ && (step_ % node_->max_run_times() == 0)) {
     VLOG(3) << "Interceptor " << GetInterceptorId()
             << " is stopping carrier.";
+    // FIXME(wangxi): handle multiple sink interceptors
     StopCarrier();
   }
 }
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
index 697c4aaaf3a..a5badcb36eb 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -89,6 +89,11 @@ void FleetExecutor::Init(
   CreateCarrier();
   InitCarrier();
   InitMessageBus();
+
+  // Refine this? Wait until all carriers are ready.
+  // NOTE(wangxi): must be called after Carrier::SetMsgBus, since we use
+  // MessageBus::IncreaseBarrierCount when receiving a barrier message.
+  GetCarrier()->Barrier();
 }
 
 void FleetExecutor::InitCarrier() {
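The ordering above is load-bearing: the carrier must hold the message bus before the bus starts listening, because a remote rank may deliver its barrier control message as soon as the port is up, and that message is counted through MessageBus::IncreaseBarrierCount via the carrier. Condensed, this is the sequence the brpc test at the end of this patch follows (no new API; identifiers come from the fleet_executor headers and that test):

    auto carrier = FleetExecutor::CreateCarrier(rank, interceptor_id_to_rank);
    auto msg_bus = std::make_shared<MessageBus>();
    carrier->SetMsgBus(msg_bus);              // 1. wire the bus into the carrier
    msg_bus->Init(rank, rank_to_addr, addr);  // 2. then start listening
    carrier->Barrier();                       // 3. finally rendezvous with all ranks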
diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc
index f5501754cd7..710ebda4124 100644
--- a/paddle/fluid/distributed/fleet_executor/interceptor.cc
+++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc
@@ -14,26 +14,21 @@
 
 #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
 #include "paddle/fluid/distributed/fleet_executor/carrier.h"
+#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
 #include "paddle/fluid/distributed/fleet_executor/task_node.h"
 
 namespace paddle {
 namespace distributed {
 
 Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
-    : interceptor_id_(interceptor_id), node_(node) {
-  interceptor_thread_ = std::thread([this]() {
-    VLOG(3) << "Interceptor " << interceptor_id_
-            << " starts the thread pooling it's local mailbox.";
-    PoolTheMailbox();
-  });
-}
-
-Interceptor::~Interceptor() { Join(); }
-
-void Interceptor::Join() {
-  if (interceptor_thread_.joinable()) {
-    interceptor_thread_.join();
-  }
+    : interceptor_id_(interceptor_id), node_(node) {}
+
+Interceptor::~Interceptor() {
+  // FIXME(wangxi): move this check into a Stop() function and throw there
+  // std::lock_guard<std::mutex> lock(mutex_);
+  // PADDLE_ENFORCE_EQ(messages_.empty(), true,
+  //                   platform::errors::PreconditionNotMet(
+  //                       "Interceptor must destruct with messages empty"));
 }
 
 void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
@@ -44,25 +39,47 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
   handle_(msg);
 }
 
+void Interceptor::LoopOnce() {
+  std::deque<InterceptorMessage> tmp_messages;
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    messages_.swap(tmp_messages);
+  }
+  PADDLE_ENFORCE_EQ(tmp_messages.empty(), false,
+                    platform::errors::PreconditionNotMet(
+                        "tmp_messages must not be empty in the task loop"));
+
+  for (auto& msg : tmp_messages) {
+    const MessageType message_type = msg.message_type();
+    VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
+            << " from interceptor " << msg.src_id()
+            << " with message: " << message_type << ".";
+
+    Handle(msg);
+  }
+}
+
 void Interceptor::StopCarrier() {
   PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
                                         "Carrier is not registered."));
-  std::condition_variable& cond_var = carrier_->GetCondVar();
-  // probably double notify, but ok for ut
-  cond_var.notify_all();
-}
-
-int64_t Interceptor::GetInterceptorId() const {
-  // return the interceptor id
-  return interceptor_id_;
+  carrier_->WakeUp();
 }
 
 void Interceptor::EnqueueRemoteInterceptorMessage(
-    const InterceptorMessage& interceptor_message) {
+    const InterceptorMessage& message) {
   // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
-  VLOG(3) << "Enqueue message: " << interceptor_message.message_type()
-          << " into " << interceptor_id_ << "'s remote mailbox.";
-  remote_mailbox_.Push(interceptor_message);
+  VLOG(3) << "Enqueue message: " << message.message_type() << " into "
+          << interceptor_id_ << "'s remote mailbox.";
+
+  bool empty = false;
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    empty = messages_.empty();
+    messages_.emplace_back(message);
+  }
+  if (empty) {
+    loop_->QueueInLoop([this]() { LoopOnce(); });
+  }
 }
 
 bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
@@ -73,39 +90,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
   return carrier_->Send(msg);
 }
 
-void Interceptor::PoolTheMailbox() {
-  // pool the local mailbox, parse the Message
-  for (;;) {
-    if (local_mailbox_.empty()) {
-      // local mailbox is empty, fetch the remote mailbox
-      VLOG(3) << interceptor_id_ << "'s local mailbox is empty. "
-              << "Fetch the remote mailbox.";
-      PADDLE_ENFORCE_EQ(FetchRemoteMailbox(), true,
-                        platform::errors::InvalidArgument(
-                            "Error encountered when fetch remote mailbox."));
-    }
-    const InterceptorMessage interceptor_message = local_mailbox_.front();
-    local_mailbox_.pop_front();
-    const MessageType message_type = interceptor_message.message_type();
-    VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
-            << " from interceptor " << interceptor_message.src_id()
-            << " with message: " << message_type << ".";
-
-    Handle(interceptor_message);
-
-    if (stop_) {
-      // break the pooling thread
-      VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting.";
-      break;
-    }
-  }
-}
-
-bool Interceptor::FetchRemoteMailbox() {
-  remote_mailbox_.PopAll(&local_mailbox_);
-  return !local_mailbox_.empty();
-}
-
 static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
   static InterceptorFactory::CreateInterceptorMap interceptorMap;
   return interceptorMap;
diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h
index d9e8d050dd1..cb7ff2da89a 100644
--- a/paddle/fluid/distributed/fleet_executor/interceptor.h
+++ b/paddle/fluid/distributed/fleet_executor/interceptor.h
@@ -38,6 +38,7 @@ namespace distributed {
 
 class TaskNode;
 class Carrier;
+class TaskLoop;
 
 class Interceptor {
  public:
@@ -50,15 +51,13 @@ class Interceptor {
 
   virtual ~Interceptor();
 
-  void Join();
-
   // register interceptor handle
   void RegisterMsgHandle(MsgHandle handle);
 
   void Handle(const InterceptorMessage& msg);
 
   // return the interceptor id
-  int64_t GetInterceptorId() const;
+  int64_t GetInterceptorId() const { return interceptor_id_; }
 
   // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
   void EnqueueRemoteInterceptorMessage(
@@ -77,6 +76,7 @@ class Interceptor {
     gc_ = gc;
   }
   void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; }
+  void RegisterTaskLoop(TaskLoop* loop) { loop_ = loop; }
 
   TaskNode* GetTaskNode() const { return node_; }
 
@@ -101,28 +101,16 @@ class Interceptor {
   std::shared_ptr<framework::GarbageCollector> gc_{nullptr};
 
   Carrier* carrier_;
+  TaskLoop* loop_;
 
  private:
-  // pool the local mailbox, parse the Message
-  void PoolTheMailbox();
-
-  // fetch all Message from remote mailbox to local mailbox
-  // return true if remote mailbox not empty, otherwise return false
-  bool FetchRemoteMailbox();
+  void LoopOnce();
 
   // interceptor handle which process message
   MsgHandle handle_{nullptr};
 
-  // interceptor runs PoolTheMailbox() function to poll local mailbox
-  std::thread interceptor_thread_;
-
-  // remote mailbox, written by EnqueueRemoteMessage()
-  // read by FetchRemoteMailbox()
-  framework::BlockingQueue<InterceptorMessage> remote_mailbox_;
-
-  // local mailbox, written by FetchRemoteMailbox()
-  // read by PoolTheMailbox()
-  std::deque<InterceptorMessage> local_mailbox_;
+  std::mutex mutex_;
+  std::deque<InterceptorMessage> messages_;
 
   int64_t already_run_times_{0};
   int64_t used_slot_nums_{0};
diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc
index ac7b08c4b28..dd95a90ad1b 100644
--- a/paddle/fluid/distributed/fleet_executor/message_bus.cc
+++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc
@@ -105,21 +105,53 @@ bool MessageBus::Send(int64_t dst_rank,
   return true;
 }
 
-void MessageBus::TestConnection() {
-  InterceptorMessage ctrl_msg;
-  ctrl_msg.set_ctrl_message(true);
-  ctrl_msg.set_src_id(rank_);
-  for (const auto& dst_rank_pair : rank_to_addr_) {
-    int64_t dst_rank = dst_rank_pair.first;
-    if (dst_rank != rank_) {
-      ctrl_msg.set_dst_id(dst_rank);
-      VLOG(3) << "Send control message bus from rank " << rank_ << " to rank "
-              << dst_rank;
-      while (!Send(dst_rank, ctrl_msg)) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
-      }
-      VLOG(3) << "Message bus has connected to rank: " << dst_rank << ".";
-    }
-  }
-}
+void MessageBus::IncreaseBarrierCount() {
+  VLOG(3) << "IncreaseBarrierCount";
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    ++count_;
+    cv_.notify_one();
+  }
+  VLOG(3) << "End IncreaseBarrierCount";
+}
+
+void MessageBus::Barrier() {
+  // gather to root
+  if (rank_ != 0) {
+    InterceptorMessage ctrl_msg;
+    ctrl_msg.set_ctrl_message(true);
+    ctrl_msg.set_src_id(rank_);
+    ctrl_msg.set_dst_id(0);
+    VLOG(3) << "Barrier Gather ctrl message from " << rank_ << " to 0";
+    while (!Send(0, ctrl_msg)) {
+      std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+    }
+  } else {
+    VLOG(3) << "Barrier: rank 0 waits for the other ranks to be ready";
+    std::unique_lock<std::mutex> lock(mutex_);
+    cv_.wait(lock, [this] {
+      return count_ == static_cast<int>(rank_to_addr_.size() - 1);
+    });
+    count_ = 0;
+  }
+
+  // scatter from root
+  if (rank_ == 0) {
+    for (int i = 1; i < static_cast<int>(rank_to_addr_.size()); ++i) {
+      InterceptorMessage ctrl_msg;
+      ctrl_msg.set_ctrl_message(true);
+      ctrl_msg.set_src_id(0);
+      ctrl_msg.set_dst_id(i);
+      VLOG(3) << "Barrier Scatter ctrl message from 0 to " << i;
+      while (!Send(i, ctrl_msg)) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+      }
+    }
+  } else {
+    VLOG(3) << "Barrier: rank " << rank_ << " waits for the release from rank 0";
+    std::unique_lock<std::mutex> lock(mutex_);
+    cv_.wait(lock, [this] { return count_ == 1; });
+    count_ = 0;
+  }
+}
 
@@ -151,7 +183,6 @@ void MessageBus::ListenPort() {
     interval += 500;
   }
   LOG(INFO) << "Message bus's listen port thread starts successful.";
-  TestConnection();
 #else
   LOG(WARNING)
       << "Fleet executor's ListenPort() is a fake function when Paddle is "
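The new mailbox wakes the loop at most once per batch: EnqueueRemoteInterceptorMessage appends under the lock and only posts a LoopOnce() callback on the empty-to-non-empty transition, and LoopOnce() then drains the whole queue with a single swap. A self-contained sketch of this pattern (the Mailbox class and the int payload are hypothetical stand-ins for Interceptor and InterceptorMessage):

    #include <deque>
    #include <functional>
    #include <mutex>
    #include <utility>

    class Mailbox {
     public:
      // `post` stands in for TaskLoop::QueueInLoop.
      explicit Mailbox(std::function<void(std::function<void()>)> post)
          : post_(std::move(post)) {}

      void Push(int msg) {
        bool was_empty = false;
        {
          std::lock_guard<std::mutex> lock(mutex_);
          was_empty = messages_.empty();
          messages_.push_back(msg);
        }
        // Only the empty -> non-empty transition schedules a drain; messages
        // arriving while a drain is pending are swapped out by that drain.
        if (was_empty) post_([this]() { DrainOnce(); });
      }

     private:
      void DrainOnce() {
        std::deque<int> tmp;
        {
          std::lock_guard<std::mutex> lock(mutex_);
          messages_.swap(tmp);
        }
        for (int msg : tmp) {
          // handle(msg), serially, on the loop thread
          (void)msg;
        }
      }

      std::function<void(std::function<void()>)> post_;
      std::mutex mutex_;
      std::deque<int> messages_;
    };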
diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h
index d4a2af54e6c..c8685a73900 100644
--- a/paddle/fluid/distributed/fleet_executor/message_bus.h
+++ b/paddle/fluid/distributed/fleet_executor/message_bus.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <condition_variable>
 #include <mutex>
 #include <string>
 #include <thread>
@@ -51,14 +52,15 @@ class MessageBus final {
   // called by Interceptor, send InterceptorMessage to dst
   bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message);
 
+  void IncreaseBarrierCount();
+  void Barrier();
+
 private:
   DISABLE_COPY_AND_ASSIGN(MessageBus);
 
   // function keep listen the port and handle the message
   void ListenPort();
 
-  void TestConnection();
-
   const std::string& GetAddr(int64_t rank) const;
 
 #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
@@ -84,6 +86,11 @@ class MessageBus final {
   // brpc server
   brpc::Server server_;
 #endif
+
+  // for barrier
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  int count_{0};
 };
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/fleet_executor/task_loop.cc b/paddle/fluid/distributed/fleet_executor/task_loop.cc
new file mode 100644
index 00000000000..bfe9a939b96
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/task_loop.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
+
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/errors.h"
+
+namespace paddle {
+namespace distributed {
+
+thread_local TaskLoop* TaskLoop::thread_local_loop_ = nullptr;
+
+TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; }
+
+TaskLoop::TaskLoop()
+    : looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) {
+  PADDLE_ENFORCE_EQ(
+      thread_local_loop_, nullptr,
+      platform::errors::AlreadyExists(
+          "Another TaskLoop has already been initialized in this thread."));
+  thread_local_loop_ = this;
+}
+
+TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; }
+
+void TaskLoop::Loop() {
+  PADDLE_ENFORCE_EQ(looping_, false,
+                    platform::errors::PreconditionNotMet(
+                        "Loop can only be executed in its own loop thread"));
+  AssertInLoopThread();
+
+  looping_ = true;
+  quit_ = false;
+
+  while (!quit_) {
+    auto tasks = tasks_.PopAll();
+    for (auto& task : tasks) {
+      task();
+    }
+  }
+  looping_ = false;
+}
+
+void TaskLoop::Quit() {
+  quit_ = true;
+  if (!IsInLoopThread()) WakeUp();
+}
+
+void TaskLoop::RunInLoop(Functor cb) {
+  if (IsInLoopThread()) {
+    cb();
+  } else {
+    QueueInLoop(cb);
+  }
+}
+
+void TaskLoop::QueueInLoop(Functor cb) { tasks_.Push(cb); }
+
+void TaskLoop::WakeUp() {
+  Functor task([] {});
+  QueueInLoop(task);
+}
+
+void TaskLoop::AbortNotInLoopThread() {
+  PADDLE_THROW(platform::errors::PreconditionNotMet(
+      "This TaskLoop was created in thread %d, but current thread is %d",
+      thread_id_, std::this_thread::get_id()));
+}
+
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/fleet_executor/task_loop.h b/paddle/fluid/distributed/fleet_executor/task_loop.h
new file mode 100644
index 00000000000..91425304e57
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/task_loop.h
@@ -0,0 +1,81 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <atomic>
+#include <functional>
+#include <future>
+#include <memory>
+#include <thread>
+
+#include "paddle/fluid/framework/blocking_queue.h"
+
+namespace paddle {
+namespace distributed {
+
+class TaskLoop {
+ public:
+  static TaskLoop* GetTaskLoopOfCurrentThread();
+
+  using Functor = std::function<void()>;
+
+  TaskLoop();
+  ~TaskLoop();
+
+  void Loop();
+  void Quit();
+
+  void RunInLoop(Functor cb);
+  void QueueInLoop(Functor cb);
+
+  template <class F, class... Args>
+  auto Enqueue(F&& f, Args&&... args)
+      -> std::future<typename std::result_of<F(Args...)>::type> {
+    using return_type = typename std::result_of<F(Args...)>::type;
+
+    auto task = std::make_shared<std::packaged_task<return_type()>>(
+        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+    std::future<return_type> task_future = task->get_future();
+
+    tasks_.Push([task]() { (*task)(); });
+    return task_future;
+  }
+
+  void WakeUp();
+
+  bool IsInLoopThread() const {
+    return thread_id_ == std::this_thread::get_id();
+  }
+
+  void AssertInLoopThread() {
+    if (!IsInLoopThread()) {
+      AbortNotInLoopThread();
+    }
+  }
+
+ private:
+  void AbortNotInLoopThread();
+
+  static thread_local TaskLoop* thread_local_loop_;
+
+  bool looping_;
+  std::atomic<bool> quit_;
+  std::thread::id thread_id_;
+
+  framework::BlockingQueue<Functor> tasks_;
+};
+
+}  // namespace distributed
+}  // namespace paddle
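TaskLoop::Enqueue above wraps the callable in a std::packaged_task, so a caller on another thread can block on the result through a std::future. A usage sketch (the Sketch() wrapper and the addition lambda are illustrative; TaskLoopThread, defined in the next file, is what actually spins the loop):

    #include <future>

    #include "paddle/fluid/distributed/fleet_executor/task_loop.h"
    #include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"

    using paddle::distributed::TaskLoop;
    using paddle::distributed::TaskLoopThread;

    int Sketch() {
      TaskLoopThread thread;
      TaskLoop* loop = thread.StartLoop();  // blocks until the loop is live

      std::future<int> sum =
          loop->Enqueue([](int a, int b) { return a + b; }, 1, 2);
      return sum.get();  // 3; ~TaskLoopThread() quits the loop and joins
    }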
diff --git a/paddle/fluid/distributed/fleet_executor/task_loop_thread.cc b/paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
new file mode 100644
index 00000000000..bb313ad3789
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/task_loop_thread.cc
@@ -0,0 +1,58 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
+
+#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/errors.h"
+
+namespace paddle {
+namespace distributed {
+
+TaskLoopThread::TaskLoopThread() : start_(false), loop_(nullptr) {}
+
+TaskLoopThread::~TaskLoopThread() {
+  if (loop_ != nullptr) {
+    loop_->Quit();
+    thread_.join();
+  }
+}
+
+TaskLoop* TaskLoopThread::StartLoop() {
+  PADDLE_ENFORCE_EQ(start_, false,
+                    platform::errors::PreconditionNotMet(
+                        "thread is already running."));
+  start_ = true;
+  thread_ = std::thread([this]() { Loop(); });
+
+  std::unique_lock<std::mutex> lock(mutex_);
+  cv_.wait(lock, [=] { return loop_ != nullptr; });
+  return loop_;
+}
+
+void TaskLoopThread::Loop() {
+  TaskLoop loop;
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    loop_ = &loop;
+    cv_.notify_one();
+  }
+  loop.Loop();
+
+  std::unique_lock<std::mutex> lock(mutex_);
+  loop_ = nullptr;
+}
+
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/fleet_executor/task_loop_thread.h b/paddle/fluid/distributed/fleet_executor/task_loop_thread.h
new file mode 100644
index 00000000000..07952abdc24
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/task_loop_thread.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <condition_variable>
+#include <mutex>
+#include <thread>
+
+namespace paddle {
+namespace distributed {
+
+class TaskLoop;
+
+class TaskLoopThread {
+ public:
+  TaskLoopThread();
+  ~TaskLoopThread();
+
+  TaskLoop* StartLoop();
+
+ private:
+  void Loop();
+
+  bool start_;
+  TaskLoop* loop_;
+  std::thread thread_;
+  std::mutex mutex_;
+  std::condition_variable cv_;
+};
+
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc b/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc
new file mode 100644
index 00000000000..ed34bbb87fc
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
+
+#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
+#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/errors.h"
+
+namespace paddle {
+namespace distributed {
+
+TaskLoopThreadPool::TaskLoopThreadPool() : TaskLoopThreadPool(1) {}
+
+TaskLoopThreadPool::TaskLoopThreadPool(int thread_num)
+    : start_(false), thread_num_(thread_num) {}
+
+TaskLoopThreadPool::~TaskLoopThreadPool() = default;
+
+void TaskLoopThreadPool::Start() {
+  PADDLE_ENFORCE_EQ(start_, false,
+                    platform::errors::PreconditionNotMet(
+                        "thread pool has already started."));
+  PADDLE_ENFORCE_GT(thread_num_, 0,
+                    platform::errors::InvalidArgument(
+                        "thread num must be greater than 0, but got %d",
+                        thread_num_));
+
+  start_ = true;
+  for (int i = 0; i < thread_num_; ++i) {
+    threads_.emplace_back(new TaskLoopThread());
+    loops_.push_back(threads_[i]->StartLoop());
+  }
+}
+
+TaskLoop* TaskLoopThreadPool::GetLoop(int tid) {
+  PADDLE_ENFORCE_EQ(start_, true,
+                    platform::errors::PreconditionNotMet(
+                        "thread pool must be started first."));
+  PADDLE_ENFORCE_GE(tid, 0,
+                    platform::errors::OutOfRange(
+                        "tid must be >= 0, but got %d", tid));
+  PADDLE_ENFORCE_LT(tid, thread_num_,
+                    platform::errors::OutOfRange(
+                        "tid must be < thread_num, but got tid=%d thread_num=%d",
+                        tid, thread_num_));
+  return loops_[tid];
+}
+
+std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() {
+  PADDLE_ENFORCE_EQ(start_, true,
+                    platform::errors::PreconditionNotMet(
+                        "thread pool must be started first."));
+  return loops_;
+}
+
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h b/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h
new file mode 100644
index 00000000000..ffc9588f4e7
--- /dev/null
+++ b/paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace distributed {
+
+class TaskLoop;
+class TaskLoopThread;
+
+class TaskLoopThreadPool {
+ public:
+  TaskLoopThreadPool();
+  explicit TaskLoopThreadPool(int thread_num);
+  ~TaskLoopThreadPool();
+
+  void SetThreadNum(int thread_num) { thread_num_ = thread_num; }
+
+  void Start();
+
+  TaskLoop* GetLoop(int tid);
+  std::vector<TaskLoop*> GetAllLoops();
+
+ private:
+  bool start_;
+  int thread_num_;
+  std::vector<std::unique_ptr<TaskLoopThread>> threads_;
+  std::vector<TaskLoop*> loops_;
+};
+
+}  // namespace distributed
+}  // namespace paddle
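Lifecycle note for the pool: SetThreadNum is only honored before Start, GetLoop and GetAllLoops enforce that Start already ran, and destruction tears everything down in order (each TaskLoopThread calls TaskLoop::Quit and joins its thread). Illustrative only; the per-loop setup lambda is a placeholder:

    #include "paddle/fluid/distributed/fleet_executor/task_loop.h"
    #include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"

    void PoolLifecycleSketch() {
      paddle::distributed::TaskLoopThreadPool pool;  // thread_num defaults to 1
      pool.SetThreadNum(2);                          // must precede Start()
      pool.Start();                                  // spawns two TaskLoopThreads

      for (auto* loop : pool.GetAllLoops()) {
        loop->QueueInLoop([]() { /* per-loop setup */ });
      }
      // ~TaskLoopThreadPool(): each TaskLoopThread quits its TaskLoop
      // and joins its thread.
    }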
diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
index a577b30fa8c..262e5caa8c8 100644
--- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
+++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc
@@ -115,10 +115,13 @@ TEST(InterceptorTest, PingPong) {
         FleetExecutor::CreateCarrier(0, interceptor_id_to_rank);
     carrier->SetCreatingFlag(false);
     auto msg_bus = std::make_shared<MessageBus>();
-    msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
     carrier->SetMsgBus(msg_bus);
+    // NOTE: msg_bus must be initialized after Carrier::SetMsgBus
+    msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
     Interceptor* a = carrier->SetInterceptor(
         0, InterceptorFactory::Create("PingPong", 0, nullptr));
+    carrier->Barrier();
+
     InterceptorMessage msg;
     a->Send(1, msg);
     carrier->Wait();
@@ -127,10 +130,12 @@ TEST(InterceptorTest, PingPong) {
         FleetExecutor::CreateCarrier(1, interceptor_id_to_rank);
     carrier->SetCreatingFlag(false);
     auto msg_bus = std::make_shared<MessageBus>();
-    msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
     carrier->SetMsgBus(msg_bus);
+    msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
     carrier->SetInterceptor(
         1, InterceptorFactory::Create("PingPong", 1, nullptr));
+    carrier->Barrier();
+
     carrier->Wait();
     int status;
     int ret = waitpid(pid, &status, 0);
diff --git a/paddle/fluid/framework/blocking_queue.h b/paddle/fluid/framework/blocking_queue.h
index 5bc38c1398a..04937fa6b97 100644
--- a/paddle/fluid/framework/blocking_queue.h
+++ b/paddle/fluid/framework/blocking_queue.h
@@ -81,6 +81,16 @@ class BlockingQueue {
     std::swap(*empty_queue, q_);
   }
 
+  std::deque<T> PopAll() {
+    std::deque<T> ret;
+    {
+      std::unique_lock<std::mutex> lock(mutex_);
+      cv_.wait(lock, [this] { return !q_.empty(); });
+      std::swap(ret, q_);
+    }
+    return ret;
+  }
+
   T Pop() {
     std::unique_lock<std::mutex> lock(mutex_);
     cv_.wait(lock, [=] { return !q_.empty(); });
--
GitLab
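Postscript on the queue change: the new blocking PopAll() complements the existing non-blocking PopAll(std::deque<T>*) overload just above it. It sleeps on the condition variable until something arrives, then hands the whole backlog to the caller in one swap, which is what lets TaskLoop::Loop idle without busy-polling. A consumer sketch (ConsumerLoop and the quit flag are hypothetical):

    #include <atomic>
    #include <functional>

    #include "paddle/fluid/framework/blocking_queue.h"

    void ConsumerLoop(paddle::framework::BlockingQueue<std::function<void()>>* tasks,
                      std::atomic<bool>* quit) {
      while (!quit->load()) {
        // Blocks until at least one task is queued, then takes the whole batch.
        for (auto& task : tasks->PopAll()) task();
      }
    }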