未验证 提交 bcc9126e 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14056 from panyx0718/fix

Fix threadpool
...@@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0, ...@@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0,
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr); std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr);
std::once_flag ThreadPool::init_flag_; std::once_flag ThreadPool::init_flag_;
...@@ -47,8 +46,7 @@ void ThreadPool::Init() { ...@@ -47,8 +46,7 @@ void ThreadPool::Init() {
} }
} }
ThreadPool::ThreadPool(int num_threads) ThreadPool::ThreadPool(int num_threads) : running_(true) {
: total_threads_(num_threads), idle_threads_(num_threads), running_(true) {
threads_.resize(num_threads); threads_.resize(num_threads);
for (auto& thread : threads_) { for (auto& thread : threads_) {
// TODO(Yancey1989): binding the thread on the specify CPU number // TODO(Yancey1989): binding the thread on the specify CPU number
...@@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads) ...@@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads)
ThreadPool::~ThreadPool() { ThreadPool::~ThreadPool() {
{ {
// notify all threads to stop running // notify all threads to stop running
std::lock_guard<std::mutex> l(mutex_);
running_ = false; running_ = false;
scheduled_.notify_all(); scheduled_.notify_all();
} }
...@@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() { ...@@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() {
} }
} }
void ThreadPool::Wait() {
std::unique_lock<std::mutex> lock(mutex_);
completed_.wait(lock, [=] { return Done() == true; });
}
void ThreadPool::TaskLoop() { void ThreadPool::TaskLoop() {
while (running_) { while (true) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
if (!running_) { scheduled_.wait(
break; lock, [this] { return !this->tasks_.empty() || !this->running_; });
if (!running_ || tasks_.empty()) {
return;
} }
// pop a task from the task queue // pop a task from the task queue
auto task = std::move(tasks_.front()); auto task = std::move(tasks_.front());
tasks_.pop(); tasks_.pop();
--idle_threads_;
lock.unlock(); lock.unlock();
// run the task // run the task
task(); task();
{
std::unique_lock<std::mutex> lock(mutex_);
++idle_threads_;
if (Done()) {
completed_.notify_all();
}
}
} }
} }
......
...@@ -57,15 +57,6 @@ class ThreadPool { ...@@ -57,15 +57,6 @@ class ThreadPool {
~ThreadPool(); ~ThreadPool();
// Returns the number of threads created by the constructor.
size_t Threads() const { return total_threads_; }
// Returns the number of currently idle threads.
size_t IdleThreads() {
std::unique_lock<std::mutex> lock(mutex_);
return idle_threads_;
}
// Run pushes a function to the task queue and returns a std::future // Run pushes a function to the task queue and returns a std::future
// object. To wait for the completion of the task, call // object. To wait for the completion of the task, call
// std::future::wait(). // std::future::wait().
...@@ -94,25 +85,13 @@ class ThreadPool { ...@@ -94,25 +85,13 @@ class ThreadPool {
}); });
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future(); std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
tasks_.push(std::move(task)); tasks_.push(std::move(task));
lock.unlock();
scheduled_.notify_one(); scheduled_.notify_one();
return f; return f;
} }
// Wait until all the tasks are completed.
void Wait();
private: private:
DISABLE_COPY_AND_ASSIGN(ThreadPool); DISABLE_COPY_AND_ASSIGN(ThreadPool);
// If the task queue is empty and avaialbe is equal to the number of
// threads, means that all tasks are completed. Note: this function
// is not thread-safe. Returns true if all tasks are completed.
// Note: don't delete the data member total_threads_ and use
// threads_.size() instead; because you'd need to lock the mutex
// before accessing threads_.
bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
// The constructor starts threads to run TaskLoop, which retrieves // The constructor starts threads to run TaskLoop, which retrieves
// and runs tasks from the queue. // and runs tasks from the queue.
void TaskLoop(); void TaskLoop();
...@@ -125,14 +104,11 @@ class ThreadPool { ...@@ -125,14 +104,11 @@ class ThreadPool {
static std::once_flag init_flag_; static std::once_flag init_flag_;
std::vector<std::unique_ptr<std::thread>> threads_; std::vector<std::unique_ptr<std::thread>> threads_;
const size_t total_threads_;
size_t idle_threads_;
std::queue<Task> tasks_; std::queue<Task> tasks_;
std::mutex mutex_; std::mutex mutex_;
bool running_; bool running_;
std::condition_variable scheduled_; std::condition_variable scheduled_;
std::condition_variable completed_;
}; };
class ThreadPoolIO : ThreadPool { class ThreadPoolIO : ThreadPool {
......
...@@ -19,10 +19,11 @@ limitations under the License. */ ...@@ -19,10 +19,11 @@ limitations under the License. */
namespace framework = paddle::framework; namespace framework = paddle::framework;
void do_sum(framework::ThreadPool* pool, std::atomic<int>* sum, int cnt) { void do_sum(std::vector<std::future<void>>* fs, std::mutex* mu,
std::vector<std::future<void>> fs; std::atomic<int>* sum, int cnt) {
for (int i = 0; i < cnt; ++i) { for (int i = 0; i < cnt; ++i) {
fs.push_back(framework::Async([sum]() { sum->fetch_add(1); })); std::lock_guard<std::mutex> l(*mu);
fs->push_back(framework::Async([sum]() { sum->fetch_add(1); }));
} }
} }
...@@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) { ...@@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) {
} }
TEST(ThreadPool, ConcurrentRun) { TEST(ThreadPool, ConcurrentRun) {
framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
std::atomic<int> sum(0); std::atomic<int> sum(0);
std::vector<std::thread> threads; std::vector<std::thread> threads;
std::vector<std::future<void>> fs;
std::mutex fs_mu;
int n = 50; int n = 50;
// sum = (n * (n + 1)) / 2 // sum = (n * (n + 1)) / 2
for (int i = 1; i <= n; ++i) { for (int i = 1; i <= n; ++i) {
std::thread t(do_sum, pool, &sum, i); std::thread t(do_sum, &fs, &fs_mu, &sum, i);
threads.push_back(std::move(t)); threads.push_back(std::move(t));
} }
for (auto& t : threads) { for (auto& t : threads) {
t.join(); t.join();
} }
pool->Wait(); for (auto& t : fs) {
t.wait();
}
EXPECT_EQ(sum, ((n + 1) * n) / 2); EXPECT_EQ(sum, ((n + 1) * n) / 2);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册