diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 79b390dde48adf67d8ea9d7b57fb0777ca0e5e68..5ce92ad826741c3cc7240256be7a13db89daada4 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -30,19 +30,6 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( VLOG(3) << "build AsyncSSAGraphExecutor"; PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); - if (strategy_.num_iteration_per_run_ > 1) { - int read_op_num = 0; - for (auto *node : graphs_[0]->Nodes()) { - if (node->IsOp() && node->Name() == "read") { - read_op_num++; - } - } - if (read_op_num == 0) { - LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model " - "should use pyreader to feed data!"; - } - } - // set the correct size of thread pool to each device. strategy_.num_threads_ = strategy_.num_threads_ < places_.size() ? 1UL @@ -69,9 +56,6 @@ FeedFetchList AsyncSSAGraphExecutor::Run( for (size_t i = 0; i < places_.size(); ++i) { auto call = [this, i, &fetch_tensors]() -> FeedFetchList { try { - for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) { - executors_[i]->Run(fetch_tensors); - } return executors_[i]->Run(fetch_tensors); } catch (...) { exception_holder_.Catch(std::current_exception()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 677a2937945b03fa577317cb4f26e09354d06957..16fa2a6db689b182aacb20839630996cf2e04f05 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -32,9 +32,22 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( places_(places), fetch_ctxs_(places), running_ops_(0), - strategy_(strategy) {} + strategy_(strategy) { + if (strategy_.num_iteration_per_run_ > 1) { + int read_op_num = 0; + for (auto *node : graph_->Nodes()) { + if (node->IsOp() && node->Name() == "read") { + read_op_num++; + } + } + if (read_op_num == 0) { + LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model " + "should use pyreader to feed data!"; + } + } +} -FeedFetchList ThreadedSSAGraphExecutor::Run( +inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl( const std::vector &fetch_tensors) { std::unique_ptr event( new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr)); @@ -140,6 +153,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( return fetch_data; } +FeedFetchList ThreadedSSAGraphExecutor::Run( + const std::vector &fetch_tensors) { + for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) { + RunImpl({}); + } + return RunImpl(fetch_tensors); +} + void ThreadedSSAGraphExecutor::InsertFetchOps( const std::vector &fetch_tensors, std::vector *fetch_ops, diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 24da56c09e3e0f3894d58e5af8838c98e3e1e67c..3809b6e9ae0c43581e8dd8d0fbe0f89225fa45c3 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ~ThreadedSSAGraphExecutor() final = default; private: + inline FeedFetchList RunImpl(const std::vector &fetch_tensors); void RunOp(const std::shared_ptr> &ready_var_q, details::OpHandleBase *op);