diff --git a/paddle/fluid/framework/details/exception_holder.h b/paddle/fluid/framework/details/exception_holder.h index 6e302a29233b96451df14b4685911be1cd87c1ab..c97b364de1ecae21e97351196389615187932b5e 100644 --- a/paddle/fluid/framework/details/exception_holder.h +++ b/paddle/fluid/framework/details/exception_holder.h @@ -14,6 +14,7 @@ #pragma once +#include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -22,27 +23,24 @@ namespace details { class ExceptionHolder { public: - void Catch(const platform::EnforceNotMet& exp) { - std::lock_guard lock(mu_); - exception_.reset(new platform::EnforceNotMet(exp)); - type_ = kEnforceNotMet; - } - - void Catch(const platform::EOFException& exp) { - std::lock_guard lock(mu_); - // EOFException will not cover up existing EnforceNotMet. - if (exception_.get() == nullptr) { - exception_.reset(new platform::EOFException(exp)); - type_ = kEOF; + void Catch(std::exception_ptr eptr) { + try { + std::rethrow_exception(eptr); + } catch (platform::EOFException exp) { + Catch(exp); + } catch (platform::EnforceNotMet exp) { + Catch(exp); + } catch (...) { + LOG(FATAL) << "Unknown exception caught"; } } - bool ExceptionCatched() const { + bool IsCaught() const { std::lock_guard lock(mu_); return exception_.get() != nullptr; } - void Throw() { + void ReThrow() { std::lock_guard lock(mu_); switch (type_) { case kNone: @@ -50,27 +48,41 @@ class ExceptionHolder { case kEnforceNotMet: { auto e = *static_cast(exception_.get()); throw e; - break; } case kEOF: { auto e = *static_cast(exception_.get()); throw e; - break; } - default: - LOG(FATAL) << "Unknown exception."; } - exception_.reset(); - type_ = kNone; + ClearImpl(); } void Clear() { std::lock_guard lock(mu_); + ClearImpl(); + } + + private: + void ClearImpl() { exception_.reset(); type_ = kNone; } - private: + void Catch(const platform::EnforceNotMet& exp) { + std::lock_guard lock(mu_); + exception_.reset(new platform::EnforceNotMet(exp)); + type_ = kEnforceNotMet; + } + + void Catch(const platform::EOFException& exp) { + std::lock_guard lock(mu_); + // EOFException will not cover up existing EnforceNotMet. + if (exception_.get() == nullptr) { + exception_.reset(new platform::EOFException(exp)); + type_ = kEOF; + } + } + enum ExceptionType { kNone, kEnforceNotMet, kEOF }; ExceptionType type_{kNone}; diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 994bb6492f685138d02971a6caf12572aecd6d6f..c9e331ef359f853263f8dad38dd0a2be4d9618ad 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -107,11 +107,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto cur_ready_vars = ready_vars.PopAll(1, &timeout); if (timeout) { - if (exception_holder_.ExceptionCatched()) { + if (exception_holder_.IsCaught()) { for (auto &run_op_future : run_op_futures_) { run_op_future.wait(); } - exception_holder_.Throw(); + exception_holder_.ReThrow(); } else { continue; } @@ -220,12 +220,8 @@ void ThreadedSSAGraphExecutor::RunOp( running_ops_--; ready_var_q->Extend(op->Outputs()); VLOG(10) << op << " " << op->Name() << "Signal posted"; - } catch (platform::EOFException ex) { - exception_holder_.Catch(ex); - } catch (platform::EnforceNotMet ex) { - exception_holder_.Catch(ex); } catch (...) { - LOG(FATAL) << "Unknown exception catched"; + exception_holder_.Catch(std::current_exception()); } }; if (pool_) {