ngraph_subgraph_pass.cc 5.8 KB
Newer Older
M
mozga-intel 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/ngraph_subgraph_pass.h"
23
#include "paddle/fluid/framework/ir/subgraph_detector.h"
M
mozga-intel 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

// Builds a cache key for an nGraph engine.
//
// Concatenates every input name, then every output name (each std::set is
// visited in its sorted order), then `size`. It hashes that string with
// std::hash<std::string> and returns the hash rendered as a decimal string.
//
// NOTE(review): names are concatenated without a separator, so e.g.
// inputs {"ab"} / outputs {"c"} and inputs {"a"} / outputs {"bc"} yield the
// same key. The format is kept unchanged in case another component
// reproduces it — confirm before changing.
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
                              const std::set<std::string> &engine_outputs,
                              const std::string &size) {
  std::string engine_hash_key;
  // Bind by const reference so each name is appended without being copied.
  for (const auto &name : engine_inputs) {
    engine_hash_key += name;
  }
  for (const auto &name : engine_outputs) {
    engine_hash_key += name;
  }
  engine_hash_key += size;
  return std::to_string(std::hash<std::string>()(engine_hash_key));
}

47 48
// Runs SubGraphFuser over `graph` to produce "ngraph_engine" nodes.
//
// Each op node that ends up carrying a non-empty subgraph is turned into an
// ngraph_engine op (via CreateNgraphEngineOp), and the op nodes of its
// subgraph are removed from the graph. Finally, every op node the fuser
// marked as deleted is removed.
void NgraphSubgraphPass::ApplyImpl(Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph);
  FusePassBase::Init("ngraph_subgraph_pass", graph);

  // Node predicate handed to SubGraphFuser. Non-op nodes and op nodes without
  // an OpDesc are never selected. An op node is selected when
  // NgraphBridge::isRegister(op_type) returns false.
  // NOTE(review): confirm isRegister's polarity in ngraph_bridge — the name
  // alone would suggest this selects unregistered ops.
  auto teller = [](const Node *node) {
    if (!node->IsOp() || !node->Op()) return false;
    auto op_type = node->Op()->Type();
    return !paddle::operators::NgraphBridge::isRegister(op_type);
  };

  SubGraphFuser fuser(graph, teller, 0, "ngraph_engine");
  fuser();

  // NOTE(review): GraphSafeRemoveNodes is called while iterating
  // graph->Nodes(). This is presumably safe only as long as it never removes
  // the node currently being visited — confirm.
  for (auto *node : graph->Nodes()) {
    if (node->IsOp() && !Agent(node).subgraph()->empty()) {
      OpDesc *op_desc = node->Op();
      op_desc->SetType("ngraph_engine");

      CreateNgraphEngineOp(node, graph);

      // The ops now represented by the engine node are no longer needed.
      std::unordered_set<const Node *> nodes2remove(
          Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
      GraphSafeRemoveNodes(graph, nodes2remove);
    }
  }

  // Second sweep: drop every op node the fuser flagged as deleted.
  std::unordered_set<const Node *> nodes2remove;
  for (auto *node : graph->Nodes()) {
    if (node->IsOp() && Agent(node).deleted()) {
      nodes2remove.insert(node);
    }
  }
  framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
}

87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
// Returns false when `name` contains Node::kControlDepVarName, true otherwise.
// Used to skip those variables (presumably control-dependency placeholders)
// when collecting engine inputs/outputs.
bool IsValid(std::string name) {
  const bool mentions_control_dep =
      name.find(Node::kControlDepVarName) != std::string::npos;
  return !mentions_control_dep;
}

