diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index bf7d76a8a6e173e648cea5aaba9b7202d787173b..923a7083d4f30b646bbab03d79992b275aa2b403 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -3,7 +3,10 @@ cc_library(graph SRCS graph.cc DEPS node) cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) +cc_library(graph_traits SRCS graph_traits.cc DEPS graph) +cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) +cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter) diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter.cc b/paddle/fluid/framework/ir/graph_pattern_detecter.cc new file mode 100644 index 0000000000000000000000000000000000000000..f27d9b0509aa4561cfd1e5da3b46a3a085cc888c --- /dev/null +++ b/paddle/fluid/framework/ir/graph_pattern_detecter.cc @@ -0,0 +1,186 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" +#include "paddle/fluid/framework/ir/graph_traits.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { + nodes_.emplace_back(new PDNode(std::move(teller), name)); + auto* cur = nodes_.back().get(); + return cur; +} + +void PDPattern::AddEdge(PDNode* a, PDNode* b) { + PADDLE_ENFORCE(a); + PADDLE_ENFORCE(b); + PADDLE_ENFORCE(a != b, "can't connect to the same nodes."); + edges_.emplace_back(a, b); +} + +void GraphPatternDetecter::operator()(Graph* graph, + GraphPatternDetecter::handle_t handler) { + if (!MarkPDNodesInGraph(*graph)) return; + auto subgraphs = DetectPatterns(); + UniquePatterns(&subgraphs); + RemoveOverlappedMatch(&subgraphs); + + for (auto& g : subgraphs) { + handler(g, graph); + } +} + +bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) { + if (graph.Nodes().empty()) return false; + + for (auto& node : GraphTraits::DFS(graph)) { + for (const auto& pdnode : pattern_.nodes()) { + if (pdnode->Tell(&node)) { + pdnodes2nodes_[pdnode.get()].insert(&node); + } + } + } + return !pdnodes2nodes_.empty(); +} + +struct HitGroup { + std::unordered_map roles; + + bool Match(Node* node, PDNode* pat) { + return !roles.count(pat) || roles.at(pat) == node; + } + + void Register(Node* node, PDNode* pat) { roles[pat] = node; } +}; + +// Tell whether Node a links to b. +bool IsNodesLink(Node* a, Node* b) { + for (auto* node : a->outputs) { + if (b == node) { + return true; + } + } + return false; +} + +std::vector +GraphPatternDetecter::DetectPatterns() { + // Init empty subgraphs. + std::vector result; + std::vector init_groups; + PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed"); + auto* first_pnode = pattern_.edges().front().first; + if (!pdnodes2nodes_.count(first_pnode)) return result; + for (auto* node : pdnodes2nodes_[first_pnode]) { + HitGroup group; + group.roles[first_pnode] = node; + init_groups.emplace_back(group); + } + + int step = 0; + std::array, 2> bi_records; + bi_records[0] = std::move(init_groups); + + // Extend a PDNode to subgraphs by deducing the connection relations defined + // in edges of PDNodes. + for (const auto& edge : pattern_.edges()) { + // Each role has two PDNodes, which indicates two roles. + // Detect two Nodes that can match these two roles and they are connected. + auto& pre_groups = bi_records[step % 2]; + auto& cur_groups = bi_records[1 - (step++ % 2)]; + cur_groups.clear(); + // source -> target + for (Node* source : pdnodes2nodes_[edge.first]) { + for (Node* target : pdnodes2nodes_[edge.second]) { + // TODO(Superjomn) add some prune strategies. + for (const auto& group : pre_groups) { + HitGroup new_group = group; + if (IsNodesLink(source, target) && + new_group.Match(source, edge.first)) { + new_group.Register(source, edge.first); + if (new_group.Match(target, edge.second)) { + new_group.Register(target, edge.second); + cur_groups.push_back(new_group); + // TODO(Superjomn) need to unique + } + } + } + } + } + } + + for (auto& group : bi_records[step % 2]) { + GraphPatternDetecter::subgraph_t subgraph; + for (auto& role : group.roles) { + subgraph.emplace(role.first, role.second); + } + result.emplace_back(subgraph); + } + return result; +} + +void GraphPatternDetecter::UniquePatterns( + std::vector* subgraphs) { + if (subgraphs->empty()) return; + std::vector result; + + std::unordered_set set; + for (auto& g : *subgraphs) { + size_t key = 0; + for (auto& item : g) { + key ^= std::hash{}(item.first); + key ^= std::hash{}(item.second); + } + if (!set.count(key)) { + result.emplace_back(g); + set.insert(key); + } + } + *subgraphs = result; +} + +void GraphPatternDetecter::RemoveOverlappedMatch( + std::vector* subgraphs) { + std::vector result; + std::unordered_set node_set; + + for (const auto& subgraph : *subgraphs) { + bool valid = true; + for (auto& item : subgraph) { + if (node_set.count(item.second)) { + valid = false; + break; + } + } + if (valid) { + for (auto& item : subgraph) { + node_set.insert(item.second); + } + result.push_back(subgraph); + } + } + *subgraphs = result; +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter.h b/paddle/fluid/framework/ir/graph_pattern_detecter.h new file mode 100644 index 0000000000000000000000000000000000000000..1778bf00000f60e5cf8b2a585bf7e5dae0a582eb --- /dev/null +++ b/paddle/fluid/framework/ir/graph_pattern_detecter.h @@ -0,0 +1,181 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_TESTING +#include +#endif + +#include +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace framework { +namespace ir { + +// Some basic torminolygies: +// - PDPattern: a pattern defined as a data flow graph. +// - PDNode: the node in the pattern, each PDNode represents an `ir::Node` +// that meets some conditions defined in `PDNode.teller`. +// - A pattern is defined with PDNodes with edges. + +// Pattern detector node. This node helps to build a pattern. +struct PDNode { + // tell whether an ir::Node* is a candidation for a PDNode. + using teller_t = std::function; + + PDNode(teller_t&& teller, const std::string& name = "") + : teller_(teller), name_(name) { + PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set."); + } + + PDNode(PDNode&& other) = default; + + std::vector inlinks; + std::vector outlinks; + + bool Tell(Node* node) const { + PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode"); + return teller_(node); + } + + const std::string& name() const { return name_; } + + PDNode(const PDNode&) = delete; + PDNode& operator=(const PDNode&) = delete; + + private: + teller_t teller_; + std::string name_; +}; + +/* + * A pattern in a graph, which defined with PDNode and edges. Most graph + * patterns can be divided into PDNodes and link relations between them. + * + * For example, the FC fusion need to filter the MUL and ELEMENTWISE_ADD + * operators from the computation graph, the MUL's output should have only one + * consumer which is the ELEMENTWISE_ADD. + * This pattern can be defined as with the following pseudo codes + * + * // Create two operator PDNodes. + * MUL = PDPattern.NewNode() + * ELE = PDPattern.NewNode() + * // Create the variable PDNodes. + * MUL_out = PDPattern.NewNode() + * // Add teller to define some rules that help to filter the target Nodes. + * MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul"; + * ELE.teller = lambda(node): \ + * node->IsOp() && node->Op()->Type == "elementwise_add"; + * MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs) + * && (ELE in node->outputs) + * + * One can add more specific tellers for PDNodes or edges, both the Operator + * and Variable Nodes can be ruled in PDNode.teller. + * + * PDPattern can record the general patterns, such as the pattern represents + * - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place. + * - Ops whose inputs and outputs share the same variables + */ +class PDPattern { + public: + using edge_t = std::pair; + + void AddEdge(PDNode* a, PDNode* b); + + PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = ""); + + const std::vector>& nodes() const { return nodes_; } + const std::vector& edges() const { return edges_; } + + private: +#ifdef PADDLE_WITH_TESTING + FRIEND_TEST(PDPattern, AddEdge); + FRIEND_TEST(PDPattern, NewNode); +#endif + + std::vector> nodes_; + std::vector edges_; +}; + +/* + * GraphPatternDetecter helps to detect the specific patterns in the graph. + * Input a pattern, output a list of the matched subgraphs/nodes. + * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.). + * + * The algorithm has three phases: + * 1. Mark the nodes that match the defined PDNodes in a PDPattern, + * 2. Extend a PDNode to subgraphs by deducing the connection relation defined + * in PAPattern(the edges), + * 3. Get the filtered subgraphs and treat them with a pre-defined handler. + * + * Usage: + * // Create a detector + * GraphPatternDetecter detector; + * // Define the detector's pattern, by adding PDNode and define the edges. + * auto* node0 = detector.mutable_pattern().AddNode(...) + * auto* node1 = detector.mutable_pattern().AddNode(...) + * node0->teller = some lambda. + * node1->teller = some lambda. + * detector.mutable_pattern().AddEdge(node0, node1); + * // Create an handler, to define the behavior of treating the filtered + * // subgraphs that comply with the patterns. + * GraphPatternDetecter::handle_t handler = some labmda + * // Execute the detector. + * detector(&graph, handler); + */ +class GraphPatternDetecter { + public: + using subgraph_t = std::unordered_map; + + // Operate on the detected pattern. + using handle_t = + std::function; + + void operator()(Graph* graph, handle_t handler); + + const PDPattern& pattern() const { return pattern_; } + PDPattern* mutable_pattern() { return &pattern_; } + + private: + // Mark the nodes that fits the pattern. + bool MarkPDNodesInGraph(const ir::Graph& graph); + + // Detect all the pattern and output the hit records. + std::vector DetectPatterns(); + + // Remove duplicate patterns. + void UniquePatterns(std::vector* subgraphs); + + // Remove overlapped match subgraphs, when overlapped, keep the previous one. + void RemoveOverlappedMatch(std::vector* subgraphs); + +#ifdef PADDLE_WITH_TESTING + FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph); + FRIEND_TEST(GraphPatternDetecter, DetectPatterns); +#endif + + private: + using hit_rcd_t = + std::pair; + PDPattern pattern_; + std::vector marked_records_; + std::unordered_map> pdnodes2nodes_; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc b/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..993c885a810fe80a170ed190b892b148d85e8b5f --- /dev/null +++ b/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" + +#include + +namespace paddle { +namespace framework { +namespace ir { + +void BuildGraph(Graph* g) { + ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation); + ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation); + ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation); + ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation); + ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation); + ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable); + ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable); + ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable); + ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable); + + // o1->v1->o2 + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + // o2->v2->o3 + // o2->v2->o4 + o2->outputs.push_back(v2); + o3->inputs.push_back(v2); + o4->inputs.push_back(v2); + v2->inputs.push_back(o2); + v2->outputs.push_back(o3); + v2->outputs.push_back(o4); + // o2->v3->o5 + o2->outputs.push_back(v3); + o5->inputs.push_back(v3); + v3->inputs.push_back(o2); + v3->outputs.push_back(o5); + // o3-v4->o5 + o3->outputs.push_back(v4); + o5->inputs.push_back(v4); + v4->inputs.push_back(o3); + v4->outputs.push_back(o5); +} + +TEST(PDPattern, NewNode) { + PDPattern x; + auto* n = x.NewNode([](Node* x) { return true; }); + ASSERT_TRUE(n); + ASSERT_EQ(x.nodes_.size(), 1UL); +} + +TEST(PDPattern, AddEdge) { + PDPattern x; + auto* a = x.NewNode([](Node* x) { return true; }); + auto* b = x.NewNode([](Node* x) { return true; }); + ASSERT_TRUE(a); + ASSERT_TRUE(b); + x.AddEdge(a, b); + ASSERT_EQ(x.nodes_.size(), 2UL); + ASSERT_EQ(x.edges_.size(), 1UL); + ASSERT_EQ(x.edges_.front().first, a); + ASSERT_EQ(x.edges_.front().second, b); + + ASSERT_EQ(x.nodes().size(), 2UL); + ASSERT_EQ(x.edges().size(), 1UL); + ASSERT_EQ(x.edges().front().first, a); + ASSERT_EQ(x.edges().front().second, b); +} + +TEST(GraphPatternDetecter, MarkPDNodesInGraph) { + GraphPatternDetecter x; + // mark o2, o3, v2 + + // The pattern is a graph: + // o2(a node named o2) -> v2(a node named v2) + // v2 -> o3(a node named o3) + auto* o2 = x.pattern_.NewNode([](Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->Name() == "op2" && node->IsOp(); + }); + auto* o3 = x.pattern_.NewNode([](Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->Name() == "op3" && node->IsOp(); + }); + auto* v2 = x.pattern_.NewNode([](Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->Name() == "var2" && node->IsVar(); + }); + + ASSERT_FALSE(o2->Tell(nullptr)); + ASSERT_FALSE(o3->Tell(nullptr)); + ASSERT_FALSE(v2->Tell(nullptr)); + + x.pattern_.AddEdge(o2, v2); + x.pattern_.AddEdge(v2, o3); + + ASSERT_EQ(x.pattern_.edges().size(), 2UL); + ASSERT_EQ(x.pattern_.edges()[0].first, o2); + ASSERT_EQ(x.pattern_.edges()[0].second, v2); + ASSERT_EQ(x.pattern_.edges()[1].first, v2); + ASSERT_EQ(x.pattern_.edges()[1].second, o3); + + ProgramDesc program; + Graph graph(program); + BuildGraph(&graph); + + x.MarkPDNodesInGraph(graph); + + ASSERT_EQ(x.pdnodes2nodes_.size(), 3UL); + + auto subgraphs = x.DetectPatterns(); + ASSERT_EQ(subgraphs.size(), 1UL); +} + +TEST(GraphPatternDetecter, MultiSubgraph) { + ProgramDesc program; + Graph graph(program); + BuildGraph(&graph); + + GraphPatternDetecter x; + + // The pattern is a graph: + // op -> var + auto* any_op = x.mutable_pattern()->NewNode( + [](Node* node) { + return node->IsOp() && (node->Name() == "op2" || node->Name() == "op3"); + }, + "OP0"); + auto* any_var = x.mutable_pattern()->NewNode( + [](Node* node) { return node->IsVar(); }, "VAR"); + auto* any_op1 = x.mutable_pattern()->NewNode( + [](Node* node) { return node->IsOp(); }, "OP1"); + + x.mutable_pattern()->AddEdge(any_op, any_var); + x.mutable_pattern()->AddEdge(any_var, any_op1); + + int count = 0; + GraphPatternDetecter::handle_t handle = [&]( + const GraphPatternDetecter::subgraph_t& s, Graph* g) { + LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> " + << s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name(); + count++; + }; + + x(&graph, handle); + + // 1. Detect op3 -> var4 -> op5 + // 2. Detect op2 -> var2 -> op3 + // 3. Detect op2 -> var2 -> op4 + // 4. Detect op2 -> var3 -> op5 + // But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2 + ASSERT_GE(count, 1UL); + ASSERT_LE(count, 2UL); +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_traits.cc b/paddle/fluid/framework/ir/graph_traits.cc new file mode 100644 index 0000000000000000000000000000000000000000..8f548913e4e1d9d5bc5bdace8b92db9065cf3b5e --- /dev/null +++ b/paddle/fluid/framework/ir/graph_traits.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/graph_traits.h" + +namespace paddle { +namespace framework { +namespace ir { + +// +// NodesDFSIterator +// +NodesDFSIterator::NodesDFSIterator(const std::vector &source) { + for (auto *x : source) stack_.push(x); +} + +NodesDFSIterator::NodesDFSIterator(NodesDFSIterator &&other) noexcept + : stack_(std::move(other.stack_)), + visited_(std::move(other.visited_)) {} + +NodesDFSIterator::NodesDFSIterator(const NodesDFSIterator &other) + : stack_(other.stack_), visited_(other.visited_) {} + +Node &NodesDFSIterator::operator*() { + PADDLE_ENFORCE(!stack_.empty()); + return *stack_.top(); +} + +NodesDFSIterator &NodesDFSIterator::operator++() { + PADDLE_ENFORCE(!stack_.empty(), "the iterator exceeds range"); + visited_.insert(stack_.top()); + auto *cur = stack_.top(); + stack_.pop(); + for (auto *x : cur->outputs) { + if (!visited_.count(x)) { + stack_.push(x); + } + } + return *this; +} +bool NodesDFSIterator::operator==(const NodesDFSIterator &other) { + if (stack_.empty()) return other.stack_.empty(); + if ((!stack_.empty()) && (!other.stack_.empty())) { + return stack_.top() == other.stack_.top(); + } + return false; +} + +NodesDFSIterator &NodesDFSIterator::operator=(const NodesDFSIterator &other) { + stack_ = other.stack_; + visited_ = other.visited_; + return *this; +} +Node *NodesDFSIterator::operator->() { return stack_.top(); } + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_traits.h b/paddle/fluid/framework/ir/graph_traits.h new file mode 100644 index 0000000000000000000000000000000000000000..edbe45acb98326ee3bf1d86495832ec8469b634e --- /dev/null +++ b/paddle/fluid/framework/ir/graph_traits.h @@ -0,0 +1,90 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace framework { +namespace ir { + +template +class iterator_range { + IteratorT begin_, end_; + + public: + template + explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {} + + iterator_range(const IteratorT &begin, const IteratorT &end) + : begin_(begin), end_(end) {} + + const IteratorT &begin() const { return begin_; } + const IteratorT &end() const { return end_; } +}; + +// DFS iterator on nodes. +struct NodesDFSIterator + : public std::iterator { + NodesDFSIterator() = default; + explicit NodesDFSIterator(const std::vector &source); + NodesDFSIterator(NodesDFSIterator &&other) noexcept; + NodesDFSIterator(const NodesDFSIterator &other); + + Node &operator*(); + NodesDFSIterator &operator++(); + // TODO(Superjomn) current implementation just compare the first + // element, need to compare the graph and all the elements in the queue and + // set. + NodesDFSIterator &operator=(const NodesDFSIterator &other); + bool operator==(const NodesDFSIterator &other); + bool operator!=(const NodesDFSIterator &other) { return !(*this == other); } + Node *operator->(); + + private: + std::stack stack_; + std::unordered_set visited_; +}; + +/* + * GraphTraits contains some graph traversal algorithms. + * + * Usage: + * + */ +struct GraphTraits { + static iterator_range DFS(const Graph &g) { + auto start_points = ExtractStartPoints(g); + NodesDFSIterator x(start_points); + return iterator_range(NodesDFSIterator(start_points), + NodesDFSIterator()); + } + + private: + // The nodes those have no input will be treated as start points. + static std::vector ExtractStartPoints(const Graph &g) { + std::vector result; + for (auto *node : g.Nodes()) { + if (node->inputs.empty()) { + result.push_back(node); + } + } + return result; + } +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index b3138fccee86fb274abe72007961fc1c982b1e96..9c0765ab8ce16733ac021aefc8c7b2bb779319f3 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -58,6 +58,9 @@ class Node { return op_desc_; } + bool IsOp() const { return type_ == Type::kOperation; } + bool IsVar() const { return type_ == Type::kVariable; } + std::vector inputs; std::vector outputs;