From f43be75b82582ec5f81c2ceba45eb14128638478 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 2 Apr 2018 20:25:11 +0800 Subject: [PATCH] multi stream thread pool --- paddle/fluid/framework/threadpool.cc | 15 +++++++++++++++ paddle/fluid/framework/threadpool.h | 16 ++++++++++++++++ paddle/fluid/operators/detail/grpc_client.cc | 12 +++++++----- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index 9854d618d2b..0a8377cc479 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -91,5 +91,20 @@ void ThreadPool::TaskLoop() { } } +std::unique_ptr MultiStreamThreadPool::io_threadpool_(nullptr); +std::once_flag MultiStreamThreadPool::io_init_flag_; + +MultiStreamThreadPool* MultiStreamThreadPool::GetInstanceIO() { + std::call_once(io_init_flag_, &MultiStreamThreadPool::InitIO); + return static_cast(io_threadpool_.get()); +} + +void MultiStreamThreadPool::InitIO() { + if (io_threadpool_.get() == nullptr) { + // TODO(typhoonzero1986): make this configurable + io_threadpool_.reset(new ThreadPool(100)); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index f9dce7105e3..5d437594ab7 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -135,6 +135,17 @@ class ThreadPool { std::condition_variable completed_; }; +class MultiStreamThreadPool : ThreadPool { + public: + static MultiStreamThreadPool* GetInstanceIO(); + static void InitIO(); + + private: + // NOTE: threadpool in base will be inhereted here. + static std::unique_ptr io_threadpool_; + static std::once_flag io_init_flag_; +}; + // Run a function asynchronously. // NOTE: The function must return void. If the function need to return a value, // you can use lambda to capture a value pointer. @@ -143,5 +154,10 @@ std::future Async(Callback callback) { return ThreadPool::GetInstance()->Run(callback); } +template +std::future AsyncIO(Callback callback) { + return MultiStreamThreadPool::GetInstanceIO()->Run(callback); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index d79ba6d2919..3f96ce37185 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -33,7 +33,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] { + framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, + this] { auto* var = p_scope->FindVar(var_name_val); ::grpc::ByteBuffer req; @@ -88,7 +89,8 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { + framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, + this] { // prepare input sendrecv::VariableMessage req; req.set_varname(var_name_val); @@ -131,8 +133,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - time_out, ch, this] { + framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, + time_out, ch, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; @@ -195,7 +197,7 @@ bool RPCClient::Wait() { std::vector> waits(req_count_); for (int i = 0; i < req_count_; i++) { - waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); }); + waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); }); } for (int i = 0; i < req_count_; i++) { -- GitLab