diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 120bdd2bc9f5633250a755558c0e441e618dfe8b..a690b3026dbc2f298fbd3e84ce3a968e93f789a0 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -47,7 +47,16 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( << "Change thread number to 1 because the toposort order is unique"; strategy_.num_threads_ = 1; } - pool_.reset(new ::ThreadPool(strategy.num_threads_)); + if (strategy_.num_threads_ > 1) { + pool_.reset(new ::ThreadPool(strategy.num_threads_)); + } else { + auto nodes = ir::TopologySortOperations(*graph_); + traced_ops_.clear(); + traced_ops_.reserve(nodes.size()); + for (auto *node : nodes) { + traced_ops_.push_back(&node->Wrapper()); + } + } for (auto &op : ir::FilterByNodeWrapper(*graph_)) { int dep = static_cast(op->NotReadyInputSize()); op_deps_.emplace(op, dep); @@ -228,7 +237,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( OpHandleBase *op, const std::shared_ptr> &complete_q) { ++remaining_; - this->pool_->enqueue([=] { + auto func = [=] { std::deque op_queue; op_queue.push_front(op); @@ -287,7 +296,12 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( } --remaining_; complete_q->Push(complete); - }); + }; + if (pool_) { + pool_->enqueue(func); + } else { + func(); + } } void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {