未验证 提交 5022ee63 编写于 作者: Y Yancey 提交者: GitHub

ThreadPool::Run interface return std::future (#7099)

* Run interface return future

* delete unused comments
上级 18311767
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
......@@ -25,10 +26,11 @@ limitations under the License. */
namespace paddle {
namespace framework {
typedef std::function<void()> Task;
class ThreadPool {
public:
typedef std::packaged_task<void()> Task;
typedef std::function<void()> Fun;
/**
* @brief Get a instance of threadpool, the thread number will
* be specified as the number of hardware thread contexts
......@@ -61,13 +63,18 @@ class ThreadPool {
/**
* @brief Push a function to the queue, and will be scheduled and
* executed if a thread is available.
* @param[in] Task will be pushed to the task queue.
* @param[in] Task, will be pushed to the task queue.
* @return std::future<void>, we could wait for the task finished by
* f.wait().
*/
void Run(const Task& fn) {
std::future<void> Run(const Fun& fn) {
std::unique_lock<std::mutex> lock(mutex_);
tasks_.push(fn);
Task task(std::bind(fn));
std::future<void> f = task.get_future();
tasks_.push(std::move(task));
lock.unlock();
scheduled_.notify_one();
return f;
}
/**
......@@ -110,7 +117,7 @@ class ThreadPool {
break;
}
// pop a task from the task queue
auto task = tasks_.front();
auto task = std::move(tasks_.front());
tasks_.pop();
--available_;
......
......@@ -20,16 +20,21 @@ limitations under the License. */
namespace framework = paddle::framework;
void do_sum(framework::ThreadPool* pool, std::atomic<int>& sum, int cnt) {
std::vector<std::future<void>> fs;
for (int i = 0; i < cnt; ++i) {
pool->Run([&sum]() { sum.fetch_add(1); });
auto f = pool->Run([&sum]() { sum.fetch_add(1); });
fs.push_back(std::move(f));
}
for (auto& f : fs) {
f.wait();
}
}
TEST(ThreadPool, ConcurrentInit) {
framework::ThreadPool* pool;
int concurrent_cnt = 50;
int n = 50;
std::vector<std::thread> threads;
for (int i = 0; i < concurrent_cnt; ++i) {
for (int i = 0; i < n; ++i) {
std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); });
threads.push_back(std::move(t));
}
......@@ -38,13 +43,13 @@ TEST(ThreadPool, ConcurrentInit) {
}
}
TEST(ThreadPool, ConcurrentStart) {
TEST(ThreadPool, ConcurrentRun) {
framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
std::atomic<int> sum(0);
std::vector<std::thread> threads;
int concurrent_cnt = 50;
int n = 50;
// sum = (n * (n + 1)) / 2
for (int i = 1; i <= concurrent_cnt; ++i) {
for (int i = 1; i <= n; ++i) {
std::thread t(do_sum, pool, std::ref(sum), i);
threads.push_back(std::move(t));
}
......@@ -52,5 +57,5 @@ TEST(ThreadPool, ConcurrentStart) {
t.join();
}
pool->Wait();
EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2);
EXPECT_EQ(sum, ((n + 1) * n) / 2);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册