// Appends to *input_names / *output_names the variable names that become the
// "Xs" / "Ys" of the ngraph_engine op built for `node`.
//
// Fast path: if no op node in `graph` has "_grad" in its name, and some
// variable node feeds an op named "fetch", the engine node's own input and
// output variable names are used as-is. This presumably detects an
// inference-only graph — confirm.
//
// Otherwise the names are gathered from the ops of the node's fused
// subgraph, in subgraph order:
//   - an input name is recorded the first time it is seen, unless an earlier
//     subgraph op already produced it;
//   - every output name is recorded.
// Names rejected by IsValid() are skipped in this path only.
//
// NOTE(review): outputs are not de-duplicated, so a name produced by more
// than one subgraph op is appended more than once — confirm this is intended.
void UpdateNgraphIO(Node *node, Graph *graph,
                    std::vector<std::string> *input_names,
                    std::vector<std::string> *output_names) {
  bool is_test = true, has_fetch = false;
  for (Node *graph_node : graph->Nodes()) {
    if (graph_node->IsOp() &&
        graph_node->Name().find("_grad") != std::string::npos) {
      is_test = false;
      // The fast path requires is_test, so the rest of the scan is moot.
      break;
    }
    if (graph_node->IsVar() && graph_node->Var()) {
      for (auto *consumer : graph_node->outputs) {
        if (consumer->Name() == "fetch") has_fetch = true;
      }
    }
  }

  if (is_test && has_fetch) {
    for (auto *x : node->inputs) {
      input_names->emplace_back(x->Name());
    }
    for (auto *x : node->outputs) {
      output_names->emplace_back(x->Name());
    }
    return;
  }

  auto &subgraph = *Agent(node).subgraph();
  std::unordered_set<std::string> inputs;
  std::unordered_set<std::string> outputs;
  for (auto *sub_op : subgraph) {
    for (auto *in : sub_op->inputs) {
      auto name = in->Name();
      if (!IsValid(name)) continue;
      // insert().second is true only for a name not yet recorded as input.
      if (!outputs.count(name) && inputs.insert(name).second) {
        input_names->emplace_back(name);
      }
    }
    for (auto *out : sub_op->outputs) {
      auto name = out->Name();
      if (!IsValid(name)) continue;
      outputs.insert(name);
      output_names->emplace_back(name);
    }
  }
}
M
mozga-intel 已提交
135

136
// Rewrites `node`'s OpDesc into an "ngraph_engine" op for its fused subgraph.
//
// The op descs of the subgraph are copied, in subgraph order, into a
// standalone BlockDesc (idx 0, parent_idx -1). The VarDesc of every variable
// node in `graph` is copied in as well. The serialized block is stored in the
// "graph" attribute, and its std::hash becomes "engine_key". "Xs"/"Ys" come
// from UpdateNgraphIO, "interval" is {0, subgraph size}, and "op_role" is 0.
// Enforces that the subgraph is non-empty and that at least one VarDesc was
// collected.
//
// NOTE(review): every var node is appended, so if the graph holds several
// var nodes with the same name the block gets duplicate VarDescs — confirm
// the engine tolerates this.
void NgraphSubgraphPass::CreateNgraphEngineOp(Node *node, Graph *graph) const {
  auto &subgraph = *Agent(node).subgraph();
  PADDLE_ENFORCE_NE(subgraph.empty(), true, "subgraph cannot be empty");

  framework::proto::BlockDesc block_proto;
  framework::BlockDesc block_desc(nullptr, &block_proto);
  block_desc.Proto()->set_parent_idx(-1);
  block_desc.Proto()->set_idx(0);
  for (auto *sub_op : subgraph) {
    auto *op = block_desc.AppendOp();
    *op->Proto() = *sub_op->Op()->Proto();
  }
  auto *vars = block_desc.Proto()->mutable_vars();
  for (Node *graph_node : graph->Nodes()) {
    if (graph_node->IsVar() && graph_node->Var()) {
      *vars->Add() = *graph_node->Var()->Proto();
    }
  }
  PADDLE_ENFORCE_NE(block_desc.Proto()->vars().empty(), true,
                    "the block has no var-desc");

  std::vector<std::string> input_names;
  std::vector<std::string> output_names;
  UpdateNgraphIO(node, graph, &input_names, &output_names);
  auto *op_desc = node->Op();
  // The name lists are passed directly; no intermediate copy is needed.
  op_desc->SetInput("Xs", input_names);
  op_desc->SetOutput("Ys", output_names);

  // The serialized block is both stored on the op and hashed into its key.
  std::string subgraph_str = block_desc.Proto()->SerializeAsString();
  std::string engine_key =
      std::to_string(std::hash<std::string>()(subgraph_str));
  // Explicit cast: subgraph.size() is unsigned and the attribute holds ints.
  std::vector<int> interval{0, static_cast<int>(subgraph.size())};
  op_desc->SetType("ngraph_engine");
  op_desc->SetAttr("interval", interval);
  op_desc->SetAttr("graph", subgraph_str);
  op_desc->SetAttr("engine_key", engine_key);
  op_desc->SetAttr("op_role", 0);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Registers NgraphSubgraphPass under the pass name "ngraph_subgraph_pass"
// (the same name ApplyImpl passes to FusePassBase::Init), presumably so it
// can be looked up by that name from the pass registry.
REGISTER_PASS(ngraph_subgraph_pass, paddle::framework::ir::NgraphSubgraphPass);