From 7e60cc63c33f0c17df36b0ee52ae50a3d04a6697 Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Wed, 29 Sep 2021 10:13:07 +0800
Subject: [PATCH] refine case when thread_num = 1 (#36201)

---
 .../fast_threaded_ssa_graph_executor.cc       | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
index 120bdd2bc9f..a690b3026db 100644
--- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
@@ -47,7 +47,16 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
         << "Change thread number to 1 because the toposort order is unique";
     strategy_.num_threads_ = 1;
   }
-  pool_.reset(new ::ThreadPool(strategy.num_threads_));
+  if (strategy_.num_threads_ > 1) {
+    pool_.reset(new ::ThreadPool(strategy.num_threads_));
+  } else {
+    auto nodes = ir::TopologySortOperations(*graph_);
+    traced_ops_.clear();
+    traced_ops_.reserve(nodes.size());
+    for (auto *node : nodes) {
+      traced_ops_.push_back(&node->Wrapper<OpHandleBase>());
+    }
+  }
   for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
     int dep = static_cast<int>(op->NotReadyInputSize());
     op_deps_.emplace(op, dep);
@@ -228,7 +237,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
     OpHandleBase *op,
     const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
   ++remaining_;
-  this->pool_->enqueue([=] {
+  auto func = [=] {
     std::deque<OpHandleBase *> op_queue;
     op_queue.push_front(op);
 
@@ -287,7 +296,12 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
     }
     --remaining_;
     complete_q->Push(complete);
-  });
+  };
+  if (pool_) {
+    pool_->enqueue(func);
+  } else {
+    func();
+  }
 }
 
 void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
-- 
GitLab