From 532eba993e3e4b327f8fc2bb40f9904373e71b9e Mon Sep 17 00:00:00 2001 From: weishengying <63448337+weishengying@users.noreply.github.com> Date: Wed, 30 Mar 2022 10:23:03 +0800 Subject: [PATCH] add rewrite pattern form paddle mlir to trt mlir (#41087) --- paddle/infrt/dialect/tensorrt/convert.h | 66 ++++++++++++++----- .../infrt/dialect/tensorrt/pd_lower_to_trt.td | 6 ++ .../dialect/tensorrt/trt_op_converter_pass.cc | 58 ++++++++++++++++ paddle/infrt/dialect/tensorrt/trt_ops.td | 15 +++++ 4 files changed, 130 insertions(+), 15 deletions(-) diff --git a/paddle/infrt/dialect/tensorrt/convert.h b/paddle/infrt/dialect/tensorrt/convert.h index 327c6a3e138..1890c839eff 100644 --- a/paddle/infrt/dialect/tensorrt/convert.h +++ b/paddle/infrt/dialect/tensorrt/convert.h @@ -22,65 +22,101 @@ namespace infrt { namespace trt { static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, mlir::Operation *op) { - ::mlir::Operation::operand_range Input(op->getOperands()); - ::mlir::Operation::operand_range Filter(op->getOperands()); - + auto conv_op = ::llvm::dyn_cast(op); ::mlir::SmallVector<::mlir::Value, 4> operands; - auto castedOp0 = ::llvm::dyn_cast(op); - (void)castedOp0; - Input = castedOp0.getODSOperands(0); - Filter = castedOp0.getODSOperands(1); - operands.push_back((*Input.begin())); + ::mlir::Operation::operand_range Input = conv_op.getODSOperands(0); + ::mlir::Operation::operand_range Filter = conv_op.getODSOperands(1); operands.push_back((*Input.begin())); + operands.push_back((*Filter.begin())); ::mlir::SmallVector<::mlir::Type, 4> resultTypes; - for (auto v : castedOp0.getODSResults(0)) { + for (auto v : conv_op.getODSResults(0)) { resultTypes.push_back(v.getType()); } ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; { + // TODO(weishengying) : get out_channel_num for filter shape auto tblgen_attr = rewriter.getSI32IntegerAttr(3); attributes.emplace_back(rewriter.getStringAttr("out_channel_num"), tblgen_attr); } { + // TODO(weishengying) : get kernel_size for filter shape auto tblgen_attr = rewriter.getI32ArrayAttr({3, 3}); attributes.emplace_back(rewriter.getStringAttr("kernel_size"), tblgen_attr); } { auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("strides"); - (void)tblgen_attr; attributes.emplace_back(rewriter.getStringAttr("strides"), tblgen_attr); } { auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("paddings"); - (void)tblgen_attr; attributes.emplace_back(rewriter.getStringAttr("paddings"), tblgen_attr); } { auto tblgen_attr = op->getAttrOfType<::mlir::StringAttr>("padding_algorithm"); - (void)tblgen_attr; attributes.emplace_back(rewriter.getStringAttr("padding_mode"), tblgen_attr); } { auto tblgen_attr = op->getAttrOfType<::mlir::IntegerAttr>("groups"); - (void)tblgen_attr; attributes.emplace_back(rewriter.getStringAttr("groups"), tblgen_attr); } { auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("dilations"); - (void)tblgen_attr; attributes.emplace_back(rewriter.getStringAttr("dilations"), tblgen_attr); } { auto tblgen_attr = op->getAttrOfType<::mlir::StringAttr>("data_format"); - (void)tblgen_attr; attributes.emplace_back(rewriter.getStringAttr("data_format"), tblgen_attr); } return rewriter.create( op->getLoc(), resultTypes, operands, attributes); } + +static mlir::Value createTRTShuffledOp(mlir::PatternRewriter &rewriter, + mlir::Operation *op, + const mlir::Value &input, + const mlir::Attribute &start, + const mlir::Attribute &stop) { + auto flatten_op = ::llvm::dyn_cast(op); + ::mlir::SmallVector<::mlir::Value, 4> operands; + operands.push_back(input); + + ::mlir::SmallVector<::mlir::Type, 4> resultTypes; + for (auto v : flatten_op.getODSResults(0)) { + resultTypes.push_back(v.getType()); + } + + ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; + mlir::IntegerAttr start_attr = start.dyn_cast(); + mlir::IntegerAttr stop_attr = stop.dyn_cast(); + + int start_axis = start_attr.getSInt(); + int stop_axis = stop_attr.getSInt(); + // TODO(weishengying) : get dim form DenseTonsor + int dims = 4; + // TODO(weishengying) : get input_dims form DenseTonsor + int input_dims[4] = {1, 2048, 1, 1}; + int dim_prod = 1; + + std::vector flatten_dim(dims - (stop_axis - start_axis)); + for (int i = 0, j = 0; i < dims; ++i) { + if (start_axis <= i + 1 && i + 1 <= stop_axis) { + int dim_i = input_dims[i]; + dim_prod *= dim_i; + if (i + 1 == stop_axis) { + flatten_dim[j++] = dim_prod; + } + } else { + flatten_dim[j++] = input_dims[i]; + } + } + auto reshape_arrt = rewriter.getI32ArrayAttr(flatten_dim); + attributes.emplace_back(rewriter.getStringAttr("reshape"), reshape_arrt); + return rewriter.create( + op->getLoc(), resultTypes, operands, attributes); +} } // namespace trt } // namespace infrt diff --git a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td index f833600a36a..b153e84b53f 100644 --- a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td +++ b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td @@ -42,4 +42,10 @@ def PD2TRT_MatrixMultipl_Lower : Pat< def PD2TRT_SoftMax_Lower : Pat< (PD_SoftmaxOp $Input, $axis, $_), (TRT_SoftMaxOp $Input, $axis)>; + +def createTRTShuffledOp : NativeCodeCall<"createTRTShuffledOp($_builder, $0.getDefiningOp(), $1, $2, $3)">; + +def PD2TRT_Flatten_contiguous_range_Lower : Pat< + (PD_Flatten_contiguous_rangeOp:$out $input, $start_axis, $end_axis), + (createTRTShuffledOp $out, $input, $start_axis, $end_axis)>; #endif // PD_LOWER_TO_TRT diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index 9516df70bb0..6bcef3d913d 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -109,6 +109,63 @@ struct PD2TRT_GraphLower : public ::mlir::RewritePattern { } }; +struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { + explicit PD2TRT_Batch_Norm_Lower(::mlir::MLIRContext *context) + : ::mlir::RewritePattern("pd.batch_norm", 1, context, {"trt.scaleNd"}) {} + ::mlir::LogicalResult matchAndRewrite( + ::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override { + auto casted_op = ::llvm::dyn_cast(op); + ::mlir::SmallVector<::mlir::Value, 4> operands; + ::mlir::Operation::operand_range Input = casted_op.getODSOperands(0); + ::mlir::Operation::operand_range Scale = casted_op.getODSOperands(1); + ::mlir::Operation::operand_range Bias = casted_op.getODSOperands(2); + + // TODO(weishengying) : recompute this via params + operands.push_back((*Input.begin())); + operands.push_back((*Scale.begin())); + operands.push_back((*Bias.begin())); + operands.push_back((*Bias.begin())); + + trt::ScaleNdOp scaleNd_op; + // inputs + ::mlir::SmallVector<::mlir::Value, 4> trt_inputs; + for (auto v : operands) { + trt_inputs.push_back(v); + } + + // resultTypes + ::mlir::SmallVector<::mlir::Type, 4> resultTypes; + for (auto v : casted_op.getODSResults(0)) { + resultTypes.push_back(v.getType()); + } + + // attributes + ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; + { + auto mode_attr = rewriter.getI32IntegerAttr(1); + attributes.emplace_back(rewriter.getStringAttr("mode"), mode_attr); + } + + { + auto axis_attr = rewriter.getI32IntegerAttr(-1); + attributes.emplace_back(rewriter.getStringAttr("axis"), axis_attr); + } + auto result = rewriter + .create( + op->getLoc(), resultTypes, operands, attributes) + .getODSResults(0); + ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + // TODO(weishengying) : update it + for (uint32_t i = 0; i < casted_op.getNumResults(); i++) { + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{result}) { + tblgen_repl_values.push_back(v); + } + } + rewriter.replaceOp(op, tblgen_repl_values); + return ::mlir::success(); + } +}; + void TRTOpConverterPass::runOnOperation() { // The first thing to define is the conversion target. This will define the // final target for this lowering. @@ -126,6 +183,7 @@ void TRTOpConverterPass::runOnOperation() { // the set of patterns that will lower the TensorRT operations. ::mlir::RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); + patterns.add(&getContext()); patterns.add(&getContext()); // With the target and rewrite patterns defined, we can now attempt the diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index eb64cafef29..3fd3f377f4e 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -201,4 +201,19 @@ def TRT_ScaleNdOp : TRT_Op<"ScaleNd", [NoSideEffect]> { let results = (outs DenseTensor:$Out); } + +def TRT_ShuffleOp : TRT_Op<"Shuffle", [NoSideEffect]> { + let summary = "TensorRT IShuffleLayer"; + let description = [{ + + TensorRT IShuffleLayer + + }]; + let arguments = (ins + DenseTensor:$input_tensor, + I32ArrayAttr:$reshape + ); + + let results = (outs DenseTensor:$Out); +} #endif // TRT_OPS -- GitLab