From c822d0309bc410e382bc869e46027111d5d6c6f2 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Tue, 10 Jul 2018 16:02:24 +0800 Subject: [PATCH] Refine code --- .../framework/details/threaded_ssa_graph_executor.cc | 11 ++++++++--- .../framework/details/threaded_ssa_graph_executor.h | 3 +++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 99b10254a..8a8c3a593 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( set.clear(); }; + // Clean run context + run_op_futures_.clear(); + exception_.reset(); + // Step 3. Execution while (!pending_vars.empty()) { // 1. Run All Ready ops @@ -98,14 +102,15 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( if (timeout) { std::lock_guard l(exception_mu_); if (exception_) { + for (auto &run_op_future : run_op_futures_) { + run_op_future.wait(); + } std::exception *exp = exception_.get(); if (dynamic_cast(exp)) { auto e = *static_cast(exp); - exception_.reset(); throw e; } else if (dynamic_cast(exp)) { auto e = *static_cast(exp); - exception_.reset(); throw e; } else { LOG(FATAL) << "Unknown exception."; @@ -222,7 +227,7 @@ void ThreadedSSAGraphExecutor::RunOp( } }; if (pool_) { - pool_->enqueue(op_run); + run_op_futures_.emplace_back(pool_->enqueue(op_run)); } else { op_run(); } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index c69e0487e..09973b7a7 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { private: ExecutionStrategy strategy_; + // use std::list because clear(), push_back, and for_each are O(1) + std::list> run_op_futures_; }; } // namespace details -- GitLab