diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index a588cb417aebe94bd4aeda02b1bc8ba07a04b960..a471c83115f2394ad81c9cf2053919211461a0c2 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -59,8 +59,8 @@ ThreadPool::~ThreadPool() { // notify all threads to stop running std::lock_guard l(mutex_); running_ = false; - scheduled_.notify_all(); } + scheduled_.notify_all(); for (auto& t : threads_) { t->join(); @@ -75,10 +75,14 @@ void ThreadPool::TaskLoop() { scheduled_.wait( lock, [this] { return !this->tasks_.empty() || !this->running_; }); - if (!running_ || tasks_.empty()) { + if (!running_ && tasks_.empty()) { return; } + if (tasks_.empty()) { + PADDLE_THROW("This thread has no task to Run"); + } + // pop a task from the task queue auto task = std::move(tasks_.front()); tasks_.pop(); diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 0687e628aaa4fb7b2e67938fa09a319c8bb35aff..7a51d18fbbf65f68725aa86a6a0ce4d15dff5673 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -58,7 +58,7 @@ class ThreadPool { ~ThreadPool(); // Run pushes a function to the task queue and returns a std::future - // object. To wait for the completion of the task, call + // object. To wait for the completion of the task, call // std::future::wait(). template std::future Run(Callback fn) { @@ -69,7 +69,6 @@ class ThreadPool { template std::future> RunAndGetException( Callback fn) { - std::unique_lock lock(mutex_); Task task([fn]() -> std::unique_ptr { try { fn(); @@ -84,7 +83,13 @@ class ThreadPool { return nullptr; }); std::future> f = task.get_future(); - tasks_.push(std::move(task)); + { + std::unique_lock lock(mutex_); + if (!running_) { + PADDLE_THROW("enqueue on stopped ThreadPool"); + } + tasks_.push(std::move(task)); + } scheduled_.notify_one(); return f; }