未验证 提交 48a774c7 编写于 作者: 石晓伟 提交者: GitHub

fix ts_sort's bug, test=develop (#20720)

上级 dc229b41
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
#include <set> #include <set>
#include <utility>
#include <vector> #include <vector>
namespace paddle { namespace paddle {
...@@ -79,29 +80,23 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) { ...@@ -79,29 +80,23 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0)); PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
} }
std::unordered_set<Node *> visited;
std::set<Node *> to_visit{source.begin(), source.end()}; std::set<Node *> to_visit{source.begin(), source.end()};
std::vector<Node *> inlink_sorted;
std::vector<Node *> inlink_visited;
while (!to_visit.empty()) { while (!to_visit.empty()) {
std::vector<Node *> queue(to_visit.begin(), to_visit.end()); std::vector<Node *> queue(to_visit.begin(), to_visit.end());
for (auto *p : queue) { for (auto *p : queue) {
inlink_visited.clear(); to_visit.erase(p);
sorted_.push_back(p);
std::copy_if(p->inputs.begin(), p->inputs.end(), for (auto *out : p->outputs) {
std::back_inserter(inlink_visited), inlink_sorted.clear();
[&](Node *x) -> bool { return visited.count(x) != 0; }); std::copy_if(out->inputs.begin(), out->inputs.end(),
std::back_inserter(inlink_sorted), [&](Node *x) -> bool {
if (inlink_visited.size() == p->inputs.size()) { return std::find(sorted_.begin(), sorted_.end(), x) !=
sorted_.push_back(p); sorted_.end();
for (auto *_ : p->outputs) { });
if (!visited.count(_)) { if (inlink_sorted.size() == out->inputs.size()) {
to_visit.insert(_); to_visit.insert(out);
}
} }
to_visit.erase(p);
visited.insert(p);
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册