From bf93050c9f874762ca8de7698754afb0554a9ded Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Mon, 28 Mar 2022 19:31:05 +0800 Subject: [PATCH] [infrt] move graph op from pd dialect to infrt dialect. (#41003) --- paddle/infrt/dialect/infrt/ir/infrt_ops.td | 10 ++++++++++ .../infrt/dialect/tensorrt/trt_graph_fuse_pass.cc | 14 +++++++------- .../infrt/dialect/tensorrt/trt_graph_fuse_pass.h | 8 ++++---- .../infrt/dialect/tensorrt/trt_graph_split_pass.cc | 7 +++---- .../infrt/dialect/tensorrt/trt_graph_split_pass.h | 2 +- .../dialect/tensorrt/trt_op_converter_pass.cc | 12 +++++++----- .../infrt/dialect/tensorrt/trt_op_converter_pass.h | 2 +- .../infrt/dialect/tensorrt/trt_op_teller_pass.cc | 4 ++-- paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h | 6 +++--- tools/infrt/custom_pdop.td | 10 ---------- 10 files changed, 38 insertions(+), 37 deletions(-) diff --git a/paddle/infrt/dialect/infrt/ir/infrt_ops.td b/paddle/infrt/dialect/infrt/ir/infrt_ops.td index 82eba2a174..cff6ce048a 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_ops.td +++ b/paddle/infrt/dialect/infrt/ir/infrt_ops.td @@ -9,6 +9,16 @@ class Infrt_Op traits = []> : Op]> { + let summary = "paddle graph Op"; + let description = [{ + Describe a paddle graph or subgraph. + }]; + let regions = (region SizedRegion<1>:$body); + let arguments = (ins Variadic:$inputs); + let results = (outs Variadic:$outputs); +} + def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> { let summary = "kernel op"; let description = [{kernel op!}]; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc index 0878163a95..c575d05949 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc @@ -55,8 +55,8 @@ bool reverseDfs(std::vector source, // merge the first&second graph op to a new graph op. void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT - infrt::pd::GraphOp first, - infrt::pd::GraphOp second) { + ::infrt::GraphOp first, + ::infrt::GraphOp second) { // comput inputs and outputs ::llvm::SmallVector inputs(first.getOperands()), outputs; for (mlir::Value input : second.getOperands()) { @@ -85,7 +85,7 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT // create the new graph op builder.setInsertionPoint(first); auto loc = first.getLoc(); - auto graph_op = builder.create(loc, return_types, inputs); + auto graph_op = builder.create<::infrt::GraphOp>(loc, return_types, inputs); mlir::Block *block = new mlir::Block; auto copy_range = second.getBody()->without_terminator(); block->getOperations().splice(block->begin(), @@ -150,13 +150,13 @@ void TRTGraphFusePass::runOnFunction() { do { changed = false; for (auto &op : body) { - infrt::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null(&op); + ::infrt::GraphOp graph_op = + ::llvm::dyn_cast_or_null<::infrt::GraphOp>(&op); if (nullptr == graph_op) continue; for (auto user_op : op.getUsers()) { - infrt::pd::GraphOp user_graph_op = - ::llvm::dyn_cast_or_null(user_op); + ::infrt::GraphOp user_graph_op = + ::llvm::dyn_cast_or_null<::infrt::GraphOp>(user_op); if (nullptr == user_graph_op) continue; // get all dst input nodes except src. std::vector source_nodes; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h index 18afba19e0..4c72147623 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h @@ -25,15 +25,15 @@ namespace trt { * source func: * * func @main(%a : tensor) -> tensor { - * %c = "pd.graph"(%a) { + * %c = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * infrt.return %m... * } ... - * %d = "pd.graph"(%c) { + * %d = "infrt.graph"(%c) { * %m = "pd.conv3d"(%c)... * infrt.return %m... * } ... - * %f = "pd.graph"(%a) { + * %f = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * infrt.return %m... * } ... @@ -42,7 +42,7 @@ namespace trt { * * destination func: * func @main(%a : tensor) -> tensor { - * %d, %f = "pd.graph"(%a) { + * %d, %f = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc index ade61bfc37..2136f19fd1 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc @@ -21,18 +21,17 @@ namespace infrt { namespace trt { // Implementation of the trtGraphSplitPass。 void TRTGraphSplitPass::runOnFunction() { - std::vector worklist; + std::vector<::infrt::GraphOp> worklist; mlir::Block& block = getFunction().front(); for (auto& op : block) { - infrt::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null(&op); + ::infrt::GraphOp graph_op = ::llvm::dyn_cast_or_null<::infrt::GraphOp>(&op); if (nullptr != graph_op && graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { worklist.push_back(graph_op); } } while (!worklist.empty()) { - infrt::pd::GraphOp graph_op = worklist.back(); + ::infrt::GraphOp graph_op = worklist.back(); worklist.pop_back(); mlir::Block* body = graph_op.getBody(); auto return_op = body->getTerminator(); diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h index a5dd4f14b2..a71b9cb653 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h @@ -26,7 +26,7 @@ namespace trt { * source func: * * func @main(%a : tensor) -> tensor { - * %d, %f = "pd.graph"(%a) { + * %d, %f = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index b7032a2aa2..e3dab7093c 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -41,14 +41,15 @@ namespace trt { #endif // INFRT_WITH_TRT template -::mlir::IntegerAttr createNvinferEnumAttr(::mlir::PatternRewriter &rewriter, - T enum_value) { +::mlir::IntegerAttr createNvinferEnumAttr( + ::mlir::PatternRewriter &rewriter, // NOLINT + T enum_value) { return rewriter.getSI32IntegerAttr((int32_t)enum_value); } template <> ::mlir::IntegerAttr createNvinferEnumAttr( - ::mlir::PatternRewriter &rewriter, std::string enum_value) { + ::mlir::PatternRewriter &rewriter, std::string enum_value) { // NOLINT (void)enum_value; return rewriter.getSI32IntegerAttr(-1); } @@ -57,10 +58,11 @@ template <> struct PD2TRT_GraphLower : public ::mlir::RewritePattern { explicit PD2TRT_GraphLower(::mlir::MLIRContext *context) - : ::mlir::RewritePattern("pd.graph", 1, context, {"trt.create_engine"}) {} + : ::mlir::RewritePattern( + "infrt.graph", 1, context, {"trt.create_engine"}) {} ::mlir::LogicalResult matchAndRewrite( ::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override { - auto casted_op = ::llvm::dyn_cast(op); + auto casted_op = ::llvm::dyn_cast<::infrt::GraphOp>(op); ::mlir::Operation::operand_range inputs = casted_op.inputs(); auto ods_loc = rewriter.getFusedLoc(op->getLoc()); CreateEngineOp create_engine_op; diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h index ede64f8bcd..685686493c 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h @@ -25,7 +25,7 @@ namespace trt { * * source ir: * func @main(%a : tensor) -> tensor { - * %d, %f = "pd.graph"(%a) { + * %d, %f = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 5918be90cd..7c9ec16d20 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -40,12 +40,12 @@ void TRTOpTellerPass::runOnFunction() { if (op->getName().getStringRef().substr(0, 3) != "pd.") continue; if (::llvm::dyn_cast_or_null(op)) continue; if (::llvm::dyn_cast_or_null(op)) continue; - if (::llvm::dyn_cast_or_null(op)) continue; + if (::llvm::dyn_cast_or_null<::infrt::GraphOp>(op)) continue; if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue; builder.setInsertionPoint(op); auto loc = getFunction().getLoc(); - auto graph_op = builder.create( + auto graph_op = builder.create<::infrt::GraphOp>( loc, op->getResultTypes(), op->getOperands()); ::llvm::SmallVector tblgen_repl_values; diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h index 1cb08dc0a2..47375d838a 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h @@ -33,15 +33,15 @@ namespace trt { * * destination func: * func @main(%a : tensor) -> tensor { - * %c = "pd.graph"(%a) { + * %c = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * infrt.return %m:... * } ... - * %d = "pd.graph"(%c) { + * %d = "infrt.graph"(%c) { * %m = "pd.conv3d"(%c)... * infrt.return %m:... * } ... - * %f = "pd.graph"(%a) { + * %f = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * infrt.return %m:... * } ... diff --git a/tools/infrt/custom_pdop.td b/tools/infrt/custom_pdop.td index ae0316036f..23ab8668ae 100644 --- a/tools/infrt/custom_pdop.td +++ b/tools/infrt/custom_pdop.td @@ -23,16 +23,6 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> { let arguments = (ins PD_Tensor :$inputs, StrAttr:$name); } -def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> { - let summary = "paddle graph Op"; - let description = [{ - Describe a paddle graph or subgraph. - }]; - let regions = (region SizedRegion<1>:$body); - let arguments = (ins Variadic:$inputs); - let results = (outs Variadic:$outputs); -} - def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods, AllTypesMatch<["value", "output"]>]> { let summary = "constant Op"; let description = [{}]; -- GitLab