ngraph_subgraph_pass.cc 6.0 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/ngraph_subgraph_pass.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

namespace ANAT = paddle::inference::analysis;

// Builds an engine cache key from the engine's input names, output names and
// a caller-supplied `size` string.
//
// The input names and then the output names are concatenated in std::set
// (sorted) order with no separator, `size` is appended, and the resulting
// string is hashed with std::hash<std::string>. The decimal form of that hash
// is returned.
//
// NOTE(review): because names are joined without a separator, different
// name sets can yield the same concatenation (e.g. {"ab"},{"c"} vs
// {"a"},{"bc"}). The format is kept as-is so existing keys do not change.
// The standard only guarantees std::hash results are consistent within a
// single program execution, so keys should not be persisted across builds.
// NOTE(review): not called from the code visible in this file;
// CreateNgraphEngineOp hashes the serialized block instead.
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
                              const std::set<std::string> &engine_outputs,
                              const std::string &size) {
  std::string engine_hash_key;
  // Bind by const reference: the original by-value loops copied every name.
  for (const auto &name : engine_inputs) {
    engine_hash_key += name;
  }
  for (const auto &name : engine_outputs) {
    engine_hash_key += name;
  }
  engine_hash_key += size;
  return std::to_string(std::hash<std::string>()(engine_hash_key));
}

50 51
void NgraphSubgraphPass::ApplyImpl(Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph);
X
xiexionghang 已提交
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
  FusePassBase::Init("ngraph_subgraph_pass", graph);

  std::unordered_set<Node *> nodes2delete;

  auto teller = [](const Node *node) {
    if (!node->IsOp() || !node->Op()) return false;
    auto op_type = node->Op()->Type();
    return !paddle::operators::NgraphBridge::isRegister(op_type);
  };

  ANAT::SubGraphFuser fuser(graph, teller, 0, "ngraph_engine");
  fuser();

  for (auto *node : graph->Nodes()) {
    if (node->IsOp() && !ANAT::Agent(node).subgraph()->empty()) {
      OpDesc *op_desc = node->Op();
      op_desc->SetType("ngraph_engine");

      CreateNgraphEngineOp(node, graph);

      std::unordered_set<const Node *> nodes2remove(
          ANAT::Agent(node).subgraph()->begin(),
          ANAT::Agent(node).subgraph()->end());
75

X
xiexionghang 已提交
76 77 78 79 80 81 82 83 84 85
      GraphSafeRemoveNodes(graph, nodes2remove);
    }
  }

  std::unordered_set<const Node *> nodes2remove;
  for (auto *node : graph->Nodes()) {
    if (node->IsOp() && ANAT::Agent(node).deleted()) {
      nodes2remove.insert(node);
    }
  }
86

X
xiexionghang 已提交
87
  framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
88
  // std::vector<ir::Node *> nodes = ir::TopologySortOperations(*graph);
X
xiexionghang 已提交
89 90
}

91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
bool IsValid(std::string name) {
  return name.find(Node::kControlDepVarName) == std::string::npos;
}

// Computes the input and output variable names of the fused engine op `node`
// and appends them to `input_names` / `output_names`.
//
// The mode is chosen by scanning every node of `graph`:
//  * "is_test" holds when no op node's name contains "_grad";
//  * "has_fetch" holds when some variable node (with a VarDesc) has an output
//    node named "fetch".
// If both hold, the names are simply those of `node`'s own input and output
// nodes, in their stored order.
// Otherwise the ops of `node`'s subgraph are walked in order:
//  * a variable read by an op is reported as an input unless an earlier op in
//    the subgraph already wrote it, or it was already reported;
//  * every variable written by an op is reported as an output (no
//    de-duplication);
//  * names rejected by IsValid() (control-dependency variables) are skipped.
void UpdateNgraphIO(Node *node, Graph *graph,
                    std::vector<std::string> *input_names,
                    std::vector<std::string> *output_names) {
  bool is_test = true;
  bool has_fetch = false;
  // Use a distinct loop variable: the original shadowed the `node` parameter.
  for (Node *graph_node : graph->Nodes()) {
    if (graph_node->IsOp() &&
        graph_node->Name().find("_grad") != std::string::npos) {
      is_test = false;
    }
    if (graph_node->IsVar() && graph_node->Var()) {
      for (auto *consumer : graph_node->outputs) {
        if (consumer->Name() == "fetch") has_fetch = true;
      }
    }
  }

  if (is_test && has_fetch) {
    // NOTE(review): unlike the subgraph walk below, this path does not filter
    // names through IsValid(); confirm that is intended.
    for (auto *x : node->inputs) {
      input_names->emplace_back(x->Name());
    }
    for (auto *x : node->outputs) {
      output_names->emplace_back(x->Name());
    }
    return;
  }

  const auto &subgraph = *ANAT::Agent(node).subgraph();
  std::unordered_set<std::string> inputs;
  std::unordered_set<std::string> outputs;
  for (auto *op : subgraph) {
    for (auto *in : op->inputs) {
      const std::string name = in->Name();
      if (!IsValid(name)) continue;
      // Values produced earlier inside the subgraph are internal, not inputs;
      // insert().second de-duplicates the reported input names.
      if (!outputs.count(name) && inputs.insert(name).second) {
        input_names->emplace_back(name);
      }
    }
    for (auto *out : op->outputs) {
      const std::string name = out->Name();
      if (!IsValid(name)) continue;
      outputs.insert(name);
      output_names->emplace_back(name);
    }
  }
}

140 141 142
void NgraphSubgraphPass::CreateNgraphEngineOp(Node *node, Graph *graph) const {
  auto &subgraph = *ANAT::Agent(node).subgraph();
  PADDLE_ENFORCE_NE(subgraph.empty(), true, "subgraph cannot be empty");
X
xiexionghang 已提交
143 144 145 146 147 148 149 150 151 152

  framework::proto::BlockDesc block_proto;
  framework::BlockDesc block_desc(nullptr, &block_proto);
  block_desc.Proto()->set_parent_idx(-1);
  block_desc.Proto()->set_idx(0);
  for (auto *node : subgraph) {
    auto *op = block_desc.AppendOp();
    *op->Proto() = *node->Op()->Proto();
  }
  auto *vars = block_desc.Proto()->mutable_vars();
153
  for (Node *node : graph->Nodes()) {
X
xiexionghang 已提交
154 155 156 157
    if (node->IsVar() && node->Var()) {
      *vars->Add() = *node->Var()->Proto();
    }
  }
158 159
  PADDLE_ENFORCE_NE(block_desc.Proto()->vars().empty(), true,
                    "the block has no var-desc");
X
xiexionghang 已提交
160

161 162 163 164 165 166 167 168
  std::vector<std::string> input_names;
  std::vector<std::string> output_names;
  UpdateNgraphIO(node, graph, &input_names, &output_names);
  auto *op_desc = node->Op();
  op_desc->SetInput(
      "Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
  op_desc->SetOutput(
      "Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
X
xiexionghang 已提交
169 170

  int sgs = subgraph.size();
171 172 173
  std::string subgraph_str = block_desc.Proto()->SerializeAsString();
  std::string engine_key =
      std::to_string(std::hash<std::string>()(subgraph_str));
X
xiexionghang 已提交
174
  std::vector<int> interval{0, sgs};
175
  op_desc->SetType("ngraph_engine");
X
xiexionghang 已提交
176
  op_desc->SetAttr("interval", interval);
177
  op_desc->SetAttr("graph", subgraph_str);
X
xiexionghang 已提交
178
  op_desc->SetAttr("engine_key", engine_key);
179
  op_desc->SetAttr("op_role", 0);
X
xiexionghang 已提交
180 181 182 183 184 185 186
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Exposes NgraphSubgraphPass through Paddle's pass registry under the name
// "ngraph_subgraph_pass".
REGISTER_PASS(ngraph_subgraph_pass, paddle::framework::ir::NgraphSubgraphPass);