From 26cfc634b9f4dc02b051b49f54e33b57938e5ff2 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 12 Apr 2018 14:48:26 +0800 Subject: [PATCH] multi stream thread pool --- paddle/fluid/framework/threadpool.cc | 10 +++++++--- paddle/fluid/framework/threadpool.h | 10 +++++----- paddle/fluid/operators/detail/grpc_server.cc | 2 +- .../paddle/fluid/tests/book/test_recognize_digits.py | 1 - 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index 0a8377cc47..109c2c745c 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -14,8 +14,12 @@ #include "paddle/fluid/framework/threadpool.h" +#include "gflags/gflags.h" #include "paddle/fluid/platform/enforce.h" +DEFINE_int32(io_threadpool_size, 100, + "number of threads used for doing IO, default 100"); + namespace paddle { namespace framework { @@ -94,15 +98,15 @@ void ThreadPool::TaskLoop() { std::unique_ptr MultiStreamThreadPool::io_threadpool_(nullptr); std::once_flag MultiStreamThreadPool::io_init_flag_; -MultiStreamThreadPool* MultiStreamThreadPool::GetInstanceIO() { +ThreadPool* MultiStreamThreadPool::GetInstanceIO() { std::call_once(io_init_flag_, &MultiStreamThreadPool::InitIO); - return static_cast(io_threadpool_.get()); + return io_threadpool_.get(); } void MultiStreamThreadPool::InitIO() { if (io_threadpool_.get() == nullptr) { // TODO(typhoonzero1986): make this configurable - io_threadpool_.reset(new ThreadPool(100)); + io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size)); } } diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 0a60488d9f..1cc058834c 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -14,12 +14,12 @@ limitations under the License. */ #pragma once -#include +#include // NOLINT #include -#include -#include +#include // NOLINT +#include // NOLINT #include -#include +#include // NOLINT #include #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" @@ -137,7 +137,7 @@ class ThreadPool { class MultiStreamThreadPool : ThreadPool { public: - static MultiStreamThreadPool* GetInstanceIO(); + static ThreadPool* GetInstanceIO(); static void InitIO(); private: diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index d5fc163bc2..36dad5dd43 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -216,10 +216,10 @@ void AsyncGRPCServer::RunSyncUpdate() { std::function prefetch_register = std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this); + // TODO(wuyi): Run these "HandleRequest" in thread pool t_send_.reset( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_send_.get(), "cq_send", send_register))); - t_get_.reset( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", get_register))); diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py index e4997b4069..5ec6890c1b 100644 --- a/python/paddle/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/fluid/tests/book/test_recognize_digits.py @@ -157,7 +157,6 @@ def train(nn_type, for ip in pserver_ips.split(","): eplist.append(':'.join([ip, port])) pserver_endpoints = ",".join(eplist) # ip:port,ip:port... - pserver_endpoints = os.getenv("PSERVERS") trainers = int(os.getenv("TRAINERS")) current_endpoint = os.getenv("POD_IP") + ":" + port trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID")) -- GitLab