threadpool.h 4.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
D
dzhwinter 已提交
16

T
typhoonzero 已提交
17
#include <condition_variable>  // NOLINT
Y
Yancey 已提交
18
#include <functional>
T
typhoonzero 已提交
19 20
#include <future>  // NOLINT
#include <mutex>   // NOLINT
Y
Yancey 已提交
21
#include <queue>
T
typhoonzero 已提交
22
#include <thread>  // NOLINT
Y
Yi Wang 已提交
23
#include <vector>
Y
Yang Yu 已提交
24
#include "glog/logging.h"
Y
Yi Wang 已提交
25 26
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
Y
Yancey 已提交
27 28 29

namespace paddle {
namespace framework {
T
typhoonzero 已提交
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
// Callable that ThreadPool::Run hands to std::async(std::launch::deferred).
// Invoking it blocks until the wrapped task's future is ready, then
// LOG(FATAL)s if the task reported a platform::EnforceNotMet.
//
// NOTE: std::future::get() may be called only once, so operator() must run
// at most once per handler. A deferred std::async invokes it at most once,
// on the first wait()/get() of the future it returns.
struct ExceptionHandler {
  // mutable: get() is non-const, but std::async stores the callable and
  // operator() is declared const.
  mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;

  explicit ExceptionHandler(
      std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
      : future_(std::move(f)) {}

  void operator()() const {
    auto ex = this->future_.get();
    if (ex != nullptr) {
      // Separate the fixed explanation from the exception text; previously
      // what() was appended directly after "LOG(FATAL)." with no break.
      LOG(FATAL) << "The exception is thrown inside the thread pool. You "
                    "should use RunAndGetException to handle the exception.\n"
                    "The default exception handler is LOG(FATAL)."
                 << "\n"
                 << ex->what();
    }
  }
};

Y
Yi Wang 已提交
46 47
// ThreadPool maintains a queue of tasks, and runs them using a fixed
// number of threads.
Y
Yancey 已提交
48 49
class ThreadPool {
 public:
Y
Yu Yang 已提交
50 51
  explicit ThreadPool(int num_threads);

Y
Yang Yu 已提交
52
  using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>;
53

Y
Yi Wang 已提交
54 55
  // Returns the singleton of ThreadPool.
  static ThreadPool* GetInstance();
Y
Yancey 已提交
56

57
  ~ThreadPool();
Y
Yancey 已提交
58

Y
Yi Wang 已提交
59
  // Run pushes a function to the task queue and returns a std::future
Q
Qiao Longfei 已提交
60
  // object. To wait for the completion of the task, call
Y
Yi Wang 已提交
61
  // std::future::wait().
62 63
  template <typename Callback>
  std::future<void> Run(Callback fn) {
Y
Yang Yu 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76
    auto f = this->RunAndGetException(fn);
    return std::async(std::launch::deferred, ExceptionHandler(std::move(f)));
  }

  template <typename Callback>
  std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
      Callback fn) {
    Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
      try {
        fn();
      } catch (platform::EnforceNotMet ex) {
        return std::unique_ptr<platform::EnforceNotMet>(
            new platform::EnforceNotMet(ex));
X
Xin Pan 已提交
77 78 79 80
      } catch (const std::exception& e) {
        LOG(FATAL) << "Unexpected exception is catched in thread pool. All "
                      "throwable exception in Fluid should be an EnforceNotMet."
                   << e.what();
Y
Yang Yu 已提交
81
      }
Y
Yi Wang 已提交
82
      return nullptr;
Y
Yang Yu 已提交
83 84
    });
    std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
Q
Qiao Longfei 已提交
85 86 87 88 89 90 91
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (!running_) {
        PADDLE_THROW("enqueue on stopped ThreadPool");
      }
      tasks_.push(std::move(task));
    }
Y
Yancey 已提交
92
    scheduled_.notify_one();
93
    return f;
Y
Yancey 已提交
94 95 96
  }

 private:
D
dzhwinter 已提交
97
  DISABLE_COPY_AND_ASSIGN(ThreadPool);
Y
Yancey 已提交
98

Y
Yi Wang 已提交
99 100
  // The constructor starts threads to run TaskLoop, which retrieves
  // and runs tasks from the queue.
Y
Yancey1989 已提交
101
  void TaskLoop();
Y
Yi Wang 已提交
102 103 104

  // Init is called by GetInstance.
  static void Init();
Y
Yancey 已提交
105 106

 private:
Y
Yi Wang 已提交
107 108
  static std::unique_ptr<ThreadPool> threadpool_;
  static std::once_flag init_flag_;
Y
Yancey 已提交
109 110

  std::vector<std::unique_ptr<std::thread>> threads_;
Y
Yi Wang 已提交
111 112

  std::queue<Task> tasks_;
Y
Yancey 已提交
113
  std::mutex mutex_;
Y
Yi Wang 已提交
114
  bool running_;
Y
Yancey 已提交
115 116 117
  std::condition_variable scheduled_;
};

T
typhoonzero 已提交
118
class ThreadPoolIO : ThreadPool {
T
typhoonzero 已提交
119
 public:
T
typhoonzero 已提交
120
  static ThreadPool* GetInstanceIO();
T
typhoonzero 已提交
121 122 123 124 125 126 127 128
  static void InitIO();

 private:
  // NOTE: threadpool in base will be inhereted here.
  static std::unique_ptr<ThreadPool> io_threadpool_;
  static std::once_flag io_init_flag_;
};

Y
Yang Yu 已提交
129 130 131
// Run a function asynchronously.
// NOTE: The function must return void. If the function need to return a value,
// you can use lambda to capture a value pointer.
132 133 134
template <typename Callback>
std::future<void> Async(Callback callback) {
  return ThreadPool::GetInstance()->Run(callback);
Y
Yang Yu 已提交
135
}
Y
Yang Yu 已提交
136

T
typhoonzero 已提交
137 138
template <typename Callback>
std::future<void> AsyncIO(Callback callback) {
T
typhoonzero 已提交
139
  return ThreadPoolIO::GetInstanceIO()->Run(callback);
T
typhoonzero 已提交
140 141
}

Y
Yancey 已提交
142 143
}  // namespace framework
}  // namespace paddle