/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <condition_variable>  // NOLINT
#include <functional>
#include <future>  // NOLINT
#include <memory>
#include <mutex>  // NOLINT
#include <queue>
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN

namespace paddle {
namespace framework {

struct ExceptionHandler {
  mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
  explicit ExceptionHandler(
      std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
      : future_(std::move(f)) {}
  void operator()() const {
    auto ex = this->future_.get();
    if (ex != nullptr) {
42 43 44 45 46
      PADDLE_THROW(platform::errors::Fatal(
          "The exception is thrown inside the thread pool. You "
          "should use RunAndGetException to handle the exception."
          "The exception is:\n %s.",
          ex->what()));
T
typhoonzero 已提交
47 48 49 50
    }
  }
};

// ThreadPool maintains a queue of tasks, and runs them using a fixed
// number of threads.
Y
Yancey 已提交
53 54
class ThreadPool {
 public:
Y
Yu Yang 已提交
55 56
  explicit ThreadPool(int num_threads);

Y
Yang Yu 已提交
57
  using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>;
58

Y
Yi Wang 已提交
59 60
  // Returns the singleton of ThreadPool.
  static ThreadPool* GetInstance();
Y
Yancey 已提交
61

62
  ~ThreadPool();
Y
Yancey 已提交
63

Y
Yi Wang 已提交
64
  // Run pushes a function to the task queue and returns a std::future
Q
Qiao Longfei 已提交
65
  // object. To wait for the completion of the task, call
Y
Yi Wang 已提交
66
  // std::future::wait().
67 68
  template <typename Callback>
  std::future<void> Run(Callback fn) {
Y
Yang Yu 已提交
69 70 71 72 73 74 75 76 77 78
    auto f = this->RunAndGetException(fn);
    return std::async(std::launch::deferred, ExceptionHandler(std::move(f)));
  }

  template <typename Callback>
  std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
      Callback fn) {
    Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
      try {
        fn();
79
      } catch (platform::EnforceNotMet& ex) {
Y
Yang Yu 已提交
80 81
        return std::unique_ptr<platform::EnforceNotMet>(
            new platform::EnforceNotMet(ex));
X
Xin Pan 已提交
82
      } catch (const std::exception& e) {
83 84 85 86 87
        PADDLE_THROW(platform::errors::Fatal(
            "Unexpected exception is catched in thread pool. All "
            "throwable exception in Paddle should be an EnforceNotMet."
            "The exception is:\n %s.",
            e.what()));
Y
Yang Yu 已提交
88
      }
Y
Yi Wang 已提交
89
      return nullptr;
Y
Yang Yu 已提交
90 91
    });
    std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
Q
Qiao Longfei 已提交
92 93 94
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (!running_) {
95 96
        PADDLE_THROW(platform::errors::Unavailable(
            "Task is enqueued into stopped ThreadPool."));
Q
Qiao Longfei 已提交
97 98 99
      }
      tasks_.push(std::move(task));
    }
Y
Yancey 已提交
100
    scheduled_.notify_one();
101
    return f;
Y
Yancey 已提交
102 103 104
  }

 private:
D
dzhwinter 已提交
105
  DISABLE_COPY_AND_ASSIGN(ThreadPool);
Y
Yancey 已提交
106

Y
Yi Wang 已提交
107 108 109 110 111 112
  // The constructor starts threads to run TaskLoop, which retrieves
  // and runs tasks from the queue.
  void TaskLoop();

  // Init is called by GetInstance.
  static void Init();
Y
Yancey 已提交
113 114

 private:
Y
Yi Wang 已提交
115 116
  static std::unique_ptr<ThreadPool> threadpool_;
  static std::once_flag init_flag_;
Y
Yancey 已提交
117 118

  std::vector<std::unique_ptr<std::thread>> threads_;
Y
Yi Wang 已提交
119 120

  std::queue<Task> tasks_;
Y
Yancey 已提交
121
  std::mutex mutex_;
Y
Yi Wang 已提交
122
  bool running_;
Y
Yancey 已提交
123 124 125
  std::condition_variable scheduled_;
};

class ThreadPoolIO : ThreadPool {
T
typhoonzero 已提交
127
 public:
T
typhoonzero 已提交
128
  static ThreadPool* GetInstanceIO();
T
typhoonzero 已提交
129 130 131 132 133 134 135 136
  static void InitIO();

 private:
  // NOTE: threadpool in base will be inhereted here.
  static std::unique_ptr<ThreadPool> io_threadpool_;
  static std::once_flag io_init_flag_;
};

// Run a function asynchronously.
// NOTE: The function must return void. If the function need to return a value,
// you can use lambda to capture a value pointer.
140 141 142
template <typename Callback>
std::future<void> Async(Callback callback) {
  return ThreadPool::GetInstance()->Run(callback);
Y
Yang Yu 已提交
143
}
template <typename Callback>
std::future<void> AsyncIO(Callback callback) {
T
typhoonzero 已提交
147
  return ThreadPoolIO::GetInstanceIO()->Run(callback);
T
typhoonzero 已提交
148 149
}

}  // namespace framework
}  // namespace paddle