From 44112817a5592d2f5a5a414c9ddaaabff26f0cdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Wed, 22 Dec 2021 15:04:32 +0800 Subject: [PATCH] [infrt] add tensorrt op teller pass. test=develop (#38304) --- cmake/infrt_lib.cmake | 17 +- paddle/infrt/CMakeLists.txt | 3 + paddle/infrt/dialect/CMakeLists.txt | 2 + paddle/infrt/dialect/pd_ops.cc | 5 - paddle/infrt/dialect/pd_ops.h | 4 + paddle/infrt/dialect/pd_ops.td | 22 ++- paddle/infrt/dialect/tensorrt/CMakeLists.txt | 8 + .../dialect/tensorrt/trt_graph_fuse_pass.cc | 187 ++++++++++++++++++ .../dialect/tensorrt/trt_graph_fuse_pass.h | 63 ++++++ paddle/infrt/dialect/tensorrt/trt_op_base.td | 77 ++++++++ .../dialect/tensorrt/trt_op_teller_pass.cc | 65 ++++++ .../dialect/tensorrt/trt_op_teller_pass.h | 64 ++++++ paddle/infrt/dialect/tensorrt/trt_ops.cc | 39 ++++ paddle/infrt/dialect/tensorrt/trt_ops.h | 50 +++++ paddle/infrt/dialect/tensorrt/trt_ops.td | 30 +++ paddle/scripts/infrt_build.sh | 3 + 16 files changed, 624 insertions(+), 15 deletions(-) create mode 100755 paddle/infrt/dialect/tensorrt/CMakeLists.txt create mode 100644 paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc create mode 100644 paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h create mode 100755 paddle/infrt/dialect/tensorrt/trt_op_base.td create mode 100644 paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc create mode 100644 paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h create mode 100644 paddle/infrt/dialect/tensorrt/trt_ops.cc create mode 100755 paddle/infrt/dialect/tensorrt/trt_ops.h create mode 100755 paddle/infrt/dialect/tensorrt/trt_ops.td diff --git a/cmake/infrt_lib.cmake b/cmake/infrt_lib.cmake index 73a8cdbee5..5b27c9d840 100644 --- a/cmake/infrt_lib.cmake +++ b/cmake/infrt_lib.cmake @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-set(PADDLE_INFRT_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_infrt_install_dir" CACHE STRING
+set(INFRT_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_infrt_install_dir" CACHE STRING
   "A path setting paddle infrt shared and static libraries")
 
 function(copy TARGET)
@@ -52,18 +52,17 @@ add_custom_target(infrt_lib_dist DEPENDS ${infrt_lib_deps})
 # CMakeCache Info
 copy(infrt_lib_dist
         SRCS ${CMAKE_BINARY_DIR}/CMakeCache.txt
-        DSTS ${PADDLE_INFRT_INSTALL_DIR})
+        DSTS ${INFRT_INSTALL_DIR})
 
-set(src_dir "${PADDLE_SOURCE_DIR}/paddle/infrt")
-set(paddle_infrt_lib ${PADDLE_BINARY_DIR}/paddle/infrt/libinfrt.*)
+set(infrt_lib ${INFRT_BINARY_DIR}/libinfrt.*)
 
 copy(infrt_lib_dist
-        SRCS ${src_dir}/api/infrt_api.h ${paddle_infrt_lib}
-        DSTS ${PADDLE_INFRT_INSTALL_DIR}/infrt/include ${PADDLE_INFRT_INSTALL_DIR}/infrt/lib)
+        SRCS ${INFRT_SOURCE_DIR}/api/infrt_api.h ${infrt_lib}
+        DSTS ${INFRT_INSTALL_DIR}/infrt/include ${INFRT_INSTALL_DIR}/infrt/lib)
 
 copy(infrt_lib_dist
-        SRCS ${CMAKE_BINARY_DIR}/paddle/infrt/paddle/framework.pb.h
-        DSTS ${PADDLE_INFRT_INSTALL_DIR}/infrt/include/internal)
+        SRCS ${INFRT_BINARY_DIR}/paddle/framework.pb.h
+        DSTS ${INFRT_INSTALL_DIR}/infrt/include/internal)
 
 # paddle fluid version
 function(version version_file)
@@ -74,4 +73,4 @@ function(version version_file)
   file(WRITE ${version_file} "GIT COMMIT ID: ${PADDLE_GIT_COMMIT}\n")
   file(APPEND ${version_file} "CXX compiler version: ${CMAKE_CXX_COMPILER_VERSION}\n")
 endfunction()
-version(${PADDLE_INFRT_INSTALL_DIR}/version.txt)
+version(${INFRT_INSTALL_DIR}/version.txt)
diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt
index 465397977f..c81a0fe91c 100644
--- a/paddle/infrt/CMakeLists.txt
+++ b/paddle/infrt/CMakeLists.txt
@@ -1,6 +1,8 @@
 if (NOT WITH_INFRT)
   return()
 endif()
+set(INFRT_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/infrt" )
+set(INFRT_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/infrt" )
 
 include(infrt_lib)
 
@@ -74,6 +76,7 @@ set(infrt_mlir_incs
     dense_tensor_inc
     pd_ops_inc
     rewrite_inc
+    trt_ops_inc
     )
 message(STATUS "infrt srcs:\n${infrt_src}")
diff --git a/paddle/infrt/dialect/CMakeLists.txt b/paddle/infrt/dialect/CMakeLists.txt
index c06d777163..d145843684 100644
--- a/paddle/infrt/dialect/CMakeLists.txt
+++ b/paddle/infrt/dialect/CMakeLists.txt
@@ -51,3 +51,5 @@ infrt_exec_check(test_infrt_tensor_type mlir_tests/tensor_type.mlir)
 infrt_exec_check(test_infrt__basic mlir_tests/basic.mlir)
 infrt_exec_check(test_infrt_benchmark mlir_tests/benchmark.mlir)
 infrt_exec_check(test_infrt_mlir_dense_tensor mlir_tests/dense_tensor.mlir)
+
+add_subdirectory(tensorrt)
diff --git a/paddle/infrt/dialect/pd_ops.cc b/paddle/infrt/dialect/pd_ops.cc
index 7ca07dd5fc..ce10be6d10 100644
--- a/paddle/infrt/dialect/pd_ops.cc
+++ b/paddle/infrt/dialect/pd_ops.cc
@@ -20,11 +20,6 @@
 
 namespace mlir {
 namespace pd {
-
-#define GET_OP_CLASSES
-#include "paddle/infrt/dialect/pd_ops.hpp.inc"
-#undef GET_OP_CLASSES
-
 PaddleDialect::PaddleDialect(MLIRContext *context)
     : Dialect("pd", context, TypeID::get<PaddleDialect>()) {
   addOperations<
diff --git a/paddle/infrt/dialect/pd_ops.h b/paddle/infrt/dialect/pd_ops.h
index d09b603225..71e0a53988 100644
--- a/paddle/infrt/dialect/pd_ops.h
+++ b/paddle/infrt/dialect/pd_ops.h
@@ -53,5 +53,9 @@ class PaddleDialect : public Dialect {
   }
 };
 
+#define GET_OP_CLASSES
+#include "paddle/infrt/dialect/pd_ops.hpp.inc"
+#undef GET_OP_CLASSES
+
 } // namespace pd
 } // namespace mlir
diff --git a/paddle/infrt/dialect/pd_ops.td b/paddle/infrt/dialect/pd_ops.td
index 9e906ad0c0..2aa7ab576a 100644
--- a/paddle/infrt/dialect/pd_ops.td
+++ b/paddle/infrt/dialect/pd_ops.td
@@ -6,7 +6,7 @@ include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/OpBase.td"
 include "paddle/infrt/dialect/pd_op_base.td"
 
-def PD_FeedOp : PD_Op<"Feed", [NoSideEffect]> {
+def PD_FeedOp : PD_Op<"feed", [NoSideEffect]> {
   let summary = "Feed Op";
 
   let description = [{
@@ -21,6 +21,26 @@ def PD_FeedOp : PD_Op<"Feed", [NoSideEffect]> {
   }];
 }
 
+def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
+  let summary = "fetch Op";
+
+  let description = [{
+    Fetch tensor from the graph.
+  }];
+
+  let arguments = (ins Variadic<PD_Tensor>:$inputs);
+}
+
+def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> {
+  let summary = "paddle graph Op";
+  let description = [{
+    Describe a paddle graph or subgraph.
+  }];
+  let regions = (region SizedRegion<1>:$body);
+  let arguments = (ins Variadic<PD_Tensor>:$inputs);
+  let results = (outs Variadic<PD_Tensor>:$outputs);
+}
+
 def PD_ConstantOp : PD_Op<"Constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods<InferTypeOpInterface>, AllTypesMatch<["value", "output"]>]> {
   let summary = "constant Op";
   let description = [{}];
diff --git a/paddle/infrt/dialect/tensorrt/CMakeLists.txt b/paddle/infrt/dialect/tensorrt/CMakeLists.txt
new file mode 100755
index 0000000000..e1e3f40075
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/CMakeLists.txt
@@ -0,0 +1,8 @@
+core_gather_headers()
+
+gather_srcs(infrt_src SRCS
+    trt_ops.cc
+    trt_op_teller_pass.cc
+    trt_graph_fuse_pass.cc
+  )
+mlir_tablegen_on(trt_ops)
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
new file mode 100644
index 0000000000..a44fcece43
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
@@ -0,0 +1,187 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
+
+#include <functional>
+#include <unordered_set>
+#include <vector>
+#include "llvm/ADT/SetVector.h"
+#include "mlir/IR/Builders.h"
+#include "paddle/infrt/dialect/pd_ops.h"
+#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
+
+namespace infrt {
+namespace trt {
+namespace {
+
+// FlexibleDFS
+// Do a reverse DFS: leave(node) is called after all parents of node have
+// been visited. Modeled on the function of the same name defined in:
+// paddle/fluid/framework/ir/subgraph_detector.cc.
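+//
+// Example: if op a defines an operand of op b, and b defines an operand of
+// op c, then FlexibleDFS({c}, leave) walks upward through the operands and
+// calls leave(a), leave(b), leave(c) in that order; returning false from
+// leave aborts the whole traversal. The fuse pass below relies on this early
+// exit to test whether any path leads back to a given op.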
+void FlexibleDFS(const std::vector<::mlir::Operation *> &source,
+                 const std::function<bool(const ::mlir::Operation *)> &leave) {
+  typedef struct {
+    ::mlir::Operation *node;
+    bool leave;
+  } FNode;
+
+  std::vector<FNode> stack;
+  for (auto &node : source) {
+    stack.push_back(FNode{node, false});
+  }
+  std::unordered_set<::mlir::Operation *> visited;
+  while (!stack.empty()) {
+    auto fnode = stack.back();
+    stack.pop_back();
+
+    if (fnode.leave) {
+      if (leave && !leave(fnode.node)) return;
+    }
+    if (visited.count(fnode.node)) continue;
+    visited.insert(fnode.node);
+
+    if (leave) stack.push_back(FNode{fnode.node, true});
+    auto values = fnode.node->getOperands();
+    for (auto value : values) {
+      ::mlir::Operation *node = value.getDefiningOp();
+      if (!visited.count(node)) {
+        stack.push_back(FNode{node, false});
+      }
+    }
+  }
+}
+
+// Merge the first & second graph ops into a new graph op.
+void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder,  // NOLINT
+                             ::mlir::pd::GraphOp first,
+                             ::mlir::pd::GraphOp second) {
+  // compute inputs and outputs
+  ::llvm::SmallVector<::mlir::Value, 4> inputs(first.getOperands()), outputs;
+  for (::mlir::Value input : second.getOperands()) {
+    if (input.getDefiningOp() != first) {
+      inputs.push_back(input);
+    }
+  }
+  ::llvm::DenseMap<::mlir::Value, unsigned int> op_output_mapping;
+  for (::mlir::Value output : first.getResults()) {
+    for (::mlir::Operation *user : output.getUsers()) {
+      if (user != second && user->getParentOp() != second) {
+        op_output_mapping[output] = outputs.size();
+        outputs.push_back(output);
+        break;
+      }
+    }
+  }
+  auto fetch_op = second.getBody()->getTerminator();
+  outputs.append(fetch_op->getOperands().begin(),
+                 fetch_op->getOperands().end());
+  ::llvm::SmallVector<::mlir::Type, 4> fetch_types;
+  for (auto value : outputs) {
+    fetch_types.push_back(value.getType());
+  }
+
+  // create the new graph op
+  builder.setInsertionPoint(first);
+  auto loc = first.getLoc();
+  auto graph_op = builder.create<::mlir::pd::GraphOp>(loc, fetch_types, inputs);
+  ::mlir::Block *block = new ::mlir::Block;
+  auto copy_range = second.getBody()->without_terminator();
+  block->getOperations().splice(block->begin(),
+                                second.getBody()->getOperations(),
+                                copy_range.begin(),
+                                copy_range.end());
+  copy_range = first.getBody()->without_terminator();
+  block->getOperations().splice(block->begin(),
+                                first.getBody()->getOperations(),
+                                copy_range.begin(),
+                                copy_range.end());
+  builder.setInsertionPointToEnd(block);
+  builder.create<::mlir::pd::FetchOp>(loc, outputs);
+  graph_op.body().push_back(block);
+
+  // mapping the output
+  unsigned int num_result = first.getNumResults();
+  fetch_op = first.getBody()->getTerminator();
+  for (unsigned int index = 0; index < num_result; ++index) {
+    auto origin_value = first.getResult(index);
+    if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
+      origin_value.replaceAllUsesWith(fetch_op->getOperand(index));
+    } else {
+      auto inner_value = fetch_op->getOperand(index);
+      auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
+      while (!origin_value.use_empty()) {
+        auto replace_value =
+            origin_value.use_begin()->getOwner()->getParentOp() == graph_op
+                ? inner_value
+                : outer_value;
+        origin_value.use_begin()->set(replace_value);
+      }
+    }
+  }
+  second.replaceAllUsesWith(
+      graph_op.getResults().take_back(second.getNumResults()));
+  first.erase();
+  second.erase();
+}
+
+} // namespace
+
+// Implementation of the trtGraphFusePass.
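+//
+// A sketch of how the pass is expected to be scheduled, after trtOpTellerPass
+// has wrapped each candidate op into a single-op graph (assuming the standard
+// MLIR PassManager API of this LLVM version; illustrative only, not part of
+// this patch):
+//
+//   ::mlir::PassManager pm(context);
+//   auto &func_pm = pm.nest<::mlir::FuncOp>();
+//   func_pm.addPass(std::make_unique<trtOpTellerPass>());
+//   func_pm.addPass(std::make_unique<trtGraphFusePass>());
+//   if (::mlir::failed(pm.run(module))) { /* handle failure */ }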
+void trtGraphFusePass::runOnFunction() {
+  mlir::Block &body = getFunction().front();
+  ::mlir::OpBuilder builder(&body, body.begin());
+  bool changed = false;
+  do {
+    changed = false;
+    for (auto &op : body) {
+      ::mlir::pd::GraphOp graph_op =
+          ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
+      if (nullptr == graph_op) continue;
+
+      for (auto user_op : op.getUsers()) {
+        ::mlir::pd::GraphOp user_graph_op =
+            ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(user_op);
+        if (nullptr == user_graph_op) continue;
+        // get all dst input nodes except src.
+        std::vector<::mlir::Operation *> source_nodes;
+        for (auto operand : user_op->getOperands()) {
+          auto input = operand.getDefiningOp();
+          if (input != &op) {
+            source_nodes.push_back(input);
+          }
+        }
+        // Reverse DFS from the source_nodes.
+        bool have_excess_path = false;
+        FlexibleDFS(source_nodes,
+                    [&have_excess_path, &op](const ::mlir::Operation *n) {
+                      if (n == &op) {
+                        have_excess_path = true;
+                        return false;
+                      }
+                      return true;
+                    });
+        if (!have_excess_path) {
+          mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
+          changed = true;
+          break;
+        }
+      }
+      if (changed) break;
+    }
+  } while (changed);
+}
+
+} // namespace trt
+} // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
new file mode 100644
index 0000000000..d5a019eff8
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "mlir/Pass/Pass.h"
+
+namespace infrt {
+namespace trt {
+/*
+ * trtGraphFusePass.
+ *
+ * Merge adjacent graph ops into a new graph op.
+ *
+ * source func:
+ *
+ * func @main() -> tensor<?xf32> {
+ *   %a = "pd.feed"()
+ *   %c = "pd.graph"(%a) {
+ *     %m = "pd.conv2d"(%a)...
+ *     "pd.fetch" %m
+ *   } ...
+ *   %d = "pd.graph"(%c) {
+ *     %m = "pd.conv3d"(%c)...
+ *     "pd.fetch" %m
+ *   } ...
+ *   %f = "pd.graph"(%a) {
+ *     %m = "pd.conv2d"(%a)...
+ *     "pd.fetch" %m
+ *   } ...
+ *   "pd.fetch" %d, %f
+ * }
+ *
+ * destination func:
+ * func @main() -> tensor<?xf32> {
+ *   %a = "pd.feed"()
+ *   %d, %f = "pd.graph"(%a) {
+ *     %m = "pd.conv2d"(%a)...
+ *     %n = "pd.conv3d"(%m)...
+ *     %s = "pd.conv2d"(%a)...
+ *     "pd.fetch" %n, %s
+ *   } ...
+ *   "pd.fetch" %d, %f
+ * }
+ */
+class trtGraphFusePass
+    : public ::mlir::PassWrapper<trtGraphFusePass, ::mlir::FunctionPass> {
+ public:
+  ::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
+  void runOnFunction() override;
+};
+
+} // namespace trt
+} // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_base.td b/paddle/infrt/dialect/tensorrt/trt_op_base.td
new file mode 100755
index 0000000000..5722f17d59
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/trt_op_base.td
@@ -0,0 +1,77 @@
+// This file defines some basic elements of the Paddle (alias trt) dialect.
+// We learned much from the TensorFlow MLIR dialect:
+// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+
+#ifndef TRT_OP_BASE
+#define TRT_OP_BASE
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def TRT_Dialect : Dialect {
+  let name = "trt";
+
+  let description = [{
+    The TensorRT dialect.
+
+    This dialect contains the TensorRT operators.
+  }];
+
+  let cppNamespace = "::infrt::trt";
+}
+
+class TRT_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<TRT_Dialect, mnemonic, traits>;
+
+
+class TRT_PaddleAttr <string name, string description> :
+    Attr<CPred<"$_self.isa<mlir::trt::" # name # "Attr>()">,
+         "PaddlePaddle " # description # " attribute">;
+
+
+//===----------------------------------------------------------------------===//
+// PaddlePaddle type definitions
+//===----------------------------------------------------------------------===//
+
+def TRT_TRTDialectType : Type<CPred<"$_self.isa<mlir::trt::TRTDialectType>()">, "PaddlePaddle type">;
+
+class TRT_PaddleType <string name, string description> :
+    Type<CPred<"$_self.isa<mlir::trt::" # name # "Type>()">,
+         "Paddle " # description # " type">,
+    BuildableType<"getType<mlir::trt::" # name # "Type>()">;
+
+//===----------------------------------------------------------------------===//
+// Integer types
+def TRT_Bool : AnyTypeOf<[I<1>], "bool">;
+def TRT_Int8 : AnyTypeOf<[I8], "8-bit integer">;
+def TRT_Int16 : AnyTypeOf<[I16], "16-bit integer">;
+def TRT_Int32 : AnyTypeOf<[I32], "32-bit integer">;
+def TRT_Int64 : AnyTypeOf<[I64], "64-bit integer">;
+
+def TRT_UInt8 : AnyTypeOf<[UI<8>], "8-bit unsigned integer">;
+def TRT_UInt16 : AnyTypeOf<[UI<16>], "16-bit unsigned integer">;
+def TRT_UInt32 : AnyTypeOf<[UI<32>], "32-bit unsigned integer">;
+def TRT_UInt64 : AnyTypeOf<[UI<64>], "64-bit unsigned integer">;
+
+def TRT_SInt : AnyTypeOf<[TRT_Int8, TRT_Int16, TRT_Int32, TRT_Int64], "signed integer">;
+def TRT_UInt : AnyTypeOf<[TRT_UInt8, TRT_UInt16, TRT_UInt32, TRT_UInt64], "unsigned integer">;
+def TRT_Int : AnyTypeOf<[TRT_SInt, TRT_UInt], "integer">;
+
+// Float types
+def TRT_Float16 : AnyTypeOf<[F16], "16-bit float">;
+def TRT_Float32 : AnyTypeOf<[F32], "32-bit float">;
+def TRT_Float64 : AnyTypeOf<[F64], "64-bit float">;
+
+def TRT_Float : AnyTypeOf<[TRT_Float16, TRT_Float32, TRT_Float64], "floating-point">;
+
+
+// Tensor types
+
+def TRT_ElementType : Type<Or<[TRT_Float.predicate, TRT_Bool.predicate, TRT_Int.predicate]>,
+                           "trt.dtype">;
+
+def TRT_Tensor : TensorOf<[TRT_ElementType]>;
+
+
+#endif // TRT_OP_BASE
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
new file mode 100644
index 0000000000..14cffdbf98
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
@@ -0,0 +1,65 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
+
+#include "mlir/IR/Builders.h"
+#include "paddle/infrt/dialect/pd_ops.h"
+#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
+
+namespace infrt {
+namespace trt {
+
+// Implementation of the trtOpTellerPass.
+void trtOpTellerPass::runOnFunction() {
+  ::mlir::Block &body = getFunction().front();
+  std::vector<::mlir::Operation *> worklist;
+  worklist.reserve(body.getOperations().size());
+  for (auto &op : body) {
+    worklist.push_back(&op);
+  }
+  // Build GraphOp.
+  ::mlir::OpBuilder builder(&body, body.begin());
+  while (!worklist.empty()) {
+    auto *op = worklist.back();
+    worklist.pop_back();
+    if (op == nullptr) continue;
+    auto op1 = ::llvm::dyn_cast_or_null<::mlir::pd::FeedOp>(op);
+    if (op1) continue;
+    auto op2 = ::llvm::dyn_cast_or_null<::mlir::pd::FetchOp>(op);
+    if (op2) continue;
+    auto op3 = ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(op);
+    if (op3) continue;
+    builder.setInsertionPoint(op);
+    auto loc = getFunction().getLoc();
+    auto graph_op = builder.create<::mlir::pd::GraphOp>(
+        loc, op->getResultTypes(), op->getOperands());
+
+    ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;
+    for (auto v :
+         ::llvm::SmallVector<::mlir::Value, 4>{graph_op.getODSResults(0)}) {
+      tblgen_repl_values.push_back(v);
+    }
+    op->replaceAllUsesWith(tblgen_repl_values);
+    // Build graph op.
+    ::mlir::Block *block = new ::mlir::Block;
+    graph_op.body().push_back(block);
+    op->moveBefore(block, block->begin());
+    builder.setInsertionPointToEnd(block);
+    builder.create<::mlir::pd::FetchOp>(loc, op->getResults());
+  }
+}
+
+} // namespace trt
+} // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
new file mode 100644
index 0000000000..e9cebaf36d
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
@@ -0,0 +1,64 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "mlir/Pass/Pass.h"
+
+namespace infrt {
+namespace trt {
+
+/*
+ * trtOpTellerPass.
+ *
+ * Pick out the operators supported by TensorRT and wrap each of them into a
+ * graph op.
+ *
+ * source func:
+ *
+ * func @main() -> tensor<?xf32> {
+ *   %a = "pd.feed"()
+ *   %c = "pd.conv2d"(%a) ...
+ *   %d = "pd.conv3d"(%c) ...
+ *   %f = "pd.conv2d"(%a) ...
+ *   "pd.fetch" %d, %f
+ * }
+ *
+ * destination func:
+ * func @main() -> tensor<?xf32> {
+ *   %a = "pd.feed"()
+ *   %c = "pd.graph"(%a) {
+ *     %m = "pd.conv2d"(%a)...
+ *     "pd.fetch" %m
+ *   } ...
+ *   %d = "pd.graph"(%c) {
+ *     %m = "pd.conv3d"(%c)...
+ *     "pd.fetch" %m
+ *   } ...
+ *   %f = "pd.graph"(%a) {
+ *     %m = "pd.conv2d"(%a)...
+ *     "pd.fetch" %m
+ *   } ...
+ *   "pd.fetch" %d, %f
+ * }
+ * TODO(winter-wang): Supplement the logic for judging whether an operator is
+ * supported by tensorrt.
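+ *
+ * A teller predicate would conceptually look like the following sketch (the
+ * IsTrtSupported helper is hypothetical, not part of this patch):
+ *
+ *   bool IsTrtSupported(::mlir::Operation *op) {
+ *     // e.g. whitelist by op name, then check dtype/shape/attr constraints.
+ *     return op->getName().getStringRef() == "pd.conv2d";
+ *   }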
+ */
+class trtOpTellerPass
+    : public ::mlir::PassWrapper<trtOpTellerPass, ::mlir::FunctionPass> {
+ public:
+  ::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
+  void runOnFunction() override;
+};
+
+} // namespace trt
+} // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.cc b/paddle/infrt/dialect/tensorrt/trt_ops.cc
new file mode 100644
index 0000000000..4c02238b10
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/trt_ops.cc
@@ -0,0 +1,39 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+namespace infrt {
+namespace trt {
+
+TensorRTDialect::TensorRTDialect(::mlir::MLIRContext *context)
+    : ::mlir::Dialect("trt", context, ::mlir::TypeID::get<TensorRTDialect>()) {
+  addOperations<
+#define GET_OP_LIST
+#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc"  // NOLINT
+      >();
+#undef GET_OP_LIST
+}
+
+#define GET_OP_CLASSES
+#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc"  // NOLINT
+#undef GET_OP_CLASSES
+
+} // namespace trt
+} // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.h b/paddle/infrt/dialect/tensorrt/trt_ops.h
new file mode 100755
index 0000000000..176db98842
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/trt_ops.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+namespace infrt {
+namespace trt {
+
+class TensorRTDialect : public ::mlir::Dialect {
+ public:
+  explicit TensorRTDialect(::mlir::MLIRContext* context);
+  static llvm::StringRef getDialectNamespace() { return "trt"; }
+};
+
+// mlir bug. Can be removed safely when mlir is updated to llvm11.
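+// (Reason, as we understand it: the op classes generated into
+// trt_ops.hpp.inc refer to mlir types without full qualification, so they
+// only compile with the mlir namespace visible here.)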
+using namespace mlir;  // NOLINT
+
+#define GET_OP_CLASSES
+#include "dialect/tensorrt/trt_ops.hpp.inc"
+#undef GET_OP_CLASSES
+
+} // namespace trt
+} // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td
new file mode 100755
index 0000000000..cc072b6e68
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/trt_ops.td
@@ -0,0 +1,30 @@
+#ifndef TRT_OPS
+#define TRT_OPS
+
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "paddle/infrt/dialect/tensorrt/trt_op_base.td"
+
+def TRT_FetchOp : TRT_Op<"fetch", [Terminator]> {
+  let summary = "TensorRT engine return operation";
+  let description = [{
+    The `trt.fetch` operation terminates and returns values for the
+    `trt.graph` operation.
+  }];
+
+  let arguments = (ins Variadic<TRT_Tensor>:$inputs);
+}
+
+def TRT_GraphOp : TRT_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> {
+  let summary = "trt Graph Op";
+  let description = [{
+    Describe a tensorrt subgraph.
+  }];
+  let regions = (region SizedRegion<1>:$body);
+
+  let results = (outs Variadic<TRT_Tensor>:$outputs);
+
+}
+#endif // TRT_OPS
diff --git a/paddle/scripts/infrt_build.sh b/paddle/scripts/infrt_build.sh
index 0e386ef950..2119c65664 100644
--- a/paddle/scripts/infrt_build.sh
+++ b/paddle/scripts/infrt_build.sh
@@ -113,6 +113,9 @@ function main() {
         infrt_gen_and_build ${parallel_number}
         test_infrt
         ;;
+      build_only)
+        infrt_gen_and_build ${parallel_number}
+        ;;
       *)
         print_usage
         exit 1
-- 
GitLab