diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index 18cdca3a658a6a89e6ab637a7f38825756acfea8..3d42aea229dbd26fa69a593babe10f76b478c1e2 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -34,6 +34,11 @@ ThreadPool* ThreadPool::GetInstance() { return threadpool_.get(); } +void ThreadPool::Reset() { + threadpool_.reset(nullptr); + ThreadPool::Init(); +} + void ThreadPool::Init() { if (threadpool_.get() == nullptr) { // TODO(Yancey1989): specify the max threads number @@ -59,6 +64,7 @@ ThreadPool::ThreadPool(int num_threads) ThreadPool::~ThreadPool() { { // notify all threads to stop running + std::lock_guard l(mutex_); running_ = false; scheduled_.notify_all(); } @@ -69,19 +75,18 @@ ThreadPool::~ThreadPool() { } } -void ThreadPool::Wait() { - std::unique_lock lock(mutex_); - completed_.wait(lock, [=] { return Done() == true; }); -} - void ThreadPool::TaskLoop() { - while (running_) { + while (true) { std::unique_lock lock(mutex_); - scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; }); - if (!running_) { - break; + scheduled_.wait( + lock, [this] { return !this->tasks_.empty() || !this->running_; }); + + std::lock_guard l(mutex_); + if (!running_ || tasks_.empty()) { + return; } + // pop a task from the task queue auto task = std::move(tasks_.front()); tasks_.pop(); @@ -91,14 +96,6 @@ void ThreadPool::TaskLoop() { // run the task task(); - - { - std::unique_lock lock(mutex_); - ++idle_threads_; - if (Done()) { - completed_.notify_all(); - } - } } } diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 94111ee335b1a5df327b3e46d62069b4735c54f6..459aba9ef41039a3423a9dd93b2689a94916c774 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -55,16 +55,10 @@ class ThreadPool { // Returns the singleton of ThreadPool. static ThreadPool* GetInstance(); - ~ThreadPool(); - - // Returns the number of threads created by the constructor. - size_t Threads() const { return total_threads_; } + // delete current thread pool and create a new one. + static void Reset(); - // Returns the number of currently idle threads. - size_t IdleThreads() { - std::unique_lock lock(mutex_); - return idle_threads_; - } + ~ThreadPool(); // Run pushes a function to the task queue and returns a std::future // object. To wait for the completion of the task, call @@ -94,25 +88,13 @@ class ThreadPool { }); std::future> f = task.get_future(); tasks_.push(std::move(task)); - lock.unlock(); scheduled_.notify_one(); return f; } - // Wait until all the tasks are completed. - void Wait(); - private: DISABLE_COPY_AND_ASSIGN(ThreadPool); - // If the task queue is empty and avaialbe is equal to the number of - // threads, means that all tasks are completed. Note: this function - // is not thread-safe. Returns true if all tasks are completed. - // Note: don't delete the data member total_threads_ and use - // threads_.size() instead; because you'd need to lock the mutex - // before accessing threads_. - bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; } - // The constructor starts threads to run TaskLoop, which retrieves // and runs tasks from the queue. void TaskLoop(); diff --git a/paddle/fluid/framework/threadpool_test.cc b/paddle/fluid/framework/threadpool_test.cc index 27a4ffd4fcbf293a3dea1744b29384d0bee0c137..1d55e011c77d07dc6c637bcf5d3650ef5e696bb2 100644 --- a/paddle/fluid/framework/threadpool_test.cc +++ b/paddle/fluid/framework/threadpool_test.cc @@ -52,6 +52,6 @@ TEST(ThreadPool, ConcurrentRun) { for (auto& t : threads) { t.join(); } - pool->Wait(); + framework::ThreadPool::Reset(); EXPECT_EQ(sum, ((n + 1) * n) / 2); }