提交 5ec2fb0c 编写于 作者: N nhzlx

add flexibledfs for find path between two nodes

上级 af15f6f0
......@@ -480,6 +480,8 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
for (auto *out : op_nodes[i]->outlinks) {
if (follow_up_input_names.count(out->name())) {
filtered_subgraph_outlinks.push_back(out);
} else {
out->SetDeleted();
}
}
PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL);
......@@ -487,6 +489,41 @@ void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
}
}
void FlexibleDFS(const std::vector<Node *> &source, bool reverse,
const std::function<bool(const Node *)> &enter,
const std::function<bool(const Node *)> &leave) {
typedef struct {
const Node *node;
bool leave;
} FNode;
std::vector<FNode> stack;
for (auto &node : source) {
stack.push_back(FNode{node, false});
}
std::unordered_set<const Node *> visited;
while (!stack.empty()) {
auto fnode = stack.back();
stack.pop_back();
if (fnode.leave) {
if (leave && !leave(fnode.node)) return;
}
if (visited.count(fnode.node)) continue;
visited.insert(fnode.node);
if (enter && !enter(fnode.node)) return;
if (leave) stack.push_back(FNode{fnode.node, true});
const std::vector<Node *> iter_nodes =
reverse == true ? fnode.node->inlinks : fnode.node->outlinks;
for (const Node *node : iter_nodes) {
if (!visited.count(node)) {
stack.push_back(FNode{node, false});
}
}
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -204,6 +204,9 @@ std::pair<std::vector<Node *>, std::vector<Node *>>
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph); // NOLINT
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph);
void FlexibleDFS(const std::vector<Node *> &source, bool reverse,
const std::function<bool(const Node *)> &enter,
const std::function<bool(const Node *)> &leave);
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -160,6 +160,77 @@ TEST(DataFlowGraph, Build_IR_Graph) {
ASSERT_EQ(graph.nodes.size(), ir_graph.Nodes().size());
}
// FlexibleDFS
/*
* Graph topology
* inputs: 0
* 0 -> 1
* 1 -> 2
* 1 -> 3
* 3 -> 4
* 4 -> 5
* 5 -> 2
*/
TEST(DataFlowGraph, flexibledfs) {
DataFlowGraph graph;
for (int i = 0; i < 6; i++) {
auto* node = graph.nodes.Create(Node::Type::kValue);
node->SetName("node-" + std::to_string(i));
}
auto add_link = [&](int i, int j) {
Node* source = graph.nodes.GetMutable(i);
Node* target = graph.nodes.GetMutable(j);
target->inlinks.push_back(source);
source->outlinks.push_back(target);
};
add_link(0, 1);
add_link(1, 2);
add_link(1, 3);
add_link(3, 4);
add_link(4, 5);
add_link(5, 2);
graph.Build();
std::vector<const Node*> order;
FlexibleDFS(graph.inputs(), false, nullptr, [&order](const Node* n) {
order.push_back(n);
return true;
});
ASSERT_EQ(order.size(), 6UL);
order.clear();
// reverse dfs
FlexibleDFS(graph.outputs(), true, nullptr, [&order](const Node* n) {
order.push_back(n);
return true;
});
ASSERT_EQ(order.size(), 6UL);
// If we delete
Node* last_node = graph.nodes.GetMutable(2);
Node* direct_node = graph.nodes.GetMutable(1);
std::vector<Node*> source_nodes;
for (Node* node : last_node->inlinks) {
if (node != direct_node) source_nodes.push_back(node);
}
bool has_cycle = false;
FlexibleDFS(source_nodes, true, nullptr,
[&has_cycle, direct_node](const Node* n) {
if (n == direct_node) {
has_cycle = true;
return false;
}
return true;
});
ASSERT_TRUE(has_cycle);
}
} // namespace analysis
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册