From 5ec2fb0c93d0e799ea1fc215be0072488399c31e Mon Sep 17 00:00:00 2001 From: nhzlx Date: Fri, 31 Aug 2018 11:32:35 +0000 Subject: [PATCH] add flexibledfs for find path between two nodes --- .../inference/analysis/data_flow_graph.cc | 37 ++++++++++ .../inference/analysis/data_flow_graph.h | 3 + .../analysis/data_flow_graph_tester.cc | 71 +++++++++++++++++++ 3 files changed, 111 insertions(+) diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc index 100a7504b..e4f4bbf43 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph.cc @@ -480,6 +480,8 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) { for (auto *out : op_nodes[i]->outlinks) { if (follow_up_input_names.count(out->name())) { filtered_subgraph_outlinks.push_back(out); + } else { + out->SetDeleted(); } } PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL); @@ -487,6 +489,41 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) { } } +void FlexibleDFS(const std::vector &source, bool reverse, + const std::function &enter, + const std::function &leave) { + typedef struct { + const Node *node; + bool leave; + } FNode; + std::vector stack; + for (auto &node : source) { + stack.push_back(FNode{node, false}); + } + std::unordered_set visited; + while (!stack.empty()) { + auto fnode = stack.back(); + stack.pop_back(); + + if (fnode.leave) { + if (leave && !leave(fnode.node)) return; + } + if (visited.count(fnode.node)) continue; + visited.insert(fnode.node); + + if (enter && !enter(fnode.node)) return; + + if (leave) stack.push_back(FNode{fnode.node, true}); + const std::vector iter_nodes = + reverse == true ? fnode.node->inlinks : fnode.node->outlinks; + for (const Node *node : iter_nodes) { + if (!visited.count(node)) { + stack.push_back(FNode{node, false}); + } + } + } +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph.h b/paddle/fluid/inference/analysis/data_flow_graph.h index 437e097ac..4fefc175f 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.h +++ b/paddle/fluid/inference/analysis/data_flow_graph.h @@ -204,6 +204,9 @@ std::pair, std::vector> ExtractInputAndOutputOfSubGraph(std::vector &graph); // NOLINT void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph); +void FlexibleDFS(const std::vector &source, bool reverse, + const std::function &enter, + const std::function &leave); } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc index 1682011c3..040ca1951 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc @@ -160,6 +160,77 @@ TEST(DataFlowGraph, Build_IR_Graph) { ASSERT_EQ(graph.nodes.size(), ir_graph.Nodes().size()); } +// FlexibleDFS +/* + * Graph topology + * inputs: 0 + * 0 -> 1 + * 1 -> 2 + * 1 -> 3 + * 3 -> 4 + * 4 -> 5 + * 5 -> 2 + */ +TEST(DataFlowGraph, flexibledfs) { + DataFlowGraph graph; + + for (int i = 0; i < 6; i++) { + auto* node = graph.nodes.Create(Node::Type::kValue); + node->SetName("node-" + std::to_string(i)); + } + + auto add_link = [&](int i, int j) { + Node* source = graph.nodes.GetMutable(i); + Node* target = graph.nodes.GetMutable(j); + target->inlinks.push_back(source); + source->outlinks.push_back(target); + }; + + add_link(0, 1); + add_link(1, 2); + add_link(1, 3); + add_link(3, 4); + add_link(4, 5); + add_link(5, 2); + graph.Build(); + + std::vector order; + FlexibleDFS(graph.inputs(), false, nullptr, [&order](const Node* n) { + order.push_back(n); + return true; + }); + + ASSERT_EQ(order.size(), 6UL); + + order.clear(); + // reverse dfs + FlexibleDFS(graph.outputs(), true, nullptr, [&order](const Node* n) { + order.push_back(n); + return true; + }); + + ASSERT_EQ(order.size(), 6UL); + + // If we delete + Node* last_node = graph.nodes.GetMutable(2); + Node* direct_node = graph.nodes.GetMutable(1); + std::vector source_nodes; + for (Node* node : last_node->inlinks) { + if (node != direct_node) source_nodes.push_back(node); + } + + bool has_cycle = false; + FlexibleDFS(source_nodes, true, nullptr, + [&has_cycle, direct_node](const Node* n) { + if (n == direct_node) { + has_cycle = true; + return false; + } + return true; + }); + ASSERT_TRUE(has_cycle); +} + } // namespace analysis } // namespace inference } // namespace paddle -- GitLab