未验证 提交 2acc2b14 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] fix CINN graph symbolization topo sort fixed (#52556)

上级 84bb7a96
......@@ -162,14 +162,50 @@ CinnGraphSymbolization::CreateCinnScope(const FeedInfoMap& feed_map) {
}
std::vector<Node*> CinnGraphSymbolization::TopologicalSort() const {
std::unordered_set<Node*> op_nodes;
std::unordered_set<Node*> node_set;
std::for_each(
graph_.Nodes().begin(), graph_.Nodes().end(), [&op_nodes](Node* n) {
if (n->IsOp()) {
op_nodes.emplace(n);
graph_.Nodes().begin(), graph_.Nodes().end(), [&node_set](Node* n) {
if (n && n->IsOp()) {
node_set.emplace(n);
}
});
std::vector<Node*> op_nodes(node_set.begin(), node_set.end());
std::stable_sort(op_nodes.begin(), op_nodes.end(), [](Node* op1, Node* op2) {
auto out1 = op1->outputs;
auto out2 = op2->outputs;
if (out1.size() != out2.size()) {
return out1.size() < out2.size();
}
auto var_compare = [](Node* var1, Node* var2) {
if (!var1) {
// the null node one the front
return true;
} else if (!var2) {
return false;
}
// sorted by name
return var1->Name() < var2->Name();
};
std::stable_sort(out1.begin(), out1.end(), var_compare);
std::stable_sort(out2.begin(), out2.end(), var_compare);
for (int i = 0; i < out1.size(); ++i) {
if (!out1[i] && !out2[i]) {
continue;
} else if (!out1[i]) {
return true;
} else if (!out2[i]) {
return false;
} else if (out1[i]->Name() != out2[i]->Name()) {
return out1[i]->Name() < out2[i]->Name();
}
}
return true;
});
std::unordered_map<Node*, std::unordered_map<Node*, size_t>> adj_list;
std::unordered_map<Node*, size_t> in_degrees;
for (auto* n : op_nodes) {
......@@ -177,7 +213,7 @@ std::vector<Node*> CinnGraphSymbolization::TopologicalSort() const {
for (auto* in_var : n->inputs) {
// the var's input is op
for (auto* in_op : in_var->inputs) {
if (op_nodes.count(in_op)) {
if (node_set.count(in_op)) {
++adj_list[in_op][n];
++in_degrees[n];
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册