未验证 提交 84994608 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #8181 from reyoung/feature/add_exception_for_thread_pool

Add RunAndGetException in threadpool
...@@ -21,7 +21,8 @@ limitations under the License. */ ...@@ -21,7 +21,8 @@ limitations under the License. */
#include <queue> #include <queue>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "glog/logging.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
...@@ -31,7 +32,7 @@ namespace framework { ...@@ -31,7 +32,7 @@ namespace framework {
// number of threads. // number of threads.
class ThreadPool { class ThreadPool {
public: public:
typedef std::packaged_task<void()> Task; using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>;
// Returns the singleton of ThreadPool. // Returns the singleton of ThreadPool.
static ThreadPool* GetInstance(); static ThreadPool* GetInstance();
...@@ -52,9 +53,28 @@ class ThreadPool { ...@@ -52,9 +53,28 @@ class ThreadPool {
// std::future::wait(). // std::future::wait().
template <typename Callback> template <typename Callback>
std::future<void> Run(Callback fn) { std::future<void> Run(Callback fn) {
auto f = this->RunAndGetException(fn);
return std::async(std::launch::deferred, ExceptionHandler(std::move(f)));
}
template <typename Callback>
std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
Callback fn) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
Task task(std::bind(fn)); Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
std::future<void> f = task.get_future(); try {
fn();
return nullptr;
} catch (platform::EnforceNotMet ex) {
return std::unique_ptr<platform::EnforceNotMet>(
new platform::EnforceNotMet(ex));
} catch (...) {
LOG(FATAL)
<< "Unexpected exception is catched in thread pool. All "
"throwable exception in Fluid should be an EnforceNotMet.";
}
});
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
tasks_.push(std::move(task)); tasks_.push(std::move(task));
lock.unlock(); lock.unlock();
scheduled_.notify_one(); scheduled_.notify_one();
...@@ -65,6 +85,22 @@ class ThreadPool { ...@@ -65,6 +85,22 @@ class ThreadPool {
void Wait(); void Wait();
private: private:
struct ExceptionHandler {
mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
explicit ExceptionHandler(
std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
: future_(std::move(f)) {}
void operator()() const {
auto ex = this->future_.get();
if (ex != nullptr) {
LOG(FATAL) << "The exception is thrown inside the thread pool. You "
"should use RunAndGetException to handle the exception.\n"
"The default exception handler is LOG(FATAL)."
<< ex->what();
}
}
};
DISABLE_COPY_AND_ASSIGN(ThreadPool); DISABLE_COPY_AND_ASSIGN(ThreadPool);
explicit ThreadPool(int num_threads); explicit ThreadPool(int num_threads);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册