/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include "paddle/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN

namespace paddle {
namespace framework {

// ThreadPool maintains a queue of tasks, and runs them using a fixed
// number of threads.
Y
Yancey 已提交
32 33
class ThreadPool {
 public:
34 35
  typedef std::packaged_task<void()> Task;

Y
Yi Wang 已提交
36 37
  // Returns the singleton of ThreadPool.
  static ThreadPool* GetInstance();
Y
Yancey 已提交
38

Y
Yi Wang 已提交
39
  ~ThreadPool();
Y
Yancey 已提交
40

Y
Yi Wang 已提交
41 42
  // Returns the number of threads created by the constructor.
  size_t Threads() const { return total_threads_; }
Y
Yancey 已提交
43

Y
Yi Wang 已提交
44 45
  // Returns the number of currently idle threads.
  size_t IdleThreads() {
Y
Yancey 已提交
46
    std::unique_lock<std::mutex> lock(mutex_);
Y
Yi Wang 已提交
47
    return idle_threads_;
Y
Yancey 已提交
48 49
  }

Y
Yi Wang 已提交
50 51 52
  // Run pushes a function to the task queue and returns a std::future
  // object.  To wait for the completion of the task, call
  // std::future::wait().
53 54
  template <typename Callback>
  std::future<void> Run(Callback fn) {
Y
Yancey 已提交
55
    std::unique_lock<std::mutex> lock(mutex_);
56 57 58
    Task task(std::bind(fn));
    std::future<void> f = task.get_future();
    tasks_.push(std::move(task));
Y
Yancey 已提交
59 60
    lock.unlock();
    scheduled_.notify_one();
61
    return f;
Y
Yancey 已提交
62 63
  }

Y
Yi Wang 已提交
64 65
  // Wait until all the tasks are completed.
  void Wait();
Y
Yancey 已提交
66 67

 private:
D
dzhwinter 已提交
68
  DISABLE_COPY_AND_ASSIGN(ThreadPool);
Y
Yancey 已提交
69

Y
Yi Wang 已提交
70
  explicit ThreadPool(int num_threads);
Y
Yancey 已提交
71

Y
Yi Wang 已提交
72 73 74 75 76 77 78
  // If the task queue is empty and avaialbe is equal to the number of
  // threads, means that all tasks are completed.  Note: this function
  // is not thread-safe.  Returns true if all tasks are completed.
  // Note: don't delete the data member total_threads_ and use
  // threads_.size() instead; because you'd need to lock the mutex
  // before accessing threads_.
  bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
Y
Yancey 已提交
79

Y
Yi Wang 已提交
80 81 82 83 84 85
  // The constructor starts threads to run TaskLoop, which retrieves
  // and runs tasks from the queue.
  void TaskLoop();

  // Init is called by GetInstance.
  static void Init();
Y
Yancey 已提交
86 87

 private:
Y
Yi Wang 已提交
88 89
  static std::unique_ptr<ThreadPool> threadpool_;
  static std::once_flag init_flag_;
Y
Yancey 已提交
90 91

  std::vector<std::unique_ptr<std::thread>> threads_;
Y
Yi Wang 已提交
92 93 94 95
  const size_t total_threads_;
  size_t idle_threads_;

  std::queue<Task> tasks_;
Y
Yancey 已提交
96
  std::mutex mutex_;
Y
Yi Wang 已提交
97
  bool running_;
Y
Yancey 已提交
98 99 100 101
  std::condition_variable scheduled_;
  std::condition_variable completed_;
};

// Run a function asynchronously.
// NOTE: The function must return void. If the function need to return a value,
// you can use lambda to capture a value pointer.
105 106 107
template <typename Callback>
std::future<void> Async(Callback callback) {
  return ThreadPool::GetInstance()->Run(callback);
Y
Yang Yu 已提交
108
}
}  // namespace framework
}  // namespace paddle