提交 f4a76078 编写于 作者: Q Qiao Longfei

optimize thread pool

上级 913b5699
...@@ -59,8 +59,8 @@ ThreadPool::~ThreadPool() { ...@@ -59,8 +59,8 @@ ThreadPool::~ThreadPool() {
// notify all threads to stop running // notify all threads to stop running
std::lock_guard<std::mutex> l(mutex_); std::lock_guard<std::mutex> l(mutex_);
running_ = false; running_ = false;
scheduled_.notify_all();
} }
scheduled_.notify_all();
for (auto& t : threads_) { for (auto& t : threads_) {
t->join(); t->join();
...@@ -75,10 +75,14 @@ void ThreadPool::TaskLoop() { ...@@ -75,10 +75,14 @@ void ThreadPool::TaskLoop() {
scheduled_.wait( scheduled_.wait(
lock, [this] { return !this->tasks_.empty() || !this->running_; }); lock, [this] { return !this->tasks_.empty() || !this->running_; });
if (!running_ || tasks_.empty()) { if (!running_ && tasks_.empty()) {
return; return;
} }
if (tasks_.empty()) {
PADDLE_THROW("This thread has no task to Run");
}
// pop a task from the task queue // pop a task from the task queue
auto task = std::move(tasks_.front()); auto task = std::move(tasks_.front());
tasks_.pop(); tasks_.pop();
......
...@@ -58,7 +58,7 @@ class ThreadPool { ...@@ -58,7 +58,7 @@ class ThreadPool {
~ThreadPool(); ~ThreadPool();
// Run pushes a function to the task queue and returns a std::future // Run pushes a function to the task queue and returns a std::future
// object. To wait for the completion of the task, call // object. To wait for the completion of the task, call
// std::future::wait(). // std::future::wait().
template <typename Callback> template <typename Callback>
std::future<void> Run(Callback fn) { std::future<void> Run(Callback fn) {
...@@ -69,7 +69,6 @@ class ThreadPool { ...@@ -69,7 +69,6 @@ class ThreadPool {
template <typename Callback> template <typename Callback>
std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException( std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
Callback fn) { Callback fn) {
std::unique_lock<std::mutex> lock(mutex_);
Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> { Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
try { try {
fn(); fn();
...@@ -84,7 +83,13 @@ class ThreadPool { ...@@ -84,7 +83,13 @@ class ThreadPool {
return nullptr; return nullptr;
}); });
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future(); std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
tasks_.push(std::move(task)); {
std::unique_lock<std::mutex> lock(mutex_);
if (!running_) {
PADDLE_THROW("enqueue on stopped ThreadPool");
}
tasks_.push(std::move(task));
}
scheduled_.notify_one(); scheduled_.notify_one();
return f; return f;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册