diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index 9854d618d2b29ed123833f55198179638c95d6db..0a8377cc4794be29af255067517d9c6841aea108 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -91,5 +91,20 @@ void ThreadPool::TaskLoop() { } } +std::unique_ptr MultiStreamThreadPool::io_threadpool_(nullptr); +std::once_flag MultiStreamThreadPool::io_init_flag_; + +MultiStreamThreadPool* MultiStreamThreadPool::GetInstanceIO() { + std::call_once(io_init_flag_, &MultiStreamThreadPool::InitIO); + return static_cast(io_threadpool_.get()); +} + +void MultiStreamThreadPool::InitIO() { + if (io_threadpool_.get() == nullptr) { + // TODO(typhoonzero1986): make this configurable + io_threadpool_.reset(new ThreadPool(100)); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index f9dce7105e32ff0ba03d03f8faaac3a4ed1a3595..5d437594ab754ce6ee6becea3574c58cc59cbc63 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -135,6 +135,17 @@ class ThreadPool { std::condition_variable completed_; }; +class MultiStreamThreadPool : ThreadPool { + public: + static MultiStreamThreadPool* GetInstanceIO(); + static void InitIO(); + + private: + // NOTE: threadpool in base will be inhereted here. + static std::unique_ptr io_threadpool_; + static std::once_flag io_init_flag_; +}; + // Run a function asynchronously. // NOTE: The function must return void. If the function need to return a value, // you can use lambda to capture a value pointer. @@ -143,5 +154,10 @@ std::future Async(Callback callback) { return ThreadPool::GetInstance()->Run(callback); } +template +std::future AsyncIO(Callback callback) { + return MultiStreamThreadPool::GetInstanceIO()->Run(callback); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index d79ba6d291950e1f089eb11713bd1c3e4d154b27..3f96ce37185bcc7aa0baa9cb9b6fada3c838c21f 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -33,7 +33,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] { + framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, + this] { auto* var = p_scope->FindVar(var_name_val); ::grpc::ByteBuffer req; @@ -88,7 +89,8 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { + framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, + this] { // prepare input sendrecv::VariableMessage req; req.set_varname(var_name_val); @@ -131,8 +133,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - time_out, ch, this] { + framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, + time_out, ch, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; @@ -195,7 +197,7 @@ bool RPCClient::Wait() { std::vector> waits(req_count_); for (int i = 0; i < req_count_; i++) { - waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); }); + waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); }); } for (int i = 0; i < req_count_; i++) {