未验证 提交 6449faec 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #14259 from jacquesqiao/optimize-thread-pool

Optimize thread pool
...@@ -57,10 +57,10 @@ ThreadPool::ThreadPool(int num_threads) : running_(true) { ...@@ -57,10 +57,10 @@ ThreadPool::ThreadPool(int num_threads) : running_(true) {
ThreadPool::~ThreadPool() { ThreadPool::~ThreadPool() {
{ {
// notify all threads to stop running // notify all threads to stop running
std::lock_guard<std::mutex> l(mutex_); std::unique_lock<std::mutex> l(mutex_);
running_ = false; running_ = false;
scheduled_.notify_all();
} }
scheduled_.notify_all();
for (auto& t : threads_) { for (auto& t : threads_) {
t->join(); t->join();
...@@ -70,19 +70,25 @@ ThreadPool::~ThreadPool() { ...@@ -70,19 +70,25 @@ ThreadPool::~ThreadPool() {
void ThreadPool::TaskLoop() { void ThreadPool::TaskLoop() {
while (true) { while (true) {
std::unique_lock<std::mutex> lock(mutex_); Task task;
{
std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait( scheduled_.wait(
lock, [this] { return !this->tasks_.empty() || !this->running_; }); lock, [this] { return !this->tasks_.empty() || !this->running_; });
if (!running_ || tasks_.empty()) { if (!running_ && tasks_.empty()) {
return; return;
} }
if (tasks_.empty()) {
PADDLE_THROW("This thread has no task to Run");
}
// pop a task from the task queue // pop a task from the task queue
auto task = std::move(tasks_.front()); task = std::move(tasks_.front());
tasks_.pop(); tasks_.pop();
lock.unlock(); }
// run the task // run the task
task(); task();
......
...@@ -69,7 +69,6 @@ class ThreadPool { ...@@ -69,7 +69,6 @@ class ThreadPool {
template <typename Callback> template <typename Callback>
std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException( std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
Callback fn) { Callback fn) {
std::unique_lock<std::mutex> lock(mutex_);
Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> { Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
try { try {
fn(); fn();
...@@ -84,7 +83,13 @@ class ThreadPool { ...@@ -84,7 +83,13 @@ class ThreadPool {
return nullptr; return nullptr;
}); });
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future(); std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
{
std::unique_lock<std::mutex> lock(mutex_);
if (!running_) {
PADDLE_THROW("enqueue on stopped ThreadPool");
}
tasks_.push(std::move(task)); tasks_.push(std::move(task));
}
scheduled_.notify_one(); scheduled_.notify_one();
return f; return f;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册