diff --git a/paddle/infrt/dialect/infrt/ir/infrt_base.td b/paddle/infrt/dialect/infrt/ir/infrt_base.td index 9b1d2132292df708b7c170442be702417593cfb4..ba8867d223a9a4e5553614a17e6d1e3d79c364f0 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_base.td +++ b/paddle/infrt/dialect/infrt/ir/infrt_base.td @@ -10,6 +10,7 @@ def Infrt_Dialect : Dialect { let name = "infrt"; let cppNamespace = "::infrt"; + let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; } diff --git a/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc b/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc index eb69a95c583f2a7d987e1e5f4617fcd34e0dad28..c4f20cb4d35c54d3e5b9eaf9fa378907f8872567 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc +++ b/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc @@ -183,16 +183,41 @@ void InfrtDialect::printType(::mlir::Type type, llvm_unreachable("unknown infrt type."); } -// /// Parse an attribute registered to this dialect. -// ::mlir::Attribute InfrtDialect::parseAttribute(::mlir::DialectAsmParser -// &parser, -// ::mlir::Type type) const { -// return mlir::Attribute(); -// } -// /// Print an attribute registered to this dialect. -// void InfrtDialect::printAttribute(::mlir::Attribute attr, -// ::mlir::DialectAsmPrinter &os) const { - -// } +mlir::Operation *InfrtDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + return builder.create(loc, value); +} + +void ConstantOp::build(mlir::OpBuilder &builder, + mlir::OperationState &state, + mlir::Attribute value) { + if (auto elem_attr = value.dyn_cast()) { + return ConstantOp::build(builder, state, elem_attr); + } else if (value.isa()) { + mlir::ShapedType type = + mlir::RankedTensorType::get(/*shape=*/{}, value.getType()); + state.addAttribute("value", mlir::DenseElementsAttr::get(type, value)); + state.addTypes(type); + return; + } + llvm_unreachable("unsupported attribute type for building pd.constant"); +} + +mlir::LogicalResult ConstantOp::inferReturnTypes( + mlir::MLIRContext *context, + mlir::Optional location, + mlir::ValueRange operands, + mlir::DictionaryAttr attributes, + mlir::RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(attributes.get("value").getType()); + return mlir::success(); +} +mlir::OpFoldResult ConstantOp::fold( + ::llvm::ArrayRef operands) { + return value(); +} } // namespace infrt diff --git a/paddle/infrt/dialect/infrt/ir/infrt_dialect.h b/paddle/infrt/dialect/infrt/ir/infrt_dialect.h index 3e6ea2a74c79d43015a62f166928e10adb48698a..e2e9b9348eb46da737d72757680a4fdf4aee5282 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_dialect.h +++ b/paddle/infrt/dialect/infrt/ir/infrt_dialect.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "paddle/infrt/dialect/infrt/common/types.h" diff --git a/paddle/infrt/dialect/infrt/ir/infrt_ops.td b/paddle/infrt/dialect/infrt/ir/infrt_ops.td index cff6ce048a36c1d1e535dc5d44806555c6c2855d..2736b7ad8c25f210690d1e3ce8ee217fce820577 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_ops.td +++ b/paddle/infrt/dialect/infrt/ir/infrt_ops.td @@ -1,3 +1,4 @@ +include "mlir/Interfaces/InferTypeOpInterface.td" include "paddle/infrt/dialect/infrt/ir/infrt_base.td" // Op definition @@ -9,7 +10,7 @@ class Infrt_Op traits = []> : Op]> { +def Infrt_GraphOp : Infrt_Op<"graph", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> { let summary = "paddle graph Op"; let description = [{ Describe a paddle graph or subgraph. @@ -69,3 +70,16 @@ def Infrt_TensorCastOp : Infrt_Op<"tensor_cast", [NoSideEffect]> { let arguments = (ins AnyType:$input); let results = (outs AnyType:$output); } + +def Infrt_ConstantOp : Infrt_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods, AllTypesMatch<["value", "output"]>]> { + let summary = "constant Op"; + let description = [{}]; + + let arguments = (ins ElementsAttr:$value); + let results = (outs AnyType:$output); + let hasFolder = 1; + + let builders = [ + OpBuilder<(ins "mlir::Attribute":$value)>, + ]; +} diff --git a/paddle/infrt/dialect/infrt/pass/infrt_op_fuse.td b/paddle/infrt/dialect/infrt/pass/infrt_op_fuse.td index 3d825a9c762f4833e577125d20423a5f1d41737f..1d6c0a75382b2cd0f2aff35090901c29637a76f9 100644 --- a/paddle/infrt/dialect/infrt/pass/infrt_op_fuse.td +++ b/paddle/infrt/dialect/infrt/pass/infrt_op_fuse.td @@ -9,10 +9,6 @@ def FuseTensorCastPattern : Pat< (Infrt_TensorCastOp (Infrt_TensorCastOp $arg)), (Infrt_TensorCastOp $arg)>; -def FuseFeedTensorCastPattern : Pat< - (Infrt_TensorCastOp (PD_FeedOp $name)), - (PD_FeedOp $name)>; - def TypesAreIdentical : Constraint>; def RedundantTensorCastOptPattern : Pat< (Infrt_TensorCastOp:$res $arg), (replaceWithValue $arg), diff --git a/paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.cc b/paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.cc index eec0e0bc7c5ab624e9db7744c357b58ff5107eef..a674e395da4ab405b028d1a84399ab358ba1473f 100644 --- a/paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.cc +++ b/paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.cc @@ -38,7 +38,7 @@ void InfrtOpFusePass::runOnFunction() { ::mlir::RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - // Fuse pd.return Operation + // Fuse infrt.return Operation auto terminator_op = getFunction().front().getTerminator(); if (nullptr == terminator_op) return; for (auto operand : terminator_op->getOperands()) { diff --git a/paddle/infrt/dialect/pd/ir/pd_op_base.td b/paddle/infrt/dialect/pd/ir/pd_op_base.td index e28854a848023c1161c8cda24edb705f536b5698..cb6f7aadd9891fa04dd7d0dc7759067bdbeb10b5 100644 --- a/paddle/infrt/dialect/pd/ir/pd_op_base.td +++ b/paddle/infrt/dialect/pd/ir/pd_op_base.td @@ -16,7 +16,6 @@ def Paddle_Dialect : Dialect { This dialect contains the PaddlePaddle operators. }]; - let hasConstantMaterializer = 1; let cppNamespace = "infrt::pd"; } diff --git a/paddle/infrt/dialect/pd/ir/pd_ops.cc b/paddle/infrt/dialect/pd/ir/pd_ops.cc index b5ba48581ee62f4e77328ed9f91ad956632dbbb7..be6ff4cf749c62ad7121e9b8dfb5f4ccf223f28a 100644 --- a/paddle/infrt/dialect/pd/ir/pd_ops.cc +++ b/paddle/infrt/dialect/pd/ir/pd_ops.cc @@ -35,42 +35,5 @@ void PaddleDialect::initialize() { #include "paddle/infrt/dialect/pd/ir/pd_extra_ops.cpp.inc" // NOLINT >(); } - -mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder, - mlir::Attribute value, - mlir::Type type, - mlir::Location loc) { - return builder.create(loc, value); -} - -void ConstantOp::build(mlir::OpBuilder &builder, - mlir::OperationState &state, - mlir::Attribute value) { - if (auto elem_attr = value.dyn_cast()) { - return ConstantOp::build(builder, state, elem_attr); - } else if (value.isa()) { - mlir::ShapedType type = - mlir::RankedTensorType::get(/*shape=*/{}, value.getType()); - state.addAttribute("value", mlir::DenseElementsAttr::get(type, value)); - state.addTypes(type); - return; - } - llvm_unreachable("unsupported attribute type for building pd.constant"); -} - -mlir::LogicalResult ConstantOp::inferReturnTypes( - mlir::MLIRContext *context, - mlir::Optional location, - mlir::ValueRange operands, - mlir::DictionaryAttr attributes, - mlir::RegionRange regions, - llvm::SmallVectorImpl &inferredReturnTypes) { - inferredReturnTypes.push_back(attributes.get("value").getType()); - return mlir::success(); -} -mlir::OpFoldResult ConstantOp::fold( - ::llvm::ArrayRef operands) { - return value(); -} } // namespace pd } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 9c3d80d77e1c2984272d60badfeac7951f771719..77c22c12854c6495e7a279edc2a0aab12c5a8ab8 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -39,13 +39,6 @@ void TRTOpTellerPass::runOnFunction() { worklist.pop_back(); if (op == nullptr) continue; if (op->getName().getStringRef().substr(0, 3) != "pd.") continue; - if (::llvm::dyn_cast_or_null(op)) continue; - if (::llvm::dyn_cast_or_null(op)) continue; - if (::llvm::dyn_cast_or_null<::infrt::GraphOp>(op)) continue; - if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue; - if (::llvm::dyn_cast_or_null<::infrt::phi::TensorMapGetTensorOp>(op)) - continue; - builder.setInsertionPoint(op); auto loc = getFunction().getLoc(); auto graph_op = builder.create<::infrt::GraphOp>( diff --git a/paddle/infrt/tests/dialect/disabled_rewrite_conv_bn.mlir b/paddle/infrt/tests/dialect/disabled_rewrite_conv_bn.mlir index 2889b92b18ef08fb6014eff948e2a5fc3d50c7f3..4a1e627b609c21cd3feb438777e603d8d84d8626 100644 --- a/paddle/infrt/tests/dialect/disabled_rewrite_conv_bn.mlir +++ b/paddle/infrt/tests/dialect/disabled_rewrite_conv_bn.mlir @@ -1,6 +1,5 @@ // CHECK-LABEL: @main -func @main() -> tensor { - %a = "pd.feed"() {name="input0"} : () -> tensor +func @main(%a:tensor) -> tensor { %filter = "pd.constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32> %bias = "pd.constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32> @@ -11,5 +10,5 @@ func @main() -> tensor { %c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor %d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor - "pd.fetch"(%d) {name="output"} :(tensor)->() -} \ No newline at end of file + infrt.return %d:tensor +} diff --git a/paddle/infrt/tests/dialect/paddle_ops.mlir b/paddle/infrt/tests/dialect/paddle_ops.mlir deleted file mode 100644 index 4b8055514936417dd83a6bb23afaea31eb2d1013..0000000000000000000000000000000000000000 --- a/paddle/infrt/tests/dialect/paddle_ops.mlir +++ /dev/null @@ -1,9 +0,0 @@ -// RUN: infrtopt %s | FileCheck %s -// CHECK-LABEL: @ops -func @ops() { - %a = pd.feed() {name="input0"} : tensor - %b = pd.feed() {name="input1"}: tensor - %d = pd.feed() {name="input3"}: !infrt.lod_tensor<3x4x9xf32, 0> - %c = "pd.matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor - infrt.return -} diff --git a/paddle/infrt/tests/dialect/pd/rewrite.mlir b/paddle/infrt/tests/dialect/pd/rewrite.mlir index ea0248b9d95d28e0160192a44f4c542d50a4892d..295ad47770784c8f85d724a85e4788d707a8694e 100644 --- a/paddle/infrt/tests/dialect/pd/rewrite.mlir +++ b/paddle/infrt/tests/dialect/pd/rewrite.mlir @@ -1,28 +1,20 @@ // RUN: infrtopt --pd-op-fuse %s | FileCheck %s // CHECK-LABEL: @main -func @main() -> tensor { - %a = "pd.feed"() {name="input0"} : () -> tensor - %b = "pd.feed"() {name="input1"} : () -> tensor - %bias = "pd.feed"() {name="input2"} : () -> tensor +func @main(%arg0: tensor, %arg1: tensor, %arg2:tensor, %arg3:tensor, %arg4:tensor, %arg5:tensor, %arg6:tensor) -> tensor { - %b1 = "pd.feed"() {name="input3"} : () -> tensor - %b2 = "pd.feed"() {name="input4"} : () -> tensor - %bias1 = "pd.feed"() {name="input5"} : () -> tensor - %bias2 = "pd.feed"() {name="input6"} : () -> tensor - - // CHECK: %{{[0-9]+}} = "pd.FC"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor - %c = "pd.matmul_v2"(%a, %b) {transpose_y=false} : (tensor, tensor) -> tensor - %d = "pd.elementwise_add"(%c, %bias) {axis=1:si32} : (tensor, tensor) -> tensor + // CHECK: %0 = "pd.FC"(%arg0, %arg1, %arg4) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor + %c = "pd.matmul_v2"(%arg0, %arg1) {transpose_y=false} : (tensor, tensor) -> tensor + %d = "pd.elementwise_add"(%c, %arg4) {axis=1:si32} : (tensor, tensor) -> tensor %e = "pd.relu6"(%d) {} : (tensor) -> tensor - // CHECK: %{{[0-9]+}} = "pd.FC"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor - %c1 = "pd.matmul_v2"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor, tensor) -> tensor - %d1 = "pd.elementwise_add"(%c1, %bias1) {axis=1:si32} : (tensor, tensor) -> tensor + // CHECK: %2 = "pd.FC"(%1, %arg2, %arg5) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor + %c1 = "pd.matmul_v2"(%e, %arg2) {transpose_x=false, transpose_y=false} : (tensor, tensor) -> tensor + %d1 = "pd.elementwise_add"(%c1, %arg5) {axis=1:si32} : (tensor, tensor) -> tensor %e1 = "pd.relu"(%d1) {} : (tensor) -> tensor - // CHECK: %{{[0-9]+}} = "pd.FC"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor - %c2 = "pd.matmul_v2"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor - %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:si32} : (tensor, tensor) -> tensor + // CHECK: %4 = "pd.FC"(%3, %arg3, %arg6) {in_num_col_dims = 1 : i32} : (tensor, tensor, tensor) -> tensor + %c2 = "pd.matmul_v2"(%e1, %arg3) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor + %d2 = "pd.elementwise_add"(%c2, %arg6) {axis=1:si32} : (tensor, tensor) -> tensor %e2 = "pd.relu"(%d2) {} : (tensor) -> tensor - "pd.fetch"(%e2) {name="output"} :(tensor)->() + infrt.return %e2:tensor } diff --git a/paddle/infrt/tests/dialect/phi/phi_pass.mlir b/paddle/infrt/tests/dialect/phi/phi_pass.mlir index 47badd97d37db578ec36f496b21212d73fd9920e..784ead5b2a0e35aea73a5084cf957573e53281fe 100644 --- a/paddle/infrt/tests/dialect/phi/phi_pass.mlir +++ b/paddle/infrt/tests/dialect/phi/phi_pass.mlir @@ -1,18 +1,15 @@ // RUN: infrtopt -phi-op-convert -infrt-op-fuse %s // CHECK-LABEL: @ops -func @ops() { - %a = pd.feed() {name="input0"} : !infrt.lod_tensor - %b = pd.feed() {name="input1"} : !infrt.lod_tensor - %d = pd.feed() {name="input3"} : !infrt.lod_tensor<3x4x9xf32, 0> +func @ops(%a:!infrt.lod_tensor, %b:!infrt.lod_tensor) { %g = "pd.elementwise_add"(%a, %b) {axis=1:si32} : (!infrt.lod_tensor, !infrt.lod_tensor) -> tensor %h = "pd.abs"(%g):(tensor) -> tensor - "pd.fetch"(%h) {name="output"} :(tensor)->() + infrt.return %h:tensor } // CHECK-LABEL: @op_execute func @op_execute(%a:!infrt.lod_tensor, %b:!infrt.lod_tensor, %c:!infrt.lod_tensor) -> !infrt.lod_tensor { %g = "pd.elementwise_add"(%a, %b) {axis=1:si32} : (!infrt.lod_tensor, !infrt.lod_tensor) -> tensor %h = "pd.abs"(%g):(tensor) -> tensor - "pd.fetch"(%h) {name="output"} :(tensor)->() + infrt.return %h:tensor } diff --git a/tools/infrt/custom_pdop.td b/tools/infrt/custom_pdop.td deleted file mode 100644 index 23ab8668ae6dc20356ce2ccf24d5258438c041d5..0000000000000000000000000000000000000000 --- a/tools/infrt/custom_pdop.td +++ /dev/null @@ -1,37 +0,0 @@ -def PD_FeedOp : PD_Op<"feed", [NoSideEffect]> { - let summary = "Feed Op"; - - let description = [{ - Feed a tensor into the model. - }]; - - let arguments = (ins StrAttr:$name); - let results = (outs PD_Tensor:$out); - - let assemblyFormat = [{ - `(` `)` attr-dict `:` type($out) - }]; -} - -def PD_FetchOp : PD_Op<"fetch", [Terminator]> { - let summary = "fetch Op"; - - let description = [{ - Return the output tensor from the subgraph. - }]; - - let arguments = (ins PD_Tensor :$inputs, StrAttr:$name); -} - -def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods, AllTypesMatch<["value", "output"]>]> { - let summary = "constant Op"; - let description = [{}]; - - let arguments = (ins ElementsAttr:$value); - let results = (outs PD_Tensor:$output); - let hasFolder = 1; - - let builders = [ - OpBuilder<(ins "mlir::Attribute":$value)>, - ]; -} diff --git a/tools/infrt/generate_pd_op_dialect_from_paddle_op_maker.py b/tools/infrt/generate_pd_op_dialect_from_paddle_op_maker.py index a85bf231d5c84a2e7cd59a8b583285ef88c678d7..a4f93a5d6c320fc51900a9a6d42a412b90df9d78 100644 --- a/tools/infrt/generate_pd_op_dialect_from_paddle_op_maker.py +++ b/tools/infrt/generate_pd_op_dialect_from_paddle_op_maker.py @@ -209,7 +209,6 @@ def get_constraint(op_type, op_proto): # funtion to generate paddle op dialect file def convert_op_proto_into_mlir(op_descs): dst_dialect_file = "../../paddle/infrt/dialect/pd/ir/pd_ops.td" - custom_dialect_file = "custom_pdop.td" # 1. Head files comment_ = "/*===- TableGen'source file -----------------------------------------------===*\\\n\ @@ -372,19 +371,13 @@ def convert_op_proto_into_mlir(op_descs): ops_mlir_file.write(RESULTS) ops_mlir_file.write("}\n") + with open(dst_dialect_file, 'a') as ops_mlir_file: + ops_mlir_file.write("\n#endif // PD_OPS") + print("Skipped ops num: " + str(len(skipped_op_list))) print("Automatically generated op dialects num: " + str( len(automatically_generated_op_dialect))) - # 3. custom op dialect and end of file - with open(dst_dialect_file, 'a') as ops_mlir_file: - with open(custom_dialect_file, 'r') as custom_ops_file: - custom_ops = custom_ops_file.readlines() - ops_mlir_file.writelines(custom_ops) - - end_ = "\n#endif // PD_OPS" - ops_mlir_file.write(end_) - if __name__ == "__main__": all_op_protos_dict = get_all_ops_desc()