未验证 提交 9f0958fa 编写于 作者: 王明冬 提交者: GitHub

[infrt] add trt_graph_split_pass for infrt. test=develop (#38494)

上级 dfdc9960
# Skip the whole infrt subtree unless it was enabled at configure time.
if (NOT WITH_INFRT)
return()
endif()
# compile flags
# Extra warning suppressions applied to both C and C++ infrt sources;
# safe_set_cflag/safe_set_cxxflag only add a flag the compiler accepts.
set(INFRT_FLAGS -Wno-comment)
foreach(flag ${INFRT_FLAGS})
safe_set_cflag(CMAKE_C_FLAGS ${flag})
safe_set_cxxflag(CMAKE_CXX_FLAGS ${flag})
endforeach()
# Source/binary roots for the infrt component.
set(INFRT_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/infrt" )
set(INFRT_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/infrt" )
# Internal cache list of infrt test targets, populated by sub-directories.
set(INFRT_TEST_TARGETS CACHE INTERNAL "")
......
......@@ -36,7 +36,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source) {
// context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry());
context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
// Currently, we only use the CinnDialect and mlir::BuiltinDialect is
// enough. We don't need StandardOpsDialect.
// context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
mlir::ScopedDiagnosticHandler scope_handler(
context, [](mlir::Diagnostic& diag) {
......
func @ops() {
%a = pd.feed() : tensor<?xf32>
%b = pd.feed() : tensor<?xf32>
%a = pd.feed() {name="input0"} : tensor<?xf32>
%b = pd.feed() {name="input1"}: tensor<?xf32>
%c = "pd.matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
......
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
%a = "pd.feed"() : () -> tensor<?xf32>
%b = "pd.feed"() : () -> tensor<?xf32>
%bias = "pd.feed"() : () -> tensor<?xf32>
%a = "pd.feed"() {name="input0"} : () -> tensor<?xf32>
%b = "pd.feed"() {name="input1"} : () -> tensor<?xf32>
%bias = "pd.feed"() {name="input2"} : () -> tensor<?xf32>
%b1 = "pd.feed"() : () -> tensor<?xf32>
%b2 = "pd.feed"() : () -> tensor<?xf32>
%bias1 = "pd.feed"() : () -> tensor<?xf32>
%bias2 = "pd.feed"() : () -> tensor<?xf32>
%b1 = "pd.feed"() {name="input3"} : () -> tensor<?xf32>
%b2 = "pd.feed"() {name="input4"} : () -> tensor<?xf32>
%bias1 = "pd.feed"() {name="input5"} : () -> tensor<?xf32>
%bias2 = "pd.feed"() {name="input6"} : () -> tensor<?xf32>
%c = "pd.matmul"(%a, %b) {transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
......
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
%a = "pd.feed"() : () -> tensor<?x3x256x256xf32>
%a = "pd.feed"() {name="input0"} : () -> tensor<?x3x256x256xf32>
%filter = "pd.constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32>
%bias = "pd.constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
......
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
%a = "pd.feed"() : () -> tensor<?xf32>
%b = "pd.feed"() : () -> tensor<?xf32>
%bias = "pd.feed"() : () -> tensor<?xf32>
%c = "pd.feed"() : () -> tensor<?xf32>
%b1 = "pd.feed"() : () -> tensor<?xf32>
%b2 = "pd.feed"() : () -> tensor<?xf32>
%bias1 = "pd.feed"() : () -> tensor<?xf32>
%bias2 = "pd.feed"() : () -> tensor<?xf32>
%bias = "pd.feed"() {name="input0"} : () -> tensor<?xf32>
%c = "pd.feed"() {name="input1"} : () -> tensor<?xf32>
%b1 = "pd.feed"() {name="input2"} : () -> tensor<?xf32>
%b2 = "pd.feed"() {name="input3"} : () -> tensor<?xf32>
%bias1 = "pd.feed"() {name="input4"} : () -> tensor<?xf32>
%bias2 = "pd.feed"() {name="input5"} : () -> tensor<?xf32>
%d = "pd.elementwise_add"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
......
......@@ -6,14 +6,14 @@ include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/pd_op_base.td"
def PD_FeedOp : PD_Op<"feed", [NoSideEffect]> {
def PD_FeedOp : PD_Op<"feed"> {
let summary = "Feed Op";
let description = [{
Feed a tensor into the model.
}];
let arguments = (ins);
let arguments = (ins StrAttr:$name);
let results = (outs PD_Tensor:$out);
let assemblyFormat = [{
......
......@@ -4,5 +4,9 @@ gather_srcs(infrt_src SRCS
trt_ops.cc
trt_op_teller_pass.cc
trt_graph_fuse_pass.cc
trt_graph_split_pass.cc
)
mlir_tablegen_on(trt_ops)
# Standalone driver that runs the TensorRT lowering passes on an MLIR file.
add_executable(trt-exec trt_exec.cc)
# PRIVATE: infrt and the MLIR IR libraries are implementation details of the
# executable; nothing links against trt-exec, so nothing needs to inherit them.
target_link_libraries(trt-exec PRIVATE infrt ${MLIR_IR_LIBS})
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
// Driver for the TensorRT lowering pipeline: load an MLIR file, run the
// teller/fuse/split passes over every function, and dump the module before
// and after.
int main(int argc, char** argv) {
  // Positional input filename; "-" (stdin) by default, following MLIR tools.
  static llvm::cl::opt<std::string> input_file(
      llvm::cl::Positional,
      llvm::cl::desc("Specify input filename"),
      llvm::cl::init("-"));
  llvm::cl::ParseCommandLineOptions(argc, argv);
  mlir::MLIRContext* context = infrt::Global::getMLIRContext();
  auto module = infrt::dialect::LoadMlirFile(input_file.c_str(), context);
  // Guard against a failed parse: dereferencing a null module would crash.
  if (!module) {
    std::cout << "\nfailed to load mlir file: " << input_file << std::endl;
    return 1;
  }
  module->dump();
  mlir::PassManager pm(context);
  // The TRT passes are FunctionPasses, so nest them under FuncOp.
  mlir::OpPassManager& trt_pass_manager = pm.nest<mlir::FuncOp>();
  trt_pass_manager.addPass(std::make_unique<infrt::trt::trtOpTellerPass>());
  trt_pass_manager.addPass(std::make_unique<infrt::trt::trtGraphFusePass>());
  // Graphs with <= 10 ops are inlined back (not worth a TensorRT engine).
  trt_pass_manager.addPass(std::make_unique<infrt::trt::trtGraphSplitPass>(10));
  if (mlir::failed(pm.run(*module))) {
    std::cout << "\npass failed!\n" << std::endl;
    return 4;
  }
  module->dump();
  return 0;
}
......@@ -18,6 +18,7 @@
#include <unordered_set>
#include <vector>
#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
......@@ -25,42 +26,31 @@
namespace infrt {
namespace trt {
namespace {
// FlexibleDFS
// do reverse dfs. calls leave(node) after visiting all parents of node.
// Reference the function with the same name but defined in:
// ReverseDfs
// do reverse dfs. calls "func" to search when visit a node.
// The elements in 'source' can't be nullptr.
// Reference the function nameed "FlexibleDFS" but defined in:
// paddle/fluid/framework/ir/subgraph_detector.cc.
void FlexibleDFS(const std::vector<::mlir::Operation *> &source,
const std::function<bool(const ::mlir::Operation *)> &leave) {
typedef struct {
::mlir::Operation *node;
bool leave;
} FNode;
std::vector<FNode> stack;
for (auto &node : source) {
stack.push_back(FNode{node, false});
}
bool reverseDfs(std::vector<::mlir::Operation *> source,
const std::function<bool(const ::mlir::Operation *)> &func) {
std::unordered_set<const ::mlir::Operation *> visited;
while (!stack.empty()) {
auto fnode = stack.back();
stack.pop_back();
if (fnode.leave) {
if (leave && !leave(fnode.node)) return;
}
if (visited.count(fnode.node)) continue;
visited.insert(fnode.node);
if (leave) stack.push_back(FNode{fnode.node, true});
auto values = fnode.node->getOperands();
while (!source.empty()) {
auto node = source.back();
source.pop_back();
if (visited.count(node)) continue;
visited.insert(node);
if (func(node)) return true;
auto values = node->getOperands();
for (auto value : values) {
// if the value is a block argument, the node is nullptr.
::mlir::Operation *node = value.getDefiningOp();
if (!visited.count(node)) {
stack.push_back(FNode{node, false});
if (node != nullptr && !visited.count(node)) {
source.emplace_back(node);
}
}
}
return false;
}
// merge the first&second graph op to a new graph op.
......@@ -136,6 +126,20 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
second.erase();
}
// Topologically re-order the operations of a block in place.
void topoSortBlock(mlir::Block &body) {  // NOLINT
  if (body.empty()) return;
  // Seed the set with the block's ops in reverse program order.
  llvm::SetVector<Operation *> pending;
  for (auto it = body.rbegin(); it != body.rend(); ++it) {
    pending.insert(&*it);
  }
  llvm::SetVector<Operation *> ordered =
      ::mlir::topologicalSort(std::move(pending));
  // Re-emit each op just before the terminator, preserving sorted order.
  for (auto *operation : ordered) {
    operation->moveBefore(body.getTerminator());
  }
}
} // namespace
// Implementation of the trtGraphFusePass.
......@@ -158,21 +162,14 @@ void trtGraphFusePass::runOnFunction() {
std::vector<::mlir::Operation *> source_nodes;
for (auto operand : user_op->getOperands()) {
auto input = operand.getDefiningOp();
if (input != &op) {
if (input != &op && input != nullptr) {
source_nodes.push_back(input);
}
}
// Reverse DFS from the source_nodes.
bool have_excess_path = false;
FlexibleDFS(source_nodes,
[&have_excess_path, &op](const ::mlir::Operation *n) {
if (n == &op) {
have_excess_path = true;
return false;
}
return true;
});
if (!have_excess_path) {
if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) {
return n == &op;
})) {
mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
changed = true;
break;
......@@ -181,7 +178,7 @@ void trtGraphFusePass::runOnFunction() {
if (changed) break;
}
} while (changed);
topoSortBlock(body);
}
} // namespace trt
} // namespace infrt
......@@ -25,7 +25,7 @@ namespace trt {
* source func:
*
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()
* %a = "pd.feed"()...
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
......@@ -42,7 +42,7 @@ namespace trt {
*
* destination func:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()
* %a = "pd.feed"()...
* %d, %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
......@@ -58,6 +58,5 @@ class trtGraphFusePass
::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
void runOnFunction() override;
};
} // namespace trt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
// Implementation of the trtGraphSplitPass.
// Inlines every pd.graph op whose body holds no more than min_subgraph_size_
// operations back into the surrounding function block (the opposite of the
// fuse pass: tiny subgraphs are not worth a TensorRT engine).
void trtGraphSplitPass::runOnFunction() {
std::vector<::mlir::pd::GraphOp> worklist;
::mlir::Block& block = getFunction().front();
// Pass 1: collect the graph ops that are small enough to be dissolved.
// The count includes the body's terminator (pd.fetch).
for (auto& op : block) {
::mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op);
}
}
// Pass 2: dissolve each collected graph op.
while (!worklist.empty()) {
::mlir::pd::GraphOp graph_op = worklist.back();
worklist.pop_back();
::mlir::Block* body = graph_op.getBody();
auto fetch_op = body->getTerminator();
// Rewire users of the graph op's results to the fetch terminator's
// operands (the values the subgraph produced) BEFORE moving the body out.
graph_op.replaceAllUsesWith(fetch_op->getOperands());
auto copy_range = body->without_terminator();
// splice MOVES (not copies) every non-terminator op from the graph body
// into the parent block, right before the graph op itself.
block.getOperations().splice(::mlir::Block::iterator(graph_op),
body->getOperations(),
copy_range.begin(),
copy_range.end());
// The graph op is now empty apart from its terminator and has no users.
graph_op.erase();
}
}
} // namespace trt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
namespace infrt {
namespace trt {
/*
* trtGraphSplitPass.
*
* Split the graph op when the number of operations is too small.
* The feature is the opposite of 'trtOpTellerPass'.
*
* source func:
*
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %d, %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s
* } ...
* "pd.fetch" %d, %f
* }
*
* destination func:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()...
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f
* }
*/
// Function pass that inlines pd.graph ops whose body is smaller than
// min_subgraph_size_ back into the parent function (see comment above).
class trtGraphSplitPass
: public ::mlir::PassWrapper<trtGraphSplitPass, ::mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void runOnFunction() override;
// min_subgraph_size: graph ops with at most this many operations
// (terminator included) are split back; defaults to 3.
explicit trtGraphSplitPass(size_t min_subgraph_size = 3)
: min_subgraph_size_(min_subgraph_size) {}
private:
// Threshold used by runOnFunction when selecting graph ops to dissolve.
size_t min_subgraph_size_;
};
} // namespace trt
} // namespace infrt
......@@ -20,7 +20,6 @@
namespace infrt {
namespace trt {
// Implementation of the trtOpTellerPass.
void trtOpTellerPass::runOnFunction() {
::mlir::Block &body = getFunction().front();
......@@ -60,6 +59,5 @@ void trtOpTellerPass::runOnFunction() {
builder.create<mlir::pd::FetchOp>(loc, op->getResults());
}
}
} // namespace trt
} // namespace infrt
......@@ -17,7 +17,6 @@
namespace infrt {
namespace trt {
/*
* trtOpTellerPass.
*
......@@ -26,7 +25,7 @@ namespace trt {
* source func:
*
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()
* %a = "pd.feed"()...
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
......@@ -35,7 +34,7 @@ namespace trt {
*
* destination func:
* func @main() -> tensor<?xf32> {
* %a = "pd.feed"()
* %a = "pd.feed"()...
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
......@@ -59,6 +58,5 @@ class trtOpTellerPass
::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
void runOnFunction() override;
};
} // namespace trt
} // namespace infrt
文件模式从 100755 更改为 100644
......@@ -65,12 +65,12 @@ function infrt_gen_and_build() {
mkdir -p ${PADDLE_ROOT}/build
cd ${PADDLE_ROOT}/build
rm -f infrt_summary.txt
cmake .. -DWITH_MKL=OFF -DWITH_GPU=OFF -DCMAKE_BUILD_TYPE=Release -DWITH_INFRT=ON -DWITH_PYTHON=OFF -DWITH_TESTING==${WITH_TESTING:-ON}; build_error=$?
cmake .. -DWITH_MKL=OFF -DWITH_GPU=OFF -DWITH_CRYPTO=OFF -DCMAKE_BUILD_TYPE=Release -DWITH_INFRT=ON -DWITH_PYTHON=OFF -DWITH_TESTING==${WITH_TESTING:-ON}; build_error=$?
if [ "$build_error" != 0 ];then
exit 7;
fi
make -j ${parallel_number} infrt infrtopt infrt-exec test_infrt_exec infrt_lib_dist;build_error=$?
make -j ${parallel_number} infrt infrtopt infrt-exec test_infrt_exec trt-exec infrt_lib_dist;build_error=$?
if [ "$build_error" != 0 ];then
exit 7;
fi
......@@ -115,6 +115,9 @@ function main() {
build_only)
infrt_gen_and_build ${parallel_number}
;;
test_only)
test_infrt
;;
*)
print_usage
exit 1
......@@ -126,7 +129,7 @@ function main() {
cat ${PADDLE_ROOT}/build/infrt_summary.txt
echo "========================================================"
fi
echo "paddle_build script finished as expected"
echo "paddle_build script finished as expected!"
}
main $@
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册