提交 784a19ec 编写于 作者: X Xin Pan

fix some thread-safty issue and simplify threadpool

test=develop
上级 2256fae4
...@@ -34,6 +34,11 @@ ThreadPool* ThreadPool::GetInstance() { ...@@ -34,6 +34,11 @@ ThreadPool* ThreadPool::GetInstance() {
return threadpool_.get(); return threadpool_.get();
} }
void ThreadPool::Reset() {
threadpool_.reset(nullptr);
ThreadPool::Init();
}
void ThreadPool::Init() { void ThreadPool::Init() {
if (threadpool_.get() == nullptr) { if (threadpool_.get() == nullptr) {
// TODO(Yancey1989): specify the max threads number // TODO(Yancey1989): specify the max threads number
...@@ -59,6 +64,7 @@ ThreadPool::ThreadPool(int num_threads) ...@@ -59,6 +64,7 @@ ThreadPool::ThreadPool(int num_threads)
ThreadPool::~ThreadPool() { ThreadPool::~ThreadPool() {
{ {
// notify all threads to stop running // notify all threads to stop running
std::lock_guard<std::mutex> l(mutex_);
running_ = false; running_ = false;
scheduled_.notify_all(); scheduled_.notify_all();
} }
...@@ -69,19 +75,18 @@ ThreadPool::~ThreadPool() { ...@@ -69,19 +75,18 @@ ThreadPool::~ThreadPool() {
} }
} }
void ThreadPool::Wait() {
std::unique_lock<std::mutex> lock(mutex_);
completed_.wait(lock, [=] { return Done() == true; });
}
void ThreadPool::TaskLoop() { void ThreadPool::TaskLoop() {
while (running_) { while (true) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
if (!running_) { scheduled_.wait(
break; lock, [this] { return !this->tasks_.empty() || !this->running_; });
std::lock_guard<std::mutex> l(mutex_);
if (!running_ || tasks_.empty()) {
return;
} }
// pop a task from the task queue // pop a task from the task queue
auto task = std::move(tasks_.front()); auto task = std::move(tasks_.front());
tasks_.pop(); tasks_.pop();
...@@ -91,14 +96,6 @@ void ThreadPool::TaskLoop() { ...@@ -91,14 +96,6 @@ void ThreadPool::TaskLoop() {
// run the task // run the task
task(); task();
{
std::unique_lock<std::mutex> lock(mutex_);
++idle_threads_;
if (Done()) {
completed_.notify_all();
}
}
} }
} }
......
...@@ -55,16 +55,10 @@ class ThreadPool { ...@@ -55,16 +55,10 @@ class ThreadPool {
// Returns the singleton of ThreadPool. // Returns the singleton of ThreadPool.
static ThreadPool* GetInstance(); static ThreadPool* GetInstance();
~ThreadPool(); // delete current thread pool and create a new one.
static void Reset();
// Returns the number of threads created by the constructor.
size_t Threads() const { return total_threads_; }
// Returns the number of currently idle threads. ~ThreadPool();
size_t IdleThreads() {
std::unique_lock<std::mutex> lock(mutex_);
return idle_threads_;
}
// Run pushes a function to the task queue and returns a std::future // Run pushes a function to the task queue and returns a std::future
// object. To wait for the completion of the task, call // object. To wait for the completion of the task, call
...@@ -94,25 +88,13 @@ class ThreadPool { ...@@ -94,25 +88,13 @@ class ThreadPool {
}); });
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future(); std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
tasks_.push(std::move(task)); tasks_.push(std::move(task));
lock.unlock();
scheduled_.notify_one(); scheduled_.notify_one();
return f; return f;
} }
// Wait until all the tasks are completed.
void Wait();
private: private:
DISABLE_COPY_AND_ASSIGN(ThreadPool); DISABLE_COPY_AND_ASSIGN(ThreadPool);
// If the task queue is empty and avaialbe is equal to the number of
// threads, means that all tasks are completed. Note: this function
// is not thread-safe. Returns true if all tasks are completed.
// Note: don't delete the data member total_threads_ and use
// threads_.size() instead; because you'd need to lock the mutex
// before accessing threads_.
bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
// The constructor starts threads to run TaskLoop, which retrieves // The constructor starts threads to run TaskLoop, which retrieves
// and runs tasks from the queue. // and runs tasks from the queue.
void TaskLoop(); void TaskLoop();
......
...@@ -52,6 +52,6 @@ TEST(ThreadPool, ConcurrentRun) { ...@@ -52,6 +52,6 @@ TEST(ThreadPool, ConcurrentRun) {
for (auto& t : threads) { for (auto& t : threads) {
t.join(); t.join();
} }
pool->Wait(); framework::ThreadPool::Reset();
EXPECT_EQ(sum, ((n + 1) * n) / 2); EXPECT_EQ(sum, ((n + 1) * n) / 2);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册