diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index a588cb417aebe94bd4aeda02b1bc8ba07a04b960..fcec955360f1c681a62929e904d5736854a8ffad 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -57,10 +57,10 @@ ThreadPool::ThreadPool(int num_threads) : running_(true) { ThreadPool::~ThreadPool() { { // notify all threads to stop running - std::lock_guard l(mutex_); + std::unique_lock l(mutex_); running_ = false; - scheduled_.notify_all(); } + scheduled_.notify_all(); for (auto& t : threads_) { t->join(); @@ -70,19 +70,25 @@ ThreadPool::~ThreadPool() { void ThreadPool::TaskLoop() { while (true) { - std::unique_lock lock(mutex_); + Task task; - scheduled_.wait( - lock, [this] { return !this->tasks_.empty() || !this->running_; }); + { + std::unique_lock lock(mutex_); + scheduled_.wait( + lock, [this] { return !this->tasks_.empty() || !this->running_; }); - if (!running_ || tasks_.empty()) { - return; - } + if (!running_ && tasks_.empty()) { + return; + } + + if (tasks_.empty()) { + PADDLE_THROW("This thread has no task to Run"); + } - // pop a task from the task queue - auto task = std::move(tasks_.front()); - tasks_.pop(); - lock.unlock(); + // pop a task from the task queue + task = std::move(tasks_.front()); + tasks_.pop(); + } // run the task task(); diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 0687e628aaa4fb7b2e67938fa09a319c8bb35aff..7a51d18fbbf65f68725aa86a6a0ce4d15dff5673 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -58,7 +58,7 @@ class ThreadPool { ~ThreadPool(); // Run pushes a function to the task queue and returns a std::future - // object. To wait for the completion of the task, call + // object. To wait for the completion of the task, call // std::future::wait(). template std::future Run(Callback fn) { @@ -69,7 +69,6 @@ class ThreadPool { template std::future> RunAndGetException( Callback fn) { - std::unique_lock lock(mutex_); Task task([fn]() -> std::unique_ptr { try { fn(); @@ -84,7 +83,13 @@ class ThreadPool { return nullptr; }); std::future> f = task.get_future(); - tasks_.push(std::move(task)); + { + std::unique_lock lock(mutex_); + if (!running_) { + PADDLE_THROW("enqueue on stopped ThreadPool"); + } + tasks_.push(std::move(task)); + } scheduled_.notify_one(); return f; }