From 0cefb9461f596cacb76c7659aef3a55f200a1f6d Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Wed, 11 Jul 2018 16:46:55 +0800 Subject: [PATCH] add topological sortting (#12059) --- .../inference/analysis/data_flow_graph.cc | 86 ++++++++++++++++++- .../inference/analysis/data_flow_graph.h | 36 ++++++++ .../analysis/data_flow_graph_tester.cc | 69 ++++++++++++++- 3 files changed, 188 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc index d09bf3ed161..bd24e8a7d9c 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph.cc @@ -90,6 +90,20 @@ std::string DataFlowGraph::DotString() const { return dot.Build(); } +std::string DataFlowGraph::HumanReadableInfo(bool show_values, + bool show_functions) const { + std::stringstream values, functions; + for (auto &n : nodes.nodes()) { + if (show_values && n->IsValue()) { + values << n->repr() << "\n"; + } + if (show_functions && n->IsFunction()) { + functions << n->repr() << "\n"; + } + } + return "Values:\n" + values.str() + "\n\n" + "Functions:\n" + functions.str(); +} + // // NodesBFSIterator // @@ -146,7 +160,7 @@ bool GraphTraits::NodesBFSIterator::operator==( if ((!queue_.empty()) && (!other.queue_.empty())) { return queue_.front() == other.queue_.front() && visited_.size() == other.visited_.size(); // here need to check the - // equality of queue and + // equality of queue and // visited. Just a light but week implementation. } return false; @@ -208,6 +222,76 @@ Node *GraphTraits::NodesDFSIterator::operator->() { return stack_.top(); } +GraphTraits::NodesTSIterator::NodesTSIterator( + const std::vector &source) { + PADDLE_ENFORCE(!source.empty(), + "Start points of topological sorting should not be empty!"); + std::unordered_set visited; + std::unordered_set to_visit{source.begin(), source.end()}; + + std::vector inlink_visited; + while (!to_visit.empty()) { + std::vector queue(to_visit.begin(), to_visit.end()); + for (auto *p : queue) { + inlink_visited.clear(); + + std::copy_if(p->inlinks.begin(), p->inlinks.end(), + std::back_inserter(inlink_visited), + [&](Node *x) { return visited.count(x); }); + + if (inlink_visited.size() == p->inlinks.size()) { + sorted_.push_back(p); + for (auto *_ : p->outlinks) { + if (!visited.count(_)) { + to_visit.insert(_); + } + } + + to_visit.erase(p); + visited.insert(p); + } + } + } +} + +GraphTraits::NodesTSIterator::NodesTSIterator( + const paddle::inference::analysis::GraphTraits< + DataFlowGraph>::NodesTSIterator &other) + : sorted_(other.sorted_), cursor_(other.cursor_) {} + +Node &GraphTraits::NodesTSIterator::operator*() { + PADDLE_ENFORCE_LT(cursor_, sorted_.size()); + return *sorted_[cursor_]; +} + +paddle::inference::analysis::GraphTraits::NodesTSIterator + &GraphTraits::NodesTSIterator::operator++() { + if (++cursor_ >= sorted_.size()) { + sorted_.clear(); + cursor_ = 0; + } + return *this; +} +paddle::inference::analysis::GraphTraits::NodesTSIterator & +GraphTraits::NodesTSIterator::operator=( + const paddle::inference::analysis::GraphTraits< + DataFlowGraph>::NodesTSIterator &other) { + cursor_ = other.cursor_; + sorted_ = other.sorted_; + return *this; +} + +bool GraphTraits::NodesTSIterator::operator==( + const paddle::inference::analysis::GraphTraits< + DataFlowGraph>::NodesTSIterator &other) { + return sorted_ == other.sorted_ && cursor_ == other.cursor_; +} + +Node *GraphTraits::NodesTSIterator::operator->() { + PADDLE_ENFORCE_LT(cursor_, sorted_.size()); + return sorted_[cursor_]; +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph.h b/paddle/fluid/inference/analysis/data_flow_graph.h index a4fefc83e0c..5dd914d1971 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.h +++ b/paddle/fluid/inference/analysis/data_flow_graph.h @@ -48,6 +48,9 @@ struct DataFlowGraph { // Output a DOT graph file for debug. std::string DotString() const; + std::string HumanReadableInfo(bool show_values = true, + bool show_functions = true) const; + private: // Remove duplicate edges and so on. void Clean(); @@ -107,6 +110,32 @@ struct GraphTraits { std::unordered_set visited_; }; + // Topological sorting iterator on nodes. + struct NodesTSIterator + : public std::iterator { + NodesTSIterator() = default; + explicit NodesTSIterator(const std::vector &source); + NodesTSIterator(NodesTSIterator &&other) + : sorted_(std::move(other.sorted_)), cursor_(other.cursor_) { + other.cursor_ = 0; + } + NodesTSIterator(const NodesTSIterator &other); + + Node &operator*(); + NodesTSIterator &operator++(); + // TODO(Superjomn) current implementation just compare the first + // element, need to compare the graph and all the elements in the queue and + // set. + NodesTSIterator &operator=(const NodesTSIterator &other); + bool operator==(const NodesTSIterator &other); + bool operator!=(const NodesTSIterator &other) { return !(*this == other); } + Node *operator->(); + + private: + std::vector sorted_; + int cursor_{0}; + }; + explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {} // default use BFS to visit the nodes. @@ -119,17 +148,24 @@ struct GraphTraits { iterator_range nodes_in_DFS() { return iterator_range(nodes_dfs_begin(), nodes_dfs_end()); } + iterator_range nodes_in_TS() { + return iterator_range(nodes_ts_begin(), nodes_ts_end()); + } private: NodesBFSIterator nodes_bfs_begin() { return NodesBFSIterator(graph_->inputs); } NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); } + NodesDFSIterator nodes_dfs_begin() { return NodesDFSIterator(graph_->inputs); } NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); } + NodesTSIterator nodes_ts_begin() { return NodesTSIterator(graph_->inputs); } + NodesTSIterator nodes_ts_end() { return NodesTSIterator(); } + private: DataFlowGraph *graph_; }; diff --git a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc index 9d7cceeb658..7912f8d7f17 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc @@ -24,11 +24,11 @@ TEST(DataFlowGraph, BFS) { auto dfg = ProgramDescToDFG(desc); dfg.Build(); - for (auto* in : dfg.inputs) { + for (auto *in : dfg.inputs) { LOG(INFO) << "inputs: " << in->name() << " " << static_cast(in->type()); } - for (auto* out : dfg.outputs) { + for (auto *out : dfg.outputs) { LOG(INFO) << "outputs: " << out->name() << " " << static_cast(out->type()); } @@ -57,6 +57,71 @@ TEST(DataFlowGraph, DFS) { ASSERT_EQ(count, dfg.nodes.size()); } +// Topological sorting. +/* + * Graph topology + * inputs: 0, 1, 2 + * 0 -> 4 + * 0 -> 5 + * 1 -> 6 + * 2 -> 7 + * 4 -> 5 + * 4 -> 7 + * 4 -> 3 + * 7 -> 3 + */ +TEST(DataFlowGraph, TS) { + DataFlowGraph graph; + + for (int i = 0; i < 8; i++) { + auto *node = graph.nodes.Create(Node::Type::kValue); + node->SetName("node-" + std::to_string(i)); + } + + auto add_link = [&](int i, int j) { + Node *source = graph.nodes.GetMutable(i); + Node *target = graph.nodes.GetMutable(j); + target->inlinks.push_back(source); + source->outlinks.push_back(target); + }; + + graph.inputs.push_back(graph.nodes.GetMutable(0)); + graph.inputs.push_back(graph.nodes.GetMutable(1)); + graph.inputs.push_back(graph.nodes.GetMutable(2)); + + add_link(0, 4); + add_link(0, 5); + add_link(1, 6); + add_link(2, 7); + add_link(4, 5); + add_link(4, 7); + add_link(4, 3); + add_link(7, 3); + + auto its = GraphTraits(&graph).nodes_in_TS(); + std::vector sorted_ids; + for (auto it = its.begin(); it != its.end(); ++it) { + LOG(INFO) << it->name(); + sorted_ids.push_back(it->id()); + } + + // Assert a occurs prior to b in the sorted_ids. + auto assert_positive_sequence_pair = [&](int a, int b) { + auto a_offset = std::find(sorted_ids.begin(), sorted_ids.end(), a); + auto b_offset = std::find(sorted_ids.begin(), sorted_ids.end(), b); + ASSERT_LT(a_offset, b_offset); + }; + + assert_positive_sequence_pair(2, 7); + assert_positive_sequence_pair(7, 3); + assert_positive_sequence_pair(4, 3); + assert_positive_sequence_pair(0, 4); + assert_positive_sequence_pair(0, 5); + assert_positive_sequence_pair(1, 6); + assert_positive_sequence_pair(4, 5); + assert_positive_sequence_pair(4, 7); +} + } // namespace analysis } // namespace inference } // namespace paddle -- GitLab