diff --git a/paddle/infrt/dialect/infrt/ir/infrt_ops.td b/paddle/infrt/dialect/infrt/ir/infrt_ops.td index 82eba2a1746cce31e3fe99ae71c782bb88524930..cff6ce048a36c1d1e535dc5d44806555c6c2855d 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_ops.td +++ b/paddle/infrt/dialect/infrt/ir/infrt_ops.td @@ -9,6 +9,16 @@ class Infrt_Op traits = []> : Op]> { + let summary = "paddle graph Op"; + let description = [{ + Describe a paddle graph or subgraph. + }]; + let regions = (region SizedRegion<1>:$body); + let arguments = (ins Variadic:$inputs); + let results = (outs Variadic:$outputs); +} + def Infrt_KernelOp : Infrt_Op<"kernel", [NoSideEffect]> { let summary = "kernel op"; let description = [{kernel op!}]; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc index 0878163a955af236c6a40f60850e9e5cad67b2aa..c575d05949a3f706e294040d8ad26c7c31f9bc17 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc @@ -55,8 +55,8 @@ bool reverseDfs(std::vector source, // merge the first&second graph op to a new graph op. void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT - infrt::pd::GraphOp first, - infrt::pd::GraphOp second) { + ::infrt::GraphOp first, + ::infrt::GraphOp second) { // comput inputs and outputs ::llvm::SmallVector inputs(first.getOperands()), outputs; for (mlir::Value input : second.getOperands()) { @@ -85,7 +85,7 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT // create the new graph op builder.setInsertionPoint(first); auto loc = first.getLoc(); - auto graph_op = builder.create(loc, return_types, inputs); + auto graph_op = builder.create<::infrt::GraphOp>(loc, return_types, inputs); mlir::Block *block = new mlir::Block; auto copy_range = second.getBody()->without_terminator(); block->getOperations().splice(block->begin(), @@ -150,13 +150,13 @@ void TRTGraphFusePass::runOnFunction() { do { changed = false; for (auto &op : body) { - infrt::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null(&op); + ::infrt::GraphOp graph_op = + ::llvm::dyn_cast_or_null<::infrt::GraphOp>(&op); if (nullptr == graph_op) continue; for (auto user_op : op.getUsers()) { - infrt::pd::GraphOp user_graph_op = - ::llvm::dyn_cast_or_null(user_op); + ::infrt::GraphOp user_graph_op = + ::llvm::dyn_cast_or_null<::infrt::GraphOp>(user_op); if (nullptr == user_graph_op) continue; // get all dst input nodes except src. std::vector source_nodes; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h index 18afba19e06189294078bcfc1a0b2bb341eb7126..4c7214762303c0d909425323eff83b97ff4928a1 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h @@ -25,15 +25,15 @@ namespace trt { * source func: * * func @main(%a : tensor) -> tensor { - * %c = "pd.graph"(%a) { + * %c = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * infrt.return %m... * } ... - * %d = "pd.graph"(%c) { + * %d = "infrt.graph"(%c) { * %m = "pd.conv3d"(%c)... * infrt.return %m... * } ... - * %f = "pd.graph"(%a) { + * %f = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * infrt.return %m... * } ... @@ -42,7 +42,7 @@ namespace trt { * * destination func: * func @main(%a : tensor) -> tensor { - * %d, %f = "pd.graph"(%a) { + * %d, %f = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc index ade61bfc370f550cf85267b3088d697bf1bea997..2136f19fd1af56f0f9a089a995000d96c1ae88f6 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc @@ -21,18 +21,17 @@ namespace infrt { namespace trt { // Implementation of the trtGraphSplitPass。 void TRTGraphSplitPass::runOnFunction() { - std::vector worklist; + std::vector<::infrt::GraphOp> worklist; mlir::Block& block = getFunction().front(); for (auto& op : block) { - infrt::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null(&op); + ::infrt::GraphOp graph_op = ::llvm::dyn_cast_or_null<::infrt::GraphOp>(&op); if (nullptr != graph_op && graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { worklist.push_back(graph_op); } } while (!worklist.empty()) { - infrt::pd::GraphOp graph_op = worklist.back(); + ::infrt::GraphOp graph_op = worklist.back(); worklist.pop_back(); mlir::Block* body = graph_op.getBody(); auto return_op = body->getTerminator(); diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h index a5dd4f14b2946fe232b7b725f6ace7caf74ff4d4..a71b9cb6536c5f1e1930d5cf43d7dd2fd360788e 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h @@ -26,7 +26,7 @@ namespace trt { * source func: * * func @main(%a : tensor) -> tensor { - * %d, %f = "pd.graph"(%a) { + * %d, %f = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index b7032a2aa25c92e5f8d04414a27da7d6fa232d98..e3dab7093c59f7242356ca0d736b2a6011085164 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -41,14 +41,15 @@ namespace trt { #endif // INFRT_WITH_TRT template -::mlir::IntegerAttr createNvinferEnumAttr(::mlir::PatternRewriter &rewriter, - T enum_value) { +::mlir::IntegerAttr createNvinferEnumAttr( + ::mlir::PatternRewriter &rewriter, // NOLINT + T enum_value) { return rewriter.getSI32IntegerAttr((int32_t)enum_value); } template <> ::mlir::IntegerAttr createNvinferEnumAttr( - ::mlir::PatternRewriter &rewriter, std::string enum_value) { + ::mlir::PatternRewriter &rewriter, std::string enum_value) { // NOLINT (void)enum_value; return rewriter.getSI32IntegerAttr(-1); } @@ -57,10 +58,11 @@ template <> struct PD2TRT_GraphLower : public ::mlir::RewritePattern { explicit PD2TRT_GraphLower(::mlir::MLIRContext *context) - : ::mlir::RewritePattern("pd.graph", 1, context, {"trt.create_engine"}) {} + : ::mlir::RewritePattern( + "infrt.graph", 1, context, {"trt.create_engine"}) {} ::mlir::LogicalResult matchAndRewrite( ::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override { - auto casted_op = ::llvm::dyn_cast(op); + auto casted_op = ::llvm::dyn_cast<::infrt::GraphOp>(op); ::mlir::Operation::operand_range inputs = casted_op.inputs(); auto ods_loc = rewriter.getFusedLoc(op->getLoc()); CreateEngineOp create_engine_op; diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h index ede64f8bcd556a73b779fc3b772bf3fa8f74eaf9..685686493c9ab64ae144fb5e7a4c12de8166096b 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h @@ -25,7 +25,7 @@ namespace trt { * * source ir: * func @main(%a : tensor) -> tensor { - * %d, %f = "pd.graph"(%a) { + * %d, %f = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 5918be90cdd303496bac93cec4483bef04d567d0..7c9ec16d20400cce028f6622fa46c5f72276f785 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -40,12 +40,12 @@ void TRTOpTellerPass::runOnFunction() { if (op->getName().getStringRef().substr(0, 3) != "pd.") continue; if (::llvm::dyn_cast_or_null(op)) continue; if (::llvm::dyn_cast_or_null(op)) continue; - if (::llvm::dyn_cast_or_null(op)) continue; + if (::llvm::dyn_cast_or_null<::infrt::GraphOp>(op)) continue; if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue; builder.setInsertionPoint(op); auto loc = getFunction().getLoc(); - auto graph_op = builder.create( + auto graph_op = builder.create<::infrt::GraphOp>( loc, op->getResultTypes(), op->getOperands()); ::llvm::SmallVector tblgen_repl_values; diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h index 1cb08dc0a2161eeb5720191bada52f9b54e94893..47375d838a987482ec5ee8ceb3a04697dd0f3bc1 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h @@ -33,15 +33,15 @@ namespace trt { * * destination func: * func @main(%a : tensor) -> tensor { - * %c = "pd.graph"(%a) { + * %c = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * infrt.return %m:... * } ... - * %d = "pd.graph"(%c) { + * %d = "infrt.graph"(%c) { * %m = "pd.conv3d"(%c)... * infrt.return %m:... * } ... - * %f = "pd.graph"(%a) { + * %f = "infrt.graph"(%a) { * %m = "pd.conv2d"(%a)... * infrt.return %m:... * } ... diff --git a/tools/infrt/custom_pdop.td b/tools/infrt/custom_pdop.td index ae0316036f1854e281e07de59fb5aa53201bd35e..23ab8668ae6dc20356ce2ccf24d5258438c041d5 100644 --- a/tools/infrt/custom_pdop.td +++ b/tools/infrt/custom_pdop.td @@ -23,16 +23,6 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> { let arguments = (ins PD_Tensor :$inputs, StrAttr:$name); } -def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> { - let summary = "paddle graph Op"; - let description = [{ - Describe a paddle graph or subgraph. - }]; - let regions = (region SizedRegion<1>:$body); - let arguments = (ins Variadic:$inputs); - let results = (outs Variadic:$outputs); -} - def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods, AllTypesMatch<["value", "output"]>]> { let summary = "constant Op"; let description = [{}];