// subgraph.h
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"

namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {

struct SubGraph {
  int type{-1};
  std::string func_name;
  std::unordered_set<Node*> nodes_set;

  bool IsEmpty() { return nodes_set.empty(); }

  size_t GetNumNodes() { return nodes_set.size(); }

  int GetNumOperations() {
    int num_operations = 0;
    for (auto* n : nodes_set) {
      if (n && n->IsOp() && n->Op()) {
        num_operations++;
      }
    }
    return num_operations;
  }

  std::vector<Node*> GetInputVarNodes() const {
    // The order of input nodes should be consistent with that of the generated
    // code.
    std::vector<Node*> input_vars;
    for (auto* n : nodes_set) {
      if (n && n->IsVar() && n->Var()) {
        bool is_found = true;
        // When the inputs size is 0, it is also considered the input var of
        // subgraph.
        if (n->inputs.size() == 0U) {
          is_found = false;
        }
        // Normally a var node has only one input op node.
        for (auto* in : n->inputs) {
          if (nodes_set.find(in) == nodes_set.end()) {
            is_found = false;
          }
        }
        if (!is_found) {
          input_vars.push_back(n);
        }
      }
    }
    return input_vars;
  }

  std::vector<Node*> GetOutputVarNodes() const {
    // The order of output nodes should be consistant with that of the generated
    // code.
    std::vector<Node*> output_vars;
    for (auto* n : nodes_set) {
      if (n && n->IsVar() && n->Var()) {
        bool is_found = true;
        if (n->outputs.size() == 0U) {
          is_found = false;
        }
        for (auto* out : n->outputs) {
          if (nodes_set.find(out) == nodes_set.end()) {
            is_found = false;
          }
        }
        if (!is_found) {
          output_vars.push_back(n);
        }
      }
    }
    return output_vars;
  }
};

}  // namespace fusion_group
}  // namespace ir
}  // namespace framework
}  // namespace paddle