// subgraph.h
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"

namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {

32 33
class SubGraph {
 public:
34
  SubGraph() = default;
35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
  explicit SubGraph(int type) : type_(type) {}
  SubGraph(int type, std::string func_name, bool save_intermediate_out,
           const std::unordered_set<Node*>& nodes_set)
      : type_(type),
        func_name_(func_name),
        save_intermediate_out_(save_intermediate_out) {
    for (auto* n : nodes_set) {
      nodes_set_.insert(n);
      if (n && n->IsOp() && n->Op()) {
        // If the node is an op node, then add its input/output var nodes
        //  into the subgraph.
        for (auto* in : n->inputs) {
          nodes_set_.insert(in);
        }
        for (auto* out : n->outputs) {
          nodes_set_.insert(out);
        }
      }
    }
  }
55

56 57 58 59 60 61 62 63 64
  bool IsValid(int min_subgraph_size) {
    int num_operations = GetNumOperations();
    if (num_operations < min_subgraph_size) {
      VLOG(2) << "There are only " << num_operations
              << " operations in the subgraph. Expected at least "
              << min_subgraph_size;
      return false;
    }

65
    return true;
66
  }
67

68
  int GetType() const { return type_; }
69

70 71 72
  void SetFuncName(std::string func_name) { func_name_ = func_name; }
  std::string GetFuncName() const { return func_name_; }

73 74
  bool SaveIntermediateOut() const { return save_intermediate_out_; }

75
  const std::unordered_set<Node*>& Nodes() const { return nodes_set_; }
76
  const std::vector<Node*>& SortedNodes() {
77 78
    if (!is_sorted_) {
      TopologicalSort();
79
    }
80
    return sorted_nodes_;
81 82
  }

83
  size_t GetNumNodes() { return nodes_set_.size(); }
84

85
  bool Has(Node* n) { return nodes_set_.find(n) != nodes_set_.end(); }
86

87 88
  int GetNumOperations() {
    int num_operations = 0;
89
    for (auto* n : nodes_set_) {
90 91 92 93 94 95 96
      if (n && n->IsOp() && n->Op()) {
        num_operations++;
      }
    }
    return num_operations;
  }

97 98
  std::vector<Node*> GetInputVarNodes() {
    // The order of input nodes should be consistent anywhere.
99
    std::vector<Node*> input_vars;
100
    for (auto* n : SortedNodes()) {
101 102 103 104 105 106 107 108 109
      if (n && n->IsVar() && n->Var()) {
        bool is_found = true;
        // When the inputs size is 0, it is also considered the input var of
        // subgraph.
        if (n->inputs.size() == 0U) {
          is_found = false;
        }
        // Normally a var node has only one input op node.
        for (auto* in : n->inputs) {
110
          if (!Has(in)) {
111 112 113 114 115 116 117 118 119 120 121
            is_found = false;
          }
        }
        if (!is_found) {
          input_vars.push_back(n);
        }
      }
    }
    return input_vars;
  }

122
  std::vector<Node*> GetOutputVarNodes(bool with_intermediate_out) {
123
    // The order of output nodes should be consistant anywhere..
124
    std::vector<Node*> output_vars;
125
    for (auto* n : SortedNodes()) {
126
      if (IsOutputOfInternalOp(n)) {
127 128
        // If the var_node is the output of some op_node in the subgraph, it
        // is considered the output var node of the subgraph.
129 130 131 132 133
        if (with_intermediate_out) {
          output_vars.push_back(n);
        } else {
          if (n->outputs.empty() || IsInputOfExternalOp(n)) {
            output_vars.push_back(n);
134 135 136 137
          }
        }
      }
    }
138
    return output_vars;
139
  }
140

141
  std::vector<Node*> GetIntermediateOutVarNodes() {
142 143 144 145 146 147 148 149 150 151 152 153 154
    // Intermediate output var nodes: the output of some op_node in the
    // subgraph, but not referenced outside the subgraph.
    std::vector<Node*> intermediate_out_vars;
    for (auto* n : SortedNodes()) {
      if (IsOutputOfInternalOp(n) && IsInputOfInternalOp(n) &&
          !IsInputOfExternalOp(n)) {
        // When the outputs size is 0, it is also considered a intermidiate
        // output. It maybe an unused output or the fetching vars, so that we
        // cannot eleiminate it directly here.
        intermediate_out_vars.push_back(n);
      }
    }
    return intermediate_out_vars;
155
  }
156

157 158 159 160 161
  std::unordered_set<Node*> GetIntermediateOutVarNodesSet() {
    std::vector<Node*> intermediate_out_vars = GetIntermediateOutVarNodes();
    return std::unordered_set<Node*>(intermediate_out_vars.begin(),
                                     intermediate_out_vars.end());
  }
162

163 164 165 166 167 168 169 170
 private:
  bool IsInputOfInternalOp(Node* n) {
    bool is_input_of_internal_op = false;
    if (Has(n) && n && n->IsVar() && n->Var()) {
      for (auto* out : n->outputs) {
        if (Has(out)) {
          is_input_of_internal_op = true;
          break;
171
        }
172 173 174 175
      }
    }
    return is_input_of_internal_op;
  }
176

177 178 179 180 181 182 183 184 185
  bool IsInputOfExternalOp(Node* n) {
    // If n is the input any one node outside the subgraph.
    bool is_input_of_external_op = false;
    if (Has(n) && n && n->IsVar() && n->Var()) {
      for (auto* out : n->outputs) {
        if (!Has(out)) {
          is_input_of_external_op = true;
          break;
        }
186
      }
187 188 189
    }
    return is_input_of_external_op;
  }
190

191 192 193 194 195 196 197 198
  bool IsOutputOfInternalOp(Node* n) {
    bool is_output_of_internal_op = false;
    if (Has(n) && n && n->IsVar() && n->Var()) {
      for (auto* in : n->inputs) {
        if (Has(in)) {
          is_output_of_internal_op = true;
          break;
        }
199 200
      }
    }
201
    return is_output_of_internal_op;
202 203
  }

204 205 206 207 208 209 210 211
  void TopologicalSort() {
    if (!is_sorted_) {
      std::unordered_map<Node*, std::vector<Node*>> inputs_map;
      std::unordered_map<Node*, std::vector<Node*>> outputs_map;
      for (auto* n : nodes_set_) {
        inputs_map[n] = n->inputs;
        outputs_map[n] = n->outputs;
      }
212

213
      for (auto* n : nodes_set_) {
214
        if (n && ((n->IsVar() && n->Var()) || n->IsCtrlVar())) {
215 216 217 218 219
          // Set the input of subgraph's input var node to null.
          std::vector<Node*> inputs;
          for (auto* in : n->inputs) {
            if (Has(in)) {
              inputs.push_back(in);
220 221
            }
          }
222 223 224 225 226 227
          // Set the output of subgraph's output var node to null.
          std::vector<Node*> outputs;
          for (auto* out : n->outputs) {
            if (Has(out)) {
              outputs.push_back(out);
            }
228
          }
229 230
          n->inputs = inputs;
          n->outputs = outputs;
231 232
        }
      }
233 234 235 236 237
      // Collect the start points of the subgraph.
      std::vector<Node*> start_points;
      for (auto* n : nodes_set_) {
        if (n->inputs.empty()) {
          start_points.push_back(n);
238 239
        }
      }
240 241 242 243 244
      // Sort the subgraph.
      NodesTSIterator x(start_points);
      for (auto& n : iterator_range<NodesTSIterator>(
               NodesTSIterator(start_points), NodesTSIterator())) {
        sorted_nodes_.push_back(&n);
245
      }
246 247 248 249
      // Reset the inputs, outputs.
      for (auto* n : nodes_set_) {
        n->inputs = inputs_map[n];
        n->outputs = outputs_map[n];
250 251
      }
    }
252
    is_sorted_ = true;
253 254 255
  }

 private:
256
  int type_{-1};
257
  std::string data_type_;
258 259 260 261 262 263
  std::string func_name_;
  bool save_intermediate_out_{true};

  std::unordered_set<Node*> nodes_set_;
  bool is_sorted_{false};
  std::vector<Node*> sorted_nodes_;
264 265 266 267 268 269
};

}  // namespace fusion_group
}  // namespace ir
}  // namespace framework
}  // namespace paddle