Unverified commit dba59db7, authored by WangXi, committed by GitHub

[fleet_executor] Add task loop thread pool (#38420)

Parent: 5b6b88ab
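This change removes the dedicated std::thread each Interceptor used to spawn
for polling its mailbox and instead runs all interceptors on a shared
TaskLoopThreadPool owned by the Carrier. A rough sketch of the resulting
dispatch pattern (the setup values here are illustrative, not part of the diff):

// Carrier starts the pool once; each interceptor is then pinned to one loop
// via (interceptor_id % thread_num), so all handling for a given interceptor
// is serialized on a single loop thread.
TaskLoopThreadPool pool;
pool.SetThreadNum(1);  // the commit hard-codes thread_num_ = 1 for now
pool.Start();
TaskLoop* loop = pool.GetLoop(interceptor_id % 1);  // interceptor_id: placeholder
loop->QueueInLoop([] { /* drain this interceptor's pending messages */ });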
......@@ -10,10 +10,12 @@ else()
set(BRPC_DEPS "")
endif()
cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
executor_gc_helper gflags glog ${BRPC_DEPS})
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog ${BRPC_DEPS})
if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
......
......@@ -42,30 +42,17 @@ void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
place_ = place;
root_scope_ = root_scope;
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
CreateInterceptors();
is_init_ = true;
}
void Carrier::Release() {
// NOTE(wangxi): must join before the derived Interceptor destructs,
// otherwise the derived object will be destroyed before its thread completes.
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
InterceptorMessage stop_msg;
// source node STOP is sent by the carrier, so set src_id=-1
stop_msg.set_src_id(-1);
stop_msg.set_dst_id(id);
stop_msg.set_message_type(STOP);
Send(stop_msg);
}
// TODO(wangxi): Maybe need a better way to use threads.
for (auto& interceptor : interceptor_idx_to_interceptor_) {
interceptor.second->Join();
}
}
void Carrier::Release() {}
Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
......@@ -75,18 +62,9 @@ bool Carrier::EnqueueInterceptorMessage(
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
// for barrier
msg_bus_->IncreaseBarrierCount();
} else {
{
std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
if (creating_interceptors_) {
std::unique_lock<std::mutex> lock_message(tmp_message_mutex_);
// Cannot deliver the message to the interceptor since interceptors
// are still being created. Will enqueue it into a tmp stack.
VLOG(3) << "Receiving message while creating interceptors.";
message_tmp_.emplace_back(interceptor_message);
return true;
}
}
int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id);
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
......@@ -94,6 +72,8 @@ bool Carrier::EnqueueInterceptorMessage(
return true;
}
void Carrier::Barrier() { msg_bus_->Barrier(); }
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(),
......@@ -109,6 +89,11 @@ void Carrier::Wait() {
cond_var_.wait(lock);
}
void Carrier::WakeUp() {
// may double-notify, but that's acceptable for the unit tests
cond_var_.notify_all();
}
void Carrier::Start() {
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
......@@ -126,12 +111,11 @@ void Carrier::Start() {
start_msg.set_message_type(DATA_IS_READY);
Send(start_msg);
}
// TODO(wangxi): async step
Wait();
dev_ctx_->Wait();
}
std::condition_variable& Carrier::GetCondVar() { return cond_var_; }
bool Carrier::IsInit() const { return is_init_; }
int64_t Carrier::GetRank(int64_t interceptor_id) const {
......@@ -183,51 +167,19 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id should be unique.",
interceptor_id));
interceptor->RegisterCarrier(this);
// TODO(fleet_exe dev): get loop
auto* loop = thread_pool_.GetLoop(interceptor_id % thread_num_);
PADDLE_ENFORCE_NOT_NULL(
loop, platform::errors::Fatal("thread task loop must not be null"));
interceptor->RegisterTaskLoop(loop);
auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
return ptr;
}
void Carrier::SetCreatingFlag(bool flag) {
// set the creating flag
creating_flag_mutex_.lock();
VLOG(3) << "Carrier is set the creating flag from " << creating_interceptors_
<< " to " << flag << ".";
creating_interceptors_ = flag;
creating_flag_mutex_.unlock();
if (!flag) {
for (auto& pair : interceptor_idx_to_interceptor_) {
// update the source interceptor id
if (std::find(source_interceptor_ids_.begin(),
source_interceptor_ids_.end(),
pair.first) == source_interceptor_ids_.end()) {
auto task = pair.second->GetTaskNode();
if (task != nullptr && task->upstream().empty()) {
source_interceptor_ids_.emplace_back(pair.first);
}
}
}
// finished creating interceptors outside; handle tmp messages
HandleTmpMessages();
}
}
void Carrier::HandleTmpMessages() {
// NOTE: It's ok to lock tmp_message_mutex_ here. When entering this
// `HandleTmpMessages` method, the creating_interceptors_ flag
// must be false, so there is no conflict with the
// lock on tmp_message_mutex_ inside `EnqueueInterceptorMessage`
// on the same thread.
std::unique_lock<std::mutex> lock(tmp_message_mutex_);
VLOG(3) << "Carrier has received " << message_tmp_.size()
<< " messages during creating interceptors.";
for (const auto& msg : message_tmp_) {
EnqueueInterceptorMessage(msg);
}
message_tmp_.clear();
}
static std::shared_ptr<framework::GarbageCollector> GetGC(
const platform::Place& place) {
int64_t max_memory_size = framework::GetEagerDeletionThreshold();
......@@ -285,12 +237,6 @@ void Carrier::CreateInterceptors() {
source_interceptor_ids_.emplace_back(interceptor_id);
}
}
// The carrier will always be waiting for the outside initializer,
// since no interceptor has been created during auto init
creating_flag_mutex_.lock();
creating_interceptors_ = false;
creating_flag_mutex_.unlock();
HandleTmpMessages();
}
} // namespace distributed
......
......@@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
......@@ -47,7 +48,11 @@ class Carrier final {
Carrier() = default;
Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {}
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {
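// Start the loop threads eagerly in the constructor so that SetInterceptor
// can bind a TaskLoop even before Init() runs.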
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
}
~Carrier();
void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
......@@ -56,6 +61,7 @@ class Carrier final {
void Release();
void Wait();
void WakeUp();
// Enqueue a message to the interceptor with the corresponding id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
......@@ -67,23 +73,18 @@ class Carrier final {
Interceptor* SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor>);
void SetCreatingFlag(bool flag);
void SetCreatingFlag(bool flag) {}
void SetMsgBus(const std::shared_ptr<MessageBus>& msg_bus) {
msg_bus_ = msg_bus;
}
std::condition_variable& GetCondVar();
void Start();
bool IsInit() const;
bool Send(const InterceptorMessage& msg);
// NOTE: This mutex will be used in the interceptor's RunOps function.
// It prevents forward ops and backward ops from running
// simultaneously, which would lead to a random hang for some sync ops.
std::mutex run;
void Barrier();
private:
DISABLE_COPY_AND_ASSIGN(Carrier);
......@@ -91,8 +92,6 @@ class Carrier final {
// create each Interceptor
void CreateInterceptors();
void HandleTmpMessages();
int64_t GetRank(int64_t interceptor_id) const;
// interceptor logical id to the actual interceptor
......@@ -101,10 +100,6 @@ class Carrier final {
std::vector<int64_t> source_interceptor_ids_;
std::vector<InterceptorMessage> message_tmp_{};
std::mutex tmp_message_mutex_;
bool creating_interceptors_{true};
std::mutex creating_flag_mutex_;
bool is_init_{false};
std::mutex running_mutex_;
......@@ -118,6 +113,9 @@ class Carrier final {
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
int thread_num_;
TaskLoopThreadPool thread_pool_;
};
} // namespace distributed
......
......@@ -170,7 +170,6 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
void ComputeInterceptor::RunOps() {
std::unique_lock<std::mutex> lock(carrier_->run);
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ + 1 << " time.";
for (auto op : node_->ops()) {
......@@ -198,6 +197,7 @@ void ComputeInterceptor::Run() {
if (is_last_ && (step_ % node_->max_run_times() == 0)) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " is stopping carrier.";
// FIXME(wangxi): does not yet handle multiple sink interceptors
StopCarrier();
}
}
......
......@@ -89,6 +89,11 @@ void FleetExecutor::Init(
CreateCarrier();
InitCarrier();
InitMessageBus();
// TODO: refine this? Wait until all carriers are ready.
// NOTE(wangxi): must be added after Carrier::SetMsgBus, since we use
// MessageBus::IncreaseBarrierCount when receiving the barrier msg.
GetCarrier()->Barrier();
}
void FleetExecutor::InitCarrier() {
......
......@@ -14,26 +14,21 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
: interceptor_id_(interceptor_id), node_(node) {
interceptor_thread_ = std::thread([this]() {
VLOG(3) << "Interceptor " << interceptor_id_
<< " starts the thread pooling it's local mailbox.";
PoolTheMailbox();
});
}
Interceptor::~Interceptor() { Join(); }
void Interceptor::Join() {
if (interceptor_thread_.joinable()) {
interceptor_thread_.join();
}
: interceptor_id_(interceptor_id), node_(node) {}
Interceptor::~Interceptor() {
// FIXME(wangxi): throw in stop function
// std::lock_guard<std::mutex> lock(mutex_);
// PADDLE_ENFORCE_EQ(messages_.empty(), true,
// platform::errors::PreconditionNotMet(
// "Interceptor must destruct with messages empty"));
}
void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
......@@ -44,25 +39,47 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
handle_(msg);
}
void Interceptor::LoopOnce() {
std::deque<InterceptorMessage> tmp_messages;
{
std::lock_guard<std::mutex> lock(mutex_);
messages_.swap(tmp_messages);
}
PADDLE_ENFORCE_EQ(tmp_messages.empty(), false,
platform::errors::PreconditionNotMet(
"tmp_messages must not empty in task loop"));
for (auto& msg : tmp_messages) {
const MessageType message_type = msg.message_type();
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << msg.src_id()
<< " with message: " << message_type << ".";
Handle(msg);
}
}
void Interceptor::StopCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
"Carrier is not registered."));
std::condition_variable& cond_var = carrier_->GetCondVar();
// may double-notify, but that's acceptable for the unit tests
cond_var.notify_all();
}
int64_t Interceptor::GetInterceptorId() const {
// return the interceptor id
return interceptor_id_;
carrier_->WakeUp();
}
void Interceptor::EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message) {
const InterceptorMessage& message) {
// Called by Carrier: enqueue an InterceptorMessage into this interceptor's mailbox
VLOG(3) << "Enqueue message: " << interceptor_message.message_type()
<< " into " << interceptor_id_ << "'s remote mailbox.";
remote_mailbox_.Push(interceptor_message);
VLOG(3) << "Enqueue message: " << message.message_type() << " into "
<< interceptor_id_ << "'s remote mailbox.";
bool empty = false;
{
std::lock_guard<std::mutex> lock(mutex_);
empty = messages_.empty();
messages_.emplace_back(message);
}
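// Only the empty -> non-empty transition schedules LoopOnce(); later
// producers see a non-empty queue and rely on the already-queued task,
// which drains every pending message in one swap.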
if (empty) {
loop_->QueueInLoop([this]() { LoopOnce(); });
}
}
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
......@@ -73,39 +90,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return carrier_->Send(msg);
}
void Interceptor::PoolTheMailbox() {
// poll the local mailbox and dispatch each message
for (;;) {
if (local_mailbox_.empty()) {
// local mailbox is empty, fetch the remote mailbox
VLOG(3) << interceptor_id_ << "'s local mailbox is empty. "
<< "Fetch the remote mailbox.";
PADDLE_ENFORCE_EQ(FetchRemoteMailbox(), true,
platform::errors::InvalidArgument(
"Error encountered when fetch remote mailbox."));
}
const InterceptorMessage interceptor_message = local_mailbox_.front();
local_mailbox_.pop_front();
const MessageType message_type = interceptor_message.message_type();
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << interceptor_message.src_id()
<< " with message: " << message_type << ".";
Handle(interceptor_message);
if (stop_) {
// break out of the polling loop
VLOG(3) << "Interceptor " << interceptor_id_ << " is quitting.";
break;
}
}
}
bool Interceptor::FetchRemoteMailbox() {
remote_mailbox_.PopAll(&local_mailbox_);
return !local_mailbox_.empty();
}
static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
static InterceptorFactory::CreateInterceptorMap interceptorMap;
return interceptorMap;
......
......@@ -38,6 +38,7 @@ namespace distributed {
class TaskNode;
class Carrier;
class TaskLoop;
class Interceptor {
public:
......@@ -50,15 +51,13 @@ class Interceptor {
virtual ~Interceptor();
void Join();
// register interceptor handle
void RegisterMsgHandle(MsgHandle handle);
void Handle(const InterceptorMessage& msg);
// return the interceptor id
int64_t GetInterceptorId() const;
int64_t GetInterceptorId() const { return interceptor_id_; }
// Called by Carrier: enqueue an InterceptorMessage into this interceptor's mailbox
void EnqueueRemoteInterceptorMessage(
......@@ -77,6 +76,7 @@ class Interceptor {
gc_ = gc;
}
void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; }
void RegisterTaskLoop(TaskLoop* loop) { loop_ = loop; }
TaskNode* GetTaskNode() const { return node_; }
......@@ -101,28 +101,16 @@ class Interceptor {
std::shared_ptr<framework::GarbageCollector> gc_{nullptr};
Carrier* carrier_;
TaskLoop* loop_;
private:
// poll the local mailbox and dispatch each message
void PoolTheMailbox();
// fetch all messages from the remote mailbox into the local mailbox;
// return true if the remote mailbox was not empty, false otherwise
bool FetchRemoteMailbox();
void LoopOnce();
// interceptor handle which process message
MsgHandle handle_{nullptr};
// the interceptor thread runs PoolTheMailbox() to poll the local mailbox
std::thread interceptor_thread_;
// remote mailbox, written by EnqueueRemoteInterceptorMessage(),
// read by FetchRemoteMailbox()
framework::BlockingQueue<InterceptorMessage> remote_mailbox_;
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std::deque<InterceptorMessage> local_mailbox_;
std::mutex mutex_;
std::deque<InterceptorMessage> messages_;
int64_t already_run_times_{0};
int64_t used_slot_nums_{0};
......
......@@ -105,21 +105,53 @@ bool MessageBus::Send(int64_t dst_rank,
return true;
}
void MessageBus::TestConnection() {
InterceptorMessage ctrl_msg;
ctrl_msg.set_ctrl_message(true);
ctrl_msg.set_src_id(rank_);
for (const auto& dst_rank_pair : rank_to_addr_) {
int64_t dst_rank = dst_rank_pair.first;
if (dst_rank != rank_) {
ctrl_msg.set_dst_id(dst_rank);
VLOG(3) << "Send control message bus from rank " << rank_ << " to rank "
<< dst_rank;
while (!Send(dst_rank, ctrl_msg)) {
void MessageBus::IncreaseBarrierCount() {
VLOG(3) << "IncreaseBarrierCount";
{
std::unique_lock<std::mutex> lock(mutex_);
++count_;
cv_.notify_one();
}
VLOG(3) << "End IncreaseBarrierCount";
}
void MessageBus::Barrier() {
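// Two-phase barrier over the message bus: every non-zero rank sends a ctrl
// message to rank 0 (gather), then rank 0 replies to each rank (scatter).
// count_ is bumped by IncreaseBarrierCount() when a ctrl message arrives.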
// gather to root
if (rank_ != 0) {
InterceptorMessage ctrl_msg;
ctrl_msg.set_ctrl_message(true);
ctrl_msg.set_src_id(rank_);
ctrl_msg.set_dst_id(0);
VLOG(3) << "Barrier Gather ctrl message from " << rank_ << " to 0";
while (!Send(0, ctrl_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
} else {
VLOG(3) << "Barrier 0 wait others rank ready";
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] {
return count_ == static_cast<int>(rank_to_addr_.size() - 1);
});
count_ = 0;
}
// scatter from root
if (rank_ == 0) {
for (int i = 1; i < static_cast<int>(rank_to_addr_.size()); ++i) {
InterceptorMessage ctrl_msg;
ctrl_msg.set_ctrl_message(true);
ctrl_msg.set_src_id(0);
ctrl_msg.set_dst_id(i);
VLOG(3) << "Barrier Scatter ctrl message from 0 to " << i;
while (!Send(i, ctrl_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus has connected to rank: " << dst_rank << ".";
}
} else {
VLOG(3) << "Barrier " << rank_ << " wait others rank ready";
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return count_ == 1; });
count_ = 0;
}
}
......@@ -151,7 +183,6 @@ void MessageBus::ListenPort() {
interval += 500;
}
LOG(INFO) << "Message bus's listen port thread starts successful.";
TestConnection();
#else
LOG(WARNING)
<< "Fleet executor's ListenPort() is a fake function when Paddle is "
......
......@@ -14,6 +14,7 @@
#pragma once
#include <condition_variable>
#include <mutex>
#include <string>
#include <thread>
......@@ -51,14 +52,15 @@ class MessageBus final {
// called by Interceptor, send InterceptorMessage to dst
bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message);
void IncreaseBarrierCount();
void Barrier();
private:
DISABLE_COPY_AND_ASSIGN(MessageBus);
// keep listening on the port and handling incoming messages
void ListenPort();
void TestConnection();
const std::string& GetAddr(int64_t rank) const;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
......@@ -84,6 +86,11 @@ class MessageBus final {
// brpc server
brpc::Server server_;
#endif
// for barrier
std::mutex mutex_;
std::condition_variable cv_;
int count_{0};
};
} // namespace distributed
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
thread_local TaskLoop* TaskLoop::thread_local_loop_ = nullptr;
TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; }
TaskLoop::TaskLoop()
: looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) {
PADDLE_ENFORCE_EQ(
thread_local_loop_, nullptr,
platform::errors::AlreadyExists("Another TaskLoop is already init."));
thread_local_loop_ = this;
}
TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; }
void TaskLoop::Loop() {
PADDLE_ENFORCE_EQ(looping_, false,
platform::errors::PreconditionNotMet(
"Loop can only execute in one loop thread"));
AssertInLoopThread();
looping_ = true;
quit_ = false;
while (!quit_) {
auto tasks = tasks_.PopAll();
for (auto& task : tasks) {
task();
}
}
looping_ = false;
}
void TaskLoop::Quit() {
quit_ = true;
if (!IsInLoopThread()) WakeUp();
}
void TaskLoop::RunInLoop(Functor cb) {
if (IsInLoopThread()) {
cb();
} else {
QueueInLoop(cb);
}
}
void TaskLoop::QueueInLoop(Functor cb) { tasks_.Push(cb); }
void TaskLoop::WakeUp() {
Functor task([] {});
QueueInLoop(task);
}
void TaskLoop::AbortNotInLoopThread() {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"This TaskLoop was created in thread %d, but current thread is %d",
thread_id_, std::this_thread::get_id()));
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <future>
#include <map>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
namespace paddle {
namespace distributed {
class TaskLoop {
public:
static TaskLoop* GetTaskLoopOfCurrentThread();
using Functor = std::function<void()>;
TaskLoop();
~TaskLoop();
void Loop();
void Quit();
void RunInLoop(Functor cb);
void QueueInLoop(Functor cb);
template <class F, class... Args>
auto Enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> {
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
std::future<return_type> task_future = task->get_future();
tasks_.Push([task]() { (*task)(); });
return task_future;
}
void WakeUp();
bool IsInLoopThread() const {
return thread_id_ == std::this_thread::get_id();
}
void AssertInLoopThread() {
if (!IsInLoopThread()) {
AbortNotInLoopThread();
}
}
private:
void AbortNotInLoopThread();
static thread_local TaskLoop* thread_local_loop_;
bool looping_;
std::atomic<bool> quit_;
std::thread::id thread_id_;
framework::BlockingQueue<Functor> tasks_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
TaskLoopThread::TaskLoopThread() : start_(false), loop_(nullptr) {}
TaskLoopThread::~TaskLoopThread() {
if (loop_ != nullptr) {
loop_->Quit();
thread_.join();
}
}
TaskLoop* TaskLoopThread::StartLoop() {
PADDLE_ENFORCE_EQ(start_, false, platform::errors::PreconditionNotMet(
"thread is already running."));
start_ = true;
thread_ = std::thread([this]() { Loop(); });
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return loop_ != nullptr; });
return loop_;
}
void TaskLoopThread::Loop() {
TaskLoop loop;
{
std::unique_lock<std::mutex> lock(mutex_);
loop_ = &loop;
cv_.notify_one();
}
loop.Loop();
std::unique_lock<std::mutex> lock(mutex_);
loop_ = nullptr;
}
} // namespace distributed
} // namespace paddle
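TaskLoopThread pairs one TaskLoop with one std::thread; StartLoop() blocks on
the condition variable until the loop pointer has been published by the new
thread. A minimal usage sketch (illustrative only, not taken from this diff):

#include <future>
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"

// Run work on a dedicated loop thread and fetch the result via the future
// returned by TaskLoop::Enqueue.
paddle::distributed::TaskLoopThread loop_thread;
paddle::distributed::TaskLoop* loop = loop_thread.StartLoop();
std::future<int> answer = loop->Enqueue([] { return 40 + 2; });
int v = answer.get();  // 42, computed on the loop thread
// ~TaskLoopThread() calls loop->Quit() and joins the thread.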
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <thread>
namespace paddle {
namespace distributed {
class TaskLoop;
class TaskLoopThread {
public:
TaskLoopThread();
~TaskLoopThread();
TaskLoop* StartLoop();
private:
void Loop();
bool start_;
TaskLoop* loop_;
std::thread thread_;
std::mutex mutex_;
std::condition_variable cv_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
TaskLoopThreadPool::TaskLoopThreadPool() : TaskLoopThreadPool(1) {}
TaskLoopThreadPool::TaskLoopThreadPool(int thread_num)
: start_(false), thread_num_(thread_num) {}
TaskLoopThreadPool::~TaskLoopThreadPool() = default;
void TaskLoopThreadPool::Start() {
PADDLE_ENFORCE_EQ(start_, false, platform::errors::PreconditionNotMet(
"thread pool is already start."));
PADDLE_ENFORCE_GT(
thread_num_, 0,
platform::errors::InvalidArgument(
"thread num must greater than 0, but now is %d", thread_num_));
start_ = true;
for (int i = 0; i < thread_num_; ++i) {
threads_.emplace_back(new TaskLoopThread());
loops_.push_back(threads_[i]->StartLoop());
}
}
TaskLoop* TaskLoopThreadPool::GetLoop(int tid) {
PADDLE_ENFORCE_EQ(start_, true, platform::errors::PreconditionNotMet(
"thread pool must start first."));
PADDLE_ENFORCE_GE(tid, 0, platform::errors::OutOfRange(
"tid must >= 0, but now is %d", tid));
PADDLE_ENFORCE_LT(tid, thread_num_,
platform::errors::OutOfRange(
"tid must < thread_num, but now tid=%d thread_num=%d",
tid, thread_num_));
return loops_[tid];
}
std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() {
PADDLE_ENFORCE_EQ(start_, true, platform::errors::PreconditionNotMet(
"thread pool must start first."));
return loops_;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
namespace paddle {
namespace distributed {
class TaskLoop;
class TaskLoopThread;
class TaskLoopThreadPool {
public:
TaskLoopThreadPool();
explicit TaskLoopThreadPool(int thread_num);
~TaskLoopThreadPool();
void SetThreadNum(int thread_num) { thread_num_ = thread_num; }
void Start();
TaskLoop* GetLoop(int tid);
std::vector<TaskLoop*> GetAllLoops();
private:
bool start_;
int thread_num_;
std::vector<std::unique_ptr<TaskLoopThread>> threads_;
std::vector<TaskLoop*> loops_;
};
} // namespace distributed
} // namespace paddle
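A hypothetical wiring example, mirroring what Carrier now does in its
constructor and in SetInterceptor (the thread count of 4 is an arbitrary
example value):

#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"

paddle::distributed::TaskLoopThreadPool pool;
pool.SetThreadNum(4);
pool.Start();
// Interceptors whose ids collide modulo 4 share one loop thread and are
// therefore serialized with respect to each other.
paddle::distributed::TaskLoop* loop = pool.GetLoop(3);
loop->QueueInLoop([] { /* handle one interceptor's messages */ });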
......@@ -115,10 +115,13 @@ TEST(InterceptorTest, PingPong) {
FleetExecutor::CreateCarrier(0, interceptor_id_to_rank);
carrier->SetCreatingFlag(false);
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
carrier->SetMsgBus(msg_bus);
// NOTE: msg_bus must be initialized after Carrier::SetMsgBus
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr));
carrier->Barrier();
InterceptorMessage msg;
a->Send(1, msg);
carrier->Wait();
......@@ -127,10 +130,12 @@ TEST(InterceptorTest, PingPong) {
FleetExecutor::CreateCarrier(1, interceptor_id_to_rank);
carrier->SetCreatingFlag(false);
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetMsgBus(msg_bus);
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->SetInterceptor(
1, InterceptorFactory::Create("PingPong", 1, nullptr));
carrier->Barrier();
carrier->Wait();
int status;
int ret = waitpid(pid, &status, 0);
......
......@@ -81,6 +81,16 @@ class BlockingQueue {
std::swap(*empty_queue, q_);
}
std::deque<T> PopAll() {
std::deque<T> ret;
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return !q_.empty(); });
std::swap(ret, q_);
}
return ret;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); });
......
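The new PopAll() overload blocks until the queue is non-empty and then drains
it in a single swap; TaskLoop::Loop() relies on exactly this to sleep while
idle and to process work in batches. A minimal consumer sketch (the element
type is illustrative):

#include <deque>
#include <functional>
#include "paddle/fluid/framework/blocking_queue.h"

paddle::framework::BlockingQueue<std::function<void()>> q;
// producer thread(s):
q.Push([] { /* one work item */ });
// consumer thread, as in TaskLoop::Loop():
for (;;) {
  std::deque<std::function<void()>> batch = q.PopAll();  // blocks while empty
  for (auto& fn : batch) fn();
}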