From a7d0d888255bd572e87bdfabb5265e916d293927 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Sun, 20 Oct 2019 16:23:58 +0800 Subject: [PATCH] CHERRY_PICK 20720: fix ts_sort's bug, test=develop (#20726) test=release/1.6 --- paddle/fluid/framework/ir/graph_traits.cc | 31 ++++++++++------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_traits.cc b/paddle/fluid/framework/ir/graph_traits.cc index 929d9edc34..abcba32a64 100644 --- a/paddle/fluid/framework/ir/graph_traits.cc +++ b/paddle/fluid/framework/ir/graph_traits.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/ir/graph_traits.h" #include +#include #include namespace paddle { @@ -79,29 +80,23 @@ NodesTSIterator::NodesTSIterator(const std::vector &source) { PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0)); } - std::unordered_set visited; std::set to_visit{source.begin(), source.end()}; - - std::vector inlink_visited; + std::vector inlink_sorted; while (!to_visit.empty()) { std::vector queue(to_visit.begin(), to_visit.end()); for (auto *p : queue) { - inlink_visited.clear(); - - std::copy_if(p->inputs.begin(), p->inputs.end(), - std::back_inserter(inlink_visited), - [&](Node *x) -> bool { return visited.count(x) != 0; }); - - if (inlink_visited.size() == p->inputs.size()) { - sorted_.push_back(p); - for (auto *_ : p->outputs) { - if (!visited.count(_)) { - to_visit.insert(_); - } + to_visit.erase(p); + sorted_.push_back(p); + for (auto *out : p->outputs) { + inlink_sorted.clear(); + std::copy_if(out->inputs.begin(), out->inputs.end(), + std::back_inserter(inlink_sorted), [&](Node *x) -> bool { + return std::find(sorted_.begin(), sorted_.end(), x) != + sorted_.end(); + }); + if (inlink_sorted.size() == out->inputs.size()) { + to_visit.insert(out); } - - to_visit.erase(p); - visited.insert(p); } } } -- GitLab