From 597653628ce664add3ada7301839acddf016a994 Mon Sep 17 00:00:00 2001
From: weishengying <63448337+weishengying@users.noreply.github.com>
Date: Tue, 29 Mar 2022 11:02:04 +0800
Subject: [PATCH] add rewrite pattern from paddle mlir to trt mlir (#41011)

---
 paddle/infrt/dialect/tensorrt/convert.h       | 86 ++++++++++++++++
 .../infrt/dialect/tensorrt/pd_lower_to_trt.td | 19 +++-
 .../dialect/tensorrt/trt_op_converter_pass.cc |  1 +
 .../dialect/tensorrt/trt_op_teller_pass.cc    |  3 +
 paddle/infrt/dialect/tensorrt/trt_ops.td      | 98 ++++++++++++++++++-
 5 files changed, 203 insertions(+), 4 deletions(-)
 create mode 100644 paddle/infrt/dialect/tensorrt/convert.h

diff --git a/paddle/infrt/dialect/tensorrt/convert.h b/paddle/infrt/dialect/tensorrt/convert.h
new file mode 100644
index 00000000000..327c6a3e138
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/convert.h
@@ -0,0 +1,86 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/IR/Value.h>
+#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
+#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
+
+namespace infrt {
+namespace trt {
+static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter,
+                                     mlir::Operation *op) {
+  ::mlir::Operation::operand_range Input(op->getOperands());
+  ::mlir::Operation::operand_range Filter(op->getOperands());
+
+  ::mlir::SmallVector<::mlir::Value, 4> operands;
+  auto castedOp0 = ::llvm::dyn_cast<infrt::pd::Conv2dOp>(op);
+  (void)castedOp0;
+  Input = castedOp0.getODSOperands(0);
+  Filter = castedOp0.getODSOperands(1);
+  operands.push_back((*Input.begin()));
+  operands.push_back((*Filter.begin()));
+
+  ::mlir::SmallVector<::mlir::Type, 4> resultTypes;
+  for (auto v : castedOp0.getODSResults(0)) {
+    resultTypes.push_back(v.getType());
+  }
+  ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes;
+  {
+    auto tblgen_attr = rewriter.getSI32IntegerAttr(3);
+    attributes.emplace_back(rewriter.getStringAttr("out_channel_num"),
+                            tblgen_attr);
+  }
+  {
+    auto tblgen_attr = rewriter.getI32ArrayAttr({3, 3});
+    attributes.emplace_back(rewriter.getStringAttr("kernel_size"), tblgen_attr);
+  }
+  {
+    auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("strides");
+    (void)tblgen_attr;
+    attributes.emplace_back(rewriter.getStringAttr("strides"), tblgen_attr);
+  }
+  {
+    auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("paddings");
+    (void)tblgen_attr;
+    attributes.emplace_back(rewriter.getStringAttr("paddings"), tblgen_attr);
+  }
+  {
+    auto tblgen_attr =
+        op->getAttrOfType<::mlir::StringAttr>("padding_algorithm");
+    (void)tblgen_attr;
+    attributes.emplace_back(rewriter.getStringAttr("padding_mode"),
+                            tblgen_attr);
+  }
+  {
+    auto tblgen_attr = op->getAttrOfType<::mlir::IntegerAttr>("groups");
+    (void)tblgen_attr;
+    attributes.emplace_back(rewriter.getStringAttr("groups"), tblgen_attr);
+  }
+  {
+    auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("dilations");
+    (void)tblgen_attr;
+    attributes.emplace_back(rewriter.getStringAttr("dilations"),
+                            tblgen_attr);
+  }
+  {
+    auto tblgen_attr = op->getAttrOfType<::mlir::StringAttr>("data_format");
+    (void)tblgen_attr;
+    attributes.emplace_back(rewriter.getStringAttr("data_format"), tblgen_attr);
+  }
+  return rewriter.create<trt::ConvolutionOp>(
+      op->getLoc(), resultTypes, operands, attributes);
+}
+}  // namespace trt
+}  // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td
index 1c5ba689368..f833600a36a 100644
--- a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td
+++ b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td
@@ -14,7 +14,7 @@ def PD2TRT_Matmul_Lower : Pat<
   (TRT_MatrixMultiplyOp $X, $transpose_X, $Y, $transpose_Y)>;
 
 def PD2TRT_ElementwiseAdd_Lower : Pat<
-  (PD_Elementwise_addOp $X, $Y, ConstantAttr),
+  (PD_Elementwise_addOp $X, $Y, $_),
   (TRT_ElementWiseOp $X, $Y, (TRT_createNvinferEnumAttr<"nvinfer1::ElementWiseOperation", "kSUM">))>;
 
 def PD2TRT_Relu_Lower : Pat<
@@ -25,4 +25,21 @@ def PD2TRT_Relu6_Lower : Pat<
   (PD_Relu6Op $X, $threshold),
   (TRT_ActivationOp $X, (TRT_createNvinferEnumAttr<"nvinfer1::ActivationType", "kCLIP">), (INFRT_createF32Attr<"0.0">), $threshold)>;
+def createTRTConv2dOp : NativeCodeCall<"createTRTConv2dOp($_builder, $0.getDefiningOp())">;
+
+def PD2TRT_Conv2d_Lower : Pat<
+  (PD_Conv2dOp:$old_value $Input, $Filter, $strides, $paddings, $padding_algorithm, $groups, $dilations, $data_format),
+  (createTRTConv2dOp $old_value)>;
+
+def PD2TRT_Pooling_Lower : Pat<
+  (PD_Pool2dOp $Input, $pooling_type, $ksize, $global_pooling, $strides, $paddings, $exclusive, $adaptive, $ceil_mode, $data_format, $padding_algorithm),
+  (TRT_PoolingOp $Input, (INFRT_createI32Attr<"0">)/*kmax*/, $ksize, $strides, $paddings, $padding_algorithm)>;
+
+def PD2TRT_MatrixMultipl_Lower : Pat<
+  (PD_MulOp $Input1, $Input2, $x_num_col_dims, $y_num_col_dims),
+  (TRT_MatrixMultiplOp $Input1, (INFRT_createI32Attr<"0">)/*kNONE*/, $Input2, (INFRT_createI32Attr<"0">)/*kNONE*/)>;
+
+def PD2TRT_SoftMax_Lower : Pat<
+  (PD_SoftmaxOp $Input, $axis, $_),
+  (TRT_SoftMaxOp $Input, $axis)>;
 
 #endif  // PD_LOWER_TO_TRT
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
index e3dab7093c5..9516df70bb0 100644
--- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
@@ -21,6 +21,7 @@
 #include "paddle/infrt/dialect/pd/ir/pd_ops.h"
 #include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
 #include "paddle/infrt/dialect/phi/ir/phi_base.h"
+#include "paddle/infrt/dialect/tensorrt/convert.h"
 #include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h"
 #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
 
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
index 7c9ec16d204..9c3d80d77e1 100644
--- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
@@ -20,6 +20,7 @@
 #include "paddle/infrt/dialect/infrt/ir/basic_kernels.h"
 #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
 #include "paddle/infrt/dialect/pd/ir/pd_ops.h"
+#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
 
 namespace infrt {
 namespace trt {
@@ -42,6 +43,8 @@ void TRTOpTellerPass::runOnFunction() {
     if (::llvm::dyn_cast_or_null(op)) continue;
     if (::llvm::dyn_cast_or_null<::infrt::GraphOp>(op)) continue;
     if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue;
+    if (::llvm::dyn_cast_or_null<::infrt::phi::TensorMapGetTensorOp>(op))
+      continue;
 
     builder.setInsertionPoint(op);
     auto loc = getFunction().getLoc();
diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td
index d0585532adf..eb64cafef29 100755
--- a/paddle/infrt/dialect/tensorrt/trt_ops.td
+++ b/paddle/infrt/dialect/tensorrt/trt_ops.td
@@ -9,6 +9,7 @@ include "paddle/infrt/dialect/tensorrt/trt_op_base.td"
 
 include "paddle/infrt/dialect/infrt/ir/infrt_base.td"
 include "paddle/infrt/dialect/phi/ir/infrt_phi_base.td"
+include "paddle/infrt/dialect/pd/ir/pd_op_base.td"
 
 def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
   let summary = "trt CreateEngine Op";
@@ -16,7 +17,7 @@ def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<
     Describe a tensorrt subgraph.
   }];
   let regions = (region SizedRegion<1>:$body);
-  let arguments = (ins Variadic:$inputs, DefaultValuedAttr:$run_once);
+  let arguments = (ins Variadic:$inputs, DefaultValuedAttr:$run_once);
   let results = (outs TRT_EngineType:$engine);
 }
 
@@ -75,9 +76,32 @@ def TRT_ConvolutionOp : TRT_Op<"Convolution", [NoSideEffect]> {
   let arguments = (ins
     DenseTensor:$input_tensor,
     DenseTensor:$kernel_weights,
-    DenseTensor:$bias_weights,
+    Optional<DenseTensor>:$bias_weights,
     SI32Attr:$out_channel_num,
-    I32ArrayAttr:$kernel_size
+    I32ArrayAttr:$kernel_size,
+    I32ArrayAttr:$strides,
+    I32ArrayAttr:$paddings,
+    StrAttr:$padding_mode,
+    SI32Attr:$groups,
+    I32ArrayAttr:$dilations
+  );
+  let results = (outs
+    DenseTensor:$output_tensor
+  );
+}
+
+def TRT_PoolingOp : TRT_Op<"Pooling", [NoSideEffect]> {
+  let summary = "TensorRT IPoolingLayer";
+  let description = [{
+    TensorRT IPoolingLayer
+  }];
+  let arguments = (ins
+    DenseTensor:$input_tensor,
+    I32Attr:$pool_type,
+    I32ArrayAttr:$window_size,
+    I32ArrayAttr:$strides,
+    I32ArrayAttr:$paddings,
+    StrAttr:$padding_mode
   );
   let results = (outs
     DenseTensor:$output_tensor
   );
@@ -109,4 +133,72 @@ def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> {
   let results = (outs DenseTensor:$output);
 }
+def TRT_ScaleOp : TRT_Op<"scale", [NoSideEffect]> {
+  let summary = "TensorRT IScaleLayer";
+  let description = [{
+
+    TensorRT IScaleLayer
+
+  }];
+  let arguments = (ins
+    DenseTensor:$input_tensor,
+    DefaultValuedAttr:$mode,
+    DenseTensor:$shift,
+    DenseTensor:$scale,
+    DenseTensor:$power
+  );
+
+  let results = (outs DenseTensor:$Out);
+}
+
+def TRT_MatrixMultiplOp : TRT_Op<"MatrixMultiplOp", [NoSideEffect]> {
+  let summary = "TensorRT IMatrixMultiplyLayer";
+  let description = [{
+
+    TensorRT IMatrixMultiplyLayer
+
+  }];
+  let arguments = (ins
+    DenseTensor:$input1,
+    DefaultValuedAttr:$matrix_operation1,
+    DenseTensor:$input2,
+    DefaultValuedAttr:$matrix_operation2
+  );
+
+  let results = (outs DenseTensor:$Out);
+}
+
+def TRT_SoftMaxOp : TRT_Op<"SoftMaxOp", [NoSideEffect]> {
+  let summary = "TensorRT ISoftMaxLayer";
+  let description = [{
+
+    TensorRT ISoftMaxLayer
+
+  }];
+  let arguments = (ins
+    DenseTensor:$input_tensor,
+    SI32Attr:$axis
+  );
+
+  let results = (outs DenseTensor:$Out);
+}
+
+def TRT_ScaleNdOp : TRT_Op<"ScaleNd", [NoSideEffect]> {
+  let summary = "TensorRT IScaleLayer";
+  let description = [{
+
+    TensorRT IScaleLayer
+
+  }];
+  let arguments = (ins
+    DenseTensor:$input_tensor,
+    I32Attr:$mode,
+    DenseTensor:$shift,
+    DenseTensor:$scale,
+    DenseTensor:$power,
+    I32Attr:$axis
+  );
+
+  let results = (outs DenseTensor:$Out);
+}
 
 #endif  // TRT_OPS
-- 
GitLab
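
Note on how the new patterns are driven (a sketch, not part of the patch): the DRR records added to pd_lower_to_trt.td above, including PD2TRT_Conv2d_Lower, which reaches createTRTConv2dOp in convert.h through a NativeCodeCall, only take effect once the tblgen-generated rewriters are registered with a pattern driver; in this change that wiring lives in trt_op_converter_pass.cc. A minimal, illustrative C++ sketch of such wiring follows. The helper name applyPDToTRTPatterns and the generated-include name pd_lower_to_trt.cpp.inc are assumptions for illustration, not code taken from this PR.

// Illustrative sketch only; names marked below are assumptions.
#include <mlir/IR/PatternMatch.h>                        // mlir::RewritePatternSet
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>  // applyPatternsAndFoldGreedily

#include "paddle/infrt/dialect/pd/ir/pd_ops.h"      // pd.* source ops matched by the patterns
#include "paddle/infrt/dialect/tensorrt/convert.h"  // createTRTConv2dOp (NativeCodeCall target)
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"  // trt.* target ops

namespace infrt {
namespace trt {

// Rewriters generated by mlir-tblgen -gen-rewriters from pd_lower_to_trt.td;
// the exact generated file name is an assumption here.
#include "paddle/infrt/dialect/tensorrt/pd_lower_to_trt.cpp.inc"

// Hypothetical helper: applies all PD->TRT patterns (Matmul, ElementwiseAdd,
// Relu, Relu6, Conv2d, Pooling, Mul, Softmax) under `root` until fixpoint.
static void applyPDToTRTPatterns(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // populateWithGenerated is emitted alongside the generated pattern classes.
  populateWithGenerated(patterns);
  // Greedily rewrite each matched pd.* op into the corresponding trt.* op,
  // e.g. pd.conv2d -> trt.Convolution via createTRTConv2dOp.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}

}  // namespace trt
}  // namespace infrt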