未验证 提交 b0c0ffb9 编写于 作者: Z Zeng Jinle 提交者: GitHub

refine pe when exception raises, test=develop (#20894)

上级 20cdff0e
...@@ -77,9 +77,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -77,9 +77,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
// run the recorded operators directly. This strategy could make the // run the recorded operators directly. This strategy could make the
// execution faster. // execution faster.
VLOG(3) << "Run the traced ops."; VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_); bool is_exception_free =
RunTracedOps(fetch_ops); RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops);
if (exception_.IsCaught()) { if (!is_exception_free) {
ExecutionFinal(&fetch_ops); ExecutionFinal(&fetch_ops);
} }
} else { } else {
...@@ -259,25 +259,25 @@ void FastThreadedSSAGraphExecutor::ExecutionFinal( ...@@ -259,25 +259,25 @@ void FastThreadedSSAGraphExecutor::ExecutionFinal(
exception_.ReThrow(); exception_.ReThrow();
} }
void FastThreadedSSAGraphExecutor::RunTracedOps( bool FastThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) { const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) { for (auto &op : traced_ops) {
if (exception_.IsCaught()) { if (!RunOpSync(op)) return false;
return;
}
RunOpSync(op);
} }
return true;
} }
void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { bool FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try { try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
if (LIKELY(!strategy_.dry_run_)) { if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_); op->Run(strategy_.use_cuda_);
} }
VLOG(10) << op << " " << op->Name() << " Done "; VLOG(10) << op << " " << op->Name() << " Done ";
return true;
} catch (...) { } catch (...) {
exception_.Catch(std::current_exception()); exception_.Catch(std::current_exception());
return false;
} }
} }
......
...@@ -78,9 +78,9 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -78,9 +78,9 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops); inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op); inline bool RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops); bool RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
void InsertFetchOps( void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches, const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
......
...@@ -81,9 +81,9 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( ...@@ -81,9 +81,9 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
// run the recorded operators directly. This strategy could make the // run the recorded operators directly. This strategy could make the
// execution faster. // execution faster.
VLOG(3) << "Run the traced ops."; VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_); bool is_exception_free =
RunTracedOps(fetch_ops); RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops);
if (exception_holder_.IsCaught()) { if (!is_exception_free) {
ExecutionFinal(&fetch_ops); ExecutionFinal(&fetch_ops);
} }
} else { } else {
...@@ -308,25 +308,25 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -308,25 +308,25 @@ void ThreadedSSAGraphExecutor::RunOp(
RecordOps(op); RecordOps(op);
} }
void ThreadedSSAGraphExecutor::RunTracedOps( bool ThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) { const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) { for (auto &op : traced_ops) {
if (exception_holder_.IsCaught()) { if (!RunOpSync(op)) return false;
return;
}
RunOpSync(op);
} }
return true;
} }
void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try { try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
if (LIKELY(!strategy_.dry_run_)) { if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_); op->Run(strategy_.use_cuda_);
} }
VLOG(10) << op << " " << op->Name() << " Done "; VLOG(10) << op << " " << op->Name() << " Done ";
return true;
} catch (...) { } catch (...) {
exception_holder_.Catch(std::current_exception()); exception_holder_.Catch(std::current_exception());
return false;
} }
} }
......
...@@ -109,9 +109,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -109,9 +109,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops); inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op); inline bool RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops); bool RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
}; };
} // namespace details } // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册