Commit 7d3e3c14 authored by Jenni Kilduff, committed by TensorFlower Gardener

[mlir][tosa] Adds MHLO -> TOSA legalizations for iota

This creates a const op filled with the values [0, 1, 2, ..., iotaSize - 1] along the iota dimension, then tiles it to the iota result shape.
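As a rough sketch of the lowering (the shapes here are only illustrative; the updated tests below show the exact output that is checked), an op such as

    %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4x8xf32>

becomes a 1-D "tosa.const" holding [0.0, 1.0, 2.0, 3.0], followed by a "tosa.tile" with multiples = [1, 8] that broadcasts it across the non-iota dimension.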

PiperOrigin-RevId: 480930548
Parent 09165bd3
@@ -177,6 +177,64 @@ struct ConvertMhloDotOp : public OpRewritePattern<mhlo::DotOp> {
  }
};

// TODO(jennik): Consider the case of a non-constant expansion.
struct ConvertMhloIotaOp : public OpRewritePattern<mhlo::IotaOp> {
  using OpRewritePattern<mhlo::IotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::IotaOp op,
                                PatternRewriter& rewriter) const override {
    auto resultType = op.getResult().getType();
    auto elementType = resultType.cast<ShapedType>().getElementType();
    auto resultRankedType = resultType.dyn_cast<RankedTensorType>();

    if (!resultRankedType) {
      return rewriter.notifyMatchFailure(op, "result tensor must be ranked");
    }
    if (!resultRankedType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "result tensor must be static");
    }

    auto resultShape = resultRankedType.getShape();
    auto iotaDimension = op.getIotaDimension();
    int64_t iotaArrayLength = resultShape[iotaDimension];

    // Create a const op of [0, 1, 2...iotaArrayLength - 1] to be tiled.
    llvm::SmallVector<mlir::Attribute, 4> constValues;
    constValues.resize(iotaArrayLength);
    for (int i = 0; i < iotaArrayLength; i++) {
      if (elementType.isa<FloatType>()) {
        constValues[i] = rewriter.getFloatAttr(elementType, i);
      } else {
        constValues[i] = rewriter.getIntegerAttr(elementType, i);
      }
    }

    RankedTensorType constType =
        RankedTensorType::get(iotaArrayLength, elementType);
    auto constOp = rewriter.create<tosa::ConstOp>(
        op.getLoc(), constType, DenseElementsAttr::get(constType, constValues));

    // Create the multiples attr for the tile op, where all dimensions except
    // the iota dimension are multiplied.
    llvm::SmallVector<int64_t, 4> tileMultiples;
    size_t tileMultiplesSize = resultShape.size();
    tileMultiples.resize(tileMultiplesSize);

    for (int i = 0; i < tileMultiplesSize; i++) {
      if (i == iotaDimension) {
        tileMultiples[i] = 1;
      } else {
        tileMultiples[i] = resultShape[i];
      }
    }

    // Tile the const array to the result shape of the iota op.
    rewriter.replaceOpWithNewOp<tosa::TileOp>(
        op, resultType, constOp, rewriter.getI64ArrayAttr(tileMultiples));
    return success();
  }
};

struct ConvertMhloReduceOp : public OpRewritePattern<mhlo::ReduceOp> {
  using OpRewritePattern<mhlo::ReduceOp>::OpRewritePattern;
@@ -305,6 +363,7 @@ LogicalResult LegalizeMhlo::initialize(MLIRContext* ctx) {
  patternList.addWithLabel<ConvertMhloCompareOp>({"MhloCompare"}, ctx);
  patternList.addWithLabel<ConvertMhloConcatenateOp>({"MhloConcatenate"}, ctx);
  patternList.addWithLabel<ConvertMhloDotOp>({"MhloDot"}, ctx);
  patternList.addWithLabel<ConvertMhloIotaOp>({"MhloIota"}, ctx);
  patternList.addWithLabel<ConvertMhloReduceOp>({"MhloReduce"}, ctx);
  patternList.addWithLabel<ConvertMhloSliceOp>({"MhloSlice"}, ctx);
  patternList.addWithLabel<ConvertMhloTransposeOp>({"MhloTranspose"}, ctx);
@@ -14,3 +14,19 @@ func.func @constant_f64() -> tensor<10xf64> {
  %0 = mhlo.constant dense<0.000000e+00> : tensor<10xf64>
  return %0 : tensor<10xf64>
}

// CHECK-LABEL: @iota_dimension_0
func.func @iota_dimension_0() -> tensor<4x8xf32> {
  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>}
  // CHECK-DAG: %[[VAR1:.*]] = "tosa.tile"(%[[VAR0]]) {multiples = [1, 8]}
  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> (tensor<4x8xf32>)
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: @iota_dimension_1
func.func @iota_dimension_1() -> tensor<4x8xi32> {
  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi32>}
  // CHECK-DAG: %[[VAR1:.*]] = "tosa.tile"(%[[VAR0]]) {multiples = [4, 1]}
  %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<4x8xi32>)
  return %0 : tensor<4x8xi32>
}