未验证 提交 c822d030 编写于 作者: Y yuyang18

Refine code

上级 3aaf7981
...@@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
set.clear(); set.clear();
}; };
// Clean run context
run_op_futures_.clear();
exception_.reset();
// Step 3. Execution // Step 3. Execution
while (!pending_vars.empty()) { while (!pending_vars.empty()) {
// 1. Run All Ready ops // 1. Run All Ready ops
...@@ -98,14 +102,15 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -98,14 +102,15 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
if (timeout) { if (timeout) {
std::lock_guard<std::mutex> l(exception_mu_); std::lock_guard<std::mutex> l(exception_mu_);
if (exception_) { if (exception_) {
for (auto &run_op_future : run_op_futures_) {
run_op_future.wait();
}
std::exception *exp = exception_.get(); std::exception *exp = exception_.get();
if (dynamic_cast<platform::EOFException *>(exp)) { if (dynamic_cast<platform::EOFException *>(exp)) {
auto e = *static_cast<platform::EOFException *>(exp); auto e = *static_cast<platform::EOFException *>(exp);
exception_.reset();
throw e; throw e;
} else if (dynamic_cast<platform::EnforceNotMet *>(exp)) { } else if (dynamic_cast<platform::EnforceNotMet *>(exp)) {
auto e = *static_cast<platform::EnforceNotMet *>(exp); auto e = *static_cast<platform::EnforceNotMet *>(exp);
exception_.reset();
throw e; throw e;
} else { } else {
LOG(FATAL) << "Unknown exception."; LOG(FATAL) << "Unknown exception.";
...@@ -222,7 +227,7 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -222,7 +227,7 @@ void ThreadedSSAGraphExecutor::RunOp(
} }
}; };
if (pool_) { if (pool_) {
pool_->enqueue(op_run); run_op_futures_.emplace_back(pool_->enqueue(op_run));
} else { } else {
op_run(); op_run();
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <deque> #include <deque>
#include <list>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
...@@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
private: private:
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
}; };
} // namespace details } // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册