diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index c5f87c602a30403d1b8e82443a8ae5f796b69a91..c2b3d8a353772497e51cc5d39a8165347259d545 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -3,8 +3,8 @@ // CHECK-LABEL: invert_permutation func @invert_permutation(%arg0: tensor<5xi32>) -> tensor<5xi32> { // CHECK-NEXT: %[[UPDATES:.*]] = "tf.Const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32> - // CHECK-NEXT: %[[PERM:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK-NEXT: %[[INDICES:.*]] = "tf.Transpose"(%arg0, %[[PERM]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32> + // CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK-NEXT: %[[INDICES:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32> // CHECK-NEXT: "tf.TensorScatterUpdate"(%arg0, %[[INDICES]], %[[UPDATES]]) : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32> %0 = "tf.InvertPermutation"(%arg0) : (tensor<5xi32>) -> tensor<5xi32> return %0 : tensor<5xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index f934e2ac169ccefa5894cfec38d150380bd80828..c0de6f557ab20aae81a5531a390d68a545d62641 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -253,8 +253,8 @@ class LowerDynamicStitchOp : public OpRewritePattern { // %delta = "tf.Const"() {value = dense<1> : tensor} // %updates = "tf.Range"(%start, %limit, %delta) : // (tensor, tensor, tensor) -> tensor<5xi32> -// %perm = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} -// %indices = "tf.Transpose"(%x, %perm) : (tensor<5xi32, tensor<2xi32) -> +// %shape = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>} +// %indices = "tf.Reshape"(%x, %shape) : (tensor<5xi32, tensor<2xi32) -> // tensor<5x1xi32> // "tf.TensorScatterUpdate"(%x, %indices, %updates) : // (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32> @@ -268,13 +268,12 @@ class LowerInvertPermutationOp LogicalResult matchAndRewrite(TF::InvertPermutationOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto x_type = op.x().getType().cast(); - Type int_type = x_type.getElementType(); // Could be i32 or i64. - + auto x_type = op.x().getType().dyn_cast(); // x input must have static shape. - if (!x_type.hasStaticShape()) { + if (!x_type || !x_type.hasStaticShape()) { return failure(); } + Type int_type = x_type.getElementType(); // Could be i32 or i64. auto result_type = x_type; auto start = @@ -287,13 +286,11 @@ class LowerInvertPermutationOp auto updates = rewriter.create(loc, result_type, start, limit, delta); - auto perm_type = RankedTensorType::get({2}, int_type); - auto perm = rewriter.create( - loc, DenseElementsAttr::get(perm_type, {1, 0})); - auto transposed_x_type = - RankedTensorType::get({x_type.getShape()[0], 1}, int_type); - auto indices = - rewriter.create(loc, transposed_x_type, op.x(), perm); + auto shape_type = RankedTensorType::get({2}, rewriter.getIntegerType(32)); + auto shape = rewriter.create( + loc, DenseElementsAttr::get( + shape_type, {static_cast(x_type.getDimSize(0)), 1})); + auto indices = rewriter.create(loc, op.x(), shape); rewriter.replaceOpWithNewOp( op, result_type, op.x(), indices, updates);