未验证 提交 532eba99 编写于 作者: W weishengying 提交者: GitHub

add rewrite pattern from paddle mlir to trt mlir (#41087)

上级 495ca4aa
...@@ -22,65 +22,101 @@ namespace infrt { ...@@ -22,65 +22,101 @@ namespace infrt {
namespace trt { namespace trt {
// NOTE(review): this span appears to be a side-by-side diff rendering; most
// lines carry the pre-change text followed by the post-change text, and some
// lines exist on one side only.
//
// createTRTConv2dOp builds a trt::ConvolutionOp at op's location and returns
// it as an mlir::Value.
//   rewriter - pattern rewriter used to create the new op and its attributes.
//   op       - operation that is dyn_cast to infrt::pd::Conv2dOp; the cast
//              result is used without a null check.
// Result types are taken from ODS result group 0 of the conv op.
static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter,
mlir::Operation *op) { mlir::Operation *op) {
::mlir::Operation::operand_range Input(op->getOperands()); auto conv_op = ::llvm::dyn_cast<infrt::pd::Conv2dOp>(op);
::mlir::Operation::operand_range Filter(op->getOperands());
::mlir::SmallVector<::mlir::Value, 4> operands; ::mlir::SmallVector<::mlir::Value, 4> operands;
auto castedOp0 = ::llvm::dyn_cast<infrt::pd::Conv2dOp>(op); ::mlir::Operation::operand_range Input = conv_op.getODSOperands(0);
(void)castedOp0; ::mlir::Operation::operand_range Filter = conv_op.getODSOperands(1);
Input = castedOp0.getODSOperands(0);
Filter = castedOp0.getODSOperands(1);
operands.push_back((*Input.begin()));
operands.push_back((*Input.begin())); operands.push_back((*Input.begin()));
operands.push_back((*Filter.begin()));
::mlir::SmallVector<::mlir::Type, 4> resultTypes; ::mlir::SmallVector<::mlir::Type, 4> resultTypes;
for (auto v : castedOp0.getODSResults(0)) { for (auto v : conv_op.getODSResults(0)) {
resultTypes.push_back(v.getType()); resultTypes.push_back(v.getType());
} }
::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes;
{ {
// TODO(weishengying) : get out_channel_num for filter shape
// Until the TODO above is resolved, "out_channel_num" is the constant 3.
auto tblgen_attr = rewriter.getSI32IntegerAttr(3); auto tblgen_attr = rewriter.getSI32IntegerAttr(3);
attributes.emplace_back(rewriter.getStringAttr("out_channel_num"), attributes.emplace_back(rewriter.getStringAttr("out_channel_num"),
tblgen_attr); tblgen_attr);
} }
{ {
// TODO(weishengying) : get kernel_size for filter shape
// Until the TODO above is resolved, "kernel_size" is the constant {3, 3}.
auto tblgen_attr = rewriter.getI32ArrayAttr({3, 3}); auto tblgen_attr = rewriter.getI32ArrayAttr({3, 3});
attributes.emplace_back(rewriter.getStringAttr("kernel_size"), tblgen_attr); attributes.emplace_back(rewriter.getStringAttr("kernel_size"), tblgen_attr);
} }
// The remaining attributes are read from op and forwarded to the new op.
{ {
auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("strides"); auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("strides");
(void)tblgen_attr;
attributes.emplace_back(rewriter.getStringAttr("strides"), tblgen_attr); attributes.emplace_back(rewriter.getStringAttr("strides"), tblgen_attr);
} }
{ {
auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("paddings"); auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("paddings");
(void)tblgen_attr;
attributes.emplace_back(rewriter.getStringAttr("paddings"), tblgen_attr); attributes.emplace_back(rewriter.getStringAttr("paddings"), tblgen_attr);
} }
{ {
// The source attribute "padding_algorithm" is emitted under the name
// "padding_mode" on the new op.
auto tblgen_attr = auto tblgen_attr =
op->getAttrOfType<::mlir::StringAttr>("padding_algorithm"); op->getAttrOfType<::mlir::StringAttr>("padding_algorithm");
(void)tblgen_attr;
attributes.emplace_back(rewriter.getStringAttr("padding_mode"), attributes.emplace_back(rewriter.getStringAttr("padding_mode"),
tblgen_attr); tblgen_attr);
} }
{ {
auto tblgen_attr = op->getAttrOfType<::mlir::IntegerAttr>("groups"); auto tblgen_attr = op->getAttrOfType<::mlir::IntegerAttr>("groups");
(void)tblgen_attr;
attributes.emplace_back(rewriter.getStringAttr("groups"), tblgen_attr); attributes.emplace_back(rewriter.getStringAttr("groups"), tblgen_attr);
} }
{ {
auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("dilations"); auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("dilations");
(void)tblgen_attr;
attributes.emplace_back(rewriter.getStringAttr("dilations"), tblgen_attr); attributes.emplace_back(rewriter.getStringAttr("dilations"), tblgen_attr);
} }
{ {
auto tblgen_attr = op->getAttrOfType<::mlir::StringAttr>("data_format"); auto tblgen_attr = op->getAttrOfType<::mlir::StringAttr>("data_format");
(void)tblgen_attr;
attributes.emplace_back(rewriter.getStringAttr("data_format"), tblgen_attr); attributes.emplace_back(rewriter.getStringAttr("data_format"), tblgen_attr);
} }
return rewriter.create<trt::ConvolutionOp>( return rewriter.create<trt::ConvolutionOp>(
op->getLoc(), resultTypes, operands, attributes); op->getLoc(), resultTypes, operands, attributes);
} }
// Builds a trt::ShuffleOp that replaces a pd flatten_contiguous_range op.
//
// Parameters:
//   rewriter - pattern rewriter used to create the new op and its attributes.
//   op       - the op being lowered; must be an
//              infrt::pd::Flatten_contiguous_rangeOp (asserted by cast<>).
//   input    - value that becomes the only operand of the new op.
//   start    - integer attribute holding the first axis to fold.
//   stop     - integer attribute holding the last axis to fold.
//
// Axes are compared against `i + 1`, where `i` indexes the input shape, so
// shape entry `i` is folded when start <= i + 1 <= stop. The folded axes are
// replaced by the product of their extents; every other extent is copied.
// The resulting shape is attached as the "reshape" attribute.
//
// Returns the result of the created trt::ShuffleOp.
static mlir::Value createTRTShuffledOp(mlir::PatternRewriter &rewriter,
                                       mlir::Operation *op,
                                       const mlir::Value &input,
                                       const mlir::Attribute &start,
                                       const mlir::Attribute &stop) {
  // cast<> instead of dyn_cast<>: the result is used unconditionally, so a
  // mismatching op should trip the cast assertion rather than dereference a
  // null op handle.
  auto flatten_op = ::llvm::cast<infrt::pd::Flatten_contiguous_rangeOp>(op);

  ::mlir::SmallVector<::mlir::Value, 4> operands;
  operands.push_back(input);

  ::mlir::SmallVector<::mlir::Type, 4> resultTypes;
  for (auto v : flatten_op.getODSResults(0)) {
    resultTypes.push_back(v.getType());
  }

  // Same reasoning as above: both attributes are read unconditionally.
  const int start_axis = start.cast<mlir::IntegerAttr>().getSInt();
  const int stop_axis = stop.cast<mlir::IntegerAttr>().getSInt();

  // TODO(weishengying) : get dims from DenseTensor
  const int dims = 4;
  // TODO(weishengying) : get input_dims from DenseTensor
  const int input_dims[4] = {1, 2048, 1, 1};

  // The output shape is grown with push_back instead of being pre-sized to
  // dims - (stop_axis - start_axis): that size is wrong (out-of-bounds write)
  // when start_axis <= 0 and negative when the axis range exceeds dims.
  // For 1 <= start_axis <= stop_axis <= dims the result is unchanged.
  std::vector<int> flatten_dim;
  flatten_dim.reserve(dims);
  int dim_prod = 1;
  for (int i = 0; i < dims; ++i) {
    const int axis = i + 1;
    if (axis < start_axis || axis > stop_axis) {
      // Outside the folded range: keep the extent as is.
      flatten_dim.push_back(input_dims[i]);
      continue;
    }
    dim_prod *= input_dims[i];
    if (axis == stop_axis) {
      // Last folded axis: emit the accumulated product once.
      flatten_dim.push_back(dim_prod);
    }
  }

  ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes;
  auto reshape_attr = rewriter.getI32ArrayAttr(flatten_dim);
  attributes.emplace_back(rewriter.getStringAttr("reshape"), reshape_attr);

  return rewriter.create<trt::ShuffleOp>(
      op->getLoc(), resultTypes, operands, attributes);
}
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
...@@ -42,4 +42,10 @@ def PD2TRT_MatrixMultipl_Lower : Pat< ...@@ -42,4 +42,10 @@ def PD2TRT_MatrixMultipl_Lower : Pat<
// Rewrites PD_SoftmaxOp into TRT_SoftMaxOp, forwarding $Input and $axis; the
// third source argument is matched with $_ and not used.
def PD2TRT_SoftMax_Lower : Pat< def PD2TRT_SoftMax_Lower : Pat<
(PD_SoftmaxOp $Input, $axis, $_), (PD_SoftmaxOp $Input, $axis, $_),
(TRT_SoftMaxOp $Input, $axis)>; (TRT_SoftMaxOp $Input, $axis)>;
// Native helper call: expands to the C++ function createTRTShuffledOp with
// the builder, the defining op of argument 0, and arguments 1..3 unchanged.
def createTRTShuffledOp : NativeCodeCall<"createTRTShuffledOp($_builder, $0.getDefiningOp(), $1, $2, $3)">;
// Rewrites PD_Flatten_contiguous_rangeOp through the native helper above.
// $out binds the matched op's result, so the helper receives (in order) the
// matched result, the input value, and the start/end axis attributes.
def PD2TRT_Flatten_contiguous_range_Lower : Pat<
(PD_Flatten_contiguous_rangeOp:$out $input, $start_axis, $end_axis),
(createTRTShuffledOp $out, $input, $start_axis, $end_axis)>;
#endif // PD_LOWER_TO_TRT #endif // PD_LOWER_TO_TRT
...@@ -109,6 +109,63 @@ struct PD2TRT_GraphLower : public ::mlir::RewritePattern { ...@@ -109,6 +109,63 @@ struct PD2TRT_GraphLower : public ::mlir::RewritePattern {
} }
}; };
// Rewrite pattern that replaces a "pd.batch_norm" op with a trt::ScaleNdOp.
//
// The new op receives the batch-norm input, scale and bias operands (bias is
// passed twice as a placeholder, see the TODO below), the result types of ODS
// result group 0 of the original op, and the constant attributes mode = 1 and
// axis = -1. Every result of the original op is replaced by the result of the
// new op.
struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern {
  // The generated-op name uses the mnemonic "ScaleNd" as declared for
  // TRT_ScaleNdOp; op names are case-sensitive.
  explicit PD2TRT_Batch_Norm_Lower(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern("pd.batch_norm", 1, context, {"trt.ScaleNd"}) {}

  // Returns failure() when op is not an infrt::pd::Batch_normOp; otherwise
  // performs the replacement and returns success().
  ::mlir::LogicalResult matchAndRewrite(
      ::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override {
    auto casted_op = ::llvm::dyn_cast<infrt::pd::Batch_normOp>(op);
    if (!casted_op) {
      // Do not dereference a null op handle; report "no match" instead.
      return ::mlir::failure();
    }

    ::mlir::Operation::operand_range Input = casted_op.getODSOperands(0);
    ::mlir::Operation::operand_range Scale = casted_op.getODSOperands(1);
    ::mlir::Operation::operand_range Bias = casted_op.getODSOperands(2);

    // TODO(weishengying) : recompute this via params
    ::mlir::SmallVector<::mlir::Value, 4> operands;
    operands.push_back((*Input.begin()));
    operands.push_back((*Scale.begin()));
    operands.push_back((*Bias.begin()));
    operands.push_back((*Bias.begin()));

    // Result types of the replacement op.
    ::mlir::SmallVector<::mlir::Type, 4> resultTypes;
    for (auto v : casted_op.getODSResults(0)) {
      resultTypes.push_back(v.getType());
    }

    // Attributes of the replacement op.
    ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes;
    attributes.emplace_back(rewriter.getStringAttr("mode"),
                            rewriter.getI32IntegerAttr(1));
    attributes.emplace_back(rewriter.getStringAttr("axis"),
                            rewriter.getI32IntegerAttr(-1));

    auto scale_op = rewriter.create<trt::ScaleNdOp>(
        op->getLoc(), resultTypes, operands, attributes);
    auto result = scale_op.getODSResults(0);

    // TODO(weishengying) : update it
    // The new op's result values are repeated once per result of the original
    // op so that replaceOp gets a replacement for each of them.
    ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;
    for (uint32_t i = 0; i < op->getNumResults(); i++) {
      for (::mlir::Value v : result) {
        tblgen_repl_values.push_back(v);
      }
    }
    rewriter.replaceOp(op, tblgen_repl_values);
    return ::mlir::success();
  }
};
void TRTOpConverterPass::runOnOperation() { void TRTOpConverterPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the // The first thing to define is the conversion target. This will define the
// final target for this lowering. // final target for this lowering.
...@@ -126,6 +183,7 @@ void TRTOpConverterPass::runOnOperation() { ...@@ -126,6 +183,7 @@ void TRTOpConverterPass::runOnOperation() {
// the set of patterns that will lower the TensorRT operations. // the set of patterns that will lower the TensorRT operations.
::mlir::RewritePatternSet patterns(&getContext()); ::mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns); populateWithGenerated(patterns);
patterns.add<PD2TRT_Batch_Norm_Lower>(&getContext());
patterns.add<PD2TRT_GraphLower>(&getContext()); patterns.add<PD2TRT_GraphLower>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the // With the target and rewrite patterns defined, we can now attempt the
......
...@@ -201,4 +201,19 @@ def TRT_ScaleNdOp : TRT_Op<"ScaleNd", [NoSideEffect]> { ...@@ -201,4 +201,19 @@ def TRT_ScaleNdOp : TRT_Op<"ScaleNd", [NoSideEffect]> {
let results = (outs DenseTensor:$Out); let results = (outs DenseTensor:$Out);
} }
// Op definition with mnemonic "Shuffle", marked NoSideEffect.
// It takes one DenseTensor operand plus an I32 array attribute named
// "reshape", and produces one DenseTensor result.
def TRT_ShuffleOp : TRT_Op<"Shuffle", [NoSideEffect]> {
let summary = "TensorRT IShuffleLayer";
let description = [{
TensorRT IShuffleLayer
}];
let arguments = (ins
DenseTensor:$input_tensor,
// Array of 32-bit integers; createTRTShuffledOp stores the flattened
// shape here.
I32ArrayAttr:$reshape
);
let results = (outs DenseTensor:$Out);
}
#endif // TRT_OPS #endif // TRT_OPS
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册