trt_graph_fuse_pass.cc
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"

#include <glog/logging.h>
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <list>
#include <unordered_set>
#include <vector>

#include "paddle/infrt/dialect/pd/ir/pd_ops.h"

namespace infrt {
namespace trt {
namespace {
// reverseDfs
// Performs a reverse DFS over operand edges, calling "func" on each visited
// node; returns true as soon as "func" returns true for some node.
// The elements in 'source' must not be nullptr.
// Modeled on the function named "FlexibleDFS" defined in
// paddle/fluid/framework/ir/subgraph_detector.cc.
bool reverseDfs(std::vector<mlir::Operation *> source,
                const std::function<bool(const mlir::Operation *)> &func) {
  std::unordered_set<const mlir::Operation *> visited;
  while (!source.empty()) {
    auto node = source.back();
    source.pop_back();
    if (visited.count(node)) continue;
    visited.insert(node);
    if (func(node)) return true;
    auto values = node->getOperands();
    for (auto value : values) {
      // If the value is a block argument, getDefiningOp() returns nullptr.
      mlir::Operation *defining_op = value.getDefiningOp();
      if (defining_op != nullptr && !visited.count(defining_op)) {
        source.emplace_back(defining_op);
      }
    }
  }
  return false;
}
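// Illustrative use (a sketch mirroring the pass below; `op` and `others` are
// hypothetical names):
//   bool reachable = reverseDfs(
//       others, [&op](const mlir::Operation *n) { return n == &op; });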

// Merge the first & second graph ops into a new graph op.
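// Illustrative effect (schematic IR; the textual op spellings are assumptions,
// not verbatim dialect syntax):
//   Before:
//     %b = "infrt.graph"(%a) ({ ...; infrt.return %x })
//     %c = "infrt.graph"(%b) ({ ...; infrt.return %y })
//   After:
//     %c = "infrt.graph"(%a) ({ <first's ops>; <second's ops>;
//                               infrt.return %y })
// Results of `first` that are still used outside `second` also become results
// of the merged op, placed before `second`'s results.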
void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder,  // NOLINT
59 60
                             ::infrt::GraphOp first,
                             ::infrt::GraphOp second) {
  // Compute the inputs and outputs of the fused graph op.
  ::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
  for (mlir::Value input : second.getOperands()) {
    if (input.getDefiningOp() != first) {
      inputs.push_back(input);
    }
  }
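  // Results of `first` that are used by anything other than `second` must also
  // become results of the fused op; record the result index each will occupy.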
  ::llvm::DenseMap<mlir::Value, unsigned int> op_output_mapping;
  for (mlir::Value output : first.getResults()) {
    for (mlir::Operation *user : output.getUsers()) {
      if (user != second && user->getParentOp() != second) {
        op_output_mapping[output] = outputs.size();
        outputs.push_back(output);
        break;
      }
    }
  }
  auto return_op = second.getBody()->getTerminator();
  outputs.append(return_op->getOperands().begin(),
                 return_op->getOperands().end());
  ::llvm::SmallVector<mlir::Type, 4> return_types;
  for (auto value : outputs) {
    return_types.push_back(value.getType());
  }

  // create the new graph op
  builder.setInsertionPoint(first);
  auto loc = first.getLoc();
  auto graph_op = builder.create<::infrt::GraphOp>(loc, return_types, inputs);
  mlir::Block *block = new mlir::Block;
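  // Move both graph bodies into the new block: `second`'s ops are spliced in
  // first, then `first`'s ops are spliced in front of them, so the merged
  // block runs `first`'s ops before `second`'s.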
  auto copy_range = second.getBody()->without_terminator();
  block->getOperations().splice(block->begin(),
                                second.getBody()->getOperations(),
                                copy_range.begin(),
                                copy_range.end());
  copy_range = first.getBody()->without_terminator();
  block->getOperations().splice(block->begin(),
                                first.getBody()->getOperations(),
                                copy_range.begin(),
                                copy_range.end());
  builder.setInsertionPointToEnd(block);
  builder.create<::infrt::ReturnOp>(loc, outputs);
  graph_op.body().push_back(block);

  // Remap uses of first's results: uses inside the new graph op get the value
  // returned by first's terminator; uses outside get the corresponding result
  // of the new graph op.
  unsigned int num_result = first.getNumResults();
  return_op = first.getBody()->getTerminator();
  for (unsigned int index = 0; index < num_result; ++index) {
    auto origin_value = first.getResult(index);
    if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
      origin_value.replaceAllUsesWith(return_op->getOperand(index));
    } else {
      auto inner_value = return_op->getOperand(index);
      auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
      while (!origin_value.use_empty()) {
        auto replace_value =
            origin_value.use_begin()->getOwner()->getParentOp() == graph_op
                ? inner_value
                : outer_value;
        origin_value.use_begin()->set(replace_value);
      }
    }
  }
  second.replaceAllUsesWith(
      graph_op.getResults().take_back(second.getNumResults()));
  first.erase();
  second.erase();
}

// Topologically sort the operations in the block.
void topoSortBlock(mlir::Block &body) {  // NOLINT
  llvm::SetVector<mlir::Operation *> toSort;
  if (body.empty()) return;
  for (auto it = body.rbegin(); it != body.rend(); ++it) {
    toSort.insert(&*it);
  }
  llvm::SetVector<mlir::Operation *> result = mlir::topologicalSort(toSort);
  for (auto *op : result) {
    op->moveBefore(body.getTerminator());
  }
}

}  // namespace

// Implementation of the trtGraphFusePass.
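// Sketch of the loop below: scan the block for a graph op whose user is also a
// graph op; if the user's remaining inputs do not transitively depend on the
// first op (checked with reverseDfs, so fusion cannot introduce a cycle),
// merge the pair and restart the scan. Repeat until no more fusions happen.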
void TRTGraphFusePass::runOnFunction() {
  mlir::Block &body = getFunction().front();
  mlir::OpBuilder builder(&body, body.begin());
  bool changed = false;
  do {
    changed = false;
    for (auto &op : body) {
      ::infrt::GraphOp graph_op =
          ::llvm::dyn_cast_or_null<::infrt::GraphOp>(&op);
      if (nullptr == graph_op) continue;

      for (auto user_op : op.getUsers()) {
        ::infrt::GraphOp user_graph_op =
            ::llvm::dyn_cast_or_null<::infrt::GraphOp>(user_op);
        if (nullptr == user_graph_op) continue;
        // Collect the ops that define the user's other inputs (all inputs
        // except the current op).
        std::vector<mlir::Operation *> source_nodes;
        for (auto operand : user_op->getOperands()) {
          auto input = operand.getDefiningOp();
          if (input != &op && input != nullptr) {
            source_nodes.push_back(input);
          }
        }
        // Reverse DFS from source_nodes: if the current op is reachable,
        // fusing the two graph ops would create a cycle, so skip this pair.
        if (!reverseDfs(source_nodes,
                        [&op](const mlir::Operation *n) { return n == &op; })) {
          mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
          changed = true;
          break;
        }
      }
      if (changed) break;
    }
  } while (changed);

  // TODO(wilber): Implement a toposort for efficiency.
  // topoSortBlock(body);
}

std::unique_ptr<mlir::Pass> CreateTrtGraphFusePass() {
  return std::make_unique<TRTGraphFusePass>();
}
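// Hypothetical registration sketch (the pass-manager setup below is assumed,
// not part of this file):
//   mlir::PassManager pm(module.getContext());
//   pm.addNestedPass<mlir::FuncOp>(CreateTrtGraphFusePass());
//   if (mlir::failed(pm.run(module))) { /* handle failure */ }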

}  // namespace trt
}  // namespace infrt