diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h index 4e9b58679d9e7c84adf76b6245b397c7a8872483..77d31a1176d5947655b57c70846294093a7bb5ef 100644 --- a/paddle/framework/threadpool.h +++ b/paddle/framework/threadpool.h @@ -21,7 +21,8 @@ limitations under the License. */ #include #include #include - +#include "glog/logging.h" +#include "paddle/platform/enforce.h" #include "paddle/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN namespace paddle { @@ -31,7 +32,7 @@ namespace framework { // number of threads. class ThreadPool { public: - typedef std::packaged_task Task; + using Task = std::packaged_task()>; // Returns the singleton of ThreadPool. static ThreadPool* GetInstance(); @@ -52,9 +53,28 @@ class ThreadPool { // std::future::wait(). template std::future Run(Callback fn) { + auto f = this->RunAndGetException(fn); + return std::async(std::launch::deferred, ExceptionHandler(std::move(f))); + } + + template + std::future> RunAndGetException( + Callback fn) { std::unique_lock lock(mutex_); - Task task(std::bind(fn)); - std::future f = task.get_future(); + Task task([fn]() -> std::unique_ptr { + try { + fn(); + return nullptr; + } catch (platform::EnforceNotMet ex) { + return std::unique_ptr( + new platform::EnforceNotMet(ex)); + } catch (...) { + LOG(FATAL) + << "Unexpected exception is catched in thread pool. All " + "throwable exception in Fluid should be an EnforceNotMet."; + } + }); + std::future> f = task.get_future(); tasks_.push(std::move(task)); lock.unlock(); scheduled_.notify_one(); @@ -65,6 +85,22 @@ class ThreadPool { void Wait(); private: + struct ExceptionHandler { + mutable std::future> future_; + explicit ExceptionHandler( + std::future>&& f) + : future_(std::move(f)) {} + void operator()() const { + auto ex = this->future_.get(); + if (ex != nullptr) { + LOG(FATAL) << "The exception is thrown inside the thread pool. You " + "should use RunAndGetException to handle the exception.\n" + "The default exception handler is LOG(FATAL)." + << ex->what(); + } + } + }; + DISABLE_COPY_AND_ASSIGN(ThreadPool); explicit ThreadPool(int num_threads);