From 127bc2e09c3e014a116c2e86f2f8abae8add10e6 Mon Sep 17 00:00:00 2001 From: Yancey Date: Mon, 25 Dec 2017 11:15:33 +0800 Subject: [PATCH] Implement a simple threadpool (#6684) * implement a simple threadpool * unlock before cv.notify * add done function * add lock with GetAvailable function * delete done_ * using call_once in GetInstance * update by comment * update comment * enhance unit test for multi threads task --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/threadpool.h | 161 ++++++++++++++++++++++++++++ paddle/framework/threadpool_test.cc | 58 ++++++++++ 3 files changed, 220 insertions(+) create mode 100644 paddle/framework/threadpool.h create mode 100644 paddle/framework/threadpool_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 5f826aeb837..25a0db27688 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -59,6 +59,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) +cc_test(threadpool_test SRCS threadpool_test.cc) cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_test(init_test SRCS init_test.cc DEPS init) diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h new file mode 100644 index 00000000000..9a1ece3ae84 --- /dev/null +++ b/paddle/framework/threadpool.h @@ -0,0 +1,161 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/platform/call_once.h" +#include "paddle/platform/enforce.h" + +namespace paddle { +namespace framework { + +typedef std::function Task; + +class ThreadPool { + public: + /** + * @brief Get a instance of threadpool, the thread number will + * be specified as the number of hardware thread contexts + */ + static ThreadPool* GetInstance() { + std::call_once(init_flag, &ThreadPool::Init); + return threadpool.get(); + } + + ~ThreadPool() { + { + // notify all threads to stop running + running_ = false; + scheduled_.notify_all(); + } + + for (auto& t : threads_) { + t->join(); + t.reset(nullptr); + } + } + + int GetNumThreads() const { return num_threads_; } + + int GetAvailable() { + std::unique_lock lock(mutex_); + return available_; + } + + /** + * @brief Push a function to the queue, and will be scheduled and + * executed if a thread is available. + * @param[in] Task will be pushed to the task queue. + */ + void Run(const Task& fn) { + std::unique_lock lock(mutex_); + tasks_.push(fn); + lock.unlock(); + scheduled_.notify_one(); + } + + /** + * @brief Wait until all the tasks are completed. + */ + void Wait() { + std::unique_lock lock(mutex_); + completed_.wait(lock, [=] { return Done() == true; }); + } + + private: + ThreadPool& operator=(const ThreadPool&) = delete; + ThreadPool(const ThreadPool&) = delete; + + ThreadPool(int num_threads) + : num_threads_(num_threads), available_(num_threads), running_(true) { + threads_.resize(num_threads); + for (auto& thread : threads_) { + // TODO(Yancey1989): binding the thread on the specify CPU number + thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this))); + } + } + + /** + * @brief If the task queue is empty and avaialbe + * is equal to the number of threads, means that + * all tasks are completed. + * + * Note: this function is not thread-safe. + * + * @return true if all tasks are completed. + */ + bool Done() { return tasks_.empty() && available_ == num_threads_; } + + void TaskLoop() { + while (running_) { + std::unique_lock lock(mutex_); + scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; }); + + if (!running_) { + break; + } + // pop a task from the task queue + auto task = tasks_.front(); + tasks_.pop(); + + --available_; + lock.unlock(); + + // run the task + task(); + + { + std::unique_lock lock(mutex_); + ++available_; + if (Done()) { + completed_.notify_all(); + } + } + } + } + + static void Init() { + if (threadpool.get() == nullptr) { + // TODO(Yancey1989): specify the max threads number + int num_threads = std::thread::hardware_concurrency(); + PADDLE_ENFORCE_GT(num_threads, 0); + threadpool.reset(new ThreadPool(num_threads)); + } + } + + private: + static std::unique_ptr threadpool; + static std::once_flag init_flag; + + int num_threads_; + int available_; + bool running_; + std::queue tasks_; + std::vector> threads_; + std::mutex mutex_; + std::condition_variable scheduled_; + std::condition_variable completed_; +}; + +std::unique_ptr ThreadPool::threadpool(nullptr); +std::once_flag ThreadPool::init_flag; +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/threadpool_test.cc b/paddle/framework/threadpool_test.cc new file mode 100644 index 00000000000..78c762608ed --- /dev/null +++ b/paddle/framework/threadpool_test.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "threadpool.h" +#include +#include +#include +#include +#include + +namespace framework = paddle::framework; + +void do_sum(framework::ThreadPool* pool, std::atomic& sum, int cnt) { + for (int i = 0; i < cnt; ++i) { + pool->Run([&sum]() { sum.fetch_add(1); }); + } +} + +TEST(ThreadPool, ConcurrentInit) { + framework::ThreadPool* pool; + int concurrent_cnt = 50; + std::vector threads; + for (int i = 0; i < concurrent_cnt; ++i) { + std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); }); + threads.push_back(std::move(t)); + } + for (auto& t : threads) { + t.join(); + } +} + +TEST(ThreadPool, ConcurrentStart) { + framework::ThreadPool* pool = framework::ThreadPool::GetInstance(); + std::atomic sum(0); + std::vector threads; + int concurrent_cnt = 50; + // sum = (n * (n + 1)) / 2 + for (int i = 1; i <= concurrent_cnt; ++i) { + std::thread t(do_sum, pool, std::ref(sum), i); + threads.push_back(std::move(t)); + } + for (auto& t : threads) { + t.join(); + } + pool->Wait(); + EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2); +} -- GitLab