/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/analysis/subgraph_splitter.h"

namespace paddle {
namespace inference {
namespace analysis {

// Name of the boolean node attribute used to flag a node as a member of some
// sub-graph candidate.
const char *SubGraphSplitter::kMarkerAttrName =
    "_sub_graph_splitter_inside_sub_graph";

// Entry point: flag every candidate node first, then group the flagged nodes
// into sub-graphs and return them.
std::vector<std::vector<Node *>> SubGraphSplitter::operator()() {
  MarkNodesInsideSubGraph();
  return ExtractSubGraphs();
}

// Flag every output variable of `func` as lying inside the sub-graph.
inline void MarkOutLinksInSubGraph(const Function *func) {
  for (auto *out_var : func->outlinks) {
    out_var->attr(SubGraphSplitter::kMarkerAttrName).Bool() = true;
  }
}

// Set kMarkerAttrName on every node accepted by node_inside_subgraph_teller_,
// plus the output variables of every accepted function node.
void SubGraphSplitter::MarkNodesInsideSubGraph() {
  for (auto &candidate : GraphTraits<DataFlowGraph>(*graph_).nodes()) {
    if (!node_inside_subgraph_teller_(&candidate)) continue;

    candidate.attr(kMarkerAttrName).Bool() = true;
    if (candidate.type() != Node::Type::kFunction) continue;

    // A marked function pulls its output variables into the sub-graph too.
    // For A_function -> var -> B_function with both functions marked, marking
    // var keeps A_function and B_function connected, so both end up in the
    // same sub-graph.
    MarkOutLinksInSubGraph(static_cast<const Function *>(&candidate));
  }
}

const char *kUnionFindParent = "_sub_graph_splitter_union_find_parent_";

// Use the Union Find (UF) algorithm to find connected sub-graphs: if node
// a's output is node b, then a and b belong to the same sub-graph.
The UF // algorithm will group them to the same cluster. using node_map_t = std::unordered_map; // Find the ancestor id of a node. int UnionFindGetAncestor(const node_map_t &node_map, size_t id) { int tmp = id; do { tmp = node_map.at(tmp)->attr(kUnionFindParent).Int32(); } while (node_map.at(tmp)->attr(kUnionFindParent).Int32() != tmp); return tmp; } // Make this two node share the same ancestor. // TODO(Superjom) bad performance, make a balanced tree latter. void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) { int a_ancestor = UnionFindGetAncestor(node_map, a); int b_ancestor = UnionFindGetAncestor(node_map, b); node_map.at(b_ancestor)->attr(kUnionFindParent).Int32() = a_ancestor; node_map.at(a)->attr(kUnionFindParent).Int32() = a_ancestor; node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor; } // This is a simple representation of a graph. // The BriefNode hold the pointer of the Node. // This is to avoid changing the original graph // in the process of trt graph analysis. struct BriefNode { explicit BriefNode(Node *n) { node = n; } Node *node; std::vector inlinks; std::vector outlinks; }; // Union two adjacent BriefNode. // Suppose we have two adjacent nodes src and dst. // We will perform the following operations: // 1. add all inputs(except src) of dst to src inlinks. // 2. add all outputs of dst to src outlinks. // 3. change all the dst's inputs and outputs // corresponding inlinks and outlinks to src node. // 4. delete all dst's inlinks and outlinks. void UnionContractedNodes(const std::unordered_map &node_map, int src_id, int dst_id) { // merge the two adjacent nodes into one node. BriefNode *src_node = node_map.at(src_id); BriefNode *dst_node = node_map.at(dst_id); std::unordered_set inputs(src_node->inlinks.begin(), src_node->inlinks.end()); std::unordered_set outputs; for (auto *n : src_node->outlinks) { if (n != dst_node) outputs.insert(n); } // Add the inlinks and outlinks of dst node to src node. 
std::vector dst_in_nodes = dst_node->inlinks; for (BriefNode *node : dst_in_nodes) { if (node != src_node) { inputs.insert(node); } } std::vector dst_out_nodes = dst_node->outlinks; for (BriefNode *node : dst_out_nodes) { outputs.insert(node); } // update the dst and src node's inlinks and outlinks. #ifdef __clang__ src_node->inlinks = std::vector(inputs.begin(), inputs.end()); src_node->outlinks = std::vector(outputs.begin(), outputs.end()); dst_node->inlinks.clear(); dst_node->outlinks.clear(); #else src_node->inlinks = std::move(std::vector(inputs.begin(), inputs.end())); src_node->outlinks = std::move(std::vector(outputs.begin(), outputs.end())); dst_node->inlinks.clear(); dst_node->outlinks.clear(); #endif auto inlink_or_outlink_cleaner = [&](std::vector &nodes) { for (auto *&n : nodes) { if (n == src_node || n == dst_node) { n = src_node; } } }; // Change all the dst inputs and outputs corresponding inlink and // outlink to the src node. for (auto *node : src_node->inlinks) { inlink_or_outlink_cleaner(node->outlinks); } for (auto *node : src_node->outlinks) { inlink_or_outlink_cleaner(node->inlinks); } } // FlexibleDFS // If reverse is true, do reverse dfs. // If enter func is not nullptr, calls enter(node) before visiting any children // of node. // If leave func not nullptr, calls leave(node) after visiting all parents of // node. void FlexibleDFS(const std::vector &source, bool reverse, const std::function &enter, const std::function &leave) { typedef struct { const BriefNode *node; bool leave; } FNode; std::vector stack; for (auto &node : source) { stack.push_back(FNode{node, false}); } std::unordered_set visited; while (!stack.empty()) { auto fnode = stack.back(); stack.pop_back(); if (fnode.leave) { if (leave && !leave(fnode.node)) return; } if (visited.count(fnode.node)) continue; visited.insert(fnode.node); if (enter && !enter(fnode.node)) return; if (leave) stack.push_back(FNode{fnode.node, true}); const std::vector iter_nodes = reverse == true ? 
fnode.node->inlinks : fnode.node->outlinks; for (const BriefNode *node : iter_nodes) { if (!visited.count(node)) { stack.push_back(FNode{node, false}); } } } } std::vector> SubGraphSplitter::ExtractSubGraphs() { // Run the Extract algorithm to find all subgraphs. std::vector marked_nodes; // We use brief_node_map to represent the original graph in order to avoid // changing the original graph. std::unordered_map brief_node_map; for (auto &node : GraphTraits(*graph_).nodes_in_TS()) { brief_node_map[node.id()] = new BriefNode(&node); if (node.attr(kMarkerAttrName).Bool()) { marked_nodes.push_back(&node); } } // extract sub-graphs in the marked node set, use Union Find algorithm. node_map_t node_map; // id to ptr for (auto *n : marked_nodes) { // n's parent == n.id means it is the ancestor n->attr(kUnionFindParent).Int32() = n->id(); node_map[n->id()] = n; } // create breif node map for (auto &itr : brief_node_map) { for (Node *node : itr.second->node->inlinks) { itr.second->inlinks.push_back(brief_node_map[node->id()]); } for (Node *node : itr.second->node->outlinks) { itr.second->outlinks.push_back(brief_node_map[node->id()]); } } for (auto &itr : brief_node_map) { BriefNode *brief_node = itr.second; if (!brief_node->node->attr(kMarkerAttrName).Bool()) { VLOG(4) << brief_node->node->id() << " node not a trt candicate."; continue; } // Our algorithm must guarantee that: // 1. The graph is always directed acyclic graph(DAG). // 2. If there is a path in the subgraph from X to Y (X and Y are both // nodes in the subgraph), then all paths from X to Y are in the // subgraph. // // In order to achieve the above guarantee. // For adjacent nodes src -> dst. // 1. Get all dst input nodes except src. // 2. Reverse DFS from those input nodes // 3. If there is a path from input nodes to src, // then the src and dst nodes can not be fused into one node, // otherwise it can be done. 
while (true) { std::unordered_set contract_nodes; for (auto *out : brief_node->outlinks) { // must be an trt candidate if (!out->node->attr(kMarkerAttrName).Bool()) continue; // get all dst input nodes except src. std::vector source_nodes; for (auto *n : out->inlinks) { if (n != brief_node) { source_nodes.push_back(n); } } // Reverse DFS from the source_nodes. bool have_excess_path = false; FlexibleDFS(source_nodes, true, nullptr, [&have_excess_path, brief_node](const BriefNode *n) { if (n == brief_node) { have_excess_path = true; return false; } return true; }); if (have_excess_path) continue; contract_nodes.insert(out); } if (contract_nodes.empty()) break; for (auto dst_node : contract_nodes) { UnionFindCombine(node_map, brief_node->node->id(), dst_node->node->id()); UnionContractedNodes(brief_node_map, brief_node->node->id(), dst_node->node->id()); } } } std::unordered_map> clusters; for (auto *n : marked_nodes) { if (n->type() == Node::Type::kFunction) { clusters[UnionFindGetAncestor(node_map, n->attr(kUnionFindParent).Int32())] .push_back(n); } } std::vector> result; std::for_each(clusters.begin(), clusters.end(), [&](const decltype(clusters)::value_type &it) { result.push_back(it.second); }); return result; } void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); } void SubGraphFuse::ReplaceNodesWithSubGraphs() { auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)(); for (auto &subgraph : subgraphs) { std::unordered_set subgraph_uniq(subgraph.begin(), subgraph.end()); // replace this sub-graph with the first node. Two steps: 1. Create a Block // Node that contains this subgraph 2. Mark the nodes inside the sub-graph // as deleted. 3. Replace the deleted node with the new Block Node. 
auto *block_node = static_cast( graph_->nodes.Create(Node::Type::kFunctionBlock)); auto io = ExtractInputAndOutputOfSubGraph(subgraph); block_node->inlinks = std::move(io.first); block_node->outlinks = std::move(io.second); for (auto *node : subgraph) { // TODO(Superjomn) need a unified mechanism to treat deleted node in each // pass. node->SetDeleted(); block_node->subgraph.push_back(node); } // Change all the sub-graph's inputs and outputs corresponding inlink and // outlink to this sub-graph node. auto inlink_or_outlink_cleaner = [&](std::vector &nodes) { for (auto *&n : nodes) { if (subgraph_uniq.count(n)) { n = block_node; } } std::unordered_set uniq(nodes.begin(), nodes.end()); nodes.assign(uniq.begin(), uniq.end()); }; for (auto *i : block_node->inlinks) { inlink_or_outlink_cleaner(i->outlinks); } for (auto *&o : block_node->outlinks) { inlink_or_outlink_cleaner(o->inlinks); } } FilterRedundantOutputOfSubGraph(graph_); } } // namespace analysis } // namespace inference } // namespace paddle