From 35471b1f8af735aad079ca2787c881499d7829d6 Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Mon, 28 Feb 2022 13:20:38 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90infrt=E3=80=91add=20TrtOpConverterPass?= =?UTF-8?q?=20(#39902)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add some trt layers * trtOpConverter pass ok * add comments * add constraints to some attrs in the pd_lower_to_trt patterns * update constraint * fix code style * update pass name * update code style * change .hpp.inc to .cc.inc in mlir_add_rewriter --- cmake/external/llvm.cmake | 2 +- paddle/infrt/CMakeLists.txt | 1 + paddle/infrt/dialect/infrt_base.h | 14 +++++ paddle/infrt/dialect/infrt_base.td | 6 ++ paddle/infrt/dialect/pd_ops.cc | 4 +- paddle/infrt/dialect/tensorrt/CMakeLists.txt | 2 + .../infrt/dialect/tensorrt/pd_lower_to_trt.td | 28 +++++++++ paddle/infrt/dialect/tensorrt/trt_exec.cc | 8 ++- .../dialect/tensorrt/trt_graph_fuse_pass.cc | 2 +- .../dialect/tensorrt/trt_graph_fuse_pass.h | 4 +- .../dialect/tensorrt/trt_graph_split_pass.cc | 2 +- .../dialect/tensorrt/trt_graph_split_pass.h | 6 +- .../dialect/tensorrt/trt_op_converter_pass.cc | 51 ++++++++++++++++ .../dialect/tensorrt/trt_op_converter_pass.h | 59 +++++++++++++++++++ .../dialect/tensorrt/trt_op_teller_pass.cc | 2 +- .../dialect/tensorrt/trt_op_teller_pass.h | 4 +- paddle/infrt/dialect/tensorrt/trt_ops.td | 42 ++++++++++++- .../infrt/tests/dialect/disabled_trt_ops.mlir | 6 +- 18 files changed, 223 insertions(+), 20 deletions(-) create mode 100644 paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td create mode 100644 paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc create mode 100644 paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h diff --git a/cmake/external/llvm.cmake b/cmake/external/llvm.cmake index 27210e5260..a7a9e85ffd 100644 --- a/cmake/external/llvm.cmake +++ b/cmake/external/llvm.cmake @@ -99,7 +99,7 @@ endfunction() function(mlir_add_rewriter td_base) set(LLVM_TARGET_DEFINITIONS ${td_base}.td) - mlir_tablegen(${td_base}.hpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass") + mlir_tablegen(${td_base}.cpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass") add_public_tablegen_target(${td_base}_IncGen) add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) endfunction() diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt index dc22eecc99..f2768f3dfa 100644 --- a/paddle/infrt/CMakeLists.txt +++ b/paddle/infrt/CMakeLists.txt @@ -97,6 +97,7 @@ set(infrt_mlir_incs pd_extra_ops_inc rewrite_inc trt_ops_inc + pd_lower_to_trt_inc ) if (INFRT_WITH_PHI) diff --git a/paddle/infrt/dialect/infrt_base.h b/paddle/infrt/dialect/infrt_base.h index a8e7e13a68..3ef73171dc 100644 --- a/paddle/infrt/dialect/infrt_base.h +++ b/paddle/infrt/dialect/infrt_base.h @@ -54,6 +54,20 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT return b.getIntegerAttr(b.getI32Type(), constant); } +template +static mlir::IntegerAttr createSI32Attr(mlir::OpBuilder &b, // NOLINT + mlir::Location loc, + T constant) { + return b.getSI32IntegerAttr(constant); +} + +template +static mlir::FloatAttr createF32Attr(mlir::OpBuilder &b, // NOLINT + mlir::Location loc, + T constant) { + return b.getF32FloatAttr(constant); +} + static mlir::SmallVector cvtValueToValueRange( const mlir::Value &operand) { return mlir::SmallVector(1, operand); diff --git a/paddle/infrt/dialect/infrt_base.td b/paddle/infrt/dialect/infrt_base.td index 4d4727ee8e..0f50eb2d8f 100644 --- a/paddle/infrt/dialect/infrt_base.td +++ b/paddle/infrt/dialect/infrt_base.td @@ -28,6 +28,12 @@ def BufferType : OpaqueType<"b", "buffer", "buffer">; class INFRT_createI32Attr : NativeCodeCall< "infrt::createI32Attr($_builder, $_loc, " # value # ")">; +class INFRT_createSI32Attr : NativeCodeCall< + "infrt::createSI32Attr($_builder, $_loc, " # value # ")">; + +class INFRT_createF32Attr : NativeCodeCall< + "infrt::createF32Attr($_builder, $_loc, " # value # ")">; + def INFRT_cvtValueToValueRange : NativeCodeCall< "infrt::cvtValueToValueRange($0)">; diff --git a/paddle/infrt/dialect/pd_ops.cc b/paddle/infrt/dialect/pd_ops.cc index 7cf5b2fb20..338b04e001 100644 --- a/paddle/infrt/dialect/pd_ops.cc +++ b/paddle/infrt/dialect/pd_ops.cc @@ -24,11 +24,11 @@ #define GET_OP_CLASSES #include "paddle/infrt/dialect/pd_extra_ops.cpp.inc" // NOLINT -#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT - namespace mlir { namespace pd { +#include "paddle/infrt/dialect/rewrite.cpp.inc" // NOLINT + PaddleDialect::PaddleDialect(MLIRContext *context) : Dialect("pd", context, TypeID::get()) { addOperations< diff --git a/paddle/infrt/dialect/tensorrt/CMakeLists.txt b/paddle/infrt/dialect/tensorrt/CMakeLists.txt index 794266513e..99c335ed17 100755 --- a/paddle/infrt/dialect/tensorrt/CMakeLists.txt +++ b/paddle/infrt/dialect/tensorrt/CMakeLists.txt @@ -2,11 +2,13 @@ core_gather_headers() gather_srcs(infrt_src SRCS trt_ops.cc + trt_op_converter_pass.cc trt_op_teller_pass.cc trt_graph_fuse_pass.cc trt_graph_split_pass.cc ) mlir_tablegen_on(trt_ops) +mlir_add_rewriter(pd_lower_to_trt) add_executable(trt-exec trt_exec.cc) target_link_libraries(trt-exec infrt ${MLIR_IR_LIBS}) diff --git a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td new file mode 100644 index 0000000000..701391a750 --- /dev/null +++ b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td @@ -0,0 +1,28 @@ +#ifndef PD_LOWER_TO_TRT +#define PD_LOWER_TO_TRT + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "paddle/infrt/dialect/infrt_base.td" +include "paddle/infrt/dialect/pd_ops.td" +include "paddle/infrt/dialect/tensorrt/trt_ops.td" + +def PD2TRT_Matmul_Lower : Pat< + (PD_MatmulOp $X, $Y, $transpose_X, $transpose_Y, ConstantAttr, ConstantAttr), + (TRT_MatrixMultiplyOp $X, $transpose_X, $Y, $transpose_Y)>; + +//TO DO(shangzhizhou):replace '"INFRT_createI32Attr<"0">' to enum nvinfer1::ElementWiseOperation::kSUM +def PD2TRT_ElementwiseAdd_Lower : Pat< + (PD_Elementwise_addOp $X, $Y, ConstantAttr), + (TRT_ElementWiseOp $X, $Y, (INFRT_createSI32Attr<"0">)/*kSUM*/)>; + +//TO DO(shangzhizhou):replace '"INFRT_createI32Attr<"0">' to enum nvinfer1::ActivationType::kRELU +def PD2TRT_Relu_Lower : Pat< + (PD_ReluOp $X), + (TRT_ActivationOp $X, (INFRT_createSI32Attr<"0">)/*kRELU*/, (INFRT_createF32Attr<"0.0">), (INFRT_createF32Attr<"0.0">))>; + +//TO DO(shangzhizhou):replace '"INFRT_createI32Attr<"0">' to enum nvinfer1::ActivationType::kCLIP +def PD2TRT_Relu6_Lower : Pat< + (PD_Relu6Op $X, $threshold), + (TRT_ActivationOp $X, (INFRT_createSI32Attr<"8">)/*kCLIP*/, (INFRT_createF32Attr<"0.0">), $threshold)>; + +#endif // PD_LOWER_TO_TRT diff --git a/paddle/infrt/dialect/tensorrt/trt_exec.cc b/paddle/infrt/dialect/tensorrt/trt_exec.cc index 1baef7a3f7..7af1fa53d1 100644 --- a/paddle/infrt/dialect/tensorrt/trt_exec.cc +++ b/paddle/infrt/dialect/tensorrt/trt_exec.cc @@ -19,6 +19,7 @@ #include "paddle/infrt/dialect/mlir_loader.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h" +#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h" int main(int argc, char** argv) { @@ -36,9 +37,10 @@ int main(int argc, char** argv) { mlir::PassManager pm(context); mlir::OpPassManager& trt_pass_manager = pm.nest(); - trt_pass_manager.addPass(std::make_unique()); - trt_pass_manager.addPass(std::make_unique()); - trt_pass_manager.addPass(std::make_unique(10)); + trt_pass_manager.addPass(std::make_unique()); + trt_pass_manager.addPass(std::make_unique()); + trt_pass_manager.addPass(std::make_unique(1)); + trt_pass_manager.addPass(std::make_unique()); if (mlir::failed(pm.run(*module))) { std::cout << "\npass failed!\n" << std::endl; return 4; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc index 1da80ef2c3..17633a4e8e 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc @@ -142,7 +142,7 @@ void topoSortBlock(mlir::Block &body) { // NOLINT } // namespace // Implementation of the trtGraphFusePass. -void trtGraphFusePass::runOnFunction() { +void TRTGraphFusePass::runOnFunction() { mlir::Block &body = getFunction().front(); mlir::OpBuilder builder(&body, body.begin()); bool changed = false; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h index f1e555c6f6..ebd7a4ac4b 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h @@ -52,8 +52,8 @@ namespace trt { * "pd.fetch" %d, %f * } */ -class trtGraphFusePass - : public mlir::PassWrapper { +class TRTGraphFusePass + : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtGraphFusePass"; } void runOnFunction() override; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc index 257f2b5285..f24b9cc40c 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc @@ -21,7 +21,7 @@ namespace infrt { namespace trt { // Implementation of the trtGraphSplitPass。 -void trtGraphSplitPass::runOnFunction() { +void TRTGraphSplitPass::runOnFunction() { std::vector worklist; mlir::Block& block = getFunction().front(); for (auto& op : block) { diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h index d30d186647..51f8422724 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h @@ -45,12 +45,12 @@ namespace trt { * "pd.fetch" (%d, %f) * } */ -class trtGraphSplitPass - : public mlir::PassWrapper { +class TRTGraphSplitPass + : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } void runOnFunction() override; - explicit trtGraphSplitPass(size_t min_subgraph_size = 3) + explicit TRTGraphSplitPass(size_t min_subgraph_size = 3) : min_subgraph_size_(min_subgraph_size) {} private: diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc new file mode 100644 index 0000000000..e34308a2f0 --- /dev/null +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h" +#include "mlir/IR/Builders.h" +#include "mlir/Transforms/DialectConversion.h" +#include "paddle/infrt/dialect/infrt_base.h" +#include "paddle/infrt/dialect/pd_ops.h" + +namespace infrt { +namespace trt { + +#include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT + +using namespace mlir; + +void TRTOpConverterPass::runOnOperation() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to TensorRTDialect from + // PaddleDialect + target.addLegalDialect(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the TensorRT operations. + RewritePatternSet patterns(&getContext()); + populateWithGenerated(patterns); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} + +} // namespace trt +} // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h new file mode 100644 index 0000000000..0adbf11b89 --- /dev/null +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "mlir/IR/Dialect.h" +#include "mlir/Pass/Pass.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" + +namespace infrt { +namespace trt { +/* + * trtOpConverterPass. + * + * source ir: + * func @main() -> tensor { + * %a = "pd.feed"()... + * %d, %f = "pd.graph"(%a) { + * %m = "pd.conv2d"(%a)... + * %n = "pd.conv3d"(%m)... + * %s = "pd.conv2d"(%a)... + * "pd.return" %n, %s + * } ... + * "pd.fetch" %d, %f + * } + * + * destination ir: + * func @main() -> tensor { + * %a = "pd.feed"()... + * %d, %f = "pd.graph"(%a) { + * %m = "trt.Convolution"(%a)... + * %n = "trt.Convolution"(%m)... + * %s = "trt.Convolution"(%a)... + * "pd.return" %n, %s + * } ... + * "pd.fetch" %d, %f + * } + */ +struct TRTOpConverterPass + : public mlir::PassWrapper> { + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + ::llvm::StringRef getName() const override { return "trtOpConverterPass"; } + void runOnOperation() final; +}; +} // namespace trt +} // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 4e8d40b982..176fdb7a2e 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -20,7 +20,7 @@ namespace infrt { namespace trt { // Implementation of the trtOpTellerPass。 -void trtOpTellerPass::runOnFunction() { +void TRTOpTellerPass::runOnFunction() { mlir::Block &body = getFunction().front(); std::vector worklist; worklist.reserve(body.getOperations().size()); diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h index fb16c974f7..8b9a16376c 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h @@ -52,8 +52,8 @@ namespace trt { * TODO(winter-wang): Supplementary how to judge the operators can be supported * by tensorrt. */ -class trtOpTellerPass - : public mlir::PassWrapper { +class TRTOpTellerPass + : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtOpTellerPass"; } void runOnFunction() override; diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index cc072b6e68..8e3dfffff5 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -23,8 +23,48 @@ def TRT_GraphOp : TRT_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> { Describe a tensorrt subgraph. }]; let regions = (region SizedRegion<1>:$body); - + let arguments = (ins Variadic:$inputs); let results = (outs Variadic:$outputs); } + +def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> { + let summary = "TensorRT IActivationLayer"; + let description = [{ + + TensorRT IActivationLayer. + + }]; + let arguments = (ins TRT_Tensor:$input, SI32Attr:$activation_type, + DefaultValuedAttr:$alpha, + DefaultValuedAttr:$beta); + + let results = (outs TRT_Tensor:$output); +} + +def TRT_ElementWiseOp : TRT_Op<"ElementWise", [NoSideEffect]> { + let summary = "TensorRT IElementWiseLayer"; + let description = [{ + + TensorRT IElementWiseLayer. + + }]; + let arguments = (ins TRT_Tensor:$input1, TRT_Tensor:$input2, SI32Attr:$elementwise_operation); + + let results = (outs TRT_Tensor:$output); +} + +def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> { + let summary = "TensorRT IMatrixMultiplyLayer"; + let description = [{ + + TensorRT IMatrixMultiplyLayer. + + }]; + let arguments = (ins TRT_Tensor:$input1, BoolAttr:$transpose1, + TRT_Tensor:$input2, BoolAttr:$transpose2); + + let results = (outs TRT_Tensor:$output); +} + #endif // TRT_OPS diff --git a/paddle/infrt/tests/dialect/disabled_trt_ops.mlir b/paddle/infrt/tests/dialect/disabled_trt_ops.mlir index 75ec98f046..b59cfb0481 100644 --- a/paddle/infrt/tests/dialect/disabled_trt_ops.mlir +++ b/paddle/infrt/tests/dialect/disabled_trt_ops.mlir @@ -7,15 +7,15 @@ func @main() -> tensor { %bias1 = "pd.feed"() {name="input4"} : () -> tensor %bias2 = "pd.feed"() {name="input5"} : () -> tensor - %d = "pd.elementwise_add"(%c, %bias) {axis=1:si32} : (tensor, tensor) -> tensor + %d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (tensor, tensor) -> tensor %e = "pd.relu6"(%d) {} : (tensor) -> tensor %c1 = "pd.matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor, tensor) -> tensor - %d1 = "pd.elementwise_add"(%c1, %bias1) {axis=1:si32} : (tensor, tensor) -> tensor + %d1 = "pd.elementwise_add"(%c1, %bias1) {axis=-1:si32} : (tensor, tensor) -> tensor %e1 = "pd.relu"(%d1) {} : (tensor) -> tensor %c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor - %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:si32} : (tensor, tensor) -> tensor + %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (tensor, tensor) -> tensor %e2 = "pd.relu"(%d2) {} : (tensor) -> tensor "pd.fetch"(%e2) {name="output"} :(tensor)->() -- GitLab