diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc index dfb40721d88717ef49a2a66297edf95c4c2c7422..f1a07edf08843ccee42b759a2dc39517289ca853 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc @@ -34,32 +34,63 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( executors_.emplace_back(new details::ThreadedSSAGraphExecutor( strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i]))); } - VLOG(1) << "pool size: " << places_.size(); } FeedFetchList ParallelSSAGraphExecutor::Run( const std::vector &fetch_tensors) { - std::vector> run_futures; - FeedFetchList fetch_data; + std::vector> run_futures; + + std::vector fetch_datas; + FeedFetchList ret; + + fetch_datas.reserve(places_.size()); + ret.reserve(fetch_tensors.size()); + exception_holder_.Clear(); for (size_t i = 0; i < places_.size(); ++i) { - auto call = [this, i] { - // FIXME(Yancey1989): need to fix fetch data failed. - std::vector empty; - executors_[i]->Run(empty); + auto call = [this, i, &fetch_tensors]() -> FeedFetchList { + return executors_[i]->Run(fetch_tensors); }; + if (pool_) { run_futures.emplace_back(pool_->enqueue(std::move(call))); } else { - call(); + try { + fetch_datas.emplace_back(std::move(call())); + } catch (...) { + exception_holder_.Catch(std::current_exception()); + break; + } } } + if (pool_) { for (auto &f : run_futures) { - f.wait(); + if (exception_holder_.IsCaught()) { + f.wait(); + } else { + try { + fetch_datas.emplace_back(std::move(f.get())); + } catch (...) { + exception_holder_.Catch(std::current_exception()); + } + } + } + } + if (exception_holder_.IsCaught()) { + exception_holder_.ReThrow(); + } + + for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) { + std::vector lodtensor_ptrs; + lodtensor_ptrs.reserve(local_scopes_.size()); + for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) { + lodtensor_ptrs.push_back(&fetch_datas.at(scope_idx).at(fetch_idx)); } + ret.emplace_back(); + ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace()); } - return fetch_data; + return ret; } } // namespace details diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h index 37784775f039e96a52c6382e2826aaf982aef3c4..bd777e41f8588ecce83f84838a8a95431eb81240 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h @@ -44,6 +44,7 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor { std::vector> graphs_; std::vector> executors_; + ExceptionHolder exception_holder_; }; } // namespace details diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 2a9ca3e815bd65332453655366bfe94479007097..82a7bd218590d6009b2f6e5f6bee3085d2aefce5 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -202,21 +202,6 @@ ParallelExecutor::ParallelExecutor( } } } - /** - std::vector> var_infos_list; - for (size_t i = 0; i < graphs.size(); ++i) { - std::vector var_infos; - for (auto &node : graphs[i]->Nodes()) { - if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { - var_infos.emplace_back(); - var_infos.back().name_ = node->Var()->Name(); - var_infos.back().type_ = node->Var()->GetType(); - var_infos.back().persistable_ = node->Var()->Persistable(); - } - } - var_infos_list.push_back(std::move(var_infos)); - } - **/ // If the loss_var_name is given, the number of graph should be only one. if (loss_var_name.size()) { diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 5177b7ee029d5e01956de2ff2a8d725392e63e12..8fd834be9acc80a5ff3989c158d1ae3d81873196 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include #include // NOLINT #include #include // NOLINT