提交 26cfc634 编写于 作者: T typhoonzero

multi stream thread pool

上级 70500398
......@@ -14,8 +14,12 @@
#include "paddle/fluid/framework/threadpool.h"
#include "gflags/gflags.h"
#include "paddle/fluid/platform/enforce.h"
DEFINE_int32(io_threadpool_size, 100,
"number of threads used for doing IO, default 100");
namespace paddle {
namespace framework {
......@@ -94,15 +98,15 @@ void ThreadPool::TaskLoop() {
std::unique_ptr<ThreadPool> MultiStreamThreadPool::io_threadpool_(nullptr);
std::once_flag MultiStreamThreadPool::io_init_flag_;
MultiStreamThreadPool* MultiStreamThreadPool::GetInstanceIO() {
ThreadPool* MultiStreamThreadPool::GetInstanceIO() {
std::call_once(io_init_flag_, &MultiStreamThreadPool::InitIO);
return static_cast<MultiStreamThreadPool*>(io_threadpool_.get());
return io_threadpool_.get();
}
void MultiStreamThreadPool::InitIO() {
if (io_threadpool_.get() == nullptr) {
// TODO(typhoonzero1986): make this configurable
io_threadpool_.reset(new ThreadPool(100));
io_threadpool_.reset(new ThreadPool(FLAGS_io_threadpool_size));
}
}
......
......@@ -14,12 +14,12 @@ limitations under the License. */
#pragma once
#include <condition_variable>
#include <condition_variable> // NOLINT
#include <functional>
#include <future>
#include <mutex>
#include <future> // NOLINT
#include <mutex> // NOLINT
#include <queue>
#include <thread>
#include <thread> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -137,7 +137,7 @@ class ThreadPool {
class MultiStreamThreadPool : ThreadPool {
public:
static MultiStreamThreadPool* GetInstanceIO();
static ThreadPool* GetInstanceIO();
static void InitIO();
private:
......
......@@ -216,10 +216,10 @@ void AsyncGRPCServer::RunSyncUpdate() {
std::function<void()> prefetch_register =
std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this);
// TODO(wuyi): Run these "HandleRequest" in thread pool
t_send_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));
t_get_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));
......
......@@ -157,7 +157,6 @@ def train(nn_type,
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
pserver_endpoints = os.getenv("PSERVERS")
trainers = int(os.getenv("TRAINERS"))
current_endpoint = os.getenv("POD_IP") + ":" + port
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册