提交 82726402 编写于 作者: Y Yancey1989

exception safe

上级 79082c94
......@@ -34,32 +34,63 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
}
VLOG(1) << "pool size: " << places_.size();
}
FeedFetchList ParallelSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
std::vector<std::future<void>> run_futures;
FeedFetchList fetch_data;
std::vector<std::future<FeedFetchList>> run_futures;
std::vector<FeedFetchList> fetch_datas;
FeedFetchList ret;
fetch_datas.reserve(places_.size());
ret.reserve(fetch_tensors.size());
exception_holder_.Clear();
for (size_t i = 0; i < places_.size(); ++i) {
auto call = [this, i] {
// FIXME(Yancey1989): need to fix fetch data failed.
std::vector<std::string> empty;
executors_[i]->Run(empty);
auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
return executors_[i]->Run(fetch_tensors);
};
if (pool_) {
run_futures.emplace_back(pool_->enqueue(std::move(call)));
} else {
call();
try {
fetch_datas.emplace_back(std::move(call()));
} catch (...) {
exception_holder_.Catch(std::current_exception());
break;
}
}
}
if (pool_) {
for (auto &f : run_futures) {
f.wait();
if (exception_holder_.IsCaught()) {
f.wait();
} else {
try {
fetch_datas.emplace_back(std::move(f.get()));
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
}
}
}
if (exception_holder_.IsCaught()) {
exception_holder_.ReThrow();
}
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
std::vector<const LoDTensor *> lodtensor_ptrs;
lodtensor_ptrs.reserve(local_scopes_.size());
for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
lodtensor_ptrs.push_back(&fetch_datas.at(scope_idx).at(fetch_idx));
}
ret.emplace_back();
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
}
return fetch_data;
return ret;
}
} // namespace details
......
......@@ -44,6 +44,7 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
std::vector<std::unique_ptr<ir::Graph>> graphs_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
ExceptionHolder exception_holder_;
};
} // namespace details
......
......@@ -202,21 +202,6 @@ ParallelExecutor::ParallelExecutor(
}
}
}
/**
std::vector<std::vector<details::VariableInfo>> var_infos_list;
for (size_t i = 0; i < graphs.size(); ++i) {
std::vector<details::VariableInfo> var_infos;
for (auto &node : graphs[i]->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos.emplace_back();
var_infos.back().name_ = node->Var()->Name();
var_infos.back().type_ = node->Var()->GetType();
var_infos.back().persistable_ = node->Var()->Persistable();
}
}
var_infos_list.push_back(std::move(var_infos));
}
**/
// If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) {
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#include <pthread.h>
#include <condition_variable> // NOLINT
#include <functional>
#include <future> // NOLINT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册