未验证 提交 b0c0ffb9 编写于 作者: Z Zeng Jinle 提交者: GitHub

refine pe when exception raises, test=develop (#20894)

上级 20cdff0e
......@@ -77,9 +77,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
// run the recorded operators directly. This strategy could make the
// execution faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_.IsCaught()) {
bool is_exception_free =
RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops);
if (!is_exception_free) {
ExecutionFinal(&fetch_ops);
}
} else {
......@@ -259,25 +259,25 @@ void FastThreadedSSAGraphExecutor::ExecutionFinal(
exception_.ReThrow();
}
void FastThreadedSSAGraphExecutor::RunTracedOps(
bool FastThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_.IsCaught()) {
return;
}
RunOpSync(op);
if (!RunOpSync(op)) return false;
}
return true;
}
void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
bool FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
return true;
} catch (...) {
exception_.Catch(std::current_exception());
return false;
}
}
......
......@@ -78,9 +78,9 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
inline bool RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
bool RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
......
......@@ -81,9 +81,9 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
// run the recorded operators directly. This strategy could make the
// execution faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_holder_.IsCaught()) {
bool is_exception_free =
RunTracedOps(traced_ops_) && RunTracedOps(fetch_ops);
if (!is_exception_free) {
ExecutionFinal(&fetch_ops);
}
} else {
......@@ -308,25 +308,25 @@ void ThreadedSSAGraphExecutor::RunOp(
RecordOps(op);
}
void ThreadedSSAGraphExecutor::RunTracedOps(
bool ThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_holder_.IsCaught()) {
return;
}
RunOpSync(op);
if (!RunOpSync(op)) return false;
}
return true;
}
void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
return true;
} catch (...) {
exception_holder_.Catch(std::current_exception());
return false;
}
}
......
......@@ -109,9 +109,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
inline bool RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
bool RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
};
} // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册