diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index eb4e7ec52f907f9403e21ec2734d61824f51a58b..1d80bab90f513139f807b57258177c6b2ac53ac0 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" +#include #include #include #include "paddle/fluid/framework/executor.h" @@ -53,8 +54,14 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( } } } + std::vector fetch_data; + std::exception_ptr eptr; + try { + fetch_data = underlying_executor_->Run(fetch_tensors); + } catch (...) { + eptr = std::current_exception(); + } - auto fetch_data = underlying_executor_->Run(fetch_tensors); drop_scope_counter_ += 1; if (!fetch_tensors.empty() || drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { @@ -69,7 +76,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( scope->DeleteScope(local_scope); } } - return fetch_data; + if (eptr) { + std::rethrow_exception(eptr); + } else { + return fetch_data; + } } } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 99b10254a7961bf7b27b256acaece573a71c4115..07097c7e75c6ce638549716cd6523f387cdefd92 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( set.clear(); }; + // Clean run context + run_op_futures_.clear(); + exception_.reset(); + // Step 3. Execution while (!pending_vars.empty()) { // 1. Run All Ready ops @@ -96,16 +100,19 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto cur_ready_vars = ready_vars.PopAll(1, &timeout); if (timeout) { - std::lock_guard l(exception_mu_); + std::unique_lock l(exception_mu_); if (exception_) { + l.unlock(); + for (auto &run_op_future : run_op_futures_) { + run_op_future.wait(); + } + l.lock(); std::exception *exp = exception_.get(); if (dynamic_cast(exp)) { auto e = *static_cast(exp); - exception_.reset(); throw e; } else if (dynamic_cast(exp)) { auto e = *static_cast(exp); - exception_.reset(); throw e; } else { LOG(FATAL) << "Unknown exception."; @@ -222,7 +229,7 @@ void ThreadedSSAGraphExecutor::RunOp( } }; if (pool_) { - pool_->enqueue(op_run); + run_op_futures_.emplace_back(pool_->enqueue(op_run)); } else { op_run(); } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index c69e0487e2e503a0d445300aa2fd6bb9c30b06c9..09973b7a72881464ad9e7776d4aad3d2261a118d 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { private: ExecutionStrategy strategy_; + // use std::list because clear(), push_back, and for_each are O(1) + std::list> run_op_futures_; }; } // namespace details diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index 1dbafd23e92732bdaf0d263a01e267227786d839..e17c2ffd39eea31fe85933eda144ab97cf8c3dd8 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -23,7 +23,7 @@ class BatchReader : public framework::DecoratedReader { BatchReader(const std::shared_ptr& reader, int batch_size, bool discard_leftover) : DecoratedReader(reader), - batch_size_(batch_size), + batch_size_(static_cast(batch_size)), discard_leftover_(discard_leftover) { buffer_.reserve(batch_size_); } @@ -31,7 +31,7 @@ class BatchReader : public framework::DecoratedReader { void ReadNextImpl(std::vector* out) override; private: - int batch_size_; + size_t batch_size_; bool discard_leftover_; std::vector> buffer_; }; @@ -78,7 +78,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { void BatchReader::ReadNextImpl(std::vector* out) { buffer_.clear(); buffer_.reserve(batch_size_); - for (int i = 0; i < batch_size_; ++i) { + for (size_t i = 0; i < batch_size_; ++i) { buffer_.push_back(std::vector()); reader_->ReadNext(&buffer_.back()); if (buffer_.back().empty()) { @@ -95,9 +95,9 @@ void BatchReader::ReadNextImpl(std::vector* out) { // if buffer_ is empty, the 'out' will return as an empty vector. return; } - int out_num = buffer_[0].size(); + size_t out_num = buffer_[0].size(); out->reserve(out_num); - for (int j = 0; j < out_num; ++j) { + for (size_t j = 0; j < out_num; ++j) { // Merge shape and check date type std::type_index batch_type = buffer_[0][j].type(); framework::DDim batch_shape = buffer_[0][j].dims();