未验证 提交 dba59db7 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] Add task loop thread pool (#38420)

上级 5b6b88ab
...@@ -10,10 +10,12 @@ else() ...@@ -10,10 +10,12 @@ else()
set(BRPC_DEPS "") set(BRPC_DEPS "")
endif() endif()
cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
executor_gc_helper gflags glog ${BRPC_DEPS}) op_registry executor_gc_helper gflags glog ${BRPC_DEPS})
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......
...@@ -42,30 +42,17 @@ void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph, ...@@ -42,30 +42,17 @@ void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
place_ = place; place_ = place;
root_scope_ = root_scope; root_scope_ = root_scope;
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
CreateInterceptors(); CreateInterceptors();
is_init_ = true; is_init_ = true;
} }
void Carrier::Release() { void Carrier::Release() {}
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
InterceptorMessage stop_msg;
// The STOP for a source node is sent by the carrier, so set src_id=-1
stop_msg.set_src_id(-1);
stop_msg.set_dst_id(id);
stop_msg.set_message_type(STOP);
Send(stop_msg);
}
// TODO(wangxi): Maybe need a better to use thread.
for (auto& interceptor : interceptor_idx_to_interceptor_) {
interceptor.second->Join();
}
}
Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; } Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
...@@ -75,18 +62,9 @@ bool Carrier::EnqueueInterceptorMessage( ...@@ -75,18 +62,9 @@ bool Carrier::EnqueueInterceptorMessage(
VLOG(3) << "Receiving control message from rank " VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank " << interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id(); << interceptor_message.dst_id();
// for barrier
msg_bus_->IncreaseBarrierCount();
} else { } else {
{
std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
if (creating_interceptors_) {
std::unique_lock<std::mutex> lock_message(tmp_message_mutex_);
// Cannot handle the message to interceptor since interceptors
// are still under creating. Will enqueue into a tmp stack.
VLOG(3) << "Receiving message while creating interceptors.";
message_tmp_.emplace_back(interceptor_message);
return true;
}
}
int64_t dst_id = interceptor_message.dst_id(); int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id); Interceptor* dst_interceptor = GetInterceptor(dst_id);
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message); dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
...@@ -94,6 +72,8 @@ bool Carrier::EnqueueInterceptorMessage( ...@@ -94,6 +72,8 @@ bool Carrier::EnqueueInterceptorMessage(
return true; return true;
} }
// Forwards the barrier request to the message bus (msg_bus_->Barrier()).
// NOTE(review): msg_bus_ is dereferenced without a null check, so SetMsgBus()
// presumably has to be called before this — confirm against callers.
void Carrier::Barrier() { msg_bus_->Barrier(); }
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(), PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(),
...@@ -109,6 +89,11 @@ void Carrier::Wait() { ...@@ -109,6 +89,11 @@ void Carrier::Wait() {
cond_var_.wait(lock); cond_var_.wait(lock);
} }
// Notifies every thread waiting on cond_var_.
// NOTE(review): the Wait() shown in this change calls cond_var_.wait(lock)
// without a predicate, so a notification issued before Wait() starts waiting
// looks like it would be missed — confirm that callers guarantee the ordering.
void Carrier::WakeUp() {
// May end up notifying more than once; that is acceptable for the unit tests.
cond_var_.notify_all();
}
void Carrier::Start() { void Carrier::Start() {
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true, PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
...@@ -126,12 +111,11 @@ void Carrier::Start() { ...@@ -126,12 +111,11 @@ void Carrier::Start() {
start_msg.set_message_type(DATA_IS_READY); start_msg.set_message_type(DATA_IS_READY);
Send(start_msg); Send(start_msg);
} }
// TODO(wangxi): async step
Wait(); Wait();
dev_ctx_->Wait(); dev_ctx_->Wait();
} }
std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool Carrier::IsInit() const { return is_init_; } bool Carrier::IsInit() const { return is_init_; }
int64_t Carrier::GetRank(int64_t interceptor_id) const { int64_t Carrier::GetRank(int64_t interceptor_id) const {
...@@ -183,51 +167,19 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, ...@@ -183,51 +167,19 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id should be unique.", "The interceptor id should be unique.",
interceptor_id)); interceptor_id));
interceptor->RegisterCarrier(this); interceptor->RegisterCarrier(this);
// TODO(fleet_exe dev): get loop
auto* loop = thread_pool_.GetLoop(interceptor_id % thread_num_);
PADDLE_ENFORCE_NOT_NULL(
loop, platform::errors::Fatal("thread task loop must not null"));
interceptor->RegisterTaskLoop(loop);
auto* ptr = interceptor.get(); auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert( interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor))); std::make_pair(interceptor_id, std::move(interceptor)));
return ptr; return ptr;
} }
void Carrier::SetCreatingFlag(bool flag) {
// set the creating flag
creating_flag_mutex_.lock();
VLOG(3) << "Carrier is set the creating flag from " << creating_interceptors_
<< " to " << flag << ".";
creating_interceptors_ = flag;
creating_flag_mutex_.unlock();
if (!flag) {
for (auto& pair : interceptor_idx_to_interceptor_) {
// update the source interceptor id
if (std::find(source_interceptor_ids_.begin(),
source_interceptor_ids_.end(),
pair.first) == source_interceptor_ids_.end()) {
auto task = pair.second->GetTaskNode();
if (task != nullptr && task->upstream().empty()) {
source_interceptor_ids_.emplace_back(pair.first);
}
}
}
// finished creating interceptors outside, handle tmp messages
HandleTmpMessages();
}
}
void Carrier::HandleTmpMessages() {
// NOTE: It's ok lock on the tmp_message_mutex_ here, when enter this
// `HandleTmpMessages` method, the creating_interceptors_ flag
// must be false, therefore, there won't have conflict with the
// lock on the tmp_message_mutex_ inside `EnqueueInterceptorMessage`
// on the same thread.
std::unique_lock<std::mutex> lock(tmp_message_mutex_);
VLOG(3) << "Carrier has received " << message_tmp_.size()
<< " messages during creating interceptors.";
for (const auto& msg : message_tmp_) {
EnqueueInterceptorMessage(msg);
}
message_tmp_.clear();
}
static std::shared_ptr<framework::GarbageCollector> GetGC( static std::shared_ptr<framework::GarbageCollector> GetGC(
const platform::Place& place) { const platform::Place& place) {
int64_t max_memory_size = framework::GetEagerDeletionThreshold(); int64_t max_memory_size = framework::GetEagerDeletionThreshold();
...@@ -285,12 +237,6 @@ void Carrier::CreateInterceptors() { ...@@ -285,12 +237,6 @@ void Carrier::CreateInterceptors() {
source_interceptor_ids_.emplace_back(interceptor_id); source_interceptor_ids_.emplace_back(interceptor_id);
} }
} }
// The carrier will be always waiting for outside initializer
// since there is no interceptor has been created during auto init
creating_flag_mutex_.lock();
creating_interceptors_ = false;
creating_flag_mutex_.unlock();
HandleTmpMessages();
} }
} // namespace distributed } // namespace distributed
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
...@@ -47,7 +48,11 @@ class Carrier final { ...@@ -47,7 +48,11 @@ class Carrier final {
Carrier() = default; Carrier() = default;
Carrier(int64_t rank, Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {} : rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
}
~Carrier(); ~Carrier();
void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph, void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope, framework::Scope* root_scope, framework::Scope* minibatch_scope,
...@@ -56,6 +61,7 @@ class Carrier final { ...@@ -56,6 +61,7 @@ class Carrier final {
void Release(); void Release();
void Wait(); void Wait();
void WakeUp();
// Enqueue a message to corresponding interceptor id // Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
...@@ -67,23 +73,18 @@ class Carrier final { ...@@ -67,23 +73,18 @@ class Carrier final {
Interceptor* SetInterceptor(int64_t interceptor_id, Interceptor* SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor>); std::unique_ptr<Interceptor>);
void SetCreatingFlag(bool flag); void SetCreatingFlag(bool flag) {}
void SetMsgBus(const std::shared_ptr<MessageBus>& msg_bus) { void SetMsgBus(const std::shared_ptr<MessageBus>& msg_bus) {
msg_bus_ = msg_bus; msg_bus_ = msg_bus;
} }
std::condition_variable& GetCondVar();
void Start(); void Start();
bool IsInit() const; bool IsInit() const;
bool Send(const InterceptorMessage& msg); bool Send(const InterceptorMessage& msg);
// NOTE: This mutex will be used in interceptor's RunOps function. void Barrier();
// This mutex is used for avoiding forward ops and backward ops run
// simultaneously, which will lead to a random hang for some sync ops.
std::mutex run;
private: private:
DISABLE_COPY_AND_ASSIGN(Carrier); DISABLE_COPY_AND_ASSIGN(Carrier);
...@@ -91,8 +92,6 @@ class Carrier final { ...@@ -91,8 +92,6 @@ class Carrier final {
// create each Interceptor // create each Interceptor
void CreateInterceptors(); void CreateInterceptors();
void HandleTmpMessages();
int64_t GetRank(int64_t interceptor_id) const; int64_t GetRank(int64_t interceptor_id) const;
// interceptor logic id to actually interceptor // interceptor logic id to actually interceptor
...@@ -101,10 +100,6 @@ class Carrier final { ...@@ -101,10 +100,6 @@ class Carrier final {
std::vector<int64_t> source_interceptor_ids_; std::vector<int64_t> source_interceptor_ids_;
std::vector<InterceptorMessage> message_tmp_{};
std::mutex tmp_message_mutex_;
bool creating_interceptors_{true};
std::mutex creating_flag_mutex_;
bool is_init_{false}; bool is_init_{false};
std::mutex running_mutex_; std::mutex running_mutex_;
...@@ -118,6 +113,9 @@ class Carrier final { ...@@ -118,6 +113,9 @@ class Carrier final {
std::shared_ptr<MessageBus> msg_bus_; std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_; int64_t rank_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_; std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
int thread_num_;
TaskLoopThreadPool thread_pool_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
} }
void ComputeInterceptor::RunOps() { void ComputeInterceptor::RunOps() {
std::unique_lock<std::mutex> lock(carrier_->run);
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the " VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ + 1 << " time."; << step_ + 1 << " time.";
for (auto op : node_->ops()) { for (auto op : node_->ops()) {
...@@ -198,6 +197,7 @@ void ComputeInterceptor::Run() { ...@@ -198,6 +197,7 @@ void ComputeInterceptor::Run() {
if (is_last_ && (step_ % node_->max_run_times() == 0)) { if (is_last_ && (step_ % node_->max_run_times() == 0)) {
VLOG(3) << "Interceptor " << GetInterceptorId() VLOG(3) << "Interceptor " << GetInterceptorId()
<< " is stopping carrier."; << " is stopping carrier.";
// FIXME(wangxi): with multi sink interceptor
StopCarrier(); StopCarrier();
} }
} }
......
...@@ -89,6 +89,11 @@ void FleetExecutor::Init( ...@@ -89,6 +89,11 @@ void FleetExecutor::Init(
CreateCarrier(); CreateCarrier();
InitCarrier(); InitCarrier();
InitMessageBus(); InitMessageBus();
// refine this? wait all carrier ready
// NOTE(wangxi): must add after Carrier::SetMsgBus, for we use
// MessageBus::IncreaseBarrierCount when receive barrier msg.
GetCarrier()->Barrier();
} }
void FleetExecutor::InitCarrier() { void FleetExecutor::InitCarrier() {
......
...@@ -14,26 +14,21 @@ ...@@ -14,26 +14,21 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node) Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
: interceptor_id_(interceptor_id), node_(node) { : interceptor_id_(interceptor_id), node_(node) {}
interceptor_thread_ = std::thread([this]() {
VLOG(3) << "Interceptor " << interceptor_id_ Interceptor::~Interceptor() {
<< " starts the thread pooling it's local mailbox."; // FIXME(wangxi): throw in stop function
PoolTheMailbox(); // std::lock_guard<std::mutex> lock(mutex_);
}); // PADDLE_ENFORCE_EQ(messages_.empty(), true,
} // platform::errors::PreconditionNotMet(
// "Interceptor must destruct with messages empty"));
Interceptor::~Interceptor() { Join(); }
void Interceptor::Join() {
if (interceptor_thread_.joinable()) {
interceptor_thread_.join();
}
} }
void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; } void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
...@@ -44,25 +39,47 @@ void Interceptor::Handle(const InterceptorMessage& msg) { ...@@ -44,25 +39,47 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
handle_(msg); handle_(msg);
} }
void Interceptor::LoopOnce() {
std::deque<InterceptorMessage> tmp_messages;
{
std::lock_guard<std::mutex> lock(mutex_);
messages_.swap(tmp_messages);
}
PADDLE_ENFORCE_EQ(tmp_messages.empty(), false,
platform::errors::PreconditionNotMet(
"tmp_messages must not empty in task loop"));
for (auto& msg : tmp_messages) {
const MessageType message_type = msg.message_type();
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << msg.src_id()
<< " with message: " << message_type << ".";
Handle(msg);
}
}
void Interceptor::StopCarrier() { void Interceptor::StopCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
"Carrier is not registered.")); "Carrier is not registered."));
std::condition_variable& cond_var = carrier_->GetCondVar(); carrier_->WakeUp();
// probably double notify, but ok for ut
cond_var.notify_all();
}
int64_t Interceptor::GetInterceptorId() const {
// return the interceptor id
return interceptor_id_;
} }
void Interceptor::EnqueueRemoteInterceptorMessage( void Interceptor::EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message) { const InterceptorMessage& message) {
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG(3) << "Enqueue message: " << interceptor_message.message_type() VLOG(3) << "Enqueue message: " << message.message_type() << " into "
<< " into " << interceptor_id_ << "'s remote mailbox."; << interceptor_id_ << "'s remote mailbox.";
remote_mailbox_.Push(interceptor_message);
bool empty = false;
{
std::lock_guard<std::mutex> lock(mutex_);
empty = messages_.empty();
messages_.emplace_back(message);
}
if (empty) {
loop_->QueueInLoop([this]() { LoopOnce(); });
}
} }
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
...@@ -73,39 +90,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { ...@@ -73,39 +90,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return carrier_->Send(msg); return carrier_->Send(msg);
} }
void Interceptor::PoolTheMailbox() {
// pool the local mailbox, parse the Message
for (;;) {
if (local_mailbox_.empty()) {
// local mailbox is empty, fetch the remote mailbox
VLOG(3) << interceptor_id_ << "'s local mailbox is empty. "
<< "Fetch the remote mailbox.";
PADDLE_ENFORCE_EQ(FetchRemoteMailbox(), true,
platform::errors::InvalidArgument(
"Error encountered when fetch remote mailbox."));
}
const InterceptorMessage interceptor_message = local_mailbox_.front();
local_mailbox_.pop_front();
const MessageType message_type = interceptor_message.message_type();
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << interceptor_message.src_id()
<< " with message: " << message_type << ".";
Handle(interceptor_message);
if (stop_) {
// break the pooling thread
VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting.";
break;
}
}
}
bool Interceptor::FetchRemoteMailbox() {
remote_mailbox_.PopAll(&local_mailbox_);
return !local_mailbox_.empty();
}
static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() { static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
static InterceptorFactory::CreateInterceptorMap interceptorMap; static InterceptorFactory::CreateInterceptorMap interceptorMap;
return interceptorMap; return interceptorMap;
......
...@@ -38,6 +38,7 @@ namespace distributed { ...@@ -38,6 +38,7 @@ namespace distributed {
class TaskNode; class TaskNode;
class Carrier; class Carrier;
class TaskLoop;
class Interceptor { class Interceptor {
public: public:
...@@ -50,15 +51,13 @@ class Interceptor { ...@@ -50,15 +51,13 @@ class Interceptor {
virtual ~Interceptor(); virtual ~Interceptor();
void Join();
// register interceptor handle // register interceptor handle
void RegisterMsgHandle(MsgHandle handle); void RegisterMsgHandle(MsgHandle handle);
void Handle(const InterceptorMessage& msg); void Handle(const InterceptorMessage& msg);
// return the interceptor id // return the interceptor id
int64_t GetInterceptorId() const; int64_t GetInterceptorId() const { return interceptor_id_; }
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
void EnqueueRemoteInterceptorMessage( void EnqueueRemoteInterceptorMessage(
...@@ -77,6 +76,7 @@ class Interceptor { ...@@ -77,6 +76,7 @@ class Interceptor {
gc_ = gc; gc_ = gc;
} }
void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; } void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; }
void RegisterTaskLoop(TaskLoop* loop) { loop_ = loop; }
TaskNode* GetTaskNode() const { return node_; } TaskNode* GetTaskNode() const { return node_; }
...@@ -101,28 +101,16 @@ class Interceptor { ...@@ -101,28 +101,16 @@ class Interceptor {
std::shared_ptr<framework::GarbageCollector> gc_{nullptr}; std::shared_ptr<framework::GarbageCollector> gc_{nullptr};
Carrier* carrier_; Carrier* carrier_;
TaskLoop* loop_;
private: private:
// pool the local mailbox, parse the Message void LoopOnce();
void PoolTheMailbox();
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
bool FetchRemoteMailbox();
// interceptor handle which process message // interceptor handle which process message
MsgHandle handle_{nullptr}; MsgHandle handle_{nullptr};
// interceptor runs PoolTheMailbox() function to poll local mailbox std::mutex mutex_;
std::thread interceptor_thread_; std::deque<InterceptorMessage> messages_;
// remote mailbox, written by EnqueueRemoteMessage()
// read by FetchRemoteMailbox()
framework::BlockingQueue<InterceptorMessage> remote_mailbox_;
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std::deque<InterceptorMessage> local_mailbox_;
int64_t already_run_times_{0}; int64_t already_run_times_{0};
int64_t used_slot_nums_{0}; int64_t used_slot_nums_{0};
......
...@@ -105,21 +105,53 @@ bool MessageBus::Send(int64_t dst_rank, ...@@ -105,21 +105,53 @@ bool MessageBus::Send(int64_t dst_rank,
return true; return true;
} }
void MessageBus::TestConnection() { void MessageBus::IncreaseBarrierCount() {
InterceptorMessage ctrl_msg; VLOG(3) << "IncreaseBarrierCount";
ctrl_msg.set_ctrl_message(true); {
ctrl_msg.set_src_id(rank_); std::unique_lock<std::mutex> lock(mutex_);
for (const auto& dst_rank_pair : rank_to_addr_) { ++count_;
int64_t dst_rank = dst_rank_pair.first; cv_.notify_one();
if (dst_rank != rank_) { }
ctrl_msg.set_dst_id(dst_rank); VLOG(3) << "End IncreaseBarrierCount";
VLOG(3) << "Send control message bus from rank " << rank_ << " to rank " }
<< dst_rank;
while (!Send(dst_rank, ctrl_msg)) { void MessageBus::Barrier() {
// gather to root
if (rank_ != 0) {
InterceptorMessage ctrl_msg;
ctrl_msg.set_ctrl_message(true);
ctrl_msg.set_src_id(rank_);
ctrl_msg.set_dst_id(0);
VLOG(3) << "Barrier Gather ctrl message from " << rank_ << " to 0";
while (!Send(0, ctrl_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
} else {
VLOG(3) << "Barrier 0 wait others rank ready";
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] {
return count_ == static_cast<int>(rank_to_addr_.size() - 1);
});
count_ = 0;
}
// scatter from root
if (rank_ == 0) {
for (int i = 1; i < static_cast<int>(rank_to_addr_.size()); ++i) {
InterceptorMessage ctrl_msg;
ctrl_msg.set_ctrl_message(true);
ctrl_msg.set_src_id(0);
ctrl_msg.set_dst_id(i);
VLOG(3) << "Barrier Scatter ctrl message from 0 to " << i;
while (!Send(i, ctrl_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000)); std::this_thread::sleep_for(std::chrono::milliseconds(1000));
} }
VLOG(3) << "Message bus has connected to rank: " << dst_rank << ".";
} }
} else {
VLOG(3) << "Barrier " << rank_ << " wait others rank ready";
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return count_ == 1; });
count_ = 0;
} }
} }
...@@ -151,7 +183,6 @@ void MessageBus::ListenPort() { ...@@ -151,7 +183,6 @@ void MessageBus::ListenPort() {
interval += 500; interval += 500;
} }
LOG(INFO) << "Message bus's listen port thread starts successful."; LOG(INFO) << "Message bus's listen port thread starts successful.";
TestConnection();
#else #else
LOG(WARNING) LOG(WARNING)
<< "Fleet executor's ListenPort() is a fake function when Paddle is " << "Fleet executor's ListenPort() is a fake function when Paddle is "
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <condition_variable>
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <thread> #include <thread>
...@@ -51,14 +52,15 @@ class MessageBus final { ...@@ -51,14 +52,15 @@ class MessageBus final {
// called by Interceptor, send InterceptorMessage to dst // called by Interceptor, send InterceptorMessage to dst
bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message); bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message);
void IncreaseBarrierCount();
void Barrier();
private: private:
DISABLE_COPY_AND_ASSIGN(MessageBus); DISABLE_COPY_AND_ASSIGN(MessageBus);
// function keep listen the port and handle the message // function keep listen the port and handle the message
void ListenPort(); void ListenPort();
void TestConnection();
const std::string& GetAddr(int64_t rank) const; const std::string& GetAddr(int64_t rank) const;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
...@@ -84,6 +86,11 @@ class MessageBus final { ...@@ -84,6 +86,11 @@ class MessageBus final {
// brpc server // brpc server
brpc::Server server_; brpc::Server server_;
#endif #endif
// for barrier
std::mutex mutex_;
std::condition_variable cv_;
int count_{0};
}; };
} // namespace distributed } // namespace distributed
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
// Per-thread slot holding the TaskLoop registered on that thread; nullptr
// while no TaskLoop is registered.
thread_local TaskLoop* TaskLoop::thread_local_loop_ = nullptr;
// Returns the TaskLoop registered on the calling thread, or nullptr if the
// calling thread has none.
TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; }
// Creates a loop owned by the constructing thread: that thread's id is
// recorded in thread_id_ and the loop is stored in the thread-local slot.
// The PADDLE_ENFORCE check fails (AlreadyExists) when the constructing thread
// already has a TaskLoop registered.
TaskLoop::TaskLoop()
: looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) {
PADDLE_ENFORCE_EQ(
thread_local_loop_, nullptr,
platform::errors::AlreadyExists("Another TaskLoop is already init."));
// Register only after the uniqueness check has passed.
thread_local_loop_ = this;
}
// Clears the thread-local slot of the thread that runs the destructor.
// NOTE(review): this resets the *destroying* thread's slot unconditionally;
// it presumably assumes the loop is destroyed on the thread that created it —
// confirm.
TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; }
// Runs queued tasks on the owning thread until Quit() is requested.
//
// Preconditions (both enforced): the loop is not already running, and the
// caller is the thread that constructed this TaskLoop. Every iteration takes
// all currently queued tasks with PopAll() and invokes them in order.
//
// The quit flag is deliberately NOT cleared on entry. Quit() can arrive after
// the loop object has been handed out but before Loop() starts running (for
// example when the owner shuts down right after start-up). Resetting the flag
// here would silently drop that request, and the loop would then keep waiting
// for tasks with nobody left to stop it. The flag starts out false in the
// constructor and is cleared again on exit, so Loop() can still be called
// again after a completed run.
void TaskLoop::Loop() {
  PADDLE_ENFORCE_EQ(looping_, false,
                    platform::errors::PreconditionNotMet(
                        "Loop can only execute in one loop thread"));
  AssertInLoopThread();

  looping_ = true;

  while (!quit_) {
    auto tasks = tasks_.PopAll();
    for (auto& task : tasks) {
      task();
    }
  }

  // Consume the quit request so that a subsequent Loop() call runs normally.
  quit_ = false;
  looping_ = false;
}
// Requests that Loop() stop. When called from a thread other than the loop's
// own, an empty task is queued through WakeUp() as well, so that the loop
// gets something to pop and can re-check the flag.
void TaskLoop::Quit() {
  quit_ = true;
  if (IsInLoopThread()) {
    // The loop thread is executing this call, so there is nobody to wake.
    return;
  }
  WakeUp();
}
// Executes `cb` on the loop thread: immediately when the caller already is
// the loop thread, otherwise by appending it to the task queue.
void TaskLoop::RunInLoop(Functor cb) {
  if (!IsInLoopThread()) {
    QueueInLoop(cb);
    return;
  }
  cb();
}
// Appends `cb` to the task queue; Loop() pops and runs queued tasks.
void TaskLoop::QueueInLoop(Functor cb) { tasks_.Push(cb); }
// Queues a task that does nothing, giving the loop an item to pop so that it
// goes around once more and re-evaluates its quit flag.
void TaskLoop::WakeUp() {
  QueueInLoop(Functor([] {}));
}
// Raises a PreconditionNotMet error that reports the id of the thread this
// TaskLoop was created on and the id of the calling thread. Used by
// AssertInLoopThread() when the caller is not the owning thread.
// NOTE(review): the ids are std::thread::id values passed to "%d"
// placeholders — confirm the error formatter accepts non-integer arguments.
void TaskLoop::AbortNotInLoopThread() {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"This TaskLoop was created in thread %d, but current thread is %d",
thread_id_, std::this_thread::get_id()));
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <future>
#include <map>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
namespace paddle {
namespace distributed {
// A task queue tied to one thread. The thread that constructs the TaskLoop
// becomes its owner; Loop() must run on that thread and executes the functors
// that other code has queued.
// NOTE(review): this class uses std::atomic and std::make_shared, but
// <atomic> and <memory> are not among the includes visible above it —
// presumably they arrive transitively; confirm.
class TaskLoop {
public:
// Returns the TaskLoop registered on the calling thread (may be nullptr).
static TaskLoop* GetTaskLoopOfCurrentThread();
// Type of the tasks stored in the queue.
using Functor = std::function<void()>;
TaskLoop();
~TaskLoop();
// Runs queued tasks until Quit() is requested.
void Loop();
// Asks Loop() to stop.
void Quit();
// Runs `cb` right away when called on the loop thread, otherwise queues it.
void RunInLoop(Functor cb);
// Always queues `cb`.
void QueueInLoop(Functor cb);
// Binds `f` to `args`, wraps the result in a std::packaged_task, queues it and
// returns the future through which the task's result can be obtained.
template <class F, class... Args>
auto Enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> {
using return_type = typename std::result_of<F(Args...)>::type;
// Held through a shared_ptr so the queued lambda, which captures it by copy,
// can be stored in a copyable std::function.
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
// The future is taken before the task is queued.
std::future<return_type> task_future = task->get_future();
tasks_.Push([task]() { (*task)(); });
return task_future;
}
// Queues an empty task.
void WakeUp();
// True when the calling thread is the one that constructed this loop.
bool IsInLoopThread() const {
return thread_id_ == std::this_thread::get_id();
}
// Raises an error unless called on the loop's owning thread.
void AssertInLoopThread() {
if (!IsInLoopThread()) {
AbortNotInLoopThread();
}
}
private:
void AbortNotInLoopThread();
// Per-thread pointer to the TaskLoop registered on that thread.
static thread_local TaskLoop* thread_local_loop_;
// Set while Loop() is running.
bool looping_;
// Stop request; written by Quit() and read by Loop().
std::atomic<bool> quit_;
// Id of the thread that constructed this loop.
std::thread::id thread_id_;
// Pending tasks.
framework::BlockingQueue<Functor> tasks_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
// Creates an idle object: no thread is started and no loop exists until
// StartLoop() is called.
TaskLoopThread::TaskLoopThread() : start_(false), loop_(nullptr) {}
// Stops the worker's task loop (if it is still alive) and joins the worker.
//
// loop_ points at a TaskLoop that lives on the worker thread's stack, and the
// worker resets loop_ to nullptr under mutex_ just before that object goes
// away. Reading loop_ and calling Quit() under the same mutex therefore
// guarantees the pointee is still alive for the duration of the call.
//
// The join is keyed on the thread being joinable rather than on loop_: if the
// loop was already stopped through the TaskLoop pointer handed out by
// StartLoop(), loop_ is nullptr here but the std::thread still has to be
// joined, because destroying a joinable std::thread calls std::terminate.
TaskLoopThread::~TaskLoopThread() {
  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (loop_ != nullptr) {
      loop_->Quit();
    }
  }
  // The lock is released before joining: the worker needs mutex_ to clear
  // loop_ on its way out.
  if (thread_.joinable()) {
    thread_.join();
  }
}
// Starts the worker thread and blocks until it has published its TaskLoop,
// then returns a pointer to that loop. May be called only once per object;
// a second call fails the PADDLE_ENFORCE check.
TaskLoop* TaskLoopThread::StartLoop() {
  PADDLE_ENFORCE_EQ(start_, false, platform::errors::PreconditionNotMet(
                                       "thread is already running."));
  start_ = true;
  thread_ = std::thread(&TaskLoopThread::Loop, this);

  // Wait for the worker to store the address of its loop in loop_.
  std::unique_lock<std::mutex> guard(mutex_);
  while (loop_ == nullptr) {
    cv_.wait(guard);
  }
  return loop_;
}
void TaskLoopThread::Loop() {
TaskLoop loop;
{
std::unique_lock<std::mutex> lock(mutex_);
loop_ = &loop;
cv_.notify_one();
}
loop.Loop();
std::unique_lock<std::mutex> lock(mutex_);
loop_ = nullptr;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <thread>
namespace paddle {
namespace distributed {
class TaskLoop;
// Owns one worker thread that runs a single TaskLoop.
// StartLoop() spawns the thread and hands back a pointer to its loop; the
// destructor asks a still-published loop to quit and joins the thread.
class TaskLoopThread {
public:
TaskLoopThread();
~TaskLoopThread();
// Starts the worker thread and blocks until its TaskLoop has been published.
// Calling it a second time on the same object is rejected in the .cc file.
TaskLoop* StartLoop();
private:
// Body executed by the worker thread.
void Loop();
// Set to true by StartLoop(); used to reject a second start.
bool start_;
// Address of the TaskLoop living on the worker thread's stack while it runs;
// nullptr before the loop is published and after it has returned.
TaskLoop* loop_;
std::thread thread_;
// mutex_ guards the hand-off of loop_ between the worker thread and
// StartLoop(); cv_ signals that loop_ has been published.
std::mutex mutex_;
std::condition_variable cv_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
// Default pool: configured for a single worker thread, not yet started.
TaskLoopThreadPool::TaskLoopThreadPool() : start_(false), thread_num_(1) {}

// Pool configured for `thread_num` workers. No thread is created here; the
// value is validated and used by Start().
TaskLoopThreadPool::TaskLoopThreadPool(int thread_num)
    : start_(false),
      thread_num_(thread_num) {}

// Defined in this translation unit, where TaskLoopThread is a complete type,
// so the std::unique_ptr<TaskLoopThread> elements can be destroyed.
TaskLoopThreadPool::~TaskLoopThreadPool() = default;
void TaskLoopThreadPool::Start() {
PADDLE_ENFORCE_EQ(start_, false, platform::errors::PreconditionNotMet(
"thread pool is already start."));
PADDLE_ENFORCE_GT(
thread_num_, 0,
platform::errors::InvalidArgument(
"thread num must greater than 0, but now is %d", thread_num_));
start_ = true;
for (int i = 0; i < thread_num_; ++i) {
threads_.emplace_back(new TaskLoopThread());
loops_.push_back(threads_[i]->StartLoop());
}
}
// Returns the loop run by worker `tid`. The pool must have been started and
// `tid` must index an existing worker; otherwise the PADDLE_ENFORCE checks
// reject the call.
//
// The upper bound is checked against the number of loops actually created by
// Start(), not against `thread_num_`: SetThreadNum() can change `thread_num_`
// after Start(), and trusting it would allow an out-of-bounds read of
// `loops_`.
TaskLoop* TaskLoopThreadPool::GetLoop(int tid) {
  PADDLE_ENFORCE_EQ(start_, true, platform::errors::PreconditionNotMet(
                                      "thread pool must start first."));
  PADDLE_ENFORCE_GE(tid, 0, platform::errors::OutOfRange(
                                "tid must >= 0, but now is %d", tid));
  const int loop_num = static_cast<int>(loops_.size());
  PADDLE_ENFORCE_LT(tid, loop_num,
                    platform::errors::OutOfRange(
                        "tid must < thread_num, but now tid=%d thread_num=%d",
                        tid, loop_num));
  return loops_[tid];
}
// Returns a copy of the recorded loop pointers, in the order the workers were
// started. The pool must have been started first.
std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() {
  PADDLE_ENFORCE_EQ(
      start_, true,
      platform::errors::PreconditionNotMet("thread pool must start first."));
  std::vector<TaskLoop*> snapshot(loops_);
  return snapshot;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
namespace paddle {
namespace distributed {
class TaskLoop;
class TaskLoopThread;
// A fixed set of TaskLoopThread workers, each running one TaskLoop.
// Configure the size (constructor or SetThreadNum), call Start() once, then
// fetch loops with GetLoop()/GetAllLoops().
class TaskLoopThreadPool {
public:
// Default-constructed pools are configured for one thread.
TaskLoopThreadPool();
explicit TaskLoopThreadPool(int thread_num);
~TaskLoopThreadPool();
// Only stores the value; Start() reads it when creating the workers, so
// calling this after Start() does not add or remove threads.
void SetThreadNum(int thread_num) { thread_num_ = thread_num; }
// Creates the workers; rejected if already started or if the configured
// thread count is not positive.
void Start();
// Returns the loop of worker `tid`; requires a started pool and a valid index.
TaskLoop* GetLoop(int tid);
// Returns a copy of all loop pointers; requires a started pool.
std::vector<TaskLoop*> GetAllLoops();
private:
// Set to true by Start().
bool start_;
int thread_num_;
// Owning handles for the workers; destroying them stops and joins the threads.
std::vector<std::unique_ptr<TaskLoopThread>> threads_;
// Non-owning pointers to the loops published by the workers; loops_[i] was
// returned by threads_[i]->StartLoop().
std::vector<TaskLoop*> loops_;
};
} // namespace distributed
} // namespace paddle
...@@ -115,10 +115,13 @@ TEST(InterceptorTest, PingPong) { ...@@ -115,10 +115,13 @@ TEST(InterceptorTest, PingPong) {
FleetExecutor::CreateCarrier(0, interceptor_id_to_rank); FleetExecutor::CreateCarrier(0, interceptor_id_to_rank);
carrier->SetCreatingFlag(false); carrier->SetCreatingFlag(false);
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
carrier->SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
// NOTE: need Init msg_bus after carrier SetMsgBus
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
Interceptor* a = carrier->SetInterceptor( Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr)); 0, InterceptorFactory::Create("PingPong", 0, nullptr));
carrier->Barrier();
InterceptorMessage msg; InterceptorMessage msg;
a->Send(1, msg); a->Send(1, msg);
carrier->Wait(); carrier->Wait();
...@@ -127,10 +130,12 @@ TEST(InterceptorTest, PingPong) { ...@@ -127,10 +130,12 @@ TEST(InterceptorTest, PingPong) {
FleetExecutor::CreateCarrier(1, interceptor_id_to_rank); FleetExecutor::CreateCarrier(1, interceptor_id_to_rank);
carrier->SetCreatingFlag(false); carrier->SetCreatingFlag(false);
auto msg_bus = std::make_shared<MessageBus>(); auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetMsgBus(msg_bus); carrier->SetMsgBus(msg_bus);
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetInterceptor( carrier->SetInterceptor(
1, InterceptorFactory::Create("PingPong", 1, nullptr)); 1, InterceptorFactory::Create("PingPong", 1, nullptr));
carrier->Barrier();
carrier->Wait(); carrier->Wait();
int status; int status;
int ret = waitpid(pid, &status, 0); int ret = waitpid(pid, &status, 0);
......
...@@ -81,6 +81,16 @@ class BlockingQueue { ...@@ -81,6 +81,16 @@ class BlockingQueue {
std::swap(*empty_queue, q_); std::swap(*empty_queue, q_);
} }
std::deque<T> PopAll() {
std::deque<T> ret;
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return !q_.empty(); });
std::swap(ret, q_);
}
return ret;
}
T Pop() { T Pop() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); }); cv_.wait(lock, [=] { return !q_.empty(); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册