From e72ef603b43b054f3d7787bd34000e759f88a365 Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Thu, 10 Mar 2022 09:56:41 +0800 Subject: [PATCH] Add trt execute (#40224) * add trt.execute * merge trt.engine type * update return op * update comments * fix style * fix style --- paddle/infrt/dialect/infrt/infrt_ops.td | 16 ++++++ paddle/infrt/dialect/pd_ops.cc | 1 - paddle/infrt/dialect/pd_ops.h | 1 + ...rt_dilaect_types.h => trt_dialect_types.h} | 2 + .../dialect/tensorrt/trt_graph_fuse_pass.cc | 20 +++---- .../dialect/tensorrt/trt_graph_fuse_pass.h | 32 +++++------ .../dialect/tensorrt/trt_graph_split_pass.cc | 8 +-- .../dialect/tensorrt/trt_graph_split_pass.h | 19 +++---- .../dialect/tensorrt/trt_op_converter_pass.cc | 53 +++++++++++++++++++ .../dialect/tensorrt/trt_op_converter_pass.h | 22 ++++---- .../dialect/tensorrt/trt_op_teller_pass.cc | 9 ++-- .../dialect/tensorrt/trt_op_teller_pass.h | 27 ++++------ paddle/infrt/dialect/tensorrt/trt_ops.cc | 12 ++++- paddle/infrt/dialect/tensorrt/trt_ops.h | 2 + paddle/infrt/dialect/tensorrt/trt_ops.td | 16 ++++-- paddle/infrt/tests/dialect/trt_ops.mlir | 11 +--- tools/infrt/custom_pdop.td | 2 +- 17 files changed, 162 insertions(+), 91 deletions(-) rename paddle/infrt/dialect/tensorrt/{trt_dilaect_types.h => trt_dialect_types.h} (91%) diff --git a/paddle/infrt/dialect/infrt/infrt_ops.td b/paddle/infrt/dialect/infrt/infrt_ops.td index ecd7093e72b..e07a598d9bc 100644 --- a/paddle/infrt/dialect/infrt/infrt_ops.td +++ b/paddle/infrt/dialect/infrt/infrt_ops.td @@ -18,6 +18,22 @@ def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> { let results = (outs Variadic); } +def Infrt_ReturnOp : Infrt_Op<"return", [Terminator]> { + let summary = "host executor return operation"; + let description = [{ + The "infrt.return" operation represents a return operation within a function. + + func @foo() : (i32, f8) { + infrt.return %0, %1 : i32, f8 + } + }]; + + let arguments = (ins Variadic:$operands); + + let builders = [OpBuilder<(ins), + [{ build($_builder, $_state, llvm::None); }]>]; +} + def Infrt_CvtTensorOp : Infrt_Op<"cvt_tensor", [NoSideEffect]> { let summary = "convert tensor type op"; let description = [{convert tensor type op!}]; diff --git a/paddle/infrt/dialect/pd_ops.cc b/paddle/infrt/dialect/pd_ops.cc index 338b04e0013..55ab174fcaf 100644 --- a/paddle/infrt/dialect/pd_ops.cc +++ b/paddle/infrt/dialect/pd_ops.cc @@ -16,7 +16,6 @@ #include #include -#include "paddle/infrt/dialect/infrt/infrt_dialect.h" #include "paddle/infrt/dialect/infrt_base.h" #define GET_OP_CLASSES diff --git a/paddle/infrt/dialect/pd_ops.h b/paddle/infrt/dialect/pd_ops.h index b48c68060d4..41dd2ddd94e 100644 --- a/paddle/infrt/dialect/pd_ops.h +++ b/paddle/infrt/dialect/pd_ops.h @@ -28,6 +28,7 @@ #include #include #include +#include "paddle/infrt/dialect/infrt/infrt_dialect.h" namespace mlir { namespace pd { diff --git a/paddle/infrt/dialect/tensorrt/trt_dilaect_types.h b/paddle/infrt/dialect/tensorrt/trt_dialect_types.h similarity index 91% rename from paddle/infrt/dialect/tensorrt/trt_dilaect_types.h rename to paddle/infrt/dialect/tensorrt/trt_dialect_types.h index efcf7dd5be1..0c3edcec1ed 100644 --- a/paddle/infrt/dialect/tensorrt/trt_dilaect_types.h +++ b/paddle/infrt/dialect/tensorrt/trt_dialect_types.h @@ -23,6 +23,8 @@ class EngineType : public mlir::Type::TypeBase { public: using Base::Base; + static EngineType get(); + static EngineType get(mlir::MLIRContext *context); }; } // namespace trt diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc index fa0095363c5..ad6b136463a 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc @@ -53,9 +53,9 @@ bool reverseDfs(std::vector source, } // merge the first&second graph op to a new graph op. -void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT - CreateEngineOp first, - CreateEngineOp second) { +void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT + mlir::pd::GraphOp first, + mlir::pd::GraphOp second) { // comput inputs and outputs ::llvm::SmallVector inputs(first.getOperands()), outputs; for (mlir::Value input : second.getOperands()) { @@ -84,8 +84,7 @@ void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT // create the new graph op builder.setInsertionPoint(first); auto loc = first.getLoc(); - auto graph_op = - builder.create(loc, return_types, inputs, true); + auto graph_op = builder.create(loc, return_types, inputs); mlir::Block *block = new mlir::Block; auto copy_range = second.getBody()->without_terminator(); block->getOperations().splice(block->begin(), @@ -98,7 +97,7 @@ void mergeTwoAdjacentCreateEngineOp(mlir::OpBuilder &builder, // NOLINT copy_range.begin(), copy_range.end()); builder.setInsertionPointToEnd(block); - builder.create<::infrt::dialect::ReturnOp>(loc, outputs); + builder.create<::infrt::ReturnOp>(loc, outputs); graph_op.body().push_back(block); // mapping the output @@ -150,12 +149,13 @@ void TRTGraphFusePass::runOnFunction() { do { changed = false; for (auto &op : body) { - CreateEngineOp graph_op = ::llvm::dyn_cast_or_null(&op); + mlir::pd::GraphOp graph_op = + ::llvm::dyn_cast_or_null(&op); if (nullptr == graph_op) continue; for (auto user_op : op.getUsers()) { - CreateEngineOp user_graph_op = - ::llvm::dyn_cast_or_null(user_op); + mlir::pd::GraphOp user_graph_op = + ::llvm::dyn_cast_or_null(user_op); if (nullptr == user_graph_op) continue; // get all dst input nodes except src. std::vector source_nodes; @@ -168,7 +168,7 @@ void TRTGraphFusePass::runOnFunction() { // Reverse DFS from the source_nodes. if (!reverseDfs(source_nodes, [&op](const mlir::Operation *n) { return n == &op; })) { - mergeTwoAdjacentCreateEngineOp(builder, graph_op, user_graph_op); + mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op); changed = true; break; } diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h index 350add905aa..803e53e3244 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h @@ -15,7 +15,6 @@ #pragma once #include #include "paddle/infrt/dialect/infrt_base.h" -#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -26,40 +25,37 @@ namespace trt { * * source func: * - * func @main() -> tensor { - * %a = "pd.feed"()... - * %c = "trt.create_engine"(%a) { + * func @main(%a : tensor) -> tensor { + * %c = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... - * "Infrt.return" %m + * "infrt.return" (%m) * } ... - * %d = "trt.create_engine"(%c) { + * %d = "pd.graph"(%c) { * %m = "pd.conv3d"(%c)... - * "Infrt.return" %m + * "infrt.return" (%m) * } ... - * %f = "trt.create_engine"(%a) { + * %f = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... - * "Infrt.return" %m + * "infrt.return" (%m) * } ... - * "pd.fetch" %d, %f + * "infrt.return" (%d, %f).. + * } * * destination func: - * func @main() -> tensor { - * %a = "pd.feed"()... - * %d, %f = "trt.create_engine"(%a) { + * func @main(%a : tensor) -> tensor { + * %d, %f = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "Infrt.return" %n, %s + * "infrt.return" (%n, %s) * } ... - * "pd.fetch" %d, %f + * "infrt.return" (%d, %f) * } */ class TRTGraphFusePass : public mlir::PassWrapper { public: - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } + void getDependentDialects(mlir::DialectRegistry ®istry) const override {} ::llvm::StringRef getName() const override { return "trtGraphFusePass"; } void runOnFunction() override; }; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc index 5ee7b23213a..e3a7b455024 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc @@ -16,23 +16,23 @@ #include #include "paddle/infrt/dialect/pd_ops.h" -#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { // Implementation of the trtGraphSplitPass。 void TRTGraphSplitPass::runOnFunction() { - std::vector worklist; + std::vector worklist; mlir::Block& block = getFunction().front(); for (auto& op : block) { - CreateEngineOp graph_op = ::llvm::dyn_cast_or_null(&op); + mlir::pd::GraphOp graph_op = + ::llvm::dyn_cast_or_null(&op); if (nullptr != graph_op && graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { worklist.push_back(graph_op); } } while (!worklist.empty()) { - CreateEngineOp graph_op = worklist.back(); + mlir::pd::GraphOp graph_op = worklist.back(); worklist.pop_back(); mlir::Block* body = graph_op.getBody(); auto return_op = body->getTerminator(); diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h index 28078e2bc2d..1c44a13cf9d 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h @@ -15,7 +15,6 @@ #pragma once #include #include "paddle/infrt/dialect/infrt_base.h" -#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -27,33 +26,29 @@ namespace trt { * * source func: * - * func @main() -> tensor { - * %a = "pd.feed"()... - * %d, %f = "trt.create_engine"(%a) { + * func @main(%a : tensor) -> tensor { + * %d, %f = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "Infrt.return" (%n, %s) + * "infrt.return" (%n, %s)... * } ... - * "pd.fetch" (%d, %f) + * "infrt.return" (%d, %f)... * } * * destination func: - * func @main() -> tensor { - * %a = "pd.feed"()... + * func @main(%a : tensor) -> tensor { * %c = "pd.conv2d"(%a) ... * %d = "pd.conv3d"(%c) ... * %f = "pd.conv2d"(%a) ... - * "pd.fetch" (%d, %f) + * "infrt.return" (%d, %f)... * } */ class TRTGraphSplitPass : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } + void getDependentDialects(mlir::DialectRegistry ®istry) const override {} void runOnFunction() override; explicit TRTGraphSplitPass(size_t min_subgraph_size = 3) : min_subgraph_size_(min_subgraph_size) {} diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index 8d81e739d9c..1be5f4dbc39 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -16,12 +16,64 @@ #include #include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/pd_ops.h" +#include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h" namespace infrt { namespace trt { #include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc" // NOLINT +struct PD2TRT_GraphLower : public ::mlir::RewritePattern { + PD2TRT_GraphLower(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("pd.graph", 1, context, {"trt.create_engine"}) {} + ::mlir::LogicalResult matchAndRewrite( + ::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override { + auto casted_op = ::llvm::dyn_cast(op); + ::mlir::Operation::operand_range inputs = casted_op.inputs(); + auto ods_loc = rewriter.getFusedLoc(op->getLoc()); + CreateEngineOp create_engine_op; + // inputs + ::mlir::SmallVector<::mlir::Value, 4> trt_inputs; + for (auto v : inputs) { + trt_inputs.push_back(v); + } + create_engine_op = rewriter.create( + ods_loc, + ::llvm::SmallVector(1, EngineType::get()), + trt_inputs, + true /*run_once*/); + ::mlir::Block *block = new ::mlir::Block; + block->getOperations().splice(block->begin(), + casted_op.getBody()->getOperations(), + casted_op.getBody()->begin(), + casted_op.getBody()->end()); + create_engine_op.body().push_back(block); + + // trt.execute + // outputs + ::llvm::SmallVector<::mlir::Type, 4> execute_outputs_types; + for (auto v : casted_op.getODSResults(0)) { + execute_outputs_types.push_back(v.getType()); + } + // inputs + ::mlir::SmallVector<::mlir::Value, 4> execute_inputs( + create_engine_op.getODSResults(0)); + for (auto v : inputs) { + execute_inputs.push_back(v); + } + auto execute_op = rewriter.create( + ods_loc, execute_outputs_types, execute_inputs); + + ::llvm::SmallVector<::mlir::Value, 4> replace_values; + for (auto v : + ::llvm::SmallVector<::mlir::Value, 4>{execute_op.getODSResults(0)}) { + replace_values.push_back(v); + } + rewriter.replaceOp(op, replace_values); + return ::mlir::success(); + } +}; + void TRTOpConverterPass::runOnOperation() { // The first thing to define is the conversion target. This will define the // final target for this lowering. @@ -36,6 +88,7 @@ void TRTOpConverterPass::runOnOperation() { // the set of patterns that will lower the TensorRT operations. ::mlir::RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); + patterns.add(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h index a8128a585ee..7550d8c84e1 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h @@ -15,6 +15,7 @@ #pragma once #include "mlir/IR/Dialect.h" #include "mlir/Pass/Pass.h" +#include "paddle/infrt/dialect/infrt/infrt_dialect.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { @@ -23,27 +24,26 @@ namespace trt { * trtOpConverterPass. * * source ir: - * func @main() -> tensor { - * %a = "pd.feed"()... - * %d, %f = "trt.create_engine"(%a) { + * func @main(%a : tensor) -> tensor { + * %d, %f = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "Infrt.return" %n, %s + * "infrt.return" (%n, %s)... * } ... - * "pd.fetch" %d, %f + * "infrt.return" (%d, %f)... * } * * destination ir: - * func @main() -> tensor { - * %a = "pd.feed"()... - * %d, %f = "trt.create_engine"(%a) { + * func @main(%a : tensor) -> tensor { + * %engine = "trt.create_engine"(%a) ({ * %m = "trt.Convolution"(%a)... * %n = "trt.Convolution"(%m)... * %s = "trt.Convolution"(%a)... - * "Infrt.return" %n, %s - * } ... - * "pd.fetch" %d, %f + * "infrt.return" (%n, %s)... + * }){run_once = true} ... + * %d, %f = "trt.execute"(%engine, %a)... + * "infrt.return" (%d, %f)... * } */ struct TRTOpConverterPass diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 17e893a383a..13b7f1aee55 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -16,6 +16,7 @@ #include #include "paddle/infrt/dialect/basic_kernels.h" +#include "paddle/infrt/dialect/infrt/infrt_dialect.h" #include "paddle/infrt/dialect/pd_ops.h" namespace infrt { @@ -37,11 +38,11 @@ void TRTOpTellerPass::runOnFunction() { if (::llvm::dyn_cast_or_null(op)) continue; if (::llvm::dyn_cast_or_null(op)) continue; if (::llvm::dyn_cast_or_null(op)) continue; - if (::llvm::dyn_cast_or_null(op)) continue; + if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue; builder.setInsertionPoint(op); auto loc = getFunction().getLoc(); - auto graph_op = builder.create( - loc, op->getResultTypes(), op->getOperands(), true); + auto graph_op = builder.create( + loc, op->getResultTypes(), op->getOperands()); ::llvm::SmallVector tblgen_repl_values; for (auto v : @@ -54,7 +55,7 @@ void TRTOpTellerPass::runOnFunction() { graph_op.body().push_back(block); op->moveBefore(block, block->begin()); builder.setInsertionPointToEnd(block); - builder.create<::infrt::dialect::ReturnOp>(loc, op->getResults()); + builder.create<::infrt::ReturnOp>(loc, op->getResults()); } } } // namespace trt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h index 471eafa9f9b..b9e461c8633 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h @@ -15,7 +15,6 @@ #pragma once #include #include "paddle/infrt/dialect/infrt_base.h" -#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -26,30 +25,28 @@ namespace trt { * * source func: * - * func @main() -> tensor { - * %a = "pd.feed"()... + * func @main(%a : tensor) -> tensor { * %c = "pd.conv2d"(%a) ... * %d = "pd.conv3d"(%c) ... * %f = "pd.conv2d"(%a) ... - * "pd.fetch" (%d, %f) + * "infrt.return"(%d, %f) ... * } * * destination func: - * func @main() -> tensor { - * %a = "pd.feed"()... - * %c = "trt.create_engine"(%a) { + * func @main(%a : tensor) -> tensor { + * %c = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... - * "Infrt.return" (%m) + * "infrt.return" (%m) * } ... - * %d = "trt.create_engine"(%c) { + * %d = "pd.graph"(%c) { * %m = "pd.conv3d"(%c)... - * "Infrt.return" (%m) + * "infrt.return" (%m) * } ... - * %f = "trt.create_engine"(%a) { + * %f = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... - * "Infrt.return" (%m) + * "infrt.return" (%m) * } ... - * "pd.fetch" (%d, %f) + * "infrt.return" (%d, %f) * } * TODO(winter-wang): Supplementary how to judge the operators can be supported * by tensorrt. @@ -57,9 +54,7 @@ namespace trt { class TRTOpTellerPass : public mlir::PassWrapper { public: - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } + void getDependentDialects(mlir::DialectRegistry ®istry) const override {} ::llvm::StringRef getName() const override { return "trtOpTellerPass"; } void runOnFunction() override; }; diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.cc b/paddle/infrt/dialect/tensorrt/trt_ops.cc index f179939e232..d5222976625 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.cc +++ b/paddle/infrt/dialect/tensorrt/trt_ops.cc @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include #include @@ -19,11 +18,20 @@ #include #include #include -#include "paddle/infrt/dialect/tensorrt/trt_dilaect_types.h" +#include "paddle/infrt/common/global.h" +#include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h" namespace infrt { namespace trt { +EngineType EngineType::get() { + return Base::get(::infrt::Global::getMLIRContext()); +} + +EngineType EngineType::get(mlir::MLIRContext *context) { + return Base::get(context); +} + TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context) : mlir::Dialect("trt", context, mlir::TypeID::get()) { addTypes(); diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.h b/paddle/infrt/dialect/tensorrt/trt_ops.h index 978b9906e5f..44444232915 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.h +++ b/paddle/infrt/dialect/tensorrt/trt_ops.h @@ -29,6 +29,8 @@ #include #include #include "paddle/infrt/dialect/basic_kernels.h" +#include "paddle/infrt/dialect/infrt/infrt_dialect.h" +#include "paddle/infrt/dialect/pd_ops.h" namespace infrt { namespace trt { diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index 31142a5157b..132a1d7805b 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -7,14 +7,24 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/IR/OpBase.td" include "paddle/infrt/dialect/tensorrt/trt_op_base.td" -def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::dialect::ReturnOp">]> { - let summary = "trt Graph Op"; + +def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> { + let summary = "trt CreateEngine Op"; let description = [{ Describe a tensorrt subgraph. }]; let regions = (region SizedRegion<1>:$body); let arguments = (ins Variadic:$inputs, DefaultValuedAttr:$run_once); - let results = (outs Variadic:$outputs); + let results = (outs TRT_EngineType:$output); +} + +def TRT_ExecuteOp : TRT_Op<"execute", [NoSideEffect]> { + let summary = "trt execute Op"; + let description = [{ + Describe a tensorrt runtime. + }]; + let arguments = (ins TRT_EngineType:$engine, Variadic:$inputs); + let results = (outs Variadic:$output); } def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> { diff --git a/paddle/infrt/tests/dialect/trt_ops.mlir b/paddle/infrt/tests/dialect/trt_ops.mlir index 49510bc542d..6d25044d139 100644 --- a/paddle/infrt/tests/dialect/trt_ops.mlir +++ b/paddle/infrt/tests/dialect/trt_ops.mlir @@ -1,13 +1,6 @@ // RUN: trt-exec %s // CHECK-LABEL: @main -func @main() -> tensor { - %bias = "pd.feed"() {name="input0"} : () -> tensor - %c = "pd.feed"() {name="input1"} : () -> tensor - %b1 = "pd.feed"() {name="input2"} : () -> tensor - %b2 = "pd.feed"() {name="input3"} : () -> tensor - %bias1 = "pd.feed"() {name="input4"} : () -> tensor - %bias2 = "pd.feed"() {name="input5"} : () -> tensor - +func @main(%bias:tensor, %c:tensor, %b1:tensor, %b2:tensor, %bias1:tensor, %bias2:tensor) -> tensor { %d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (tensor, tensor) -> tensor %e = "pd.relu6"(%d) {} : (tensor) -> tensor @@ -19,5 +12,5 @@ func @main() -> tensor { %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (tensor, tensor) -> tensor %e2 = "pd.relu"(%d2) {} : (tensor) -> tensor - "pd.fetch"(%e2) {name="output"} :(tensor)->() + "infrt.return"(%e2) : (tensor)->() } diff --git a/tools/infrt/custom_pdop.td b/tools/infrt/custom_pdop.td index 2139fbd8155..f7547672595 100644 --- a/tools/infrt/custom_pdop.td +++ b/tools/infrt/custom_pdop.td @@ -33,7 +33,7 @@ def PD_ReturnOp : PD_Op<"return", [Terminator]> { let arguments = (ins Variadic:$inputs); } -def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"ReturnOp">]> { +def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> { let summary = "paddle graph Op"; let description = [{ Describe a paddle graph or subgraph. -- GitLab