未验证 提交 48a774c7 编写于 作者: 石晓伟 提交者: GitHub

fix ts_sort's bug, test=develop (#20720)

上级 dc229b41
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h"
#include <set>
#include <utility>
#include <vector>
namespace paddle {
......@@ -79,29 +80,23 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
}
std::unordered_set<Node *> visited;
std::set<Node *> to_visit{source.begin(), source.end()};
std::vector<Node *> inlink_visited;
std::vector<Node *> inlink_sorted;
while (!to_visit.empty()) {
std::vector<Node *> queue(to_visit.begin(), to_visit.end());
for (auto *p : queue) {
inlink_visited.clear();
std::copy_if(p->inputs.begin(), p->inputs.end(),
std::back_inserter(inlink_visited),
[&](Node *x) -> bool { return visited.count(x) != 0; });
if (inlink_visited.size() == p->inputs.size()) {
sorted_.push_back(p);
for (auto *_ : p->outputs) {
if (!visited.count(_)) {
to_visit.insert(_);
}
to_visit.erase(p);
sorted_.push_back(p);
for (auto *out : p->outputs) {
inlink_sorted.clear();
std::copy_if(out->inputs.begin(), out->inputs.end(),
std::back_inserter(inlink_sorted), [&](Node *x) -> bool {
return std::find(sorted_.begin(), sorted_.end(), x) !=
sorted_.end();
});
if (inlink_sorted.size() == out->inputs.size()) {
to_visit.insert(out);
}
to_visit.erase(p);
visited.insert(p);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册