// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"

#include <list>
#include <unordered_set>
#include <vector>
#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"

namespace infrt {
namespace trt {
namespace {
29 30 31 32
// ReverseDfs
// do reverse dfs. calls "func" to search when visit a node.
// The elements in 'source' can't be nullptr.
// Reference the function nameed "FlexibleDFS" but defined in:
33 34
// paddle/fluid/framework/ir/subgraph_detector.cc.

35 36
bool reverseDfs(std::vector<::mlir::Operation *> source,
                const std::function<bool(const ::mlir::Operation *)> &func) {
37
  std::unordered_set<const ::mlir::Operation *> visited;
38 39 40 41 42 43 44
  while (!source.empty()) {
    auto node = source.back();
    source.pop_back();
    if (visited.count(node)) continue;
    visited.insert(node);
    if (func(node)) return true;
    auto values = node->getOperands();
45
    for (auto value : values) {
46
      // if the value is a block argument, the node is nullptr.
47
      ::mlir::Operation *node = value.getDefiningOp();
48 49
      if (node != nullptr && !visited.count(node)) {
        source.emplace_back(node);
50 51 52
      }
    }
  }
53
  return false;
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
}

// Merges the graph ops `first` and `second` into one newly created graph op
// and erases both originals.
//
// builder: used to create the merged graph op (inserted right before `first`)
//          and the fetch op that terminates its body. Its insertion point is
//          changed by this function.
// first:   graph op whose body operations end up at the front of the merged
//          body.
// second:  graph op whose body operations are placed after those of `first`.
//
// Operands of the merged op: every operand of `first`, followed by the
// operands of `second` that are not defined by `first`.
// Results of the merged op: every result of `first` that has a user which is
// neither `second` nor an operation whose parent op is `second`, followed by
// the operands of the terminator of the body of `second`.
void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder,  // NOLINT
                             ::mlir::pd::GraphOp first,
                             ::mlir::pd::GraphOp second) {
  // compute inputs and outputs
  // NOTE(review): a value that is an operand of both `first` and `second` is
  // appended twice, there is no de-duplication here -- looks harmless, confirm.
  ::llvm::SmallVector<::mlir::Value, 4> inputs(first.getOperands()), outputs;
  for (::mlir::Value input : second.getOperands()) {
    if (input.getDefiningOp() != first) {
      inputs.push_back(input);
    }
  }
  // Maps a result of `first` that must stay visible outside of the merged op
  // to its position in `outputs` (and therefore in the merged op's results).
  ::llvm::DenseMap<::mlir::Value, unsigned int> op_output_mapping;
  for (::mlir::Value output : first.getResults()) {
    for (::mlir::Operation *user : output.getUsers()) {
      // One user outside of `second` is enough to export this result.
      if (user != second && user->getParentOp() != second) {
        op_output_mapping[output] = outputs.size();
        outputs.push_back(output);
        break;
      }
    }
  }
  // Everything fetched by `second` is appended after the exported results of
  // `first`.
  auto fetch_op = second.getBody()->getTerminator();
  outputs.append(fetch_op->getOperands().begin(),
                 fetch_op->getOperands().end());
  ::llvm::SmallVector<::mlir::Type, 4> fetch_types;
  for (auto value : outputs) {
    fetch_types.push_back(value.getType());
  }

  // create the new graph op
  builder.setInsertionPoint(first);
  auto loc = first.getLoc();
  auto graph_op = builder.create<::mlir::pd::GraphOp>(loc, fetch_types, inputs);
  ::mlir::Block *block = new ::mlir::Block;
  // Move (not copy) the body of `second`, without its terminator, into the
  // new block.
  auto copy_range = second.getBody()->without_terminator();
  block->getOperations().splice(block->begin(),
                                second.getBody()->getOperations(),
                                copy_range.begin(),
                                copy_range.end());
  // The body of `first` is spliced in at block->begin() as well, so its
  // operations come before the ones taken from `second`.
  copy_range = first.getBody()->without_terminator();
  block->getOperations().splice(block->begin(),
                                first.getBody()->getOperations(),
                                copy_range.begin(),
                                copy_range.end());
  // Terminate the merged body with a fetch op. At this point `outputs` may
  // still hold results of `first`; those uses are rewritten in the mapping
  // step below.
  builder.setInsertionPointToEnd(block);
  builder.create<mlir::pd::FetchOp>(loc, outputs);
  // The block is attached to the merged op here, before the uses are
  // rewritten: the getParentOp() == graph_op test below depends on it.
  graph_op.body().push_back(block);

  // mapping the output
  unsigned int num_result = first.getNumResults();
  // The terminator of `first` was left in place (only the non-terminator
  // operations were spliced out), so its operands can still be read.
  // Presumably its i-th operand is the value behind the i-th result of
  // `first` -- TODO confirm against the GraphOp definition.
  fetch_op = first.getBody()->getTerminator();
  for (unsigned int index = 0; index < num_result; ++index) {
    auto origin_value = first.getResult(index);
    if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
      // Not exported: every use is redirected to the value fetched inside the
      // body of `first`.
      origin_value.replaceAllUsesWith(fetch_op->getOperand(index));
    } else {
      // Exported: a use owned by an operation whose parent is the merged op
      // gets the inner value, every other use gets the matching result of the
      // merged op.
      auto inner_value = fetch_op->getOperand(index);
      auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
      // Each iteration rebinds the first remaining use of origin_value; the
      // loop runs until origin_value has no use left.
      while (!origin_value.use_empty()) {
        auto replace_value =
            origin_value.use_begin()->getOwner()->getParentOp() == graph_op
                ? inner_value
                : outer_value;
        origin_value.use_begin()->set(replace_value);
      }
    }
  }
  // The results of `second` map onto the trailing results of the merged op,
  // i.e. the ones created from the operands of its terminator.
  // NOTE(review): assumes `second` has as many results as its terminator has
  // operands -- confirm.
  second.replaceAllUsesWith(
      graph_op.getResults().take_back(second.getNumResults()));
  first.erase();
  second.erase();
}

129 130 131 132 133 134 135 136 137 138 139 140 141 142
// Topological sort the function op.
void topoSortBlock(mlir::Block &body) {  // NOLINT
  llvm::SetVector<Operation *> toSort;
  if (body.empty()) return;
  for (auto it = body.rbegin(); it != body.rend(); ++it) {
    toSort.insert(&*it);
  }
  llvm::SetVector<Operation *> result =
      ::mlir::topologicalSort(std::move(toSort));
  for (auto *op : result) {
    op->moveBefore(body.getTerminator());
  }
}

}  // namespace

// Implementation of the trtGraphFusePass.
void trtGraphFusePass::runOnFunction() {
  mlir::Block &body = getFunction().front();
  ::mlir::OpBuilder builder(&body, body.begin());
  bool changed = false;
  do {
    changed = false;
    for (auto &op : body) {
      ::mlir::pd::GraphOp graph_op =
          ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
      if (nullptr == graph_op) continue;

      for (auto user_op : op.getUsers()) {
        ::mlir::pd::GraphOp user_graph_op =
            ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(user_op);
        if (nullptr == user_graph_op) continue;
        // get all dst input nodes except src.
        std::vector<::mlir::Operation *> source_nodes;
        for (auto operand : user_op->getOperands()) {
          auto input = operand.getDefiningOp();
165
          if (input != &op && input != nullptr) {
166 167 168 169
            source_nodes.push_back(input);
          }
        }
        // Reverse DFS from the source_nodes.
170 171 172
        if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) {
              return n == &op;
            })) {
173 174 175 176 177 178 179 180
          mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
          changed = true;
          break;
        }
      }
      if (changed) break;
    }
  } while (changed);
181
  topoSortBlock(body);
182 183 184
}
}  // namespace trt
}  // namespace infrt