diff --git a/paddle/fluid/framework/ir/graph_traits.cc b/paddle/fluid/framework/ir/graph_traits.cc index 929d9edc34ffb92f468d5b7af54a0b8da4121543..abcba32a6492b114193cfab6756ff87247956f6c 100644 --- a/paddle/fluid/framework/ir/graph_traits.cc +++ b/paddle/fluid/framework/ir/graph_traits.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/ir/graph_traits.h" #include +#include #include namespace paddle { @@ -79,29 +80,23 @@ NodesTSIterator::NodesTSIterator(const std::vector &source) { PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0)); } - std::unordered_set visited; std::set to_visit{source.begin(), source.end()}; - - std::vector inlink_visited; + std::vector inlink_sorted; while (!to_visit.empty()) { std::vector queue(to_visit.begin(), to_visit.end()); for (auto *p : queue) { - inlink_visited.clear(); - - std::copy_if(p->inputs.begin(), p->inputs.end(), - std::back_inserter(inlink_visited), - [&](Node *x) -> bool { return visited.count(x) != 0; }); - - if (inlink_visited.size() == p->inputs.size()) { - sorted_.push_back(p); - for (auto *_ : p->outputs) { - if (!visited.count(_)) { - to_visit.insert(_); - } + to_visit.erase(p); + sorted_.push_back(p); + for (auto *out : p->outputs) { + inlink_sorted.clear(); + std::copy_if(out->inputs.begin(), out->inputs.end(), + std::back_inserter(inlink_sorted), [&](Node *x) -> bool { + return std::find(sorted_.begin(), sorted_.end(), x) != + sorted_.end(); + }); + if (inlink_sorted.size() == out->inputs.size()) { + to_visit.insert(out); } - - to_visit.erase(p); - visited.insert(p); } } }