From b0c0ffb9aedcaea3b039add0ea999148c96c2eca Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Fri, 1 Nov 2019 09:58:12 +0800 Subject: [PATCH] refine pe when exception raises, test=develop (#20894) --- .../fast_threaded_ssa_graph_executor.cc | 18 +++++++++--------- .../details/fast_threaded_ssa_graph_executor.h | 4 ++-- .../details/threaded_ssa_graph_executor.cc | 18 +++++++++--------- .../details/threaded_ssa_graph_executor.h | 4 ++-- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 28046b4a035..e9635d8003a 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -77,9 +77,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( // run the recorded operators directly. This strategy could make the // execution faster. VLOG(3) << "Run the traced ops."; - RunTracedOps(traced_ops_); - RunTracedOps(fetch_ops); - if (exception_.IsCaught()) { + bool is_exception_free = + RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops); + if (!is_exception_free) { ExecutionFinal(&fetch_ops); } } else { @@ -259,25 +259,25 @@ void FastThreadedSSAGraphExecutor::ExecutionFinal( exception_.ReThrow(); } -void FastThreadedSSAGraphExecutor::RunTracedOps( +bool FastThreadedSSAGraphExecutor::RunTracedOps( const std::vector &traced_ops) { for (auto &op : traced_ops) { - if (exception_.IsCaught()) { - return; - } - RunOpSync(op); + if (!RunOpSync(op)) return false; } + return true; } -void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { +bool FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { try { VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); if (LIKELY(!strategy_.dry_run_)) { op->Run(strategy_.use_cuda_); } VLOG(10) << op << " " << op->Name() << " Done "; + return true; } catch (...) { exception_.Catch(std::current_exception()); + return false; } } diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h index 5d11c2cfd9e..0e904554d83 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h @@ -78,9 +78,9 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { inline void ExecutionFinal(std::vector *fetch_ops); - inline void RunOpSync(OpHandleBase *op); + inline bool RunOpSync(OpHandleBase *op); - void RunTracedOps(const std::vector &traced_ops); + bool RunTracedOps(const std::vector &traced_ops); void InsertFetchOps( const std::vector &fetch_tensors, FeedFetchList *fetches, diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 86d33c242d5..5ee47a3933b 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -81,9 +81,9 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( // run the recorded operators directly. This strategy could make the // execution faster. VLOG(3) << "Run the traced ops."; - RunTracedOps(traced_ops_); - RunTracedOps(fetch_ops); - if (exception_holder_.IsCaught()) { + bool is_exception_free = + RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops); + if (!is_exception_free) { ExecutionFinal(&fetch_ops); } } else { @@ -308,25 +308,25 @@ void ThreadedSSAGraphExecutor::RunOp( RecordOps(op); } -void ThreadedSSAGraphExecutor::RunTracedOps( +bool ThreadedSSAGraphExecutor::RunTracedOps( const std::vector &traced_ops) { for (auto &op : traced_ops) { - if (exception_holder_.IsCaught()) { - return; - } - RunOpSync(op); + if (!RunOpSync(op)) return false; } + return true; } -void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { +bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { try { VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); if (LIKELY(!strategy_.dry_run_)) { op->Run(strategy_.use_cuda_); } VLOG(10) << op << " " << op->Name() << " Done "; + return true; } catch (...) { exception_holder_.Catch(std::current_exception()); + return false; } } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index fe6ef95a135..8576e2e65a9 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -109,9 +109,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { inline void ExecutionFinal(std::vector *fetch_ops); - inline void RunOpSync(OpHandleBase *op); + inline bool RunOpSync(OpHandleBase *op); - void RunTracedOps(const std::vector &traced_ops); + bool RunTracedOps(const std::vector &traced_ops); }; } // namespace details -- GitLab