Commit 26cfc634 authored by typhoonzero

multi stream thread pool

Parent 70500398
@@ -14,8 +14,12 @@
 #include "paddle/fluid/framework/threadpool.h"
+#include "gflags/gflags.h"
 #include "paddle/fluid/platform/enforce.h"
+
+DEFINE_int32(io_threadpool_size, 100,
+             "number of threads used for doing IO, default 100");
 namespace paddle {
 namespace framework {
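
Note: the DEFINE_int32 above replaces the previously hard-coded pool size with a gflags flag, so the IO thread-pool size can be overridden at process startup (for example --io_threadpool_size=32) without recompiling. Below is a minimal, self-contained sketch of how such a flag is defined and read with gflags; the demo binary and main() are illustrative and not part of this commit.

```cpp
// Illustrative gflags demo (not Paddle code): define an integer flag,
// parse the command line, then read the flag's current value.
#include <iostream>

#include "gflags/gflags.h"

DEFINE_int32(io_threadpool_size, 100,
             "number of threads used for doing IO, default 100");

int main(int argc, char* argv[]) {
  // Run as e.g.:  ./flag_demo --io_threadpool_size=32
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  std::cout << "IO thread pool size: " << FLAGS_io_threadpool_size << "\n";
  return 0;
}
```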
@@ -94,15 +98,15 @@ void ThreadPool::TaskLoop() {
 std::unique_ptr<ThreadPool> MultiStreamThreadPool::io_threadpool_(nullptr);
 std::once_flag MultiStreamThreadPool::io_init_flag_;
 
-MultiStreamThreadPool* MultiStreamThreadPool::GetInstanceIO() {
+ThreadPool* MultiStreamThreadPool::GetInstanceIO() {
   std::call_once(io_init_flag_, &MultiStreamThreadPool::InitIO);
-  return static_cast<MultiStreamThreadPool*>(io_threadpool_.get());
+  return io_threadpool_.get();
 }
 
 void MultiStreamThreadPool::InitIO() {
   if (io_threadpool_.get() == nullptr) {
     // TODO(typhoonzero1986): make this configurable
-    io_threadpool_.reset(new ThreadPool(100));
+    io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size));
   }
 }
...
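
Note: GetInstanceIO() now returns the base-class pointer, which matches what io_threadpool_ actually stores (a plain ThreadPool), so the static_cast downcast is gone; InitIO() is still guarded by std::call_once, so the pool is created exactly once even under concurrent first use. A standalone sketch of that lazy-singleton pattern using only the standard library follows; the IOPool name and its size parameter are illustrative.

```cpp
// Standalone sketch of the std::call_once lazy-singleton pattern used
// above. IOPool is an illustrative stand-in for the real thread pool.
#include <memory>
#include <mutex>  // NOLINT

class IOPool {
 public:
  explicit IOPool(int num_threads) : num_threads_(num_threads) {}
  int size() const { return num_threads_; }

  static IOPool* GetInstance(int num_threads) {
    // The lambda runs exactly once, even if many threads race here.
    std::call_once(init_flag_, [num_threads] {
      instance_.reset(new IOPool(num_threads));
    });
    return instance_.get();
  }

 private:
  int num_threads_;
  static std::unique_ptr<IOPool> instance_;
  static std::once_flag init_flag_;
};

std::unique_ptr<IOPool> IOPool::instance_(nullptr);
std::once_flag IOPool::init_flag_;
```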
@@ -14,12 +14,12 @@ limitations under the License. */
 #pragma once
 
-#include <condition_variable>
+#include <condition_variable>  // NOLINT
 #include <functional>
-#include <future>
+#include <future>  // NOLINT
-#include <mutex>
+#include <mutex>  // NOLINT
 #include <queue>
-#include <thread>
+#include <thread>  // NOLINT
 #include <vector>
 
 #include "glog/logging.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -137,7 +137,7 @@ class ThreadPool {
 class MultiStreamThreadPool : ThreadPool {
  public:
-  static MultiStreamThreadPool* GetInstanceIO();
+  static ThreadPool* GetInstanceIO();
   static void InitIO();
 
  private:
...
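
Note: in the header, the // NOLINT markers silence cpplint's warnings about the C++11 threading headers, and GetInstanceIO()'s declaration is narrowed to return ThreadPool*, mirroring the .cc change. A hedged caller-side sketch of the accessor as declared in this excerpt; the enqueue step is left as a comment because the submission API is not shown here.

```cpp
// Hedged caller-side sketch, assuming only what this header declares:
// GetInstanceIO() hands back the IO pool through the ThreadPool base
// pointer, so no static_cast is needed at the call site.
#include "paddle/fluid/framework/threadpool.h"

void SubmitIOWork() {
  paddle::framework::ThreadPool* io_pool =
      paddle::framework::MultiStreamThreadPool::GetInstanceIO();
  // ... enqueue IO-bound work on io_pool (submission API not shown
  // in this excerpt) ...
  (void)io_pool;
}
```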
@@ -216,10 +216,10 @@ void AsyncGRPCServer::RunSyncUpdate() {
   std::function<void()> prefetch_register =
       std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this);
 
-  // TODO(wuyi): Run these "HandleRequest" in thread pool
   t_send_.reset(
       new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                 cq_send_.get(), "cq_send", send_register)));
 
   t_get_.reset(
       new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
                                 cq_get_.get(), "cq_get", get_register)));
...
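
Note: RunSyncUpdate() still pins each completion queue to its own dedicated std::thread; only the stale TODO comment is dropped here. A standalone sketch of that pattern, binding a member function with per-queue arguments and running it on its own thread; the Server type below is illustrative, not the real AsyncGRPCServer.

```cpp
// Illustrative sketch of one-dedicated-thread-per-queue dispatch via
// std::bind on a member function; Server stands in for AsyncGRPCServer.
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>  // NOLINT

class Server {
 public:
  void HandleRequest(const std::string& queue_name,
                     std::function<void()> register_fn) {
    std::cout << "handling " << queue_name << std::endl;
    register_fn();  // register the next receiver for this queue
  }

  void Start() {
    std::function<void()> reg = [] {
      std::cout << "registered one new receiver" << std::endl;
    };
    t_send_.reset(new std::thread(
        std::bind(&Server::HandleRequest, this, "cq_send", reg)));
    t_get_.reset(new std::thread(
        std::bind(&Server::HandleRequest, this, "cq_get", reg)));
  }

  void Join() {
    t_send_->join();
    t_get_->join();
  }

 private:
  std::unique_ptr<std::thread> t_send_;
  std::unique_ptr<std::thread> t_get_;
};

int main() {
  Server server;
  server.Start();
  server.Join();
  return 0;
}
```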
@@ -157,7 +157,6 @@ def train(nn_type,
         for ip in pserver_ips.split(","):
             eplist.append(':'.join([ip, port]))
         pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
-        pserver_endpoints = os.getenv("PSERVERS")
         trainers = int(os.getenv("TRAINERS"))
         current_endpoint = os.getenv("POD_IP") + ":" + port
         trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))
...
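
Note: the removed line overwrote the endpoint list that was just assembled from pserver_ips with the raw PSERVERS environment variable; after this change the constructed "ip:port,ip:port" string is the one actually used. A small illustrative sketch of that construction with hard-coded sample values (in the test they come from environment variables):

```python
# Illustrative sketch of building the pserver endpoint string; the
# sample values below are hard-coded, whereas the test reads them
# from environment variables.
pserver_ips = "192.168.0.1,192.168.0.2"  # comma-separated pserver IPs
port = "6174"                            # sample pserver port

eplist = []
for ip in pserver_ips.split(","):
    eplist.append(':'.join([ip, port]))

pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
print(pserver_endpoints)  # 192.168.0.1:6174,192.168.0.2:6174
```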