diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index 18cdca3a658a6a89e6ab637a7f38825756acfea8..a588cb417aebe94bd4aeda02b1bc8ba07a04b960 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0, namespace paddle { namespace framework { - std::unique_ptr ThreadPool::threadpool_(nullptr); std::once_flag ThreadPool::init_flag_; @@ -47,8 +46,7 @@ void ThreadPool::Init() { } } -ThreadPool::ThreadPool(int num_threads) - : total_threads_(num_threads), idle_threads_(num_threads), running_(true) { +ThreadPool::ThreadPool(int num_threads) : running_(true) { threads_.resize(num_threads); for (auto& thread : threads_) { // TODO(Yancey1989): binding the thread on the specify CPU number @@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads) ThreadPool::~ThreadPool() { { // notify all threads to stop running + std::lock_guard l(mutex_); running_ = false; scheduled_.notify_all(); } @@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() { } } -void ThreadPool::Wait() { - std::unique_lock lock(mutex_); - completed_.wait(lock, [=] { return Done() == true; }); -} - void ThreadPool::TaskLoop() { - while (running_) { + while (true) { std::unique_lock lock(mutex_); - scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; }); - if (!running_) { - break; + scheduled_.wait( + lock, [this] { return !this->tasks_.empty() || !this->running_; }); + + if (!running_ || tasks_.empty()) { + return; } + // pop a task from the task queue auto task = std::move(tasks_.front()); tasks_.pop(); - - --idle_threads_; lock.unlock(); // run the task task(); - - { - std::unique_lock lock(mutex_); - ++idle_threads_; - if (Done()) { - completed_.notify_all(); - } - } } } diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 94111ee335b1a5df327b3e46d62069b4735c54f6..0687e628aaa4fb7b2e67938fa09a319c8bb35aff 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -57,15 +57,6 @@ class ThreadPool { ~ThreadPool(); - // Returns the number of threads created by the constructor. - size_t Threads() const { return total_threads_; } - - // Returns the number of currently idle threads. - size_t IdleThreads() { - std::unique_lock lock(mutex_); - return idle_threads_; - } - // Run pushes a function to the task queue and returns a std::future // object. To wait for the completion of the task, call // std::future::wait(). @@ -94,25 +85,13 @@ class ThreadPool { }); std::future> f = task.get_future(); tasks_.push(std::move(task)); - lock.unlock(); scheduled_.notify_one(); return f; } - // Wait until all the tasks are completed. - void Wait(); - private: DISABLE_COPY_AND_ASSIGN(ThreadPool); - // If the task queue is empty and avaialbe is equal to the number of - // threads, means that all tasks are completed. Note: this function - // is not thread-safe. Returns true if all tasks are completed. - // Note: don't delete the data member total_threads_ and use - // threads_.size() instead; because you'd need to lock the mutex - // before accessing threads_. - bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; } - // The constructor starts threads to run TaskLoop, which retrieves // and runs tasks from the queue. void TaskLoop(); @@ -125,14 +104,11 @@ class ThreadPool { static std::once_flag init_flag_; std::vector> threads_; - const size_t total_threads_; - size_t idle_threads_; std::queue tasks_; std::mutex mutex_; bool running_; std::condition_variable scheduled_; - std::condition_variable completed_; }; class ThreadPoolIO : ThreadPool { diff --git a/paddle/fluid/framework/threadpool_test.cc b/paddle/fluid/framework/threadpool_test.cc index 27a4ffd4fcbf293a3dea1744b29384d0bee0c137..884d61e23428a0ad758946295ca9c470767e93ef 100644 --- a/paddle/fluid/framework/threadpool_test.cc +++ b/paddle/fluid/framework/threadpool_test.cc @@ -19,10 +19,11 @@ limitations under the License. */ namespace framework = paddle::framework; -void do_sum(framework::ThreadPool* pool, std::atomic* sum, int cnt) { - std::vector> fs; +void do_sum(std::vector>* fs, std::mutex* mu, + std::atomic* sum, int cnt) { for (int i = 0; i < cnt; ++i) { - fs.push_back(framework::Async([sum]() { sum->fetch_add(1); })); + std::lock_guard l(*mu); + fs->push_back(framework::Async([sum]() { sum->fetch_add(1); })); } } @@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) { } TEST(ThreadPool, ConcurrentRun) { - framework::ThreadPool* pool = framework::ThreadPool::GetInstance(); std::atomic sum(0); std::vector threads; + std::vector> fs; + std::mutex fs_mu; int n = 50; // sum = (n * (n + 1)) / 2 for (int i = 1; i <= n; ++i) { - std::thread t(do_sum, pool, &sum, i); + std::thread t(do_sum, &fs, &fs_mu, &sum, i); threads.push_back(std::move(t)); } for (auto& t : threads) { t.join(); } - pool->Wait(); + for (auto& t : fs) { + t.wait(); + } EXPECT_EQ(sum, ((n + 1) * n) / 2); }