diff --git a/cmake/external/llvm.cmake b/cmake/external/llvm.cmake
index 27210e5260048a57cc442fce4c6cf8657e401568..a7a9e85ffd7314ac7026fccdf45fae2fa3de09d3 100644
--- a/cmake/external/llvm.cmake
+++ b/cmake/external/llvm.cmake
@@ -99,7 +99,7 @@ endfunction()
 
 function(mlir_add_rewriter td_base)
   set(LLVM_TARGET_DEFINITIONS ${td_base}.td)
-  mlir_tablegen(${td_base}.hpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass")
+  mlir_tablegen(${td_base}.cpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass")
   add_public_tablegen_target(${td_base}_IncGen)
   add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
 endfunction()
diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt
index dc22eecc99cdd8e8a0972683c3fbfd04ad9e481f..f2768f3dfa88d3405008baa7662f5e209ca3954c 100644
--- a/paddle/infrt/CMakeLists.txt
+++ b/paddle/infrt/CMakeLists.txt
@@ -97,6 +97,7 @@ set(infrt_mlir_incs
     pd_extra_ops_inc
     rewrite_inc
     trt_ops_inc
+    pd_lower_to_trt_inc
     )
 
 if (INFRT_WITH_PHI)
diff --git a/paddle/infrt/dialect/infrt_base.h b/paddle/infrt/dialect/infrt_base.h
index a8e7e13a681caa4891c42ac01d2a759d878594d1..3ef73171dcdea4e0367837f4b3893405c29a1580 100644
--- a/paddle/infrt/dialect/infrt_base.h
+++ b/paddle/infrt/dialect/infrt_base.h
@@ -54,6 +54,20 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b,  // NOLINT
   return b.getIntegerAttr(b.getI32Type(), constant);
 }
 
+template <typename T>
+static mlir::IntegerAttr createSI32Attr(mlir::OpBuilder &b,  // NOLINT
+                                        mlir::Location loc,
+                                        T constant) {
+  return b.getSI32IntegerAttr(constant);
+}
+
+template <typename T>
+static mlir::FloatAttr createF32Attr(mlir::OpBuilder &b,  // NOLINT
+                                     mlir::Location loc,
+                                     T constant) {
+  return b.getF32FloatAttr(constant);
+}
+
 static mlir::SmallVector<mlir::Value, 4> cvtValueToValueRange(
     const mlir::Value &operand) {
   return mlir::SmallVector<mlir::Value, 4>(1, operand);
diff --git a/paddle/infrt/dialect/infrt_base.td b/paddle/infrt/dialect/infrt_base.td
index 4d4727ee8e185032c6530cd293b0545283660e46..0f50eb2d8fb4ac83578f13888d05188a9143382f 100644
--- a/paddle/infrt/dialect/infrt_base.td
+++ b/paddle/infrt/dialect/infrt_base.td
@@ -28,6 +28,12 @@ def BufferType : OpaqueType<"b", "buffer", "buffer">;
 class INFRT_createI32Attr<string value> : NativeCodeCall<
     "infrt::createI32Attr($_builder, $_loc, " # value # ")">;
 
+class INFRT_createSI32Attr<string value> : NativeCodeCall<
+    "infrt::createSI32Attr($_builder, $_loc, " # value # ")">;
+
+class INFRT_createF32Attr<string value> : NativeCodeCall<
+    "infrt::createF32Attr($_builder, $_loc, " # value # ")">;
+
 def INFRT_cvtValueToValueRange : NativeCodeCall<
     "infrt::cvtValueToValueRange($0)">;
diff --git a/paddle/infrt/dialect/pd_ops.cc b/paddle/infrt/dialect/pd_ops.cc
index 7cf5b2fb20f527eefe31f817c7fe85c7864c8669..338b04e001320289b71f6127318e7a073cefcacf 100644
--- a/paddle/infrt/dialect/pd_ops.cc
+++ b/paddle/infrt/dialect/pd_ops.cc
@@ -24,11 +24,11 @@
 #define GET_OP_CLASSES
 #include "paddle/infrt/dialect/pd_extra_ops.cpp.inc"  // NOLINT
 
-#include "paddle/infrt/dialect/rewrite.hpp.inc"  // NOLINT
-
 namespace mlir {
 namespace pd {
 
+#include "paddle/infrt/dialect/rewrite.cpp.inc"  // NOLINT
+
 PaddleDialect::PaddleDialect(MLIRContext *context)
     : Dialect("pd", context, TypeID::get<PaddleDialect>()) {
   addOperations<
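Note: the new `createSI32Attr`/`createF32Attr` helpers are thin wrappers over `mlir::OpBuilder`; the DRR patterns below reach them through the `INFRT_create*Attr` `NativeCodeCall` classes. A minimal sketch of what such a call expands to at rewrite time (the wrapper function name here is illustrative, not part of the patch):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "paddle/infrt/dialect/infrt_base.h"

// Roughly what INFRT_createSI32Attr<"0"> expands to inside a generated
// pattern: build a signed 32-bit IntegerAttr through the helper above.
static mlir::IntegerAttr buildKSumAttr(mlir::MLIRContext *ctx) {
  mlir::OpBuilder builder(ctx);
  return infrt::createSI32Attr(builder, builder.getUnknownLoc(), 0);
}
```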
diff --git a/paddle/infrt/dialect/tensorrt/CMakeLists.txt b/paddle/infrt/dialect/tensorrt/CMakeLists.txt
index 794266513eb81b36655f44bfd1f6623216690ac5..99c335ed1782e8089f77bb3f21aadb00f6f6864f 100755
--- a/paddle/infrt/dialect/tensorrt/CMakeLists.txt
+++ b/paddle/infrt/dialect/tensorrt/CMakeLists.txt
@@ -2,11 +2,13 @@ core_gather_headers()
 
 gather_srcs(infrt_src SRCS
     trt_ops.cc
+    trt_op_converter_pass.cc
     trt_op_teller_pass.cc
     trt_graph_fuse_pass.cc
     trt_graph_split_pass.cc
   )
 mlir_tablegen_on(trt_ops)
+mlir_add_rewriter(pd_lower_to_trt)
 
 add_executable(trt-exec trt_exec.cc)
 target_link_libraries(trt-exec infrt ${MLIR_IR_LIBS})
diff --git a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td
new file mode 100644
index 0000000000000000000000000000000000000000..701391a750354938efe3703ef8642b21f8a878ea
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td
@@ -0,0 +1,28 @@
+#ifndef PD_LOWER_TO_TRT
+#define PD_LOWER_TO_TRT
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "paddle/infrt/dialect/infrt_base.td"
+include "paddle/infrt/dialect/pd_ops.td"
+include "paddle/infrt/dialect/tensorrt/trt_ops.td"
+
+def PD2TRT_Matmul_Lower : Pat<
+        (PD_MatmulOp $X, $Y, $transpose_X, $transpose_Y, ConstantAttr<F32Attr, "1.0">),
+        (TRT_MatrixMultiplyOp $X, $transpose_X, $Y, $transpose_Y)>;
+
+// TODO(shangzhizhou): replace INFRT_createSI32Attr<"0"> with nvinfer1::ElementWiseOperation::kSUM
+def PD2TRT_ElementwiseAdd_Lower : Pat<
+        (PD_Elementwise_addOp $X, $Y, ConstantAttr<SI32Attr, "-1">),
+        (TRT_ElementWiseOp $X, $Y, (INFRT_createSI32Attr<"0">)/*kSUM*/)>;
+
+// TODO(shangzhizhou): replace INFRT_createSI32Attr<"0"> with nvinfer1::ActivationType::kRELU
+def PD2TRT_Relu_Lower : Pat<
+        (PD_ReluOp $X),
+        (TRT_ActivationOp $X, (INFRT_createSI32Attr<"0">)/*kRELU*/, (INFRT_createF32Attr<"0.0">), (INFRT_createF32Attr<"0.0">))>;
+
+// TODO(shangzhizhou): replace INFRT_createSI32Attr<"8"> with nvinfer1::ActivationType::kCLIP
+def PD2TRT_Relu6_Lower : Pat<
+        (PD_Relu6Op $X, $threshold),
+        (TRT_ActivationOp $X, (INFRT_createSI32Attr<"8">)/*kCLIP*/, (INFRT_createF32Attr<"0.0">), $threshold)>;
+
+#endif  // PD_LOWER_TO_TRT
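For readers unfamiliar with DRR: `mlir_add_rewriter(pd_lower_to_trt)` runs `mlir-tablegen -gen-rewriters` over the file above and emits C++ rewrite patterns into `pd_lower_to_trt.cpp.inc`, which the converter pass later registers via `populateWithGenerated`. A rough hand-written equivalent of `PD2TRT_Relu_Lower` is sketched below; the op class names (`mlir::pd::ReluOp`, `infrt::trt::ActivationOp`) are assumed from the dialect definitions, and the real generated code differs in detail:

```cpp
#include "mlir/IR/PatternMatch.h"
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"

// Sketch of the pattern tablegen derives from PD2TRT_Relu_Lower: replace a
// pd.relu with trt.Activation carrying activation_type = 0 (kRELU).
struct PD2TRTReluLower : public mlir::OpRewritePattern<mlir::pd::ReluOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::pd::ReluOp op, mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    rewriter.replaceOpWithNewOp<infrt::trt::ActivationOp>(
        op, op.getType(), op.getOperand(),
        infrt::createSI32Attr(rewriter, loc, 0),    // kRELU
        infrt::createF32Attr(rewriter, loc, 0.0f),  // alpha
        infrt::createF32Attr(rewriter, loc, 0.0f)); // beta
    return mlir::success();
  }
};
```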
diff --git a/paddle/infrt/dialect/tensorrt/trt_exec.cc b/paddle/infrt/dialect/tensorrt/trt_exec.cc
index 1baef7a3f77fdd9d3e363110ea3679aa942e222f..7af1fa53d12e3113d0fe51e7ba15bbd5c082456c 100644
--- a/paddle/infrt/dialect/tensorrt/trt_exec.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_exec.cc
@@ -19,6 +19,7 @@
 #include "paddle/infrt/dialect/mlir_loader.h"
 #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
 #include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
+#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
 #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
 
 int main(int argc, char** argv) {
@@ -36,9 +37,10 @@ int main(int argc, char** argv) {
 
   mlir::PassManager pm(context);
   mlir::OpPassManager& trt_pass_manager = pm.nest<mlir::FuncOp>();
-  trt_pass_manager.addPass(std::make_unique<trtOpTellerPass>());
-  trt_pass_manager.addPass(std::make_unique<trtGraphFusePass>());
-  trt_pass_manager.addPass(std::make_unique<trtGraphSplitPass>(10));
+  trt_pass_manager.addPass(std::make_unique<TRTOpTellerPass>());
+  trt_pass_manager.addPass(std::make_unique<TRTGraphFusePass>());
+  trt_pass_manager.addPass(std::make_unique<TRTGraphSplitPass>(1));
+  trt_pass_manager.addPass(std::make_unique<TRTOpConverterPass>());
   if (mlir::failed(pm.run(*module))) {
     std::cout << "\npass failed!\n" << std::endl;
     return 4;
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
index 1da80ef2c3b1000c045327510a03081f8aa954ca..17633a4e8e99293524e5ca635069267e27c2a603 100644
--- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
@@ -142,7 +142,7 @@ void topoSortBlock(mlir::Block &body) {  // NOLINT
 }  // namespace
 
 // Implementation of the trtGraphFusePass.
-void trtGraphFusePass::runOnFunction() {
+void TRTGraphFusePass::runOnFunction() {
   mlir::Block &body = getFunction().front();
   mlir::OpBuilder builder(&body, body.begin());
   bool changed = false;
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
index f1e555c6f67ecaadff76fb17f68ebaae1a6528e1..ebd7a4ac4bd3712d98df4a097682787b3977ebfb 100644
--- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
@@ -52,8 +52,8 @@ namespace trt {
  *  "pd.fetch" %d, %f
  * }
  */
-class trtGraphFusePass
-    : public mlir::PassWrapper<trtGraphFusePass, mlir::FunctionPass> {
+class TRTGraphFusePass
+    : public mlir::PassWrapper<TRTGraphFusePass, mlir::FunctionPass> {
  public:
   ::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
   void runOnFunction() override;
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
index 257f2b528542557db33121a4c304eb8e6f657007..f24b9cc40cdcc2b065ea033cb03638e8d292df89 100644
--- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
@@ -21,7 +21,7 @@ namespace infrt {
 namespace trt {
 
 // Implementation of the trtGraphSplitPass。
-void trtGraphSplitPass::runOnFunction() {
+void TRTGraphSplitPass::runOnFunction() {
   std::vector<mlir::Operation *> worklist;
   mlir::Block& block = getFunction().front();
   for (auto& op : block) {
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
index d30d186647fc32aa4e16047000ee4071effb900d..51f84227243403f5a2299d820acad1b49592abc3 100644
--- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
@@ -45,12 +45,12 @@ namespace trt {
 *  "pd.fetch" (%d, %f)
 * }
 */
-class trtGraphSplitPass
-    : public mlir::PassWrapper<trtGraphSplitPass, mlir::FunctionPass> {
+class TRTGraphSplitPass
+    : public mlir::PassWrapper<TRTGraphSplitPass, mlir::FunctionPass> {
  public:
   ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
   void runOnFunction() override;
-  explicit trtGraphSplitPass(size_t min_subgraph_size = 3)
+  explicit TRTGraphSplitPass(size_t min_subgraph_size = 3)
       : min_subgraph_size_(min_subgraph_size) {}
 
  private:
+#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h" +#include "mlir/IR/Builders.h" +#include "mlir/Transforms/DialectConversion.h" +#include "paddle/infrt/dialect/infrt_base.h" +#include "paddle/infrt/dialect/pd_ops.h" + +namespace infrt { +namespace trt { + +#include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT + +using namespace mlir; + +void TRTOpConverterPass::runOnOperation() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to TensorRTDialect from + // PaddleDialect + target.addLegalDialect(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the TensorRT operations. + RewritePatternSet patterns(&getContext()); + populateWithGenerated(patterns); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} + +} // namespace trt +} // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..0adbf11b89144b0a9e14dc158e2eab1c56e2563a --- /dev/null +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "mlir/IR/Dialect.h" +#include "mlir/Pass/Pass.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" + +namespace infrt { +namespace trt { +/* + * trtOpConverterPass. + * + * source ir: + * func @main() -> tensor { + * %a = "pd.feed"()... + * %d, %f = "pd.graph"(%a) { + * %m = "pd.conv2d"(%a)... + * %n = "pd.conv3d"(%m)... + * %s = "pd.conv2d"(%a)... + * "pd.return" %n, %s + * } ... + * "pd.fetch" %d, %f + * } + * + * destination ir: + * func @main() -> tensor { + * %a = "pd.feed"()... + * %d, %f = "pd.graph"(%a) { + * %m = "trt.Convolution"(%a)... + * %n = "trt.Convolution"(%m)... + * %s = "trt.Convolution"(%a)... + * "pd.return" %n, %s + * } ... 
+ * "pd.fetch" %d, %f + * } + */ +struct TRTOpConverterPass + : public mlir::PassWrapper> { + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + ::llvm::StringRef getName() const override { return "trtOpConverterPass"; } + void runOnOperation() final; +}; +} // namespace trt +} // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 4e8d40b982b2eaf13aeef4f026d783c3f353c14b..176fdb7a2e054ac2e0c952c7af27995cf8e3c433 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -20,7 +20,7 @@ namespace infrt { namespace trt { // Implementation of the trtOpTellerPass。 -void trtOpTellerPass::runOnFunction() { +void TRTOpTellerPass::runOnFunction() { mlir::Block &body = getFunction().front(); std::vector worklist; worklist.reserve(body.getOperations().size()); diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h index fb16c974f7fb3f923bdc460d62d8e5b9f628fff9..8b9a16376ce5527b2133c9f2c2ecea928fb4cd8f 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h @@ -52,8 +52,8 @@ namespace trt { * TODO(winter-wang): Supplementary how to judge the operators can be supported * by tensorrt. */ -class trtOpTellerPass - : public mlir::PassWrapper { +class TRTOpTellerPass + : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtOpTellerPass"; } void runOnFunction() override; diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index cc072b6e6885bb68df5cf216fe210aded8a6ec6a..8e3dfffff54f13cc6d1f23c3459ed45257082d4f 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -23,8 +23,48 @@ def TRT_GraphOp : TRT_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> { Describe a tensorrt subgraph. }]; let regions = (region SizedRegion<1>:$body); - + let arguments = (ins Variadic:$inputs); let results = (outs Variadic:$outputs); } + +def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> { + let summary = "TensorRT IActivationLayer"; + let description = [{ + + TensorRT IActivationLayer. + + }]; + let arguments = (ins TRT_Tensor:$input, SI32Attr:$activation_type, + DefaultValuedAttr:$alpha, + DefaultValuedAttr:$beta); + + let results = (outs TRT_Tensor:$output); +} + +def TRT_ElementWiseOp : TRT_Op<"ElementWise", [NoSideEffect]> { + let summary = "TensorRT IElementWiseLayer"; + let description = [{ + + TensorRT IElementWiseLayer. + + }]; + let arguments = (ins TRT_Tensor:$input1, TRT_Tensor:$input2, SI32Attr:$elementwise_operation); + + let results = (outs TRT_Tensor:$output); +} + +def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> { + let summary = "TensorRT IMatrixMultiplyLayer"; + let description = [{ + + TensorRT IMatrixMultiplyLayer. 
diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td
index cc072b6e6885bb68df5cf216fe210aded8a6ec6a..8e3dfffff54f13cc6d1f23c3459ed45257082d4f 100755
--- a/paddle/infrt/dialect/tensorrt/trt_ops.td
+++ b/paddle/infrt/dialect/tensorrt/trt_ops.td
@@ -23,8 +23,48 @@ def TRT_GraphOp : TRT_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> {
     Describe a tensorrt subgraph.
   }];
   let regions = (region SizedRegion<1>:$body);
-
+
   let arguments = (ins Variadic<TRT_Tensor>:$inputs);
   let results = (outs Variadic<TRT_Tensor>:$outputs);
 }
+
+def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> {
+  let summary = "TensorRT IActivationLayer";
+  let description = [{
+
+    TensorRT IActivationLayer.
+
+  }];
+  let arguments = (ins TRT_Tensor:$input, SI32Attr:$activation_type,
+                   DefaultValuedAttr<F32Attr, "0.0">:$alpha,
+                   DefaultValuedAttr<F32Attr, "0.0">:$beta);
+
+  let results = (outs TRT_Tensor:$output);
+}
+
+def TRT_ElementWiseOp : TRT_Op<"ElementWise", [NoSideEffect]> {
+  let summary = "TensorRT IElementWiseLayer";
+  let description = [{
+
+    TensorRT IElementWiseLayer.
+
+  }];
+  let arguments = (ins TRT_Tensor:$input1, TRT_Tensor:$input2, SI32Attr:$elementwise_operation);
+
+  let results = (outs TRT_Tensor:$output);
+}
+
+def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> {
+  let summary = "TensorRT IMatrixMultiplyLayer";
+  let description = [{
+
+    TensorRT IMatrixMultiplyLayer.
+
+  }];
+  let arguments = (ins TRT_Tensor:$input1, BoolAttr:$transpose1,
+                   TRT_Tensor:$input2, BoolAttr:$transpose2);
+
+  let results = (outs TRT_Tensor:$output);
+}
+
 #endif  // TRT_OPS
diff --git a/paddle/infrt/tests/dialect/disabled_trt_ops.mlir b/paddle/infrt/tests/dialect/disabled_trt_ops.mlir
index 75ec98f04661a7d8cfe55c5fbea9dbc87933ad18..b59cfb04816974cbdb923e6d18af1184be963c59 100644
--- a/paddle/infrt/tests/dialect/disabled_trt_ops.mlir
+++ b/paddle/infrt/tests/dialect/disabled_trt_ops.mlir
@@ -7,15 +7,15 @@ func @main() -> tensor<?xf32> {
   %bias1 = "pd.feed"() {name="input4"} : () -> tensor<?xf32>
   %bias2 = "pd.feed"() {name="input5"} : () -> tensor<?xf32>
 
-  %d = "pd.elementwise_add"(%c, %bias) {axis=1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %e = "pd.relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
 
   %c1 = "pd.matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %d1 = "pd.elementwise_add"(%c1, %bias1) {axis=1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %d1 = "pd.elementwise_add"(%c1, %bias1) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %e1 = "pd.relu"(%d1) {} : (tensor<?xf32>) -> tensor<?xf32>
 
   %c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
 
   "pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
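Finally, about the magic numbers carried by the new ops: `activation_type` and `elementwise_operation` are plain SI32 attributes whose values mirror TensorRT enumerators (`kRELU` = 0 and `kCLIP` = 8 in `nvinfer1::ActivationType`, `kSUM` = 0 in `nvinfer1::ElementWiseOperation`), which is what the TODOs in pd_lower_to_trt.td refer to. A later engine-building step would cast them back. An illustrative helper, not part of this patch, assuming the standard TensorRT headers:

```cpp
#include <NvInfer.h>
#include <cstdint>

// Sketch of the enum round-trip the TODOs describe: the SI32 attribute values
// baked into the patterns map directly onto TensorRT enumerators.
static nvinfer1::ActivationType toActivationType(int32_t activation_type) {
  // 0 == kRELU, 8 == kCLIP, matching the constants used in the patterns.
  return static_cast<nvinfer1::ActivationType>(activation_type);
}

static nvinfer1::ElementWiseOperation toElementWiseOp(int32_t op_code) {
  // 0 == kSUM, as emitted by PD2TRT_ElementwiseAdd_Lower.
  return static_cast<nvinfer1::ElementWiseOperation>(op_code);
}
```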