diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 99b10254a7961bf7b27b256acaece573a71c4115..8a8c3a5938e85dab426edbe101cf019910eedb9e 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( set.clear(); }; + // Clean run context + run_op_futures_.clear(); + exception_.reset(); + // Step 3. Execution while (!pending_vars.empty()) { // 1. Run All Ready ops @@ -98,14 +102,15 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( if (timeout) { std::lock_guard l(exception_mu_); if (exception_) { + for (auto &run_op_future : run_op_futures_) { + run_op_future.wait(); + } std::exception *exp = exception_.get(); if (dynamic_cast(exp)) { auto e = *static_cast(exp); - exception_.reset(); throw e; } else if (dynamic_cast(exp)) { auto e = *static_cast(exp); - exception_.reset(); throw e; } else { LOG(FATAL) << "Unknown exception."; @@ -222,7 +227,7 @@ void ThreadedSSAGraphExecutor::RunOp( } }; if (pool_) { - pool_->enqueue(op_run); + run_op_futures_.emplace_back(pool_->enqueue(op_run)); } else { op_run(); } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index c69e0487e2e503a0d445300aa2fd6bb9c30b06c9..09973b7a72881464ad9e7776d4aad3d2261a118d 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { private: ExecutionStrategy strategy_; + // use std::list because clear(), push_back, and for_each are O(1) + std::list> run_op_futures_; }; } // namespace details