未验证 提交 59765362 编写于 作者: W weishengying 提交者: GitHub

add rewrite pattern from paddle mlir to trt mlir (#41011)

上级 869287f8
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/Builders.h>
#include <mlir/Transforms/DialectConversion.h>
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
/// Creates the `trt::ConvolutionOp` that replaces a `pd.conv2d` operation.
///
/// The new op receives two operands, in this order: the first value of the
/// conv2d op's ODS operand group 0 (input) and the first value of ODS operand
/// group 1 (filter). Its result types are copied from the conv2d op's ODS
/// result group 0.
///
/// Attributes `strides`, `paddings`, `groups`, `dilations` and `data_format`
/// are forwarded under the same name; `padding_algorithm` is forwarded as
/// `padding_mode`.
///
/// @param rewriter  Builder used to create the attributes and the new op.
/// @param op        Operation to convert; must be an `infrt::pd::Conv2dOp`.
/// @return The newly created `trt::ConvolutionOp`, converted to `mlir::Value`.
static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter,
                                     mlir::Operation *op) {
  // `cast` (instead of an unchecked `dyn_cast`) asserts that `op` really is a
  // conv2d op before any of its accessors are used.
  auto conv2dOp = ::llvm::cast<infrt::pd::Conv2dOp>(op);

  ::mlir::Operation::operand_range Input = conv2dOp.getODSOperands(0);
  ::mlir::Operation::operand_range Filter = conv2dOp.getODSOperands(1);

  ::mlir::SmallVector<::mlir::Value, 4> operands;
  operands.push_back(*Input.begin());
  // The filter must be the second operand; forwarding the input twice would
  // make the convolution use the input tensor as its kernel weights.
  operands.push_back(*Filter.begin());

  ::mlir::SmallVector<::mlir::Type, 4> resultTypes;
  for (auto v : conv2dOp.getODSResults(0)) {
    resultTypes.push_back(v.getType());
  }

  ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes;

  // NOTE(review): `out_channel_num` and `kernel_size` are fixed constants
  // here rather than being derived from the filter operand - confirm that
  // this is only a placeholder.
  attributes.emplace_back(rewriter.getStringAttr("out_channel_num"),
                          rewriter.getSI32IntegerAttr(3));
  attributes.emplace_back(rewriter.getStringAttr("kernel_size"),
                          rewriter.getI32ArrayAttr({3, 3}));

  // Attributes copied from the source op.
  attributes.emplace_back(rewriter.getStringAttr("strides"),
                          op->getAttrOfType<::mlir::ArrayAttr>("strides"));
  attributes.emplace_back(rewriter.getStringAttr("paddings"),
                          op->getAttrOfType<::mlir::ArrayAttr>("paddings"));
  // The only attribute that is renamed on the way.
  attributes.emplace_back(
      rewriter.getStringAttr("padding_mode"),
      op->getAttrOfType<::mlir::StringAttr>("padding_algorithm"));
  attributes.emplace_back(rewriter.getStringAttr("groups"),
                          op->getAttrOfType<::mlir::IntegerAttr>("groups"));
  attributes.emplace_back(rewriter.getStringAttr("dilations"),
                          op->getAttrOfType<::mlir::ArrayAttr>("dilations"));
  attributes.emplace_back(
      rewriter.getStringAttr("data_format"),
      op->getAttrOfType<::mlir::StringAttr>("data_format"));

  return rewriter.create<trt::ConvolutionOp>(
      op->getLoc(), resultTypes, operands, attributes);
}
} // namespace trt
} // namespace infrt
...@@ -14,7 +14,7 @@ def PD2TRT_Matmul_Lower : Pat< ...@@ -14,7 +14,7 @@ def PD2TRT_Matmul_Lower : Pat<
(TRT_MatrixMultiplyOp $X, $transpose_X, $Y, $transpose_Y)>; (TRT_MatrixMultiplyOp $X, $transpose_X, $Y, $transpose_Y)>;
def PD2TRT_ElementwiseAdd_Lower : Pat< def PD2TRT_ElementwiseAdd_Lower : Pat<
(PD_Elementwise_addOp $X, $Y, ConstantAttr<SI32Attr, "-1">), (PD_Elementwise_addOp $X, $Y, $_),
(TRT_ElementWiseOp $X, $Y, (TRT_createNvinferEnumAttr<"nvinfer1::ElementWiseOperation", "kSUM">))>; (TRT_ElementWiseOp $X, $Y, (TRT_createNvinferEnumAttr<"nvinfer1::ElementWiseOperation", "kSUM">))>;
def PD2TRT_Relu_Lower : Pat< def PD2TRT_Relu_Lower : Pat<
...@@ -25,4 +25,21 @@ def PD2TRT_Relu6_Lower : Pat< ...@@ -25,4 +25,21 @@ def PD2TRT_Relu6_Lower : Pat<
(PD_Relu6Op $X, $threshold), (PD_Relu6Op $X, $threshold),
(TRT_ActivationOp $X, (TRT_createNvinferEnumAttr<"nvinfer1::ActivationType", "kCLIP">), (INFRT_createF32Attr<"0.0">), $threshold)>; (TRT_ActivationOp $X, (TRT_createNvinferEnumAttr<"nvinfer1::ActivationType", "kCLIP">), (INFRT_createF32Attr<"0.0">), $threshold)>;
// Native-code hook for rewrite patterns: expands to a C++ call of
// createTRTConv2dOp with the pattern's builder and the operation that defines
// the value bound to $0.
def createTRTConv2dOp : NativeCodeCall<"createTRTConv2dOp($_builder, $0.getDefiningOp())">;
// Rewrites PD_Conv2dOp by passing the matched op's result ($old_value) to the
// createTRTConv2dOp native call. The individually bound operands and
// attributes ($Input, $Filter, $strides, ...) are not referenced in the result
// pattern.
def PD2TRT_Conv2d_Lower : Pat<
(PD_Conv2dOp:$old_value $Input, $Filter, $strides, $paddings, $padding_algorithm, $groups, $dilations, $data_format),
(createTRTConv2dOp $old_value)>;
// Rewrites PD_Pool2dOp to TRT_PoolingOp, forwarding $Input, $ksize, $strides,
// $paddings and $padding_algorithm.
// NOTE(review): the pool type is always the constant 0 (annotated "kmax"
// below), and $pooling_type, $global_pooling, $exclusive, $adaptive,
// $ceil_mode and $data_format are not forwarded - confirm that only max
// pooling can reach this pattern.
def PD2TRT_Pooling_Lower : Pat<
(PD_Pool2dOp $Input, $pooling_type, $ksize, $global_pooling, $strides, $paddings, $exclusive, $adaptive, $ceil_mode, $data_format, $padding_algorithm),
(TRT_PoolingOp $Input, (INFRT_createI32Attr<"0">)/*kmax*/, $ksize, $strides, $paddings, $padding_algorithm)>;
// Rewrites PD_MulOp to TRT_MatrixMultiplOp with both matrix-operation
// attributes set to the constant 0 (annotated "kNONE" below).
// NOTE(review): $x_num_col_dims and $y_num_col_dims are matched but not
// forwarded - confirm that dropping them is intended.
def PD2TRT_MatrixMultipl_Lower : Pat<
(PD_MulOp $Input1, $Input2, $x_num_col_dims, $y_num_col_dims),
(TRT_MatrixMultiplOp $Input1, (INFRT_createI32Attr<"0">)/*kNONE*/, $Input2, (INFRT_createI32Attr<"0">)/*kNONE*/)>;
// Rewrites PD_SoftmaxOp to TRT_SoftMaxOp, forwarding $Input and $axis. The
// third matched argument is bound to $_ and ignored.
def PD2TRT_SoftMax_Lower : Pat<
(PD_SoftmaxOp $Input, $axis, $_),
(TRT_SoftMaxOp $Input, $axis)>;
#endif // PD_LOWER_TO_TRT #endif // PD_LOWER_TO_TRT
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/infrt/dialect/pd/ir/pd_ops.h" #include "paddle/infrt/dialect/pd/ir/pd_ops.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" #include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h" #include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/dialect/tensorrt/convert.h"
#include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h" #include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/infrt/dialect/infrt/ir/basic_kernels.h" #include "paddle/infrt/dialect/infrt/ir/basic_kernels.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/pd/ir/pd_ops.h" #include "paddle/infrt/dialect/pd/ir/pd_ops.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -42,6 +43,8 @@ void TRTOpTellerPass::runOnFunction() { ...@@ -42,6 +43,8 @@ void TRTOpTellerPass::runOnFunction() {
if (::llvm::dyn_cast_or_null<infrt::pd::FetchOp>(op)) continue; if (::llvm::dyn_cast_or_null<infrt::pd::FetchOp>(op)) continue;
if (::llvm::dyn_cast_or_null<::infrt::GraphOp>(op)) continue; if (::llvm::dyn_cast_or_null<::infrt::GraphOp>(op)) continue;
if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue; if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue;
if (::llvm::dyn_cast_or_null<::infrt::phi::TensorMapGetTensorOp>(op))
continue;
builder.setInsertionPoint(op); builder.setInsertionPoint(op);
auto loc = getFunction().getLoc(); auto loc = getFunction().getLoc();
......
...@@ -9,6 +9,7 @@ include "paddle/infrt/dialect/tensorrt/trt_op_base.td" ...@@ -9,6 +9,7 @@ include "paddle/infrt/dialect/tensorrt/trt_op_base.td"
include "paddle/infrt/dialect/infrt/ir/infrt_base.td" include "paddle/infrt/dialect/infrt/ir/infrt_base.td"
include "paddle/infrt/dialect/phi/ir/infrt_phi_base.td" include "paddle/infrt/dialect/phi/ir/infrt_phi_base.td"
include "paddle/infrt/dialect/pd/ir/pd_op_base.td"
def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> { def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> {
let summary = "trt CreateEngine Op"; let summary = "trt CreateEngine Op";
...@@ -16,7 +17,7 @@ def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator< ...@@ -16,7 +17,7 @@ def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<
Describe a tensorrt subgraph. Describe a tensorrt subgraph.
}]; }];
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let arguments = (ins Variadic<DenseTensor>:$inputs, DefaultValuedAttr<BoolAttr, "true">:$run_once); let arguments = (ins Variadic<PD_Tensor>:$inputs, DefaultValuedAttr<BoolAttr, "true">:$run_once);
let results = (outs TRT_EngineType:$engine); let results = (outs TRT_EngineType:$engine);
} }
...@@ -75,9 +76,32 @@ def TRT_ConvolutionOp : TRT_Op<"Convolution", [NoSideEffect]> { ...@@ -75,9 +76,32 @@ def TRT_ConvolutionOp : TRT_Op<"Convolution", [NoSideEffect]> {
let arguments = (ins let arguments = (ins
DenseTensor:$input_tensor, DenseTensor:$input_tensor,
DenseTensor:$kernel_weights, DenseTensor:$kernel_weights,
DenseTensor:$bias_weights, Optional<DenseTensor>:$bias_weights,
SI32Attr:$out_channel_num, SI32Attr:$out_channel_num,
I32ArrayAttr:$kernel_size I32ArrayAttr:$kernel_size,
I32ArrayAttr:$strides,
I32ArrayAttr:$paddings,
StrAttr:$padding_mode,
SI32Attr:$groups,
I32ArrayAttr:$dilations
);
let results = (outs
DenseTensor:$output_tensor
);
}
def TRT_PoolingOp : TRT_Op<"Pooling", [NoSideEffect]> {
let summary = "TensorRT IPoolingLayer ";
let description = [{
TensorRT IPoolingLayer
}];
let arguments = (ins
DenseTensor:$input_tensor,
I32Attr:$pool_type,
I32ArrayAttr:$window_size,
I32ArrayAttr:$strides,
I32ArrayAttr:$paddings,
StrAttr:$padding_mode
); );
let results = (outs let results = (outs
DenseTensor:$output_tensor DenseTensor:$output_tensor
...@@ -109,4 +133,72 @@ def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> { ...@@ -109,4 +133,72 @@ def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> {
let results = (outs DenseTensor:$output); let results = (outs DenseTensor:$output);
} }
// Op "scale" in the TRT dialect, declared NoSideEffect.
// Operands: an input tensor plus shift, scale and power tensors.
// Attribute: `mode`, an I32 attribute defaulting to 0
// (presumably a TensorRT scale-mode value - TODO confirm).
// Result: one DenseTensor named `Out`.
def TRT_ScaleOp : TRT_Op<"scale", [NoSideEffect]> {
let summary = "TensorRT IScaleLayer";
let description = [{
TensorRT IScaleLayer
}];
let arguments = (ins
DenseTensor:$input_tensor,
DefaultValuedAttr<I32Attr, "0">:$mode,
DenseTensor:$shift,
DenseTensor:$scale,
DenseTensor:$power
);
let results = (outs DenseTensor:$Out);
}
// Op "MatrixMultiplOp" in the TRT dialect, declared NoSideEffect.
// Operands: two input tensors. Each has an accompanying I32 matrix-operation
// attribute that defaults to 0.
// Result: one DenseTensor named `Out`.
// NOTE(review): the def name and mnemonic look like a misspelling of
// "MatrixMultiply", and a TRT_MatrixMultiplyOp also appears to be defined in
// this file - confirm whether both ops are needed.
def TRT_MatrixMultiplOp : TRT_Op<"MatrixMultiplOp", [NoSideEffect]> {
let summary = "TensorRT IMatrixMultiplyLayer";
let description = [{
TensorRT IMatrixMultiplyLayer
}];
let arguments = (ins
DenseTensor:$input1,
DefaultValuedAttr<I32Attr, "0">:$matrix_operation1,
DenseTensor:$input2,
DefaultValuedAttr<I32Attr, "0">:$matrix_operation2
);
let results = (outs DenseTensor:$Out);
}
// Op "SoftMaxOp" in the TRT dialect, declared NoSideEffect.
// Operand: one input tensor. Attribute: `axis` (signed 32-bit integer).
// Result: one DenseTensor named `Out`.
def TRT_SoftMaxOp : TRT_Op<"SoftMaxOp", [NoSideEffect]> {
let summary = "TensorRT ISoftMaxLayer";
let description = [{
TensorRT ISoftMaxLayer
}];
let arguments = (ins
DenseTensor:$input_tensor,
SI32Attr:$axis
);
let results = (outs DenseTensor:$Out);
}
// Op "ScaleNd" in the TRT dialect, declared NoSideEffect.
// Operands: an input tensor plus shift, scale and power tensors.
// Attributes: `mode` and `axis`, both I32 and both without a default value.
// Result: one DenseTensor named `Out`.
def TRT_ScaleNdOp : TRT_Op<"ScaleNd", [NoSideEffect]> {
let summary = "TensorRT IScaleLayer";
let description = [{
TensorRT IScaleLayer
}];
let arguments = (ins
DenseTensor:$input_tensor,
I32Attr:$mode,
DenseTensor:$shift,
DenseTensor:$scale,
DenseTensor:$power,
I32Attr:$axis
);
let results = (outs DenseTensor:$Out);
}
#endif // TRT_OPS #endif // TRT_OPS
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册