diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc
index 0a8377cc4794be29af255067517d9c6841aea108..109c2c745c68756de5e1e4384bd8c717d74d158a 100644
--- a/paddle/fluid/framework/threadpool.cc
+++ b/paddle/fluid/framework/threadpool.cc
@@ -14,8 +14,12 @@
 
 #include "paddle/fluid/framework/threadpool.h"
 
+#include "gflags/gflags.h"
 #include "paddle/fluid/platform/enforce.h"
 
+DEFINE_int32(io_threadpool_size, 100,
+             "number of threads used for doing IO, default 100");
+
 namespace paddle {
 namespace framework {
 
@@ -94,15 +98,15 @@ void ThreadPool::TaskLoop() {
 std::unique_ptr<ThreadPool> MultiStreamThreadPool::io_threadpool_(nullptr);
 std::once_flag MultiStreamThreadPool::io_init_flag_;
 
-MultiStreamThreadPool* MultiStreamThreadPool::GetInstanceIO() {
+ThreadPool* MultiStreamThreadPool::GetInstanceIO() {
   std::call_once(io_init_flag_, &MultiStreamThreadPool::InitIO);
-  return static_cast<MultiStreamThreadPool*>(io_threadpool_.get());
+  return io_threadpool_.get();
 }
 
 void MultiStreamThreadPool::InitIO() {
   if (io_threadpool_.get() == nullptr) {
     // TODO(typhoonzero1986): make this configurable
-    io_threadpool_.reset(new ThreadPool(100));
+    io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size));
   }
 }
 
diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h
index 0a60488d9f629c253bf66f1704fc4cde9e9b65a0..1cc058834c0f4313766773373719a8825f4285b9 100644
--- a/paddle/fluid/framework/threadpool.h
+++ b/paddle/fluid/framework/threadpool.h
@@ -14,12 +14,12 @@ limitations under the License. */
 #pragma once
 
-#include <condition_variable>
+#include <condition_variable>  // NOLINT
 #include <functional>
-#include <future>
-#include <mutex>
+#include <future>  // NOLINT
+#include <mutex>   // NOLINT
 #include <queue>
-#include <thread>
+#include <thread>  // NOLINT
 #include <vector>
 
 #include "glog/logging.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -137,7 +137,7 @@ class ThreadPool {
 
 class MultiStreamThreadPool : ThreadPool {
  public:
-  static MultiStreamThreadPool* GetInstanceIO();
+  static ThreadPool* GetInstanceIO();
   static void InitIO();
 
  private:
diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc
index d5fc163bc25409e0607b149b61c6266b38119d9d..36dad5dd43a6a0ab57633034cc45b6261c10b3fd 100644
--- a/paddle/fluid/operators/detail/grpc_server.cc
+++ b/paddle/fluid/operators/detail/grpc_server.cc
@@ -216,10 +216,10 @@ void AsyncGRPCServer::RunSyncUpdate() {
   std::function<void()> prefetch_register =
       std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this);
 
+  // TODO(wuyi): Run these "HandleRequest" in thread pool
   t_send_.reset(
       new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                 cq_send_.get(), "cq_send", send_register)));
-
   t_get_.reset(
       new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                 cq_get_.get(), "cq_get", get_register)));
diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py
index e4997b4069f60ff4382b4254bc026ae8ae29b345..5ec6890c1b0dabd2804a92071b63c9610299e67c 100644
--- a/python/paddle/fluid/tests/book/test_recognize_digits.py
+++ b/python/paddle/fluid/tests/book/test_recognize_digits.py
@@ -157,7 +157,6 @@ def train(nn_type,
         for ip in pserver_ips.split(","):
             eplist.append(':'.join([ip, port]))
         pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
-        pserver_endpoints = os.getenv("PSERVERS")
         trainers = int(os.getenv("TRAINERS"))
         current_endpoint = os.getenv("POD_IP") + ":" + port
         trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))
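Note (not part of the diff): a minimal sketch of how the new io_threadpool_size gflag interacts with MultiStreamThreadPool::GetInstanceIO(). The flag name and GetInstanceIO() come from this change; the standalone main(), the value 32, and the direct FLAGS_ assignment are illustrative assumptions only.

// Hypothetical usage sketch, not code from this PR.
#include "gflags/gflags.h"
#include "paddle/fluid/framework/threadpool.h"

// The flag is defined in threadpool.cc by this change; declare it here.
DECLARE_int32(io_threadpool_size);

int main() {
  // Override the default of 100 before the pool is first requested;
  // InitIO() reads the flag only once, guarded by std::call_once.
  FLAGS_io_threadpool_size = 32;
  paddle::framework::ThreadPool* io_pool =
      paddle::framework::MultiStreamThreadPool::GetInstanceIO();
  // Later calls return the same lazily created pool instance.
  (void)io_pool;
  return 0;
}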