未验证 提交 84994608 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #8181 from reyoung/feature/add_exception_for_thread_pool

Add RunAndGetException in threadpool
......@@ -21,7 +21,8 @@ limitations under the License. */
#include <queue>
#include <thread>
#include <vector>
#include "glog/logging.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
......@@ -31,7 +32,7 @@ namespace framework {
// number of threads.
class ThreadPool {
public:
typedef std::packaged_task<void()> Task;
using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>;
// Returns the singleton of ThreadPool.
static ThreadPool* GetInstance();
......@@ -52,9 +53,28 @@ class ThreadPool {
// std::future::wait().
template <typename Callback>
std::future<void> Run(Callback fn) {
auto f = this->RunAndGetException(fn);
return std::async(std::launch::deferred, ExceptionHandler(std::move(f)));
}
template <typename Callback>
std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
Callback fn) {
std::unique_lock<std::mutex> lock(mutex_);
Task task(std::bind(fn));
std::future<void> f = task.get_future();
Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
try {
fn();
return nullptr;
} catch (platform::EnforceNotMet ex) {
return std::unique_ptr<platform::EnforceNotMet>(
new platform::EnforceNotMet(ex));
} catch (...) {
LOG(FATAL)
<< "Unexpected exception is catched in thread pool. All "
"throwable exception in Fluid should be an EnforceNotMet.";
}
});
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
tasks_.push(std::move(task));
lock.unlock();
scheduled_.notify_one();
......@@ -65,6 +85,22 @@ class ThreadPool {
void Wait();
private:
struct ExceptionHandler {
mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
explicit ExceptionHandler(
std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
: future_(std::move(f)) {}
void operator()() const {
auto ex = this->future_.get();
if (ex != nullptr) {
LOG(FATAL) << "The exception is thrown inside the thread pool. You "
"should use RunAndGetException to handle the exception.\n"
"The default exception handler is LOG(FATAL)."
<< ex->what();
}
}
};
DISABLE_COPY_AND_ASSIGN(ThreadPool);
explicit ThreadPool(int num_threads);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册