// trt_graph_fuse_pass.cc
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"

#include <functional>
#include <list>
#include <unordered_set>
#include <utility>
#include <vector>

#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <paddle/infrt/dialect/pd_ops.h>

namespace infrt {
namespace trt {
namespace {
28 29 30 31
// ReverseDfs
// do reverse dfs. calls "func" to search when visit a node.
// The elements in 'source' can't be nullptr.
// Reference the function nameed "FlexibleDFS" but defined in:
32 33
// paddle/fluid/framework/ir/subgraph_detector.cc.

34 35 36
bool reverseDfs(std::vector<mlir::Operation *> source,
                const std::function<bool(const mlir::Operation *)> &func) {
  std::unordered_set<const mlir::Operation *> visited;
37 38 39 40 41 42 43
  while (!source.empty()) {
    auto node = source.back();
    source.pop_back();
    if (visited.count(node)) continue;
    visited.insert(node);
    if (func(node)) return true;
    auto values = node->getOperands();
44
    for (auto value : values) {
45
      // if the value is a block argument, the node is nullptr.
46
      mlir::Operation *node = value.getDefiningOp();
47 48
      if (node != nullptr && !visited.count(node)) {
        source.emplace_back(node);
49 50 51
      }
    }
  }
52
  return false;
53 54 55
}

// merge the first&second graph op to a new graph op.
56 57 58
void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder,  // NOLINT
                             mlir::pd::GraphOp first,
                             mlir::pd::GraphOp second) {
59
  // comput inputs and outputs
60 61
  ::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
  for (mlir::Value input : second.getOperands()) {
62 63 64 65
    if (input.getDefiningOp() != first) {
      inputs.push_back(input);
    }
  }
66 67 68
  ::llvm::DenseMap<mlir::Value, unsigned int> op_output_mapping;
  for (mlir::Value output : first.getResults()) {
    for (mlir::Operation *user : output.getUsers()) {
69 70 71 72 73 74 75
      if (user != second && user->getParentOp() != second) {
        op_output_mapping[output] = outputs.size();
        outputs.push_back(output);
        break;
      }
    }
  }
76 77 78 79
  auto return_op = second.getBody()->getTerminator();
  outputs.append(return_op->getOperands().begin(),
                 return_op->getOperands().end());
  ::llvm::SmallVector<mlir::Type, 4> return_types;
80
  for (auto value : outputs) {
81
    return_types.push_back(value.getType());
82 83 84 85 86
  }

  // create the new graph op
  builder.setInsertionPoint(first);
  auto loc = first.getLoc();
87 88
  auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);
  mlir::Block *block = new mlir::Block;
89 90 91 92 93 94 95 96 97 98 99
  auto copy_range = second.getBody()->without_terminator();
  block->getOperations().splice(block->begin(),
                                second.getBody()->getOperations(),
                                copy_range.begin(),
                                copy_range.end());
  copy_range = first.getBody()->without_terminator();
  block->getOperations().splice(block->begin(),
                                first.getBody()->getOperations(),
                                copy_range.begin(),
                                copy_range.end());
  builder.setInsertionPointToEnd(block);
100
  builder.create<mlir::pd::ReturnOp>(loc, outputs);
101 102 103 104
  graph_op.body().push_back(block);

  // mapping the output
  unsigned int num_result = first.getNumResults();
105
  return_op = first.getBody()->getTerminator();
106 107 108
  for (unsigned int index = 0; index < num_result; ++index) {
    auto origin_value = first.getResult(index);
    if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
109
      origin_value.replaceAllUsesWith(return_op->getOperand(index));
110
    } else {
111
      auto inner_value = return_op->getOperand(index);
112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
      auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
      while (!origin_value.use_empty()) {
        auto replace_value =
            origin_value.use_begin()->getOwner()->getParentOp() == graph_op
                ? inner_value
                : outer_value;
        origin_value.use_begin()->set(replace_value);
      }
    }
  }
  second.replaceAllUsesWith(
      graph_op.getResults().take_back(second.getNumResults()));
  first.erase();
  second.erase();
}

128 129
// Topological sort the function op.
void topoSortBlock(mlir::Block &body) {  // NOLINT
130
  llvm::SetVector<mlir::Operation *> toSort;
131 132 133 134
  if (body.empty()) return;
  for (auto it = body.rbegin(); it != body.rend(); ++it) {
    toSort.insert(&*it);
  }
135 136
  llvm::SetVector<mlir::Operation *> result =
      mlir::topologicalSort(std::move(toSort));
137 138 139 140 141
  for (auto *op : result) {
    op->moveBefore(body.getTerminator());
  }
}

142 143 144 145 146
}  // namespace

// Implementation of the trtGraphFusePass.
void trtGraphFusePass::runOnFunction() {
  mlir::Block &body = getFunction().front();
147
  mlir::OpBuilder builder(&body, body.begin());
148 149 150 151
  bool changed = false;
  do {
    changed = false;
    for (auto &op : body) {
152 153
      mlir::pd::GraphOp graph_op =
          ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
154 155 156
      if (nullptr == graph_op) continue;

      for (auto user_op : op.getUsers()) {
157 158
        mlir::pd::GraphOp user_graph_op =
            ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
159 160
        if (nullptr == user_graph_op) continue;
        // get all dst input nodes except src.
161
        std::vector<mlir::Operation *> source_nodes;
162 163
        for (auto operand : user_op->getOperands()) {
          auto input = operand.getDefiningOp();
164
          if (input != &op && input != nullptr) {
165 166 167 168
            source_nodes.push_back(input);
          }
        }
        // Reverse DFS from the source_nodes.
169 170
        if (!reverseDfs(source_nodes,
                        [&op](const mlir::Operation *n) { return n == &op; })) {
171 172 173 174 175 176 177 178
          mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
          changed = true;
          break;
        }
      }
      if (changed) break;
    }
  } while (changed);
179
  topoSortBlock(body);
180 181 182
}
}  // namespace trt
}  // namespace infrt