From 9bb3744ff742f7f9372a248273f11b0985a01d76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Tue, 29 Mar 2022 16:51:30 +0800 Subject: [PATCH] [Infrt] delete custom_pdop.td and move op to infrt dialect. (#41021) --- paddle/infrt/dialect/infrt/ir/infrt_base.td | 1 + .../infrt/dialect/infrt/ir/infrt_dialect.cc | 47 ++++++++++++++----- paddle/infrt/dialect/infrt/ir/infrt_dialect.h | 1 + paddle/infrt/dialect/infrt/ir/infrt_ops.td | 16 ++++++- .../infrt/dialect/infrt/pass/infrt_op_fuse.td | 4 -- .../dialect/infrt/pass/infrt_op_fuse_pass.cc | 2 +- paddle/infrt/dialect/pd/ir/pd_op_base.td | 1 - paddle/infrt/dialect/pd/ir/pd_ops.cc | 37 --------------- .../dialect/tensorrt/trt_op_teller_pass.cc | 7 --- .../dialect/disabled_rewrite_conv_bn.mlir | 7 ++- paddle/infrt/tests/dialect/paddle_ops.mlir | 9 ---- paddle/infrt/tests/dialect/pd/rewrite.mlir | 30 +++++------- paddle/infrt/tests/dialect/phi/phi_pass.mlir | 9 ++-- tools/infrt/custom_pdop.td | 37 --------------- ...rate_pd_op_dialect_from_paddle_op_maker.py | 13 ++--- 15 files changed, 74 insertions(+), 147 deletions(-) delete mode 100644 paddle/infrt/tests/dialect/paddle_ops.mlir delete mode 100644 tools/infrt/custom_pdop.td diff --git a/paddle/infrt/dialect/infrt/ir/infrt_base.td b/paddle/infrt/dialect/infrt/ir/infrt_base.td index 9b1d213229..ba8867d223 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_base.td +++ b/paddle/infrt/dialect/infrt/ir/infrt_base.td @@ -10,6 +10,7 @@ def Infrt_Dialect : Dialect { let name = "infrt"; let cppNamespace = "::infrt"; + let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; } diff --git a/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc b/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc index eb69a95c58..c4f20cb4d3 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc +++ b/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc @@ -183,16 +183,41 @@ void InfrtDialect::printType(::mlir::Type type, llvm_unreachable("unknown infrt type."); } -// /// Parse an attribute registered to this dialect. -// ::mlir::Attribute InfrtDialect::parseAttribute(::mlir::DialectAsmParser -// &parser, -// ::mlir::Type type) const { -// return mlir::Attribute(); -// } -// /// Print an attribute registered to this dialect. -// void InfrtDialect::printAttribute(::mlir::Attribute attr, -// ::mlir::DialectAsmPrinter &os) const { - -// } +mlir::Operation *InfrtDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + return builder.create(loc, value); +} + +void ConstantOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, + mlir::Attribute value) { + if (auto elem_attr = value.dyn_cast()) { + return ConstantOp::build(builder, state, elem_attr); + } else if (value.isa()) { + mlir::ShapedType type = + mlir::RankedTensorType::get(/*shape=*/{}, value.getType()); + state.addAttribute("value", mlir::DenseElementsAttr::get(type, value)); + state.addTypes(type); + return; + } + llvm_unreachable("unsupported attribute type for building pd.constant"); +} + +mlir::LogicalResult ConstantOp::inferReturnTypes( + mlir::MLIRContext *context, + mlir::Optional location, + mlir::ValueRange operands, + mlir::DictionaryAttr attributes, + mlir::RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(attributes.get("value").getType()); + return mlir::success(); +} +mlir::OpFoldResult ConstantOp::fold( + ::llvm::ArrayRef operands) { + return value(); +} } // namespace infrt diff --git a/paddle/infrt/dialect/infrt/ir/infrt_dialect.h b/paddle/infrt/dialect/infrt/ir/infrt_dialect.h index 3e6ea2a74c..e2e9b9348e 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_dialect.h +++ b/paddle/infrt/dialect/infrt/ir/infrt_dialect.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "paddle/infrt/dialect/infrt/common/types.h" diff --git a/paddle/infrt/dialect/infrt/ir/infrt_ops.td b/paddle/infrt/dialect/infrt/ir/infrt_ops.td index cff6ce048a..2736b7ad8c 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_ops.td +++ b/paddle/infrt/dialect/infrt/ir/infrt_ops.td @@ -1,3 +1,4 @@ +include "mlir/Interfaces/InferTypeOpInterface.td" include "paddle/infrt/dialect/infrt/ir/infrt_base.td" // Op definition @@ -9,7 +10,7 @@ class Infrt_Op traits = []> : Op]> { +def Infrt_GraphOp : Infrt_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> { let summary = "paddle graph Op"; let description = [{ Describe a paddle graph or subgraph. @@ -69,3 +70,16 @@ def Infrt_TensorCastOp : Infrt_Op<"tensor_cast", [NoSideEffect]> { let arguments = (ins AnyType:$input); let results = (outs AnyType:$output); } + +def Infrt_ConstantOp : Infrt_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods, AllTypesMatch<["value", "output"]>]> { + let summary = "constant Op"; + let description = [{}]; + + let arguments = (ins ElementsAttr:$value); + let results = (outs AnyType:$output); + let hasFolder = 1; + + let builders = [ + OpBuilder<(ins "mlir::Attribute":$value)>, + ]; +} diff --git a/paddle/infrt/dialect/infrt/pass/infrt_op_fuse.td b/paddle/infrt/dialect/infrt/pass/infrt_op_fuse.td index 3d825a9c76..1d6c0a7538 100644 --- a/paddle/infrt/dialect/infrt/pass/infrt_op_fuse.td +++ b/paddle/infrt/dialect/infrt/pass/infrt_op_fuse.td @@ -9,10 +9,6 @@ def FuseTensorCastPattern : Pat< (Infrt_TensorCastOp (Infrt_TensorCastOp $arg)), (Infrt_TensorCastOp $arg)>; -def FuseFeedTensorCastPattern : Pat< - (Infrt_TensorCastOp (PD_FeedOp $name)), - (PD_FeedOp $name)>; - def TypesAreIdentical : Constraint>; def RedundantTensorCastOptPattern : Pat< (Infrt_TensorCastOp:$res $arg), (replaceWithValue $arg), diff --git a/paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.cc b/paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.cc index eec0e0bc7c..a674e395da 100644 --- a/paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.cc +++ b/paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.cc @@ -38,7 +38,7 @@ void InfrtOpFusePass::runOnFunction() { ::mlir::RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - // Fuse pd.return Operation + // Fuse infrt.return Operation auto terminator_op = getFunction().front().getTerminator(); if (nullptr == terminator_op) return; for (auto operand : terminator_op->getOperands()) { diff --git a/paddle/infrt/dialect/pd/ir/pd_op_base.td b/paddle/infrt/dialect/pd/ir/pd_op_base.td index e28854a848..cb6f7aadd9 100644 --- a/paddle/infrt/dialect/pd/ir/pd_op_base.td +++ b/paddle/infrt/dialect/pd/ir/pd_op_base.td @@ -16,7 +16,6 @@ def Paddle_Dialect : Dialect { This dialect contains the PaddlePaddle operators. }]; - let hasConstantMaterializer = 1; let cppNamespace = "infrt::pd"; } diff --git a/paddle/infrt/dialect/pd/ir/pd_ops.cc b/paddle/infrt/dialect/pd/ir/pd_ops.cc index b5ba48581e..be6ff4cf74 100644 --- a/paddle/infrt/dialect/pd/ir/pd_ops.cc +++ b/paddle/infrt/dialect/pd/ir/pd_ops.cc @@ -35,42 +35,5 @@ void PaddleDialect::initialize() { #include "paddle/infrt/dialect/pd/ir/pd_extra_ops.cpp.inc" // NOLINT >(); } - -mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder, - mlir::Attribute value, - mlir::Type type, - mlir::Location loc) { - return builder.create(loc, value); -} - -void ConstantOp::build(mlir::OpBuilder &builder, - mlir::OperationState &state, - mlir::Attribute value) { - if (auto elem_attr = value.dyn_cast()) { - return ConstantOp::build(builder, state, elem_attr); - } else if (value.isa()) { - mlir::ShapedType type = - mlir::RankedTensorType::get(/*shape=*/{}, value.getType()); - state.addAttribute("value", mlir::DenseElementsAttr::get(type, value)); - state.addTypes(type); - return; - } - llvm_unreachable("unsupported attribute type for building pd.constant"); -} - -mlir::LogicalResult ConstantOp::inferReturnTypes( - mlir::MLIRContext *context, - mlir::Optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::RegionRange regions, - llvm::SmallVectorImpl &inferredReturnTypes) { - inferredReturnTypes.push_back(attributes.get("value").getType()); - return mlir::success(); -} -mlir::OpFoldResult ConstantOp::fold( - ::llvm::ArrayRef operands) { - return value(); -} } // namespace pd } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 9c3d80d77e..77c22c1285 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -39,13 +39,6 @@ void TRTOpTellerPass::runOnFunction() { worklist.pop_back(); if (op == nullptr) continue; if (op->getName().getStringRef().substr(0, 3) != "pd.") continue; - if (::llvm::dyn_cast_or_null(op)) continue; - if (::llvm::dyn_cast_or_null(op)) continue; - if (::llvm::dyn_cast_or_null<::infrt::GraphOp>(op)) continue; - if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue; - if (::llvm::dyn_cast_or_null<::infrt::phi::TensorMapGetTensorOp>(op)) - continue; - builder.setInsertionPoint(op); auto loc = getFunction().getLoc(); auto graph_op = builder.create<::infrt::GraphOp>( diff --git a/paddle/infrt/tests/dialect/disabled_rewrite_conv_bn.mlir b/paddle/infrt/tests/dialect/disabled_rewrite_conv_bn.mlir index 2889b92b18..4a1e627b60 100644 --- a/paddle/infrt/tests/dialect/disabled_rewrite_conv_bn.mlir +++ b/paddle/infrt/tests/dialect/disabled_rewrite_conv_bn.mlir @@ -1,6 +1,5 @@ // CHECK-LABEL: @main -func @main() -> tensor { - %a = "pd.feed"() {name="input0"} : () -> tensor +func @main(%a:tensor) -> tensor { %filter = "pd.constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32> %bias = "pd.constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32> @@ -11,5 +10,5 @@ func @main() -> tensor { %c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor %d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor - "pd.fetch"(%d) {name="output"} :(tensor)->() -} \ No newline at end of file + infrt.return %d:tensor +} diff --git a/paddle/infrt/tests/dialect/paddle_ops.mlir b/paddle/infrt/tests/dialect/paddle_ops.mlir deleted file mode 100644 index 4b80555149..0000000000 --- a/paddle/infrt/tests/dialect/paddle_ops.mlir +++ /dev/null @@ -1,9 +0,0 @@ -// RUN: infrtopt %s | FileCheck %s -// CHECK-LABEL: @ops -func @ops() { - %a = pd.feed() {name="input0"} : tensor - %b = pd.feed() {name="input1"}: tensor - %d = pd.feed() {name="input3"}: !infrt.lod_tensor<3x4x9xf32, 0> - %c = "pd.matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor - infrt.return -} diff --git a/paddle/infrt/tests/dialect/pd/rewrite.mlir b/paddle/infrt/tests/dialect/pd/rewrite.mlir index ea0248b9d9..295ad47770 100644 --- a/paddle/infrt/tests/dialect/pd/rewrite.mlir +++ b/paddle/infrt/tests/dialect/pd/rewrite.mlir @@ -1,28 +1,20 @@ // RUN: infrtopt --pd-op-fuse %s | FileCheck %s // CHECK-LABEL: @main -func @main() -> tensor { - %a = "pd.feed"() {name="input0"} : () -> tensor - %b = "pd.feed"() {name="input1"} : () -> tensor - %bias = "pd.feed"() {name="input2"} : () -> tensor +func @main(%arg0: tensor, %arg1: tensor, %arg2:tensor, %arg3:tensor, %arg4:tensor, %arg5:tensor, %arg6:tensor) -> tensor { - %b1 = "pd.feed"() {name="input3"} : () -> tensor - %b2 = "pd.feed"() {name="input4"} : () -> tensor - %bias1 = "pd.feed"() {name="input5"} : () -> tensor - %bias2 = "pd.feed"() {name="input6"} : () -> tensor - - // CHECK: %{{[0-9]+}} = "pd.FC"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor - %c = "pd.matmul_v2"(%a, %b) {transpose_y=false} : (tensor, tensor) -> tensor - %d = "pd.elementwise_add"(%c, %bias) {axis=1:si32} : (tensor, tensor) -> tensor + // CHECK: %0 = "pd.FC"(%arg0, %arg1, %arg4) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor + %c = "pd.matmul_v2"(%arg0, %arg1) {transpose_y=false} : (tensor, tensor) -> tensor + %d = "pd.elementwise_add"(%c, %arg4) {axis=1:si32} : (tensor, tensor) -> tensor %e = "pd.relu6"(%d) {} : (tensor) -> tensor - // CHECK: %{{[0-9]+}} = "pd.FC"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor - %c1 = "pd.matmul_v2"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor, tensor) -> tensor - %d1 = "pd.elementwise_add"(%c1, %bias1) {axis=1:si32} : (tensor, tensor) -> tensor + // CHECK: %2 = "pd.FC"(%1, %arg2, %arg5) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor + %c1 = "pd.matmul_v2"(%e, %arg2) {transpose_x=false, transpose_y=false} : (tensor, tensor) -> tensor + %d1 = "pd.elementwise_add"(%c1, %arg5) {axis=1:si32} : (tensor, tensor) -> tensor %e1 = "pd.relu"(%d1) {} : (tensor) -> tensor - // CHECK: %{{[0-9]+}} = "pd.FC"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor - %c2 = "pd.matmul_v2"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor - %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:si32} : (tensor, tensor) -> tensor + // CHECK: %4 = "pd.FC"(%3, %arg3, %arg6) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor + %c2 = "pd.matmul_v2"(%e1, %arg3) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor + %d2 = "pd.elementwise_add"(%c2, %arg6) {axis=1:si32} : (tensor, tensor) -> tensor %e2 = "pd.relu"(%d2) {} : (tensor) -> tensor - "pd.fetch"(%e2) {name="output"} :(tensor)->() + infrt.return %e2:tensor } diff --git a/paddle/infrt/tests/dialect/phi/phi_pass.mlir b/paddle/infrt/tests/dialect/phi/phi_pass.mlir index 47badd97d3..784ead5b2a 100644 --- a/paddle/infrt/tests/dialect/phi/phi_pass.mlir +++ b/paddle/infrt/tests/dialect/phi/phi_pass.mlir @@ -1,18 +1,15 @@ // RUN: infrtopt -phi-op-convert -infrt-op-fuse %s // CHECK-LABEL: @ops -func @ops() { - %a = pd.feed() {name="input0"} : !infrt.lod_tensor - %b = pd.feed() {name="input1"} : !infrt.lod_tensor - %d = pd.feed() {name="input3"} : !infrt.lod_tensor<3x4x9xf32, 0> +func @ops(%a:!infrt.lod_tensor, %b:!infrt.lod_tensor) { %g = "pd.elementwise_add"(%a, %b) {axis=1:si32} : (!infrt.lod_tensor, !infrt.lod_tensor) -> tensor %h = "pd.abs"(%g):(tensor) -> tensor - "pd.fetch"(%h) {name="output"} :(tensor)->() + infrt.return %h:tensor } // CHECK-LABEL: @op_execute func @op_execute(%a:!infrt.lod_tensor, %b:!infrt.lod_tensor, %c:!infrt.lod_tensor) -> !infrt.lod_tensor { %g = "pd.elementwise_add"(%a, %b) {axis=1:si32} : (!infrt.lod_tensor, !infrt.lod_tensor) -> tensor %h = "pd.abs"(%g):(tensor) -> tensor - "pd.fetch"(%h) {name="output"} :(tensor)->() + infrt.return %h:tensor } diff --git a/tools/infrt/custom_pdop.td b/tools/infrt/custom_pdop.td deleted file mode 100644 index 23ab8668ae..0000000000 --- a/tools/infrt/custom_pdop.td +++ /dev/null @@ -1,37 +0,0 @@ -def PD_FeedOp : PD_Op<"feed", [NoSideEffect]> { - let summary = "Feed Op"; - - let description = [{ - Feed a tensor into the model. - }]; - - let arguments = (ins StrAttr:$name); - let results = (outs PD_Tensor:$out); - - let assemblyFormat = [{ - `(` `)` attr-dict `:` type($out) - }]; -} - -def PD_FetchOp : PD_Op<"fetch", [Terminator]> { - let summary = "fetch Op"; - - let description = [{ - Return the output tensor from the subgraph. - }]; - - let arguments = (ins PD_Tensor :$inputs, StrAttr:$name); -} - -def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods, AllTypesMatch<["value", "output"]>]> { - let summary = "constant Op"; - let description = [{}]; - - let arguments = (ins ElementsAttr:$value); - let results = (outs PD_Tensor:$output); - let hasFolder = 1; - - let builders = [ - OpBuilder<(ins "mlir::Attribute":$value)>, - ]; -} diff --git a/tools/infrt/generate_pd_op_dialect_from_paddle_op_maker.py b/tools/infrt/generate_pd_op_dialect_from_paddle_op_maker.py index a85bf231d5..a4f93a5d6c 100644 --- a/tools/infrt/generate_pd_op_dialect_from_paddle_op_maker.py +++ b/tools/infrt/generate_pd_op_dialect_from_paddle_op_maker.py @@ -209,7 +209,6 @@ def get_constraint(op_type, op_proto): # funtion to generate paddle op dialect file def convert_op_proto_into_mlir(op_descs): dst_dialect_file = "../../paddle/infrt/dialect/pd/ir/pd_ops.td" - custom_dialect_file = "custom_pdop.td" # 1. Head files comment_ = "/*===- TableGen'source file -----------------------------------------------===*\\\n\ @@ -372,19 +371,13 @@ def convert_op_proto_into_mlir(op_descs): ops_mlir_file.write(RESULTS) ops_mlir_file.write("}\n") + with open(dst_dialect_file, 'a') as ops_mlir_file: + ops_mlir_file.write("\n#endif // PD_OPS") + print("Skipped ops num: " + str(len(skipped_op_list))) print("Automatically generated op dialects num: " + str( len(automatically_generated_op_dialect))) - # 3. custom op dialect and end of file - with open(dst_dialect_file, 'a') as ops_mlir_file: - with open(custom_dialect_file, 'r') as custom_ops_file: - custom_ops = custom_ops_file.readlines() - ops_mlir_file.writelines(custom_ops) - - end_ = "\n#endif // PD_OPS" - ops_mlir_file.write(end_) - if __name__ == "__main__": all_op_protos_dict = get_all_ops_desc() -- GitLab