threadpool.h 4.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Y
Yancey 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
D
dzhwinter 已提交
16

Y
Yancey 已提交
17 18
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
Y
Yancey 已提交
27 28 29 30

namespace paddle {
namespace framework {

Y
Yi Wang 已提交
31 32
// ThreadPool maintains a queue of tasks, and runs them using a fixed
// number of threads.
Y
Yancey 已提交
33 34
class ThreadPool {
 public:
Y
Yu Yang 已提交
35 36
  explicit ThreadPool(int num_threads);

Y
Yang Yu 已提交
37
  using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>;
38

Y
Yi Wang 已提交
39 40
  // Returns the singleton of ThreadPool.
  static ThreadPool* GetInstance();
Y
Yancey 已提交
41

Y
Yi Wang 已提交
42
  ~ThreadPool();
Y
Yancey 已提交
43

Y
Yi Wang 已提交
44 45
  // Returns the number of threads created by the constructor.
  size_t Threads() const { return total_threads_; }
Y
Yancey 已提交
46

Y
Yi Wang 已提交
47 48
  // Returns the number of currently idle threads.
  size_t IdleThreads() {
Y
Yancey 已提交
49
    std::unique_lock<std::mutex> lock(mutex_);
Y
Yi Wang 已提交
50
    return idle_threads_;
Y
Yancey 已提交
51 52
  }

Y
Yi Wang 已提交
53 54 55
  // Run pushes a function to the task queue and returns a std::future
  // object.  To wait for the completion of the task, call
  // std::future::wait().
56 57
  template <typename Callback>
  std::future<void> Run(Callback fn) {
Y
Yang Yu 已提交
58 59 60 61 62 63 64
    auto f = this->RunAndGetException(fn);
    return std::async(std::launch::deferred, ExceptionHandler(std::move(f)));
  }

  template <typename Callback>
  std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
      Callback fn) {
Y
Yancey 已提交
65
    std::unique_lock<std::mutex> lock(mutex_);
Y
Yang Yu 已提交
66 67 68 69 70 71
    Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
      try {
        fn();
      } catch (platform::EnforceNotMet ex) {
        return std::unique_ptr<platform::EnforceNotMet>(
            new platform::EnforceNotMet(ex));
X
Xin Pan 已提交
72 73 74 75
      } catch (const std::exception& e) {
        LOG(FATAL) << "Unexpected exception is catched in thread pool. All "
                      "throwable exception in Fluid should be an EnforceNotMet."
                   << e.what();
Y
Yang Yu 已提交
76
      }
Y
Yi Wang 已提交
77
      return nullptr;
Y
Yang Yu 已提交
78 79
    });
    std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
80
    tasks_.push(std::move(task));
Y
Yancey 已提交
81 82
    lock.unlock();
    scheduled_.notify_one();
83
    return f;
Y
Yancey 已提交
84 85
  }

Y
Yi Wang 已提交
86 87
  // Wait until all the tasks are completed.
  void Wait();
Y
Yancey 已提交
88 89

 private:
Y
Yang Yu 已提交
90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
  struct ExceptionHandler {
    mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
    explicit ExceptionHandler(
        std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
        : future_(std::move(f)) {}
    void operator()() const {
      auto ex = this->future_.get();
      if (ex != nullptr) {
        LOG(FATAL) << "The exception is thrown inside the thread pool. You "
                      "should use RunAndGetException to handle the exception.\n"
                      "The default exception handler is LOG(FATAL)."
                   << ex->what();
      }
    }
  };

D
dzhwinter 已提交
106
  DISABLE_COPY_AND_ASSIGN(ThreadPool);
Y
Yancey 已提交
107

Y
Yi Wang 已提交
108 109 110 111 112 113 114
  // If the task queue is empty and avaialbe is equal to the number of
  // threads, means that all tasks are completed.  Note: this function
  // is not thread-safe.  Returns true if all tasks are completed.
  // Note: don't delete the data member total_threads_ and use
  // threads_.size() instead; because you'd need to lock the mutex
  // before accessing threads_.
  bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
Y
Yancey 已提交
115

Y
Yi Wang 已提交
116 117 118 119 120 121
  // The constructor starts threads to run TaskLoop, which retrieves
  // and runs tasks from the queue.
  void TaskLoop();

  // Init is called by GetInstance.
  static void Init();
Y
Yancey 已提交
122 123

 private:
Y
Yi Wang 已提交
124 125
  static std::unique_ptr<ThreadPool> threadpool_;
  static std::once_flag init_flag_;
Y
Yancey 已提交
126 127

  std::vector<std::unique_ptr<std::thread>> threads_;
Y
Yi Wang 已提交
128 129 130 131
  const size_t total_threads_;
  size_t idle_threads_;

  std::queue<Task> tasks_;
Y
Yancey 已提交
132
  std::mutex mutex_;
Y
Yi Wang 已提交
133
  bool running_;
Y
Yancey 已提交
134 135 136 137
  std::condition_variable scheduled_;
  std::condition_variable completed_;
};

Y
Yang Yu 已提交
138 139 140
// Run a function asynchronously on the global ThreadPool.
// NOTE: Any return value of the function is discarded. If the function needs
// to produce a value, use a lambda that captures a pointer to where the
// result should be stored.
141 142 143
template <typename Callback>
std::future<void> Async(Callback callback) {
  return ThreadPool::GetInstance()->Run(callback);
Y
Yang Yu 已提交
144
}
Y
Yang Yu 已提交
145

Y
Yancey 已提交
146 147
}  // namespace framework
}  // namespace paddle