diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 28046b4a0351cc702970bd4266b3a6ef80f59252..e9635d8003aab7701d232066fd98f2fc0bda5360 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -77,9 +77,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( // run the recorded operators directly. This strategy could make the // execution faster. VLOG(3) << "Run the traced ops."; - RunTracedOps(traced_ops_); - RunTracedOps(fetch_ops); - if (exception_.IsCaught()) { + bool is_exception_free = + RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops); + if (!is_exception_free) { ExecutionFinal(&fetch_ops); } } else { @@ -259,25 +259,25 @@ void FastThreadedSSAGraphExecutor::ExecutionFinal( exception_.ReThrow(); } -void FastThreadedSSAGraphExecutor::RunTracedOps( +bool FastThreadedSSAGraphExecutor::RunTracedOps( const std::vector &traced_ops) { for (auto &op : traced_ops) { - if (exception_.IsCaught()) { - return; - } - RunOpSync(op); + if (!RunOpSync(op)) return false; } + return true; } -void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { +bool FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { try { VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); if (LIKELY(!strategy_.dry_run_)) { op->Run(strategy_.use_cuda_); } VLOG(10) << op << " " << op->Name() << " Done "; + return true; } catch (...) { exception_.Catch(std::current_exception()); + return false; } } diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h index 5d11c2cfd9ed6a8b49aa6ee01c89969dc75c21a6..0e904554d83c17cbd0b8f436fadaa85b8b8b68e9 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h @@ -78,9 +78,9 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { inline void ExecutionFinal(std::vector *fetch_ops); - inline void RunOpSync(OpHandleBase *op); + inline bool RunOpSync(OpHandleBase *op); - void RunTracedOps(const std::vector &traced_ops); + bool RunTracedOps(const std::vector &traced_ops); void InsertFetchOps( const std::vector &fetch_tensors, FeedFetchList *fetches, diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 86d33c242d5f00c88cc60443b1072d1352161145..5ee47a3933bd2cd3c5f75b7627733f5659d71475 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -81,9 +81,9 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( // run the recorded operators directly. This strategy could make the // execution faster. VLOG(3) << "Run the traced ops."; - RunTracedOps(traced_ops_); - RunTracedOps(fetch_ops); - if (exception_holder_.IsCaught()) { + bool is_exception_free = + RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops); + if (!is_exception_free) { ExecutionFinal(&fetch_ops); } } else { @@ -308,25 +308,25 @@ void ThreadedSSAGraphExecutor::RunOp( RecordOps(op); } -void ThreadedSSAGraphExecutor::RunTracedOps( +bool ThreadedSSAGraphExecutor::RunTracedOps( const std::vector &traced_ops) { for (auto &op : traced_ops) { - if (exception_holder_.IsCaught()) { - return; - } - RunOpSync(op); + if (!RunOpSync(op)) return false; } + return true; } -void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { +bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { try { VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); if (LIKELY(!strategy_.dry_run_)) { op->Run(strategy_.use_cuda_); } VLOG(10) << op << " " << op->Name() << " Done "; + return true; } catch (...) { exception_holder_.Catch(std::current_exception()); + return false; } } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index fe6ef95a135417c0c73cfb3c9a20af66dc5047e6..8576e2e65a9256bfba1f45da2cc608301b8f79ad 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -109,9 +109,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { inline void ExecutionFinal(std::vector *fetch_ops); - inline void RunOpSync(OpHandleBase *op); + inline bool RunOpSync(OpHandleBase *op); - void RunTracedOps(const std::vector &traced_ops); + bool RunTracedOps(const std::vector &traced_ops); }; } // namespace details