diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc index 17633a4e8e99293524e5ca635069267e27c2a603..fa0095363c5fd34778162fb4e3204450ef1e7815 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc @@ -53,9 +53,9 @@ bool reverseDfs(std::vector source, } // merge the first&second graph op to a new graph op. -void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT - mlir::pd::GraphOp first, - mlir::pd::GraphOp second) { +void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT + CreateEngineOp first, + CreateEngineOp second) { // comput inputs and outputs ::llvm::SmallVector inputs(first.getOperands()), outputs; for (mlir::Value input : second.getOperands()) { @@ -84,7 +84,8 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT // create the new graph op builder.setInsertionPoint(first); auto loc = first.getLoc(); - auto graph_op = builder.create(loc, return_types, inputs); + auto graph_op = + builder.create(loc, return_types, inputs, true); mlir::Block *block = new mlir::Block; auto copy_range = second.getBody()->without_terminator(); block->getOperations().splice(block->begin(), @@ -97,7 +98,7 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT copy_range.begin(), copy_range.end()); builder.setInsertionPointToEnd(block); - builder.create(loc, outputs); + builder.create<::infrt::dialect::ReturnOp>(loc, outputs); graph_op.body().push_back(block); // mapping the output @@ -149,13 +150,12 @@ void TRTGraphFusePass::runOnFunction() { do { changed = false; for (auto &op : body) { - mlir::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null(&op); + CreateEngineOp graph_op = ::llvm::dyn_cast_or_null(&op); if (nullptr == graph_op) continue; for (auto user_op : op.getUsers()) { - mlir::pd::GraphOp user_graph_op = - ::llvm::dyn_cast_or_null(user_op); + CreateEngineOp user_graph_op = + ::llvm::dyn_cast_or_null(user_op); if (nullptr == user_graph_op) continue; // get all dst input nodes except src. std::vector source_nodes; @@ -168,7 +168,7 @@ void TRTGraphFusePass::runOnFunction() { // Reverse DFS from the source_nodes. if (!reverseDfs(source_nodes, [&op](const mlir::Operation *n) { return n == &op; })) { - mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op); + mergeTwoAdjacentCreateEngineOp(builder, graph_op, user_graph_op); changed = true; break; } diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h index ebd7a4ac4bd3712d98df4a097682787b3977ebfb..350add905aac75c0ba8527aa6e9bc1510fab876c 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h @@ -14,6 +14,8 @@ #pragma once #include +#include "paddle/infrt/dialect/infrt_base.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -26,28 +28,28 @@ namespace trt { * * func @main() -> tensor { * %a = "pd.feed"()... - * %c = "pd.graph"(%a) { + * %c = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.return" %m + * "Infrt.return" %m * } ... - * %d = "pd.graph"(%c) { + * %d = "trt.create_engine"(%c) { * %m = "pd.conv3d"(%c)... - * "pd.return" %m + * "Infrt.return" %m * } ... - * %f = "pd.graph"(%a) { + * %f = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.return" %m + * "Infrt.return" %m * } ... * "pd.fetch" %d, %f * * destination func: * func @main() -> tensor { * %a = "pd.feed"()... - * %d, %f = "pd.graph"(%a) { + * %d, %f = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "pd.return" %n, %s + * "Infrt.return" %n, %s * } ... * "pd.fetch" %d, %f * } @@ -55,6 +57,9 @@ namespace trt { class TRTGraphFusePass : public mlir::PassWrapper { public: + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } ::llvm::StringRef getName() const override { return "trtGraphFusePass"; } void runOnFunction() override; }; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc index f24b9cc40cdcc2b065ea033cb03638e8d292df89..5ee7b23213a0106d3491712c37d34940f7c15c58 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc @@ -22,18 +22,17 @@ namespace infrt { namespace trt { // Implementation of the trtGraphSplitPass。 void TRTGraphSplitPass::runOnFunction() { - std::vector worklist; + std::vector worklist; mlir::Block& block = getFunction().front(); for (auto& op : block) { - mlir::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null(&op); + CreateEngineOp graph_op = ::llvm::dyn_cast_or_null(&op); if (nullptr != graph_op && graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { worklist.push_back(graph_op); } } while (!worklist.empty()) { - mlir::pd::GraphOp graph_op = worklist.back(); + CreateEngineOp graph_op = worklist.back(); worklist.pop_back(); mlir::Block* body = graph_op.getBody(); auto return_op = body->getTerminator(); diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h index 51f84227243403f5a2299d820acad1b49592abc3..28078e2bc2dbff46d6a9eaf5522b949f68785898 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h @@ -14,6 +14,8 @@ #pragma once #include +#include "paddle/infrt/dialect/infrt_base.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -27,11 +29,11 @@ namespace trt { * * func @main() -> tensor { * %a = "pd.feed"()... - * %d, %f = "pd.graph"(%a) { + * %d, %f = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "pd.return" (%n, %s) + * "Infrt.return" (%n, %s) * } ... * "pd.fetch" (%d, %f) * } @@ -49,6 +51,9 @@ class TRTGraphSplitPass : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } void runOnFunction() override; explicit TRTGraphSplitPass(size_t min_subgraph_size = 3) : min_subgraph_size_(min_subgraph_size) {} diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index e34308a2f0fa8c3c0142a62324f00c29b61fd7d3..8d81e739d9c72ebcaa57b927a360864db59d7e97 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h" -#include "mlir/IR/Builders.h" -#include "mlir/Transforms/DialectConversion.h" +#include +#include #include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/pd_ops.h" @@ -22,12 +22,10 @@ namespace trt { #include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT -using namespace mlir; - void TRTOpConverterPass::runOnOperation() { // The first thing to define is the conversion target. This will define the // final target for this lowering. - ConversionTarget target(getContext()); + ::mlir::ConversionTarget target(getContext()); // We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to TensorRTDialect from @@ -36,13 +34,13 @@ void TRTOpConverterPass::runOnOperation() { // Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the TensorRT operations. - RewritePatternSet patterns(&getContext()); + ::mlir::RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. - if (failed( + if (::mlir::failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h index 0adbf11b89144b0a9e14dc158e2eab1c56e2563a..a8128a585ee82dc60811c65b1105beb33e8c3b18 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h @@ -25,11 +25,11 @@ namespace trt { * source ir: * func @main() -> tensor { * %a = "pd.feed"()... - * %d, %f = "pd.graph"(%a) { + * %d, %f = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "pd.return" %n, %s + * "Infrt.return" %n, %s * } ... * "pd.fetch" %d, %f * } @@ -37,11 +37,11 @@ namespace trt { * destination ir: * func @main() -> tensor { * %a = "pd.feed"()... - * %d, %f = "pd.graph"(%a) { + * %d, %f = "trt.create_engine"(%a) { * %m = "trt.Convolution"(%a)... * %n = "trt.Convolution"(%m)... * %s = "trt.Convolution"(%a)... - * "pd.return" %n, %s + * "Infrt.return" %n, %s * } ... * "pd.fetch" %d, %f * } diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 176fdb7a2e054ac2e0c952c7af27995cf8e3c433..17e893a383a9cd3f893e80181858dc3cc2b0552b 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -15,6 +15,7 @@ #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h" #include +#include "paddle/infrt/dialect/basic_kernels.h" #include "paddle/infrt/dialect/pd_ops.h" namespace infrt { @@ -33,16 +34,14 @@ void TRTOpTellerPass::runOnFunction() { auto *op = worklist.back(); worklist.pop_back(); if (op == nullptr) continue; - auto op1 = ::llvm::dyn_cast_or_null(op); - if (op1) continue; - auto op2 = ::llvm::dyn_cast_or_null(op); - if (op2) continue; - auto op3 = ::llvm::dyn_cast_or_null(op); - if (op3) continue; + if (::llvm::dyn_cast_or_null(op)) continue; + if (::llvm::dyn_cast_or_null(op)) continue; + if (::llvm::dyn_cast_or_null(op)) continue; + if (::llvm::dyn_cast_or_null(op)) continue; builder.setInsertionPoint(op); auto loc = getFunction().getLoc(); - auto graph_op = builder.create( - loc, op->getResultTypes(), op->getOperands()); + auto graph_op = builder.create( + loc, op->getResultTypes(), op->getOperands(), true); ::llvm::SmallVector tblgen_repl_values; for (auto v : @@ -55,7 +54,7 @@ void TRTOpTellerPass::runOnFunction() { graph_op.body().push_back(block); op->moveBefore(block, block->begin()); builder.setInsertionPointToEnd(block); - builder.create(loc, op->getResults()); + builder.create<::infrt::dialect::ReturnOp>(loc, op->getResults()); } } } // namespace trt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h index 8b9a16376ce5527b2133c9f2c2ecea928fb4cd8f..471eafa9f9ba33dad4182ba7da55a607c2bf8f0d 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h @@ -14,6 +14,8 @@ #pragma once #include +#include "paddle/infrt/dialect/infrt_base.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -35,17 +37,17 @@ namespace trt { * destination func: * func @main() -> tensor { * %a = "pd.feed"()... - * %c = "pd.graph"(%a) { + * %c = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.return" (%m) + * "Infrt.return" (%m) * } ... - * %d = "pd.graph"(%c) { + * %d = "trt.create_engine"(%c) { * %m = "pd.conv3d"(%c)... - * "pd.return" (%m) + * "Infrt.return" (%m) * } ... - * %f = "pd.graph"(%a) { + * %f = "trt.create_engine"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.return" (%m) + * "Infrt.return" (%m) * } ... * "pd.fetch" (%d, %f) * } @@ -55,6 +57,9 @@ namespace trt { class TRTOpTellerPass : public mlir::PassWrapper { public: + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } ::llvm::StringRef getName() const override { return "trtOpTellerPass"; } void runOnFunction() override; }; diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.h b/paddle/infrt/dialect/tensorrt/trt_ops.h index a37491ec1abc7fd423fef23df5170936d2a769c7..95b2ed41fdfe9c5fdc7832dba46427528aee1332 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.h +++ b/paddle/infrt/dialect/tensorrt/trt_ops.h @@ -28,6 +28,7 @@ #include #include #include +#include "paddle/infrt/dialect/basic_kernels.h" namespace infrt { namespace trt { diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index 8e3dfffff54f13cc6d1f23c3459ed45257082d4f..31142a5157bfcd544128671fbdf22a993f1cc646 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -7,25 +7,14 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/IR/OpBase.td" include "paddle/infrt/dialect/tensorrt/trt_op_base.td" -def TRT_FetchOp : TRT_Op<"fetch", [Terminator]> { - let summary = "TensorRT engine return operation"; - let description = [{ - The `trt.fetch` operation terminates and returns values for the - `trt.graph` operation. - }]; - - let arguments = (ins Variadic:$inputs); -} - -def TRT_GraphOp : TRT_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> { +def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::dialect::ReturnOp">]> { let summary = "trt Graph Op"; let description = [{ Describe a tensorrt subgraph. }]; let regions = (region SizedRegion<1>:$body); - let arguments = (ins Variadic:$inputs); + let arguments = (ins Variadic:$inputs, DefaultValuedAttr:$run_once); let results = (outs Variadic:$outputs); - } def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> { diff --git a/paddle/infrt/tests/dialect/disabled_trt_ops.mlir b/paddle/infrt/tests/dialect/trt_ops.mlir similarity index 98% rename from paddle/infrt/tests/dialect/disabled_trt_ops.mlir rename to paddle/infrt/tests/dialect/trt_ops.mlir index b59cfb04816974cbdb923e6d18af1184be963c59..49510bc542dc0409067b5d61cb189dfab8b6601f 100644 --- a/paddle/infrt/tests/dialect/disabled_trt_ops.mlir +++ b/paddle/infrt/tests/dialect/trt_ops.mlir @@ -1,3 +1,4 @@ +// RUN: trt-exec %s // CHECK-LABEL: @main func @main() -> tensor { %bias = "pd.feed"() {name="input0"} : () -> tensor diff --git a/paddle/infrt/tests/lit.cfg.py.in b/paddle/infrt/tests/lit.cfg.py.in index 19ee0076b5594bf8c42d6888b28ef1fa172584ad..d47957dac928409ad4b49884db9c70310b38d9ca 100644 --- a/paddle/infrt/tests/lit.cfg.py.in +++ b/paddle/infrt/tests/lit.cfg.py.in @@ -21,10 +21,11 @@ build_dir = "@CMAKE_BINARY_DIR@" config.llvm_tools_dir = os.path.join(build_dir, "third_party/install/llvm/bin") config.llvm_tools_dir = os.path.join(build_dir, "/third_party/install/llvm/lib") infrtopt_bin = os.path.join(build_dir, "paddle/infrt/dialect/") +trtexec_bin = os.path.join(build_dir, "paddle/infrt/dialect/tensorrt/") infrtexec_bin = os.path.join(build_dir, "paddle/infrt/host_context/") llvm_bin = os.path.join(build_dir, "third_party/install/llvm/bin/") config.environment['PATH'] = os.path.pathsep.join( - (infrtopt_bin, infrtexec_bin, llvm_bin, config.environment['PATH'])) + (infrtopt_bin, infrtexec_bin, trtexec_bin, llvm_bin, config.environment['PATH'])) config.suffixes = ['.mlir']