// subgraph_splitter.cc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/analysis/subgraph_splitter.h"

namespace paddle {
namespace inference {
namespace analysis {

// Name of the boolean node attribute that flags a node as belonging to some
// sub-graph. It is written by MarkNodesInsideSubGraph and read back by
// ExtractSubGraphs.
const char *SubGraphSplitter::kMarkerAttrName =
    "_sub_graph_splitter_inside_sub_graph";

// Runs the splitter in two phases: first tag every node accepted by the teller
// (plus the output variables of accepted functions), then group the tagged
// nodes into connected clusters. Each element of the result holds the
// function nodes of one sub-graph.
std::vector<std::vector<Node *>> SubGraphSplitter::operator()() {
  MarkNodesInsideSubGraph();
  std::vector<std::vector<Node *>> subgraphs = ExtractSubGraphs();
  return subgraphs;
}

// Flags every output variable of `func` with SubGraphSplitter::kMarkerAttrName
// so those variables are treated as part of the same sub-graph as `func`.
inline void MarkOutLinksInSubGraph(const Function *func) {
  const auto &outputs = func->outlinks;
  for (auto *output_var : outputs) {
    output_var->attr(SubGraphSplitter::kMarkerAttrName).Bool() = true;
  }
}

// Sets kMarkerAttrName on every node accepted by node_inside_subgraph_teller_.
// For accepted function nodes, their output variables are marked as well.
void SubGraphSplitter::MarkNodesInsideSubGraph() {
  for (auto &candidate : GraphTraits<DataFlowGraph>(graph_).nodes()) {
    if (!node_inside_subgraph_teller_(&candidate)) continue;

    candidate.attr(kMarkerAttrName).Bool() = true;
    if (candidate.type() != Node::Type::kFunction) continue;

    // Marking the outputs of a marked function links it to any marked
    // consumer. For example, in A_function -> var -> B_function, if A_function
    // is marked then var is marked too, so that B_function (if also marked)
    // ends up in the same sub-graph as A_function.
    MarkOutLinksInSubGraph(static_cast<const Function *>(&candidate));
  }
}

// Name of the int32 node attribute holding a node's Union-Find parent id. A
// node whose parent id equals its own id is the root of its set.
const char *kUnionFindParent = "_sub_graph_splitter_union_find_parent_";

// The Union-Find (UF) algorithm groups marked nodes into connected sub-graphs:
// if node a's output is node b, a and b belong to the same sub-graph, and UF
// places them in the same cluster.
using node_map_t = std::unordered_map<int, Node *>;  // node id -> node

// Returns the id of the root of the Union-Find set that contains node `id`,
// following kUnionFindParent links until a self-parented node is reached.
// unordered_map::at throws std::out_of_range if an id on the chain is absent
// from `node_map`.
int UnionFindGetAncestor(const node_map_t &node_map, size_t id) {
  int current = static_cast<int>(id);
  int parent = node_map.at(current)->attr(kUnionFindParent).Int32();
  while (parent != current) {
    current = parent;
    parent = node_map.at(current)->attr(kUnionFindParent).Int32();
  }
  return current;
}
// Merges the Union-Find sets containing nodes `a` and `b`: the root of b's set
// is re-parented to the root of a's set. Both `a` and `b` are then pointed
// directly at that root, which keeps their chains short.
// TODO(Superjom) bad performance, make a balanced tree later.
void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
  const int root = UnionFindGetAncestor(node_map, a);
  const int other_root = UnionFindGetAncestor(node_map, b);
  auto point_at_root = [&](int node_id) {
    node_map.at(node_id)->attr(kUnionFindParent).Int32() = root;
  };
  point_at_root(other_root);
  point_at_root(a);
  point_at_root(b);
}

std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
  std::vector<Node *> marked_nodes;
79
  for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes_in_TS()) {
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
    if (node.attr(kMarkerAttrName).Bool()) {
      marked_nodes.push_back(&node);
    }
  }
  // extract sub-graphs in the marked node set, use Union Find algorithm.
  node_map_t node_map;  // id to ptr
  for (auto *n : marked_nodes) {
    // n's parent == n.id means it is the ancestor
    n->attr(kUnionFindParent).Int32() = n->id();
    node_map[n->id()] = n;
  }
  std::unordered_set<Node *> visited;
  for (auto *n : marked_nodes) {
    for (auto *out : n->outlinks) {
      if (node_map.count(out->id())) {
        UnionFindCombine(node_map, n->id(), out->id());
      }
    }
  }

  std::unordered_map<int /*ancestor*/, std::vector<Node *>> clusters;
  for (auto *n : marked_nodes) {
    if (n->type() == Node::Type::kFunction) {
      clusters[UnionFindGetAncestor(node_map,
                                    n->attr(kUnionFindParent).Int32())]
          .push_back(n);
    }
  }
  std::vector<std::vector<Node *>> result;
  std::for_each(clusters.begin(), clusters.end(),
                [&](const decltype(clusters)::value_type &it) {
                  result.push_back(it.second);
                });

  return result;
}

// Entry point of the fuse pass: collapses each detected sub-graph into a
// single FunctionBlock node.
void SubGraphFuse::operator()() {
  ReplaceNodesWithSubGraphs();
}

void SubGraphFuse::ReplaceNodesWithSubGraphs() {
  auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
  for (auto &subgraph : subgraphs) {
122
    std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
123 124 125
    // replace this sub-graph with the first node. Two steps: 1. Create a Block
    // Node that contains this subgraph 2. Mark the nodes inside the sub-graph
    // as deleted. 3. Replace the deleted node with the new Block Node.
126 127
    auto *block_node = static_cast<FunctionBlock *>(
        graph_->nodes.Create(Node::Type::kFunctionBlock));
128 129 130 131 132 133 134
    auto io = ExtractInputAndOutputOfSubGraph(subgraph);
    block_node->inlinks = std::move(io.first);
    block_node->outlinks = std::move(io.second);
    for (auto *node : subgraph) {
      // TODO(Superjomn) need a unified mechanism to treat deleted node in each
      // pass.
      node->SetDeleted();
135
      block_node->subgraph.push_back(node);
136 137
    }

138 139 140 141 142 143 144 145 146 147 148 149 150
    // Change all the sub-graph's inputs and outputs corresponding inlink and
    // outlink to this sub-graph node.
    auto inlink_or_outlink_cleaner = [&](std::vector<Node *> &nodes) {
      for (auto *&n : nodes) {
        if (subgraph_uniq.count(n)) {
          n = block_node;
        }
      }
      std::unordered_set<Node *> uniq(nodes.begin(), nodes.end());
      nodes.assign(uniq.begin(), uniq.end());
    };
    for (auto *i : block_node->inlinks) {
      inlink_or_outlink_cleaner(i->outlinks);
151
    }
152 153
    for (auto *&o : block_node->outlinks) {
      inlink_or_outlink_cleaner(o->inlinks);
154 155 156 157 158 159 160
    }
  }
}

}  // namespace analysis
}  // namespace inference
}  // namespace paddle