未验证 提交 a1aae040 编写于 作者: W Wilber 提交者: GitHub

[Inference] Replace unordered_map with map to support subgraph stability (#35147)

上级 e4a8815d
...@@ -117,7 +117,7 @@ void SubgraphDetector::MarkNodesInsideSubGraph() { ...@@ -117,7 +117,7 @@ void SubgraphDetector::MarkNodesInsideSubGraph() {
// Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node // Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node
// a's output is node b, that is a and b is in the same sub-graph. The UF // a's output is node b, that is a and b is in the same sub-graph. The UF
// algorithm will group them to the same cluster. // algorithm will group them to the same cluster.
using node_map_t = std::unordered_map<int, Node *>; using node_map_t = std::map<int, Node *>;
// Find the ancestor id of a node. // Find the ancestor id of a node.
int UnionFindGetAncestor(const node_map_t &node_map, size_t id) { int UnionFindGetAncestor(const node_map_t &node_map, size_t id) {
int tmp = id; int tmp = id;
...@@ -155,7 +155,7 @@ struct BriefNode { ...@@ -155,7 +155,7 @@ struct BriefNode {
// 3. change all the dst's inputs and outputs // 3. change all the dst's inputs and outputs
// corresponding inlinks and outlinks to src node. // corresponding inlinks and outlinks to src node.
// 4. delete all dst's inlinks and outlinks. // 4. delete all dst's inlinks and outlinks.
void UnionContractedNodes(const std::unordered_map<int, BriefNode *> &node_map, void UnionContractedNodes(const std::map<int, BriefNode *> &node_map,
int src_id, int dst_id) { int src_id, int dst_id) {
// merge the two adjacent nodes into one node. // merge the two adjacent nodes into one node.
BriefNode *src_node = node_map.at(src_id); BriefNode *src_node = node_map.at(src_id);
...@@ -262,7 +262,7 @@ std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubGraphs() { ...@@ -262,7 +262,7 @@ std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubGraphs() {
std::vector<Node *> marked_nodes; std::vector<Node *> marked_nodes;
// We use brief_node_map to represent the original graph in order to avoid // We use brief_node_map to represent the original graph in order to avoid
// changing the original graph. // changing the original graph.
std::unordered_map<int, BriefNode *> brief_node_map; std::map<int, BriefNode *> brief_node_map;
std::unordered_set<int32_t> valid_node_ids; std::unordered_set<int32_t> valid_node_ids;
for (auto *node : graph_->Nodes()) { for (auto *node : graph_->Nodes()) {
......
...@@ -167,7 +167,7 @@ class FusionGroupPassSumTest(FusionGroupPassTest): ...@@ -167,7 +167,7 @@ class FusionGroupPassSumTest(FusionGroupPassTest):
self.append_gradients(tmp_3) self.append_gradients(tmp_3)
self.num_fused_ops = 4 self.num_fused_ops = 3
self.fetch_list = [tmp_3, self.grad(tmp_0)] self.fetch_list = [tmp_3, self.grad(tmp_0)]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册