diff --git a/tensorflow/compiler/xla/mlir_hlo/tosa/lib/Transforms/legalize_mhlo.cc b/tensorflow/compiler/xla/mlir_hlo/tosa/lib/Transforms/legalize_mhlo.cc
index 1721606d1c14ebebacee979cb0368d8f3c6e68e2..6cf80b7af3d279574513722e3f55769e54ec0cde 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tosa/lib/Transforms/legalize_mhlo.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/tosa/lib/Transforms/legalize_mhlo.cc
@@ -177,6 +177,72 @@ struct ConvertMhloDotOp : public OpRewritePattern<mhlo::DotOp> {
   }
 };
 
+// Legalizes mhlo.iota into a tosa.const holding [0, 1, ..., N - 1] along the
+// iota dimension (extent 1 in every other dimension), followed by a tosa.tile
+// that replicates the constant up to the full result shape. Only ranked,
+// statically shaped results with integer or floating-point elements match.
+// TODO(jennik): Consider the case of a non-constant expansion.
+struct ConvertMhloIotaOp : public OpRewritePattern<mhlo::IotaOp> {
+  using OpRewritePattern<mhlo::IotaOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::IotaOp op,
+                                PatternRewriter& rewriter) const override {
+    auto resultType = op.getResult().getType();
+    auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
+
+    if (!resultRankedType) {
+      return rewriter.notifyMatchFailure(op, "result tensor must be ranked");
+    }
+    if (!resultRankedType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(op, "result tensor must be static");
+    }
+
+    // Only scalar integer/float elements can be materialized through
+    // getIntegerAttr/getFloatAttr below; bail out on anything else.
+    auto elementType = resultRankedType.getElementType();
+    if (!elementType.isIntOrFloat()) {
+      return rewriter.notifyMatchFailure(
+          op, "result element type must be integer or float");
+    }
+
+    auto resultShape = resultRankedType.getShape();
+    auto iotaDimension = op.getIotaDimension();
+    int64_t iotaArrayLength = resultShape[iotaDimension];
+
+    // Create a const op of [0, 1, 2...iotaArrayLength - 1] to be tiled.
+    const bool isFloatElement = elementType.isa<FloatType>();
+    llvm::SmallVector<Attribute> constValues;
+    constValues.resize(iotaArrayLength);
+    for (int64_t i = 0; i < iotaArrayLength; ++i) {
+      if (isFloatElement) {
+        constValues[i] = rewriter.getFloatAttr(elementType, i);
+      } else {
+        constValues[i] = rewriter.getIntegerAttr(elementType, i);
+      }
+    }
+
+    // The constant keeps the result's rank (extent 1 everywhere except the
+    // iota dimension) because tosa.tile takes one multiple per input
+    // dimension.
+    llvm::SmallVector<int64_t> constShape(resultShape.size(), 1);
+    constShape[iotaDimension] = iotaArrayLength;
+    RankedTensorType constType = RankedTensorType::get(constShape, elementType);
+    auto constOp = rewriter.create<tosa::ConstOp>(
+        op.getLoc(), constType, DenseElementsAttr::get(constType, constValues));
+
+    // Create the multiples attr for the tile op, where all dimensions except
+    // the iota dimension are multiplied.
+    llvm::SmallVector<int64_t> tileMultiples(resultShape.begin(),
+                                             resultShape.end());
+    tileMultiples[iotaDimension] = 1;
+
+    // Tile the const array to the result shape of the iota op.
+    rewriter.replaceOpWithNewOp<tosa::TileOp>(
+        op, resultType, constOp, rewriter.getI64ArrayAttr(tileMultiples));
+    return success();
+  }
+};
+
 struct ConvertMhloReduceOp : public OpRewritePattern<mhlo::ReduceOp> {
   using OpRewritePattern<mhlo::ReduceOp>::OpRewritePattern;
 
@@ -305,6 +371,7 @@ LogicalResult LegalizeMhlo::initialize(MLIRContext* ctx) {
   patternList.addWithLabel<ConvertMhloCompareOp>({"MhloCompare"}, ctx);
   patternList.addWithLabel<ConvertMhloConcatenateOp>({"MhloConcatenate"}, ctx);
   patternList.addWithLabel<ConvertMhloDotOp>({"MhloDot"}, ctx);
+  patternList.addWithLabel<ConvertMhloIotaOp>({"MhloIota"}, ctx);
   patternList.addWithLabel<ConvertMhloReduceOp>({"MhloReduce"}, ctx);
   patternList.addWithLabel<ConvertMhloSliceOp>({"MhloSlice"}, ctx);
   patternList.addWithLabel<ConvertMhloTransposeOp>({"MhloTranspose"}, ctx);
diff --git a/tensorflow/compiler/xla/mlir_hlo/tosa/tests/nullary.mlir b/tensorflow/compiler/xla/mlir_hlo/tosa/tests/nullary.mlir
index bf396a1bd54c8c9fb40ac1eb3eeb4f213f476bce..a708a98c16525a74f8505bc4ed916f7a8439c226 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tosa/tests/nullary.mlir
+++ b/tensorflow/compiler/xla/mlir_hlo/tosa/tests/nullary.mlir
@@ -14,3 +14,19 @@ func.func @constant_f64() -> tensor<10xf64> {
   %0 = mhlo.constant dense<0.000000e+00> : tensor<10xf64>
   return %0 : tensor<10xf64>
 }
+
+// CHECK-LABEL: @iota_dimension_0
+func.func @iota_dimension_0() -> tensor<4x8xf32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{\[\[}}0.000000e+00], [1.000000e+00], [2.000000e+00], [3.000000e+00]]> : tensor<4x1xf32>}
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.tile"(%[[VAR0]]) {multiples = [1, 8]}
+  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> (tensor<4x8xf32>)
+  return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @iota_dimension_1
+func.func @iota_dimension_1() -> tensor<4x8xi32> {
+  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{\[\[}}0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi32>}
+  // CHECK-DAG: %[[VAR1:.*]] = "tosa.tile"(%[[VAR0]]) {multiples = [4, 1]}
+  %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<4x8xi32>)
+  return %0 : tensor<4x8xi32>
+}