diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index 9854d618d2b29ed123833f55198179638c95d6db..f26f212d4d5793b88fd1e6d782cdf983bf341879 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -14,8 +14,12 @@ #include "paddle/fluid/framework/threadpool.h" +#include "gflags/gflags.h" #include "paddle/fluid/platform/enforce.h" +DEFINE_int32(io_threadpool_size, 100, + "number of threads used for doing IO, default 100"); + namespace paddle { namespace framework { @@ -91,5 +95,20 @@ void ThreadPool::TaskLoop() { } } +std::unique_ptr ThreadPoolIO::io_threadpool_(nullptr); +std::once_flag ThreadPoolIO::io_init_flag_; + +ThreadPool* ThreadPoolIO::GetInstanceIO() { + std::call_once(io_init_flag_, &ThreadPoolIO::InitIO); + return io_threadpool_.get(); +} + +void ThreadPoolIO::InitIO() { + if (io_threadpool_.get() == nullptr) { + // TODO(typhoonzero1986): make this configurable + io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size)); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index f9dce7105e32ff0ba03d03f8faaac3a4ed1a3595..94111ee335b1a5df327b3e46d62069b4735c54f6 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -14,12 +14,12 @@ limitations under the License. */ #pragma once -#include +#include // NOLINT #include -#include -#include +#include // NOLINT +#include // NOLINT #include -#include +#include // NOLINT #include #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" @@ -28,6 +28,22 @@ limitations under the License. */ namespace paddle { namespace framework { +struct ExceptionHandler { + mutable std::future> future_; + explicit ExceptionHandler( + std::future>&& f) + : future_(std::move(f)) {} + void operator()() const { + auto ex = this->future_.get(); + if (ex != nullptr) { + LOG(FATAL) << "The exception is thrown inside the thread pool. You " + "should use RunAndGetException to handle the exception.\n" + "The default exception handler is LOG(FATAL)." + << ex->what(); + } + } +}; + // ThreadPool maintains a queue of tasks, and runs them using a fixed // number of threads. class ThreadPool { @@ -87,22 +103,6 @@ class ThreadPool { void Wait(); private: - struct ExceptionHandler { - mutable std::future> future_; - explicit ExceptionHandler( - std::future>&& f) - : future_(std::move(f)) {} - void operator()() const { - auto ex = this->future_.get(); - if (ex != nullptr) { - LOG(FATAL) << "The exception is thrown inside the thread pool. You " - "should use RunAndGetException to handle the exception.\n" - "The default exception handler is LOG(FATAL)." - << ex->what(); - } - } - }; - DISABLE_COPY_AND_ASSIGN(ThreadPool); // If the task queue is empty and avaialbe is equal to the number of @@ -135,6 +135,17 @@ class ThreadPool { std::condition_variable completed_; }; +class ThreadPoolIO : ThreadPool { + public: + static ThreadPool* GetInstanceIO(); + static void InitIO(); + + private: + // NOTE: threadpool in base will be inhereted here. + static std::unique_ptr io_threadpool_; + static std::once_flag io_init_flag_; +}; + // Run a function asynchronously. // NOTE: The function must return void. If the function need to return a value, // you can use lambda to capture a value pointer. @@ -143,5 +154,10 @@ std::future Async(Callback callback) { return ThreadPool::GetInstance()->Run(callback); } +template +std::future AsyncIO(Callback callback) { + return ThreadPoolIO::GetInstanceIO()->Run(callback); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 45f88ec8697d9f3de2612f28889fefc36f7ddbf9..661dfa69fe1580ff3890f12defcd124225be0c06 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -35,7 +35,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] { + framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, + this] { auto* var = p_scope->FindVar(var_name_val); ::grpc::ByteBuffer req; @@ -89,7 +90,8 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { + framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, + this] { // prepare input sendrecv::VariableMessage req; req.set_varname(var_name_val); @@ -132,8 +134,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - framework::Async([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - time_out, ch, this] { + framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, + time_out, ch, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; @@ -196,7 +198,7 @@ bool RPCClient::Wait() { std::vector> waits(req_count_); for (int i = 0; i < req_count_; i++) { - waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); }); + waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); }); } for (int i = 0; i < req_count_; i++) { diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 0b582a08bc0bfbcfdc8f338a6add8edaa5e80818..119e146e078e476b2768a8495ea63e468f952fd2 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -217,10 +217,10 @@ void AsyncGRPCServer::RunSyncUpdate() { std::function prefetch_register = std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this); + // TODO(wuyi): Run these "HandleRequest" in thread pool t_send_.reset( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_send_.get(), "cq_send", send_register))); - t_get_.reset( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", get_register))); diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index e4997b4069f60ff4382b4254bc026ae8ae29b345..5ec6890c1b0dabd2804a92071b63c9610299e67c 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -157,7 +157,6 @@ def train(nn_type, for ip in pserver_ips.split(","): eplist.append(':'.join([ip, port])) pserver_endpoints = ",".join(eplist) # ip:port,ip:port... - pserver_endpoints = os.getenv("PSERVERS") trainers = int(os.getenv("TRAINERS")) current_endpoint = os.getenv("POD_IP") + ":" + port trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))