From 48a774c713b2d5bd6dc4cb71dd79a4006538367f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Sat, 19 Oct 2019 00:54:24 +0800 Subject: [PATCH] fix ts_sort's bug, test=develop (#20720) --- paddle/fluid/framework/ir/graph_traits.cc | 31 ++++++++++------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_traits.cc b/paddle/fluid/framework/ir/graph_traits.cc index 929d9edc34f..abcba32a649 100644 --- a/paddle/fluid/framework/ir/graph_traits.cc +++ b/paddle/fluid/framework/ir/graph_traits.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/ir/graph_traits.h" #include +#include #include namespace paddle { @@ -79,29 +80,23 @@ NodesTSIterator::NodesTSIterator(const std::vector &source) { PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0)); } - std::unordered_set visited; std::set to_visit{source.begin(), source.end()}; - - std::vector inlink_visited; + std::vector inlink_sorted; while (!to_visit.empty()) { std::vector queue(to_visit.begin(), to_visit.end()); for (auto *p : queue) { - inlink_visited.clear(); - - std::copy_if(p->inputs.begin(), p->inputs.end(), - std::back_inserter(inlink_visited), - [&](Node *x) -> bool { return visited.count(x) != 0; }); - - if (inlink_visited.size() == p->inputs.size()) { - sorted_.push_back(p); - for (auto *_ : p->outputs) { - if (!visited.count(_)) { - to_visit.insert(_); - } + to_visit.erase(p); + sorted_.push_back(p); + for (auto *out : p->outputs) { + inlink_sorted.clear(); + std::copy_if(out->inputs.begin(), out->inputs.end(), + std::back_inserter(inlink_sorted), [&](Node *x) -> bool { + return std::find(sorted_.begin(), sorted_.end(), x) != + sorted_.end(); + }); + if (inlink_sorted.size() == out->inputs.size()) { + to_visit.insert(out); } - - to_visit.erase(p); - visited.insert(p); } } } -- GitLab