// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"

#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <functional>
#include <list>
#include <unordered_set>
#include <vector>

#include "paddle/infrt/dialect/pd/ir/pd_ops.h"

namespace infrt {
namespace trt {
namespace {
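// reverseDfs
// Performs a reverse depth-first search: starting from the operations in
// 'source', it follows operands back to the operations that define them and
// calls "func" on every node it visits, stopping once "func" returns true.
// The elements in 'source' can't be nullptr.
// Modeled on the function named "FlexibleDFS" defined in
// paddle/fluid/framework/ir/subgraph_detector.cc.
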
// Searches backwards along operand edges (from an operation to the operations
// that define its operands), starting from `source`.
//
// `source` is taken by value because it doubles as the DFS work stack; its
// elements must not be nullptr. `func` is invoked at most once per reachable
// operation. Returns true as soon as `func` returns true for some operation,
// and false once every reachable operation has been visited.
bool reverseDfs(std::vector<mlir::Operation *> source,
                const std::function<bool(const mlir::Operation *)> &func) {
  std::unordered_set<const mlir::Operation *> visited;
  while (!source.empty()) {
    mlir::Operation *current = source.back();
    source.pop_back();
    // The same operation may have been pushed several times before it is
    // first popped; insert() reports whether this is the first visit.
    if (!visited.insert(current).second) continue;
    if (func(current)) return true;
    for (mlir::Value operand : current->getOperands()) {
      // A block argument has no defining op, so getDefiningOp() is nullptr.
      mlir::Operation *producer = operand.getDefiningOp();
      if (producer != nullptr && visited.count(producer) == 0) {
        source.emplace_back(producer);
      }
    }
  }
  return false;
}

// Merges the first and second graph ops into a single new graph op.
// Fuses `first` and `second` into one newly created graph op, inserted at the
// position of `first`, and erases both original ops.
//
// builder: used to create the merged graph op and its terminator; its
//          insertion point is changed by this function.
// first:   the graph op whose body ends up at the front of the merged body.
// second:  the graph op whose body is placed after `first`'s body. In this
//          file it is always a user of `first`.
//
// Results of the merged op are laid out as: the results of `first` that are
// still needed outside of `second`, followed by all results of `second`.
void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder,  // NOLINT
                             mlir::pd::GraphOp first,
                             mlir::pd::GraphOp second) {
  // Inputs: every operand of `first`, then the operands of `second` that are
  // not produced by `first`.
  ::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands());
  for (mlir::Value input : second.getOperands()) {
    if (input.getDefiningOp() != first) {
      inputs.push_back(input);
    }
  }

  // Outputs, part 1: results of `first` that have at least one user which is
  // neither `second` itself nor an op directly inside `second`. Remember the
  // result index each of them gets in the merged op.
  ::llvm::SmallVector<mlir::Value, 4> outputs;
  ::llvm::DenseMap<mlir::Value, unsigned int> op_output_mapping;
  for (mlir::Value output : first.getResults()) {
    for (mlir::Operation *user : output.getUsers()) {
      if (user != second && user->getParentOp() != second) {
        op_output_mapping[output] = outputs.size();
        outputs.push_back(output);
        break;
      }
    }
  }

  // Outputs, part 2: everything `second` returns.
  mlir::Operation *second_return = second.getBody()->getTerminator();
  outputs.append(second_return->getOperands().begin(),
                 second_return->getOperands().end());

  ::llvm::SmallVector<mlir::Type, 4> return_types;
  for (mlir::Value value : outputs) {
    return_types.push_back(value.getType());
  }

  // Create the new graph op in front of `first`.
  builder.setInsertionPoint(first);
  auto loc = first.getLoc();
  auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);

  // Build the merged body. Both splices insert at block->begin(), so moving
  // `second`'s ops first and `first`'s ops afterwards leaves `first`'s ops
  // ahead of `second`'s. The terminators stay behind in the old bodies.
  // The raw allocation is handed to the region by push_back() below.
  mlir::Block *block = new mlir::Block;
  auto copy_range = second.getBody()->without_terminator();
  block->getOperations().splice(block->begin(),
                                second.getBody()->getOperations(),
                                copy_range.begin(),
                                copy_range.end());
  copy_range = first.getBody()->without_terminator();
  block->getOperations().splice(block->begin(),
                                first.getBody()->getOperations(),
                                copy_range.begin(),
                                copy_range.end());
  builder.setInsertionPointToEnd(block);
  builder.create<::infrt::ReturnOp>(loc, outputs);
  graph_op.body().push_back(block);

  // Redirect every use of `first`'s results.
  mlir::Operation *first_return = first.getBody()->getTerminator();
  const unsigned int num_result = first.getNumResults();
  for (unsigned int index = 0; index < num_result; ++index) {
    mlir::Value origin_value = first.getResult(index);
    // The value inside the body that `first` returned as this result.
    mlir::Value inner_value = first_return->getOperand(index);

    auto mapping_it = op_output_mapping.find(origin_value);
    if (mapping_it == op_output_mapping.end()) {
      // Not exported by the merged op: every use takes the inner value.
      origin_value.replaceAllUsesWith(inner_value);
      continue;
    }

    // Exported: uses owned by ops directly inside the merged op take the
    // inner value, all other uses take the merged op's result. Setting a use
    // removes it from origin_value's use list, so the loop terminates.
    mlir::Value outer_value = graph_op.getResult(mapping_it->second);
    while (!origin_value.use_empty()) {
      auto &use = *origin_value.use_begin();
      mlir::Value replace_value =
          use.getOwner()->getParentOp() == graph_op ? inner_value
                                                    : outer_value;
      use.set(replace_value);
    }
  }

  // `second`'s results are the trailing results of the merged op.
  second.replaceAllUsesWith(
      graph_op.getResults().take_back(second.getNumResults()));
  first.erase();
  second.erase();
}

// Topologically sorts the operations of the function body block.
// Reorders the operations of `body` into the order produced by
// mlir::topologicalSort. Does nothing for an empty block.
void topoSortBlock(mlir::Block &body) {  // NOLINT
  if (body.empty()) return;

  // Collect every operation of the block, last one first.
  llvm::SetVector<mlir::Operation *> toSort;
  for (auto it = body.rbegin(); it != body.rend(); ++it) {
    toSort.insert(&*it);
  }
  llvm::SetVector<mlir::Operation *> result =
      mlir::topologicalSort(std::move(toSort));

  // Re-insert the ops in sorted order, each one just before the terminator.
  // The terminator is looked up once, since it is the same on every pass.
  mlir::Operation *terminator = body.getTerminator();
  for (auto *op : result) {
    op->moveBefore(terminator);
  }
}

143 144 145
}  // namespace

// Implementation of the TRTGraphFusePass.
void TRTGraphFusePass::runOnFunction() {
147
  mlir::Block &body = getFunction().front();
148
  mlir::OpBuilder builder(&body, body.begin());
149 150 151 152
  bool changed = false;
  do {
    changed = false;
    for (auto &op : body) {
S
Shang Zhizhou 已提交
153 154
      mlir::pd::GraphOp graph_op =
          ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
155 156 157
      if (nullptr == graph_op) continue;

      for (auto user_op : op.getUsers()) {
S
Shang Zhizhou 已提交
158 159
        mlir::pd::GraphOp user_graph_op =
            ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
160 161
        if (nullptr == user_graph_op) continue;
        // get all dst input nodes except src.
162
        std::vector<mlir::Operation *> source_nodes;
163 164
        for (auto operand : user_op->getOperands()) {
          auto input = operand.getDefiningOp();
165
          if (input != &op && input != nullptr) {
166 167 168 169
            source_nodes.push_back(input);
          }
        }
        // Reverse DFS from the source_nodes.
170 171
        if (!reverseDfs(source_nodes,
                        [&op](const mlir::Operation *n) { return n == &op; })) {
S
Shang Zhizhou 已提交
172
          mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
173 174 175 176 177 178 179
          changed = true;
          break;
        }
      }
      if (changed) break;
    }
  } while (changed);
180
  topoSortBlock(body);
181 182 183
}
}  // namespace trt
}  // namespace infrt