未验证 提交 4fb3c676 编写于 作者: Y Yi Wang 提交者: GitHub

Polish threadpool (#7918)

* Polish threadpool

* Add #include <vector>

* Rename variables

* Rename variables

* clang-format
上级 3646be7b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/threadpool.h" #include "paddle/framework/threadpool.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr); std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr);
std::once_flag ThreadPool::init_flag; std::once_flag ThreadPool::init_flag_;
ThreadPool* ThreadPool::GetInstance() {
std::call_once(init_flag_, &ThreadPool::Init);
return threadpool_.get();
}
void ThreadPool::Init() {
if (threadpool_.get() == nullptr) {
// TODO(Yancey1989): specify the max threads number
int num_threads = std::thread::hardware_concurrency();
PADDLE_ENFORCE_GT(num_threads, 0);
threadpool_.reset(new ThreadPool(num_threads));
}
}
ThreadPool::ThreadPool(int num_threads)
: total_threads_(num_threads), idle_threads_(num_threads), running_(true) {
threads_.resize(num_threads);
for (auto& thread : threads_) {
// TODO(Yancey1989): binding the thread on the specify CPU number
thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this)));
}
}
ThreadPool::~ThreadPool() {
{
// notify all threads to stop running
running_ = false;
scheduled_.notify_all();
}
for (auto& t : threads_) {
t->join();
t.reset(nullptr);
}
}
void ThreadPool::Wait() {
std::unique_lock<std::mutex> lock(mutex_);
completed_.wait(lock, [=] { return Done() == true; });
}
void ThreadPool::TaskLoop() {
while (running_) {
std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
if (!running_) {
break;
}
// pop a task from the task queue
auto task = std::move(tasks_.front());
tasks_.pop();
--idle_threads_;
lock.unlock();
// run the task
task();
{
std::unique_lock<std::mutex> lock(mutex_);
++idle_threads_;
if (Done()) {
completed_.notify_all();
}
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -20,52 +20,36 @@ limitations under the License. */ ...@@ -20,52 +20,36 @@ limitations under the License. */
#include <mutex> #include <mutex>
#include <queue> #include <queue>
#include <thread> #include <thread>
#include <vector>
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// ThreadPool maintains a queue of tasks, and runs them using a fixed
// number of threads.
class ThreadPool { class ThreadPool {
public: public:
typedef std::packaged_task<void()> Task; typedef std::packaged_task<void()> Task;
/** // Returns the singleton of ThreadPool.
* @brief Get a instance of threadpool, the thread number will static ThreadPool* GetInstance();
* be specified as the number of hardware thread contexts
*/
static ThreadPool* GetInstance() {
std::call_once(init_flag, &ThreadPool::Init);
return threadpool.get();
}
~ThreadPool() { ~ThreadPool();
{
// notify all threads to stop running
running_ = false;
scheduled_.notify_all();
}
for (auto& t : threads_) {
t->join();
t.reset(nullptr);
}
}
int GetNumThreads() const { return num_threads_; } // Returns the number of threads created by the constructor.
size_t Threads() const { return total_threads_; }
int GetAvailable() { // Returns the number of currently idle threads.
size_t IdleThreads() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
return available_; return idle_threads_;
} }
/** // Run pushes a function to the task queue and returns a std::future
* @brief Push a function to the queue, and will be scheduled and // object. To wait for the completion of the task, call
* executed if a thread is available. // std::future::wait().
* @param[in] Task, will be pushed to the task queue.
* @return std::future<void>, we could wait for the task finished by
* f.wait().
*/
template <typename Callback> template <typename Callback>
std::future<void> Run(Callback fn) { std::future<void> Run(Callback fn) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
...@@ -77,84 +61,40 @@ class ThreadPool { ...@@ -77,84 +61,40 @@ class ThreadPool {
return f; return f;
} }
/** // Wait until all the tasks are completed.
* @brief Wait until all the tasks are completed. void Wait();
*/
void Wait() {
std::unique_lock<std::mutex> lock(mutex_);
completed_.wait(lock, [=] { return Done() == true; });
}
private: private:
DISABLE_COPY_AND_ASSIGN(ThreadPool); DISABLE_COPY_AND_ASSIGN(ThreadPool);
explicit ThreadPool(int num_threads) explicit ThreadPool(int num_threads);
: num_threads_(num_threads), available_(num_threads), running_(true) {
threads_.resize(num_threads);
for (auto& thread : threads_) {
// TODO(Yancey1989): binding the thread on the specify CPU number
thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this)));
}
}
/** // If the task queue is empty and avaialbe is equal to the number of
* @brief If the task queue is empty and avaialbe // threads, means that all tasks are completed. Note: this function
* is equal to the number of threads, means that // is not thread-safe. Returns true if all tasks are completed.
* all tasks are completed. // Note: don't delete the data member total_threads_ and use
* // threads_.size() instead; because you'd need to lock the mutex
* Note: this function is not thread-safe. // before accessing threads_.
* bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
* @return true if all tasks are completed.
*/
bool Done() { return tasks_.empty() && available_ == num_threads_; }
void TaskLoop() {
while (running_) {
std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
if (!running_) {
break;
}
// pop a task from the task queue
auto task = std::move(tasks_.front());
tasks_.pop();
--available_;
lock.unlock();
// run the task
task();
{
std::unique_lock<std::mutex> lock(mutex_);
++available_;
if (Done()) {
completed_.notify_all();
}
}
}
}
static void Init() { // The constructor starts threads to run TaskLoop, which retrieves
if (threadpool.get() == nullptr) { // and runs tasks from the queue.
// TODO(Yancey1989): specify the max threads number void TaskLoop();
int num_threads = std::thread::hardware_concurrency();
PADDLE_ENFORCE_GT(num_threads, 0); // Init is called by GetInstance.
threadpool.reset(new ThreadPool(num_threads)); static void Init();
}
}
private: private:
static std::unique_ptr<ThreadPool> threadpool; static std::unique_ptr<ThreadPool> threadpool_;
static std::once_flag init_flag; static std::once_flag init_flag_;
int num_threads_;
int available_;
bool running_;
std::queue<Task> tasks_;
std::vector<std::unique_ptr<std::thread>> threads_; std::vector<std::unique_ptr<std::thread>> threads_;
const size_t total_threads_;
size_t idle_threads_;
std::queue<Task> tasks_;
std::mutex mutex_; std::mutex mutex_;
bool running_;
std::condition_variable scheduled_; std::condition_variable scheduled_;
std::condition_variable completed_; std::condition_variable completed_;
}; };
......
...@@ -22,11 +22,7 @@ namespace framework = paddle::framework; ...@@ -22,11 +22,7 @@ namespace framework = paddle::framework;
void do_sum(framework::ThreadPool* pool, std::atomic<int>& sum, int cnt) { void do_sum(framework::ThreadPool* pool, std::atomic<int>& sum, int cnt) {
std::vector<std::future<void>> fs; std::vector<std::future<void>> fs;
for (int i = 0; i < cnt; ++i) { for (int i = 0; i < cnt; ++i) {
auto f = pool->Run([&sum]() { sum.fetch_add(1); }); fs.push_back(framework::Async([&sum]() { sum.fetch_add(1); }));
fs.push_back(std::move(f));
}
for (auto& f : fs) {
f.wait();
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册