From acdf0663ae98fee60ea61ef25bb3e8af7d88f6b4 Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Wed, 2 Mar 2022 09:42:20 +0800 Subject: [PATCH] update pd_2_trt lower pass (#40019) * update pd_2_trt lower pass * update pd_2_trt lower pass * update style * udpate * change trt.graph to trt.create_engine * update comments * update comments * add test --- .../dialect/tensorrt/trt_graph_fuse_pass.cc | 20 +++++++++--------- .../dialect/tensorrt/trt_graph_fuse_pass.h | 21 ++++++++++++------- .../dialect/tensorrt/trt_graph_split_pass.cc | 7 +++---- .../dialect/tensorrt/trt_graph_split_pass.h | 9 ++++++-- .../dialect/tensorrt/trt_op_converter_pass.cc | 12 +++++------ .../dialect/tensorrt/trt_op_converter_pass.h | 8 +++---- .../dialect/tensorrt/trt_op_teller_pass.cc | 17 +++++++-------- .../dialect/tensorrt/trt_op_teller_pass.h | 17 +++++++++------ paddle/infrt/dialect/tensorrt/trt_ops.h | 1 + paddle/infrt/dialect/tensorrt/trt_ops.td | 15 ++----------- .../{disabled_trt_ops.mlir => trt_ops.mlir} | 1 + paddle/infrt/tests/lit.cfg.py.in | 3 ++- 12 files changed, 67 insertions(+), 64 deletions(-) rename paddle/infrt/tests/dialect/{disabled_trt_ops.mlir => trt_ops.mlir} (98%) diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc index 17633a4e8e9..fa0095363c5 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc @@ -53,9 +53,9 @@ bool reverseDfs(std::vector source, } // merge the first&second graph op to a new graph op. -void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT - mlir::pd::GraphOp first, - mlir::pd::GraphOp second) { +void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT + CreateEngineOp first, + CreateEngineOp second) { // comput inputs and outputs ::llvm::SmallVector inputs(first.getOperands()), outputs; for (mlir::Value input : second.getOperands()) { @@ -84,7 +84,8 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT // create the new graph op builder.setInsertionPoint(first); auto loc = first.getLoc(); - auto graph_op = builder.create(loc, return_types, inputs); + auto graph_op = + builder.create(loc, return_types, inputs, true); mlir::Block *block = new mlir::Block; auto copy_range = second.getBody()->without_terminator(); block->getOperations().splice(block->begin(), @@ -97,7 +98,7 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT copy_range.begin(), copy_range.end()); builder.setInsertionPointToEnd(block); - builder.create(loc, outputs); + builder.create<::infrt::dialect::ReturnOp>(loc, outputs); graph_op.body().push_back(block); // mapping the output @@ -149,13 +150,12 @@ void TRTGraphFusePass::runOnFunction() { do { changed = false; for (auto &op : body) { - mlir::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null(&op); + CreateEngineOp graph_op = ::llvm::dyn_cast_or_null(&op); if (nullptr == graph_op) continue; for (auto user_op : op.getUsers()) { - mlir::pd::GraphOp user_graph_op = - ::llvm::dyn_cast_or_null(user_op); + CreateEngineOp user_graph_op = + ::llvm::dyn_cast_or_null(user_op); if (nullptr == user_graph_op) continue; // get all dst input nodes except src. std::vector source_nodes; @@ -168,7 +168,7 @@ void TRTGraphFusePass::runOnFunction() { // Reverse DFS from the source_nodes. if (!reverseDfs(source_nodes, [&op](const mlir::Operation *n) { return n == &op; })) { - mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op); + mergeTwoAdjacentCreateEngineOp(builder, graph_op, user_graph_op); changed = true; break; } diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h index ebd7a4ac4bd..350add905aa 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h @@ -14,6 +14,8 @@ #pragma once #include +#include "paddle/infrt/dialect/infrt_base.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -26,28 +28,28 @@ namespace trt { * * func @main() -> tensor { * %a = "pd.feed"()... - * %c = "pd.graph"(%a) { + * %c = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.return" %m + * "Infrt.return" %m * } ... - * %d = "pd.graph"(%c) { + * %d = "trt.create_engine"(%c) { * %m = "pd.conv3d"(%c)... - * "pd.return" %m + * "Infrt.return" %m * } ... - * %f = "pd.graph"(%a) { + * %f = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.return" %m + * "Infrt.return" %m * } ... * "pd.fetch" %d, %f * * destination func: * func @main() -> tensor { * %a = "pd.feed"()... - * %d, %f = "pd.graph"(%a) { + * %d, %f = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "pd.return" %n, %s + * "Infrt.return" %n, %s * } ... * "pd.fetch" %d, %f * } @@ -55,6 +57,9 @@ namespace trt { class TRTGraphFusePass : public mlir::PassWrapper { public: + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } ::llvm::StringRef getName() const override { return "trtGraphFusePass"; } void runOnFunction() override; }; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc index f24b9cc40cd..5ee7b23213a 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc @@ -22,18 +22,17 @@ namespace infrt { namespace trt { // Implementation of the trtGraphSplitPass。 void TRTGraphSplitPass::runOnFunction() { - std::vector worklist; + std::vector worklist; mlir::Block& block = getFunction().front(); for (auto& op : block) { - mlir::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null(&op); + CreateEngineOp graph_op = ::llvm::dyn_cast_or_null(&op); if (nullptr != graph_op && graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { worklist.push_back(graph_op); } } while (!worklist.empty()) { - mlir::pd::GraphOp graph_op = worklist.back(); + CreateEngineOp graph_op = worklist.back(); worklist.pop_back(); mlir::Block* body = graph_op.getBody(); auto return_op = body->getTerminator(); diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h index 51f84227243..28078e2bc2d 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h @@ -14,6 +14,8 @@ #pragma once #include +#include "paddle/infrt/dialect/infrt_base.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -27,11 +29,11 @@ namespace trt { * * func @main() -> tensor { * %a = "pd.feed"()... - * %d, %f = "pd.graph"(%a) { + * %d, %f = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "pd.return" (%n, %s) + * "Infrt.return" (%n, %s) * } ... * "pd.fetch" (%d, %f) * } @@ -49,6 +51,9 @@ class TRTGraphSplitPass : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } void runOnFunction() override; explicit TRTGraphSplitPass(size_t min_subgraph_size = 3) : min_subgraph_size_(min_subgraph_size) {} diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index e34308a2f0f..8d81e739d9c 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h" -#include "mlir/IR/Builders.h" -#include "mlir/Transforms/DialectConversion.h" +#include +#include #include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/pd_ops.h" @@ -22,12 +22,10 @@ namespace trt { #include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT -using namespace mlir; - void TRTOpConverterPass::runOnOperation() { // The first thing to define is the conversion target. This will define the // final target for this lowering. - ConversionTarget target(getContext()); + ::mlir::ConversionTarget target(getContext()); // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to TensorRTDialect from @@ -36,13 +34,13 @@ void TRTOpConverterPass::runOnOperation() { // Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the TensorRT operations. - RewritePatternSet patterns(&getContext()); + ::mlir::RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. - if (failed( + if (::mlir::failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h index 0adbf11b891..a8128a585ee 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h @@ -25,11 +25,11 @@ namespace trt { * source ir: * func @main() -> tensor { * %a = "pd.feed"()... - * %d, %f = "pd.graph"(%a) { + * %d, %f = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "pd.return" %n, %s + * "Infrt.return" %n, %s * } ... * "pd.fetch" %d, %f * } @@ -37,11 +37,11 @@ namespace trt { * destination ir: * func @main() -> tensor { * %a = "pd.feed"()... - * %d, %f = "pd.graph"(%a) { + * %d, %f = "trt.create_engine"(%a) { * %m = "trt.Convolution"(%a)... * %n = "trt.Convolution"(%m)... * %s = "trt.Convolution"(%a)... - * "pd.return" %n, %s + * "Infrt.return" %n, %s * } ... * "pd.fetch" %d, %f * } diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 176fdb7a2e0..17e893a383a 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -15,6 +15,7 @@ #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h" #include +#include "paddle/infrt/dialect/basic_kernels.h" #include "paddle/infrt/dialect/pd_ops.h" namespace infrt { @@ -33,16 +34,14 @@ void TRTOpTellerPass::runOnFunction() { auto *op = worklist.back(); worklist.pop_back(); if (op == nullptr) continue; - auto op1 = ::llvm::dyn_cast_or_null(op); - if (op1) continue; - auto op2 = ::llvm::dyn_cast_or_null(op); - if (op2) continue; - auto op3 = ::llvm::dyn_cast_or_null(op); - if (op3) continue; + if (::llvm::dyn_cast_or_null(op)) continue; + if (::llvm::dyn_cast_or_null(op)) continue; + if (::llvm::dyn_cast_or_null(op)) continue; + if (::llvm::dyn_cast_or_null(op)) continue; builder.setInsertionPoint(op); auto loc = getFunction().getLoc(); - auto graph_op = builder.create( - loc, op->getResultTypes(), op->getOperands()); + auto graph_op = builder.create( + loc, op->getResultTypes(), op->getOperands(), true); ::llvm::SmallVector tblgen_repl_values; for (auto v : @@ -55,7 +54,7 @@ void TRTOpTellerPass::runOnFunction() { graph_op.body().push_back(block); op->moveBefore(block, block->begin()); builder.setInsertionPointToEnd(block); - builder.create(loc, op->getResults()); + builder.create<::infrt::dialect::ReturnOp>(loc, op->getResults()); } } } // namespace trt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h index 8b9a16376ce..471eafa9f9b 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h @@ -14,6 +14,8 @@ #pragma once #include +#include "paddle/infrt/dialect/infrt_base.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -35,17 +37,17 @@ namespace trt { * destination func: * func @main() -> tensor { * %a = "pd.feed"()... - * %c = "pd.graph"(%a) { + * %c = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.return" (%m) + * "Infrt.return" (%m) * } ... - * %d = "pd.graph"(%c) { + * %d = "trt.create_engine"(%c) { * %m = "pd.conv3d"(%c)... - * "pd.return" (%m) + * "Infrt.return" (%m) * } ... - * %f = "pd.graph"(%a) { + * %f = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.return" (%m) + * "Infrt.return" (%m) * } ... * "pd.fetch" (%d, %f) * } @@ -55,6 +57,9 @@ namespace trt { class TRTOpTellerPass : public mlir::PassWrapper { public: + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } ::llvm::StringRef getName() const override { return "trtOpTellerPass"; } void runOnFunction() override; }; diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.h b/paddle/infrt/dialect/tensorrt/trt_ops.h index a37491ec1ab..95b2ed41fdf 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.h +++ b/paddle/infrt/dialect/tensorrt/trt_ops.h @@ -28,6 +28,7 @@ #include #include #include +#include "paddle/infrt/dialect/basic_kernels.h" namespace infrt { namespace trt { diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index 8e3dfffff54..31142a5157b 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -7,25 +7,14 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/IR/OpBase.td" include "paddle/infrt/dialect/tensorrt/trt_op_base.td" -def TRT_FetchOp : TRT_Op<"fetch", [Terminator]> { - let summary = "TensorRT engine return operation"; - let description = [{ - The `trt.fetch` operation terminates and returns values for the - `trt.graph` operation. - }]; - - let arguments = (ins Variadic:$inputs); -} - -def TRT_GraphOp : TRT_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> { +def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::dialect::ReturnOp">]> { let summary = "trt Graph Op"; let description = [{ Describe a tensorrt subgraph. }]; let regions = (region SizedRegion<1>:$body); - let arguments = (ins Variadic:$inputs); + let arguments = (ins Variadic:$inputs, DefaultValuedAttr:$run_once); let results = (outs Variadic:$outputs); - } def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> { diff --git a/paddle/infrt/tests/dialect/disabled_trt_ops.mlir b/paddle/infrt/tests/dialect/trt_ops.mlir similarity index 98% rename from paddle/infrt/tests/dialect/disabled_trt_ops.mlir rename to paddle/infrt/tests/dialect/trt_ops.mlir index b59cfb04816..49510bc542d 100644 --- a/paddle/infrt/tests/dialect/disabled_trt_ops.mlir +++ b/paddle/infrt/tests/dialect/trt_ops.mlir @@ -1,3 +1,4 @@ +// RUN: trt-exec %s // CHECK-LABEL: @main func @main() -> tensor { %bias = "pd.feed"() {name="input0"} : () -> tensor diff --git a/paddle/infrt/tests/lit.cfg.py.in b/paddle/infrt/tests/lit.cfg.py.in index 19ee0076b55..d47957dac92 100644 --- a/paddle/infrt/tests/lit.cfg.py.in +++ b/paddle/infrt/tests/lit.cfg.py.in @@ -21,10 +21,11 @@ build_dir = "@CMAKE_BINARY_DIR@" config.llvm_tools_dir = os.path.join(build_dir, "third_party/install/llvm/bin") config.llvm_tools_dir = os.path.join(build_dir, "/third_party/install/llvm/lib") infrtopt_bin = os.path.join(build_dir, "paddle/infrt/dialect/") +trtexec_bin = os.path.join(build_dir, "paddle/infrt/dialect/tensorrt/") infrtexec_bin = os.path.join(build_dir, "paddle/infrt/host_context/") llvm_bin = os.path.join(build_dir, "third_party/install/llvm/bin/") config.environment['PATH'] = os.path.pathsep.join( - (infrtopt_bin, infrtexec_bin, llvm_bin, config.environment['PATH'])) + (infrtopt_bin, infrtexec_bin, trtexec_bin, llvm_bin, config.environment['PATH'])) config.suffixes = ['.mlir'] -- GitLab