diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 2ca3d2b21bf387abbc9d58ae997a272178dfe2d2..54ed2a255580ddacc96f8bd8836f8e8bf47a1a0b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -475,6 +475,7 @@ cc_library( ], deps = [ ":passes_inc_gen", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:util", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_device_pass_inc_gen", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir index 5200dbb0a64709ff7ab7cd3e341c5fc3e499068f..e685a3e3c2706066cbd033730d5c69f9c9e807fd 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir @@ -3313,14 +3313,15 @@ func.func @convert_dynamic_slice_ui32(%arg0: tensor<7x3xf32>, %arg1: tensor } -// CHECK-LABEL: func @convert_scatter_update( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { -// CHECK: %[[VAL_3:.*]] = arith.constant dense<[4, 1]> : tensor<2xi64> -// CHECK: %[[VAL_4:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_3]]) : {{.*}} -> tensor<4x1xi32> -// CHECK: %[[VAL_5:.*]] = "tf.TensorScatterUpdate"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) : {{.*}} -> tensor<20x6xf32> -// CHECK: return %[[VAL_5]] : tensor<20x6xf32> +// CHECK-LABEL: func.func @convert_scatter_update( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): +// CHECK: mhlo.return %[[VAL_5]] : tensor +// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<20x6xf32>, tensor<4xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> +// CHECK: return %[[VAL_3]] : tensor<20x6xf32> // CHECK: } func.func @convert_scatter_update(%arg0: tensor<20x6xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> { %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ @@ -3338,14 +3339,15 @@ func.func @convert_scatter_update(%arg0: tensor<20x6xf32>, %arg1: tensor<4xi32>, func.return %0 : tensor<20x6xf32> } -// CHECK-LABEL: func @convert_scatter_update_with_non_trailing_update_window_dims( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x10xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x1xi32>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<10x3xf32>) -> tensor<5x10xf32> { -// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_2]], %[[VAL_3]]) : (tensor<10x3xf32>, tensor<2xi64>) -> tensor<3x10xf32> -// CHECK: %[[VAL_5:.*]] = "tf.TensorScatterUpdate"(%[[VAL_0]], %[[VAL_1]], %[[VAL_4]]) : (tensor<5x10xf32>, tensor<3x1xi32>, tensor<3x10xf32>) -> tensor<5x10xf32> -// CHECK: return %[[VAL_5]] : tensor<5x10xf32> +// CHECK-LABEL: func.func @convert_scatter_update_with_non_trailing_update_window_dims( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x10xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x1xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<10x3xf32>) -> tensor<5x10xf32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): +// CHECK: mhlo.return %[[VAL_5]] : tensor +// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<5x10xf32>, tensor<3x1xi32>, tensor<10x3xf32>) -> tensor<5x10xf32> +// CHECK: return %[[VAL_3]] : tensor<5x10xf32> // CHECK: } func.func @convert_scatter_update_with_non_trailing_update_window_dims( %arg0: tensor<5x10xf32>, @@ -3367,16 +3369,15 @@ func.func @convert_scatter_update_with_non_trailing_update_window_dims( func.return %0 : tensor<5x10xf32> } -// CHECK-LABEL: func @convert_scatter_update_to_non_trailing_operand_dimensions( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x4x3x7xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x2xi32>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32> { -// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 3, 0, 2]> : tensor<4xi64>} : () -> tensor<4xi64> -// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<5x4x3x7xf32>, tensor<4xi64>) -> tensor<4x7x5x3xf32> -// CHECK: %[[VAL_5:.*]] = "tf.TensorScatterUpdate"(%[[VAL_4]], %[[VAL_1]], %[[VAL_2]]) : (tensor<4x7x5x3xf32>, tensor<2x2xi32>, tensor<2x5x3xf32>) -> tensor<4x7x5x3xf32> -// CHECK: %[[VAL_6:.*]] = "tf.Const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> -// CHECK: %[[VAL_7:.*]] = "tf.Transpose"(%[[VAL_5]], %[[VAL_6]]) : (tensor<4x7x5x3xf32>, tensor<4xi64>) -> tensor<5x4x3x7xf32> -// CHECK: return %[[VAL_7]] : tensor<5x4x3x7xf32> +// CHECK-LABEL: func.func @convert_scatter_update_to_non_trailing_operand_dimensions( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x4x3x7xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x2xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): +// CHECK: mhlo.return %[[VAL_5]] : tensor +// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<5x4x3x7xf32>, tensor<2x2xi32>, tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32> +// CHECK: return %[[VAL_3]] : tensor<5x4x3x7xf32> // CHECK: } func.func @convert_scatter_update_to_non_trailing_operand_dimensions( %arg0: tensor<5x4x3x7xf32>, @@ -3397,20 +3398,15 @@ func.func @convert_scatter_update_to_non_trailing_operand_dimensions( func.return %0 : tensor<5x4x3x7xf32> } -// CHECK-LABEL: func @convert_scatter_update_reshape_indices_and_updates( -// CHECK-SAME: %[[ARG_0:.*]]: tensor<16x1504xf32>, -// CHECK-SAME: %[[ARG_1:.*]]: tensor<1xi32>, -// CHECK-SAME: %[[ARG_2:.*]]: tensor<16xf32>) -> tensor<16x1504xf32> { -// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: %[[VAL_0:.*]] = "tf.Transpose"(%[[ARG_0]], %[[CST]]) : (tensor<16x1504xf32>, tensor<2xi64>) -> tensor<1504x16xf32> -// CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32> -// CHECK: %[[VAL_1:.*]] = "tf.Reshape"(%[[ARG_1]], %[[CST_0]]) : (tensor<1xi32>, tensor<2xi32>) -> tensor<1x1xi32> -// CHECK: %[[CST_1:.*]] = "tf.Const"() {value = dense<[1, 16]> : tensor<2xi32>} : () -> tensor<2xi32> -// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[ARG_2]], %[[CST_1]]) : (tensor<16xf32>, tensor<2xi32>) -> tensor<1x16xf32> -// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterUpdate"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<1504x16xf32>, tensor<1x1xi32>, tensor<1x16xf32>) -> tensor<1504x16xf32> -// CHECK: %[[CST_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> -// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_3]], %[[CST_2]]) : (tensor<1504x16xf32>, tensor<2xi64>) -> tensor<16x1504xf32> -// CHECK: return %[[VAL_4]] +// CHECK-LABEL: func.func @convert_scatter_update_reshape_indices_and_updates( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x1504xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<16xf32>) -> tensor<16x1504xf32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): +// CHECK: mhlo.return %[[VAL_5]] : tensor +// CHECK: }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<16x1504xf32>, tensor<1xi32>, tensor<16xf32>) -> tensor<16x1504xf32> +// CHECK: return %[[VAL_3]] : tensor<16x1504xf32> // CHECK: } func.func @convert_scatter_update_reshape_indices_and_updates( %arg0: tensor<16x1504xf32>, @@ -3430,11 +3426,15 @@ func.func @convert_scatter_update_reshape_indices_and_updates( func.return %0 : tensor<16x1504xf32> } -// CHECK-LABEL: func @convert_scatter_add( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { -// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterAdd"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : {{.*}} -> tensor<20x6xf32> +// CHECK-LABEL: func.func @convert_scatter_add( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): +// CHECK: %[[VAL_6:.*]] = "tf.AddV2"(%[[VAL_4]], %[[VAL_5]]) : (tensor, tensor) -> tensor +// CHECK: mhlo.return %[[VAL_6]] : tensor +// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> // CHECK: return %[[VAL_3]] : tensor<20x6xf32> // CHECK: } func.func @convert_scatter_add(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> { @@ -3454,11 +3454,15 @@ func.func @convert_scatter_add(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, func.return %0 : tensor<20x6xf32> } -// CHECK-LABEL: func @convert_scatter_max( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { -// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterMax"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : {{.*}} -> tensor<20x6xf32> +// CHECK-LABEL: func.func @convert_scatter_max( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): +// CHECK: %[[VAL_6:.*]] = "tf.Maximum"(%[[VAL_4]], %[[VAL_5]]) : (tensor, tensor) -> tensor +// CHECK: mhlo.return %[[VAL_6]] : tensor +// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> // CHECK: return %[[VAL_3]] : tensor<20x6xf32> // CHECK: } func.func @convert_scatter_max(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> { @@ -3478,11 +3482,15 @@ func.func @convert_scatter_max(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, func.return %0 : tensor<20x6xf32> } -// CHECK-LABEL: func @convert_scatter_min( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { -// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterMin"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : {{.*}} -> tensor<20x6xf32> +// CHECK-LABEL: func.func @convert_scatter_min( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): +// CHECK: %[[VAL_6:.*]] = "tf.Minimum"(%[[VAL_4]], %[[VAL_5]]) : (tensor, tensor) -> tensor +// CHECK: mhlo.return %[[VAL_6]] : tensor +// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> // CHECK: return %[[VAL_3]] : tensor<20x6xf32> // CHECK: } func.func @convert_scatter_min(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> { @@ -3502,11 +3510,15 @@ func.func @convert_scatter_min(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, func.return %0 : tensor<20x6xf32> } -// CHECK-LABEL: func @convert_scatter_sub( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { -// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterSub"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : {{.*}} -> tensor<20x6xf32> +// CHECK-LABEL: func.func @convert_scatter_sub( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> { +// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({ +// CHECK: ^bb0(%[[VAL_4:.*]]: tensor, %[[VAL_5:.*]]: tensor): +// CHECK: %[[VAL_6:.*]] = "tf.Sub"(%[[VAL_4]], %[[VAL_5]]) : (tensor, tensor) -> tensor +// CHECK: mhlo.return %[[VAL_6]] : tensor +// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32> // CHECK: return %[[VAL_3]] : tensor<20x6xf32> // CHECK: } func.func @convert_scatter_sub(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index 140d47b13a919d105aa31dd7a16a67c5a2095c52..3bd13fbec92cf608fa0c118ba0f83a9fae0e9743 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -59,6 +59,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "stablehlo/dialect/BroadcastUtils.h" // from @stablehlo #include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -110,65 +111,6 @@ LogicalResult GetConstantSplatValue(Value value, SplatValueType& splat_value) { return success(); } -struct PermutationAndShape { - DenseIntElementsAttr permutation; - ShapedType shape; -}; - -// Returns a DenseIntElementsAttr for a permutation and the shape after -// applying the permutation to a given shape through a transpose. -PermutationAndShape GetPermutationAndTransposedShape( - llvm::ArrayRef permutation_array, ShapedType input_type, - ConversionPatternRewriter& rewriter) { - assert(permutation_array.size() == input_type.getRank()); - llvm::SmallVector transposed_shape(permutation_array.size()); - for (int64_t i = 0; i < permutation_array.size(); ++i) { - transposed_shape[i] = input_type.getDimSize(permutation_array[i]); - } - auto transposed_type = - RankedTensorType::get(transposed_shape, input_type.getElementType()); - DenseIntElementsAttr permutation = DenseIntElementsAttr::get( - RankedTensorType::get(permutation_array.size(), rewriter.getI64Type()), - permutation_array); - return {permutation, transposed_type}; -} - -// Returns the inverse permutation array for a permutation array. -llvm::SmallVector GetInversePermutationArray( - llvm::ArrayRef permutation_array) { - llvm::SmallVector inverse_permutation_array( - permutation_array.size()); - const auto permutation_array_size = permutation_array.size(); - for (int64_t i = 0; i < permutation_array_size; ++i) { - inverse_permutation_array[permutation_array[i]] = i; - } - return inverse_permutation_array; -} - -// Returns the DenseIntElementsAttr for an inverse permutation given a -// permutation_array. -DenseIntElementsAttr GetInversePermutation( - llvm::ArrayRef permutation_array, - ConversionPatternRewriter& rewriter) { - SmallVector inverse_permutation_array = - GetInversePermutationArray(permutation_array); - return DenseIntElementsAttr::get( - RankedTensorType::get(inverse_permutation_array.size(), - rewriter.getI64Type()), - inverse_permutation_array); -} - -// Returns a DenseIntElementsAttr for an inverse permutation and the shape after -// applying the inverse permutation to a given shape through a transpose. -PermutationAndShape GetInversePermutationAndShape( - llvm::ArrayRef permutation_array, ShapedType input_type, - ConversionPatternRewriter& rewriter) { - SmallVector inverse_permutation_array = - GetInversePermutationArray(permutation_array); - return GetPermutationAndTransposedShape(inverse_permutation_array, input_type, - rewriter); -} - // Common functionality for ConvertConvOp classes. template struct ConvertNdConvOp { @@ -1165,33 +1107,6 @@ struct DimensionVector { llvm::SmallVector sizes; }; -// Create a single const integer. -Value BuildIntConstOp(ImplicitLocOpBuilder& builder, - ConversionPatternRewriter& rewriter, int64_t const_value, - Type type) { - Value result_const = - builder.create(rewriter.getIntegerAttr(type, const_value)); - return result_const; -} -// Create a const integer vector tensor (1-dim). -Value BuildIntArrayConstOp(ImplicitLocOpBuilder& builder, - ConversionPatternRewriter& rewriter, - ArrayRef const_value, Type type) { - DenseIntElementsAttr const_value_raw; - if (type == rewriter.getI64Type()) { - const_value_raw = rewriter.getI64TensorAttr(const_value); - } else { - // Convert I64 const array to I32. - llvm::SmallVector const_i32_vec; - for (auto element : const_value) { - const_i32_vec.push_back(static_cast(element)); - } - const_value_raw = rewriter.getI32TensorAttr(const_i32_vec); - } - Value result_const = builder.create(const_value_raw); - return result_const; -} - // Create a tensor that is reshaped from input. Value BuildReshapeOp(ImplicitLocOpBuilder& builder, ConversionPatternRewriter& rewriter, Value input, @@ -1925,43 +1840,6 @@ Value ConvertDotGeneralOp(PatternRewriter& rewriter, Operation* old_op) { dot_general_op.getLoc()); } -// Checks if the specified region is a binary reduction function that takes 2 -// inputs, passes it to an instance of the specified reduction op and then -// returns the result. -template -LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { - Block& body = function.front(); - if (body.getNumArguments() != 2) return failure(); - - mhlo::ReturnOp return_op = dyn_cast(body.back()); - if (!return_op) return failure(); - if (return_op.getNumOperands() != 1) return failure(); - - ReductionOp reduce_op = dyn_cast_or_null( - return_op.getOperands().front().getDefiningOp()); - if (!reduce_op) return failure(); - if (reduce_op.getLhs() != body.getArgument(0) || - reduce_op.getRhs() != body.getArgument(1)) - return failure(); - - return success(); -} - -// Check if the specified region is a binary reduction function that takes 2 -// inputs and returns the second input. Functions like this are used by update -// scatter like ops. -template <> -LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { - Block& body = function.front(); - if (body.getNumArguments() != 2) return failure(); - - mhlo::ReturnOp return_op = dyn_cast(body.back()); - if (!return_op) return failure(); - if (return_op.getNumOperands() != 1) return failure(); - if (return_op.getOperands().front() != body.getArgument(1)) return failure(); - return success(); -} - // Replace BinaryOp with a combination of TfBinaryOp and TfReduceOp if the // init value doesn't match the expectation of TfReduceOp. template @@ -3019,124 +2897,6 @@ bool SameTypeOrDefaultCompare(mhlo::ComparisonTypeAttr comparison_type_attr, return false; } -// Check that `arr` is an R1 iota with integer element type starting from `0` -// with `size` number of values. -bool IsIotaAttr(ArrayRef arr, int64_t size) { - if (arr.size() != size) return false; - int64_t iota = 0; - for (auto s : arr) { - if (s != iota) return false; - ++iota; - } - return true; -} - -// Convert updates into canonical form as expected by tf.scatter ops. -// -// tf.scatter expects `update_window_dims` to be the trailing dimensions. -// -// To support scatter ops generated by numpy-like slice updates: -// nd_array[:, [i,j]] = [i_values, j_values] -// -// `updates` must be transposed when the update_window_dims are the leading -// dimensions of `updates`. -// -// Other values of `update_window_dims` are left unsupported. -// -// Eg 1. An update in canonical form: -// * indices shape(A,B,C) -// * updates shape(A,B,D,E,F) -// Then: -// * D,E,F are the update window dims [2,3,4] -// * C is the index vector dimension -// * A,B iterate over the updates and indices -// -// If `update_window_dims` are not the trailing dimensions then updates must be -// transposed. -// -// Eg 2. An update in non-canonical form: -// * indices shape(a,b,c) -// * updates shape(d,e,f,a,b) -// Then: -// * d,e,f are the update window dims [0,1,2] -// * c is the index vector dimension -// * a,b iterate over the updates and indices -// -// The update needs permuting to be in the form (a,b,d,e,f) so that the update -// window dims are the trailing dimensions. -// -// To canonicalize the updates above, replace the updates with: -// transpose(updates, permutation={3,4,0,1,2}) -// -// Note: NormalizeIndexVector is assumed to have run on the indices already so -// that the index_vector_dim is the trailing dimension in `indices`. -LogicalResult CanonicalizeScatterUpdates( - Operation* scatter_op, llvm::ArrayRef update_window_dims, - const Value& indices, const ShapedType& indices_type, Value& updates, - ShapedType& updates_type, ConversionPatternRewriter& rewriter) { - auto canonical_update_window_dims = llvm::to_vector( - llvm::seq(indices_type.getRank() - 1, updates_type.getRank())); - - if (canonical_update_window_dims == update_window_dims) return success(); - - // Permute updates if `update_window_dims` are leading indices. - // Other possibilities for `update_window_dims` are not supported yet. - if (!IsIotaAttr(update_window_dims, update_window_dims.size())) - return rewriter.notifyMatchFailure( - scatter_op, "update_window_dims are not leading or trailing indices"); - - SmallVector permutation_array(updates_type.getRank()); - int64_t dim = 0; - // Move leading indices to the back of the array. - const auto permutation_array_size = permutation_array.size(); - for (int64_t i = update_window_dims.size(); i < permutation_array_size; ++i) { - permutation_array[i] = dim; - ++dim; - } - // Move trailing indices to the front of the array. - for (int64_t i = 0; i < update_window_dims.size(); ++i) { - permutation_array[i] = dim; - ++dim; - } - - auto permutation_and_shape = GetPermutationAndTransposedShape( - permutation_array, updates_type, rewriter); - - auto transposed_updates = rewriter.create( - scatter_op->getLoc(), permutation_and_shape.shape, updates, - permutation_and_shape.permutation); - - updates = transposed_updates; - updates_type = permutation_and_shape.shape; - return success(); -} - -// If index_vector_dim == indices.rank() then insert the implicit extra -// dimension into indices to normalize everything to index_vector_dim == -// indices.rank() - 1. -LogicalResult NormalizeIndexVector(Operation* parent_op, Value& indices, - ShapedType& indices_type, - int64_t index_vector_dim, - ConversionPatternRewriter& rewriter) { - if (index_vector_dim == indices_type.getRank()) { - llvm::SmallVector new_start_indices_shape( - indices_type.getShape().begin(), indices_type.getShape().end()); - new_start_indices_shape.push_back(1); - indices_type = RankedTensorType::get(new_start_indices_shape, - indices_type.getElementType()); - indices = rewriter.create(parent_op->getLoc(), - indices_type, indices); - } else if (index_vector_dim != indices_type.getRank() - 1) { - // If index_vector_dim isn't the last dimension in indices then it isn't - // supported yet. - // TODO(tberghammer): Transpose indices to support this usecase. - return rewriter.notifyMatchFailure( - parent_op, - "index vector dim isn't the last dimension in start indices"); - } - return success(); -} - class ConvertGatherOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -3551,157 +3311,6 @@ class ConvertIfOp : public OpConversionPattern { } }; -template -class ConvertScatterOp : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::ScatterOp scatter_op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const final { - OperandRange operands = scatter_op.getInputs(); - Value indices = scatter_op.getScatterIndices(); - OperandRange updates = scatter_op.getUpdates(); - if (operands.size() != 1 || updates.size() != 1) return failure(); - - ShapedType operand_type = operands[0].getType().cast(); - ShapedType indices_type = indices.getType().cast(); - ShapedType updates_type = updates[0].getType().cast(); - - Value new_updates = updates[0]; - - // Can only convert with static shaped scatter. - if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() || - !updates_type.hasStaticShape()) { - return failure(); - } - - // Match the scatter computation against computations supported by TF. - if (failed(MatchBinaryReduceFunction( - scatter_op.getUpdateComputation()))) { - return failure(); - } - - auto scatter_dimension_numbers = scatter_op.getScatterDimensionNumbers(); - - // Normalize indices so index_vector_dim == indices.rank() - 1. - int64_t index_vector_dim = scatter_dimension_numbers.getIndexVectorDim(); - if (failed(NormalizeIndexVector(scatter_op, indices, indices_type, - index_vector_dim, rewriter))) { - return failure(); - } - - // Transform updates so that update window dims are the trailing dimensions - // in the update tensor. - auto update_window_dims = scatter_dimension_numbers.getUpdateWindowDims(); - if (failed(CanonicalizeScatterUpdates(scatter_op, update_window_dims, - indices, indices_type, new_updates, - updates_type, rewriter))) { - return failure(); - } - - auto inserted_window_dims = - scatter_dimension_numbers.getInsertedWindowDims(); - auto scatter_dims_to_operand_dims = - scatter_dimension_numbers.getScatterDimsToOperandDims(); - - if (IsIotaAttr(inserted_window_dims, indices_type.getShape().back()) && - IsIotaAttr(scatter_dims_to_operand_dims, - indices_type.getShape().back())) { - rewriter.replaceOpWithNewOp(scatter_op, - scatter_op.getResult(0).getType(), - operands[0], indices, new_updates); - return success(); - } - // Insert tranposes to support scatter operations generated from - // numpy-like slice operations: - // nd_array[:, [i,j]] = [i_values, j_values] - // - if (scatter_dims_to_operand_dims != inserted_window_dims) { - // Support only dimension numbers generated by numpy-like slice - // operations. - return rewriter.notifyMatchFailure( - scatter_op, "unsupported scatter_dims_to_operand_dims"); - } - - // Transpose the operand and so that the trailing dimensions of the - // operand are being updated. Then apply a tf.scatter op and transpose - // back the result to get the same shape as the original operand. - - SmallVector permutation_array; - for (int64_t i = 0; i < scatter_dims_to_operand_dims.size(); ++i) { - permutation_array.push_back(scatter_dims_to_operand_dims[i]); - } - for (int64_t i = 0; i < operand_type.getRank(); ++i) { - if (!llvm::is_contained(scatter_dims_to_operand_dims, i)) { - permutation_array.push_back(i); - } - } - auto permutation_and_shape = GetPermutationAndTransposedShape( - permutation_array, operand_type, rewriter); - - Location loc = scatter_op.getLoc(); - auto transposed_operand = rewriter.create( - loc, permutation_and_shape.shape, operands[0], - permutation_and_shape.permutation); - - Value new_indices = indices; - int64_t index_depth = - permutation_and_shape.shape.getRank() - inserted_window_dims.size(); - int64_t num_updates = indices_type.getDimSize(0); - // For TF::TensorScatterUpdateOp, `indices` must have at least 2 axes: - // `(num_updates, index_depth)`. Reshape indices and updates if necessary. - if (std::is_same::value && - indices_type.getRank() == 1 && updates_type.getRank() == 1 && - index_depth == 1 && num_updates == 1) { - ImplicitLocOpBuilder builder(loc, rewriter); - auto indices_shape = BuildIntArrayConstOp( - builder, rewriter, - llvm::SmallVector({num_updates, index_depth}), - rewriter.getI32Type()); - new_indices = rewriter.create( - loc, - RankedTensorType::get({num_updates, index_depth}, - indices_type.getElementType()), - indices, indices_shape); - auto updates_shape = BuildIntArrayConstOp( - builder, rewriter, - llvm::SmallVector({num_updates, updates_type.getDimSize(0)}), - rewriter.getI32Type()); - new_updates = rewriter.create( - loc, - RankedTensorType::get({1, updates_type.getDimSize(0)}, - updates_type.getElementType()), - new_updates, updates_shape); - } - - // Apply TF scatter to update the trailing dimensions of the - // transposed operand. - auto tf_scatter_op = - rewriter.create(loc, permutation_and_shape.shape, - transposed_operand, new_indices, new_updates); - - // Reverse the earlier transpose. - auto inverse_permutation = - GetInversePermutation(permutation_array, rewriter); - rewriter.replaceOpWithNewOp( - scatter_op, scatter_op.getResult(0).getType(), tf_scatter_op, - inverse_permutation); - - return success(); - } -}; -using ConvertScatterAddOp = - ConvertScatterOp; -using ConvertScatterMaxOp = - ConvertScatterOp; -using ConvertScatterMinOp = - ConvertScatterOp; -using ConvertScatterSubOp = - ConvertScatterOp; -using ConvertScatterUpdateOp = - ConvertScatterOp; - // Converts mhlo.pad to tf.PadV2 Value ConvertPadOp(PatternRewriter& rewriter, Operation* old_op) { auto pad_op = cast(old_op); @@ -4140,19 +3749,18 @@ void LegalizeHloToTf::runOnOperation() { void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns, MLIRContext* context) { - patterns->add< - ConvertAvgPoolOp, Convert2DConvOp, Convert1DConvOp, - ConvertNonTrivialConvOp, ConvertDynamicSliceOp, - ConvertDynamicUpdateSliceOp, ConvertGatherOp, ConvertIfOp, - ConvertMaxPoolOp, ConvertPopulationCountOp, ConvertScatterAddOp, - ConvertScatterMaxOp, ConvertScatterMinOp, ConvertScatterSubOp, - ConvertScatterUpdateOp, ConvertSliceOp, ConvertReduceOpToTfArgmax, - ConvertReduceOpToTfArgmin, ConvertReduceOpToTfMax, ConvertReduceOpToTfMin, - ConvertReduceOpToTfAll, ConvertReduceOpToTfProd, ConvertReduceOpToTfAny, - ConvertReduceOpToTfSum, ConvertSortToTfTopk, ConvertIotaOpToTfRange, - ConvertWhileOp, ConvertLoweredCumSumOp, ConvertLoweredCumProdOp, - ConvertGetDimensionSizeOp, ConvertDynamicIotaOp, - ConvertRealDynamicSliceOp>(context); + patterns + ->add(context); populateWithGenerated(*patterns); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f771ee191985cb37d9bae80f93edf53e9458a1b7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD @@ -0,0 +1,48 @@ +# Groups the implementations of the HLO to TF operation conversions. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +cc_library( + name = "util", + srcs = [ + "util.cc", + ], + hdrs = [ + "util.h", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_a_m_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) + +cc_library( + name = "scatter", + srcs = [ + "scatter.cc", + ], + hdrs = [ + "scatter.h", + ], + deps = [ + ":util", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.cc new file mode 100644 index 0000000000000000000000000000000000000000..6453fbf4467e7a1f1048a3265836dbd8244d84b8 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.cc @@ -0,0 +1,77 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h" + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +LogicalResult CanonicalizeScatterUpdates( + Operation* scatter_op, llvm::ArrayRef update_window_dims, + const Value& indices, const ShapedType& indices_type, Value& updates, + ShapedType& updates_type, ConversionPatternRewriter& rewriter) { + auto canonical_update_window_dims = llvm::to_vector( + llvm::seq(indices_type.getRank() - 1, updates_type.getRank())); + + if (canonical_update_window_dims == update_window_dims) return success(); + + // Permute updates if `update_window_dims` are leading indices. + // Other possibilities for `update_window_dims` are not supported yet. + if (!IsIotaAttr(update_window_dims, update_window_dims.size())) + return rewriter.notifyMatchFailure( + scatter_op, "update_window_dims are not leading or trailing indices"); + + SmallVector permutation_array(updates_type.getRank()); + int64_t dim = 0; + // Move leading indices to the back of the array. + const auto permutation_array_size = permutation_array.size(); + for (int64_t i = update_window_dims.size(); i < permutation_array_size; ++i) { + permutation_array[i] = dim; + ++dim; + } + // Move trailing indices to the front of the array. + for (int64_t i = 0; i < update_window_dims.size(); ++i) { + permutation_array[i] = dim; + ++dim; + } + + auto permutation_and_shape = GetPermutationAndTransposedShape( + permutation_array, updates_type, rewriter); + + auto transposed_updates = rewriter.create( + scatter_op->getLoc(), permutation_and_shape.shape, updates, + permutation_and_shape.permutation); + + updates = transposed_updates; + updates_type = permutation_and_shape.shape; + return success(); +} + +} // end namespace odml +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h new file mode 100644 index 0000000000000000000000000000000000000000..fb0e0d80a4eb9bc9da31e0c5e09c5bbf3a494371 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h @@ -0,0 +1,242 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +// Convert updates into canonical form as expected by tf.scatter ops. +// +// tf.scatter expects `update_window_dims` to be the trailing dimensions. +// +// To support scatter ops generated by numpy-like slice updates: +// nd_array[:, [i,j]] = [i_values, j_values] +// +// `updates` must be transposed when the update_window_dims are the leading +// dimensions of `updates`. +// +// Other values of `update_window_dims` are left unsupported. +// +// Eg 1. An update in canonical form: +// * indices shape(A,B,C) +// * updates shape(A,B,D,E,F) +// Then: +// * D,E,F are the update window dims [2,3,4] +// * C is the index vector dimension +// * A,B iterate over the updates and indices +// +// If `update_window_dims` are not the trailing dimensions then updates must be +// transposed. +// +// Eg 2. An update in non-canonical form: +// * indices shape(a,b,c) +// * updates shape(d,e,f,a,b) +// Then: +// * d,e,f are the update window dims [0,1,2] +// * c is the index vector dimension +// * a,b iterate over the updates and indices +// +// The update needs permuting to be in the form (a,b,d,e,f) so that the update +// window dims are the trailing dimensions. +// +// To canonicalize the updates above, replace the updates with: +// transpose(updates, permutation={3,4,0,1,2}) +// +// Note: NormalizeIndexVector is assumed to have run on the indices already so +// that the index_vector_dim is the trailing dimension in `indices`. +LogicalResult CanonicalizeScatterUpdates( + Operation* scatter_op, llvm::ArrayRef update_window_dims, + const Value& indices, const ShapedType& indices_type, Value& updates, + ShapedType& updates_type, ConversionPatternRewriter& rewriter); + +template +class ConvertScatterOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ScatterOp scatter_op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + { + OperandRange operands = scatter_op.getInputs(); + Value indices = scatter_op.getScatterIndices(); + OperandRange updates = scatter_op.getUpdates(); + if (operands.size() != 1 || updates.size() != 1) return failure(); + + ShapedType operand_type = operands[0].getType().cast(); + ShapedType indices_type = indices.getType().cast(); + ShapedType updates_type = updates[0].getType().cast(); + + Value new_updates = updates[0]; + + // Can only convert with static shaped scatter. + if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() || + !updates_type.hasStaticShape()) { + return failure(); + } + + // Match the scatter computation against computations supported by TF. + if (failed(MatchBinaryReduceFunction( + scatter_op.getUpdateComputation()))) { + return failure(); + } + + auto scatter_dimension_numbers = scatter_op.getScatterDimensionNumbers(); + + // Normalize indices so index_vector_dim == indices.rank() - 1. + int64_t index_vector_dim = scatter_dimension_numbers.getIndexVectorDim(); + if (failed(NormalizeIndexVector(scatter_op, indices, indices_type, + index_vector_dim, rewriter))) { + return failure(); + } + + // Transform updates so that update window dims are the trailing + // dimensions in the update tensor. + auto update_window_dims = scatter_dimension_numbers.getUpdateWindowDims(); + if (failed(CanonicalizeScatterUpdates(scatter_op, update_window_dims, + indices, indices_type, new_updates, + updates_type, rewriter))) { + return failure(); + } + + auto inserted_window_dims = + scatter_dimension_numbers.getInsertedWindowDims(); + auto scatter_dims_to_operand_dims = + scatter_dimension_numbers.getScatterDimsToOperandDims(); + + if (IsIotaAttr(inserted_window_dims, indices_type.getShape().back()) && + IsIotaAttr(scatter_dims_to_operand_dims, + indices_type.getShape().back())) { + rewriter.replaceOpWithNewOp(scatter_op, + scatter_op.getResult(0).getType(), + operands[0], indices, new_updates); + return success(); + } + // Insert tranposes to support scatter operations generated from + // numpy-like slice operations: + // nd_array[:, [i,j]] = [i_values, j_values] + // + if (scatter_dims_to_operand_dims != inserted_window_dims) { + // Support only dimension numbers generated by numpy-like slice + // operations. + return rewriter.notifyMatchFailure( + scatter_op, "unsupported scatter_dims_to_operand_dims"); + } + + // Transpose the operand and so that the trailing dimensions of the + // operand are being updated. Then apply a tf.scatter op and transpose + // back the result to get the same shape as the original operand. + + SmallVector permutation_array; + for (int64_t i = 0; i < scatter_dims_to_operand_dims.size(); ++i) { + permutation_array.push_back(scatter_dims_to_operand_dims[i]); + } + for (int64_t i = 0; i < operand_type.getRank(); ++i) { + if (!llvm::is_contained(scatter_dims_to_operand_dims, i)) { + permutation_array.push_back(i); + } + } + auto permutation_and_shape = GetPermutationAndTransposedShape( + permutation_array, operand_type, rewriter); + + Location loc = scatter_op.getLoc(); + auto transposed_operand = rewriter.create( + loc, permutation_and_shape.shape, operands[0], + permutation_and_shape.permutation); + + Value new_indices = indices; + int64_t index_depth = + permutation_and_shape.shape.getRank() - inserted_window_dims.size(); + int64_t num_updates = indices_type.getDimSize(0); + // For TF::TensorScatterUpdateOp, `indices` must have at least 2 axes: + // `(num_updates, index_depth)`. Reshape indices and updates if necessary. + if (std::is_same::value && + indices_type.getRank() == 1 && updates_type.getRank() == 1 && + index_depth == 1 && num_updates == 1) { + ImplicitLocOpBuilder builder(loc, rewriter); + auto indices_shape = BuildIntArrayConstOp( + builder, rewriter, + llvm::SmallVector({num_updates, index_depth}), + rewriter.getI32Type()); + new_indices = rewriter.create( + loc, + RankedTensorType::get({num_updates, index_depth}, + indices_type.getElementType()), + indices, indices_shape); + auto updates_shape = + BuildIntArrayConstOp(builder, rewriter, + llvm::SmallVector( + {num_updates, updates_type.getDimSize(0)}), + rewriter.getI32Type()); + new_updates = rewriter.create( + loc, + RankedTensorType::get({1, updates_type.getDimSize(0)}, + updates_type.getElementType()), + new_updates, updates_shape); + } + + // Apply TF scatter to update the trailing dimensions of the + // transposed operand. + auto tf_scatter_op = + rewriter.create(loc, permutation_and_shape.shape, + transposed_operand, new_indices, new_updates); + + // Reverse the earlier transpose. + auto inverse_permutation = + GetInversePermutation(permutation_array, rewriter); + rewriter.replaceOpWithNewOp( + scatter_op, scatter_op.getResult(0).getType(), tf_scatter_op, + inverse_permutation); + + return success(); + } + } +}; + +using ConvertScatterAddOp = + ConvertScatterOp; +using ConvertScatterMaxOp = + ConvertScatterOp; +using ConvertScatterMinOp = + ConvertScatterOp; +using ConvertScatterSubOp = + ConvertScatterOp; +using ConvertScatterUpdateOp = + ConvertScatterOp; + +} // end namespace odml +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc new file mode 100644 index 0000000000000000000000000000000000000000..2cd9a689f8e780b7fb2c1753aa2cd5e69c5b1a03 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc @@ -0,0 +1,162 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" + +#include + +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +bool IsIotaAttr(ArrayRef arr, int64_t size) { + if (arr.size() != size) return false; + int64_t iota = 0; + for (auto s : arr) { + if (s != iota) return false; + ++iota; + } + return true; +} + +PermutationAndShape GetPermutationAndTransposedShape( + llvm::ArrayRef permutation_array, ShapedType input_type, + ConversionPatternRewriter& rewriter) { + assert(permutation_array.size() == input_type.getRank()); + llvm::SmallVector transposed_shape(permutation_array.size()); + for (int64_t i = 0; i < permutation_array.size(); ++i) { + transposed_shape[i] = input_type.getDimSize(permutation_array[i]); + } + auto transposed_type = + RankedTensorType::get(transposed_shape, input_type.getElementType()); + DenseIntElementsAttr permutation = DenseIntElementsAttr::get( + RankedTensorType::get(permutation_array.size(), rewriter.getI64Type()), + permutation_array); + return {permutation, transposed_type}; +} + +Value BuildIntConstOp(ImplicitLocOpBuilder& builder, + ConversionPatternRewriter& rewriter, int64_t const_value, + Type type) { + Value result_const = + builder.create(rewriter.getIntegerAttr(type, const_value)); + return result_const; +} + +Value BuildIntArrayConstOp(ImplicitLocOpBuilder& builder, + ConversionPatternRewriter& rewriter, + ArrayRef const_value, Type type) { + DenseIntElementsAttr const_value_raw; + if (type == rewriter.getI64Type()) { + const_value_raw = rewriter.getI64TensorAttr(const_value); + } else { + // Convert I64 const array to I32. + llvm::SmallVector const_i32_vec; + for (auto element : const_value) { + const_i32_vec.push_back(static_cast(element)); + } + const_value_raw = rewriter.getI32TensorAttr(const_i32_vec); + } + Value result_const = builder.create(const_value_raw); + return result_const; +} + +llvm::SmallVector GetInversePermutationArray( + llvm::ArrayRef permutation_array) { + llvm::SmallVector inverse_permutation_array( + permutation_array.size()); + const auto permutation_array_size = permutation_array.size(); + for (int64_t i = 0; i < permutation_array_size; ++i) { + inverse_permutation_array[permutation_array[i]] = i; + } + return inverse_permutation_array; +} + +DenseIntElementsAttr GetInversePermutation( + llvm::ArrayRef permutation_array, + ConversionPatternRewriter& rewriter) { + SmallVector inverse_permutation_array = + GetInversePermutationArray(permutation_array); + return DenseIntElementsAttr::get( + RankedTensorType::get(inverse_permutation_array.size(), + rewriter.getI64Type()), + inverse_permutation_array); +} + +PermutationAndShape GetInversePermutationAndShape( + llvm::ArrayRef permutation_array, ShapedType input_type, + ConversionPatternRewriter& rewriter) { + SmallVector inverse_permutation_array = + GetInversePermutationArray(permutation_array); + return GetPermutationAndTransposedShape(inverse_permutation_array, input_type, + rewriter); +} + +LogicalResult NormalizeIndexVector(Operation* parent_op, Value& indices, + ShapedType& indices_type, + int64_t index_vector_dim, + ConversionPatternRewriter& rewriter) { + if (index_vector_dim == indices_type.getRank()) { + llvm::SmallVector new_start_indices_shape( + indices_type.getShape().begin(), indices_type.getShape().end()); + new_start_indices_shape.push_back(1); + indices_type = RankedTensorType::get(new_start_indices_shape, + indices_type.getElementType()); + indices = rewriter.create(parent_op->getLoc(), + indices_type, indices); + } else if (index_vector_dim != indices_type.getRank() - 1) { + // If index_vector_dim isn't the last dimension in indices then it isn't + // supported yet. + // TODO(tberghammer): Transpose indices to support this usecase. + return rewriter.notifyMatchFailure( + parent_op, + "index vector dim isn't the last dimension in start indices"); + } + return success(); +} + +// Check if the specified region is a binary reduction function that takes 2 +// inputs and returns the second input. Functions like this are used by update +// scatter like ops. +template <> +LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { + Block& body = function.front(); + if (body.getNumArguments() != 2) return failure(); + + mhlo::ReturnOp return_op = dyn_cast(body.back()); + if (!return_op) return failure(); + if (return_op.getNumOperands() != 1) return failure(); + if (return_op.getOperands().front() != body.getArgument(1)) return failure(); + return success(); +} + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h new file mode 100644 index 0000000000000000000000000000000000000000..72485597d6c4ff0da3413e5b87cf9b65d5c1b783 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h @@ -0,0 +1,118 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_UTIL_H_ + +#include + +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +struct PermutationAndShape { + DenseIntElementsAttr permutation; + ShapedType shape; +}; + +// Check that `arr` is an R1 iota with integer element type starting from `0` +// with `size` number of values. +bool IsIotaAttr(ArrayRef arr, int64_t size); + +// Returns a DenseIntElementsAttr for a permutation and the shape after +// applying the permutation to a given shape through a transpose. +PermutationAndShape GetPermutationAndTransposedShape( + llvm::ArrayRef permutation_array, ShapedType input_type, + ConversionPatternRewriter& rewriter); + +// Create a single const integer. +Value BuildIntConstOp(ImplicitLocOpBuilder& builder, + ConversionPatternRewriter& rewriter, int64_t const_value, + Type type); + +// Create a const integer vector tensor (1-dim). +Value BuildIntArrayConstOp(ImplicitLocOpBuilder& builder, + ConversionPatternRewriter& rewriter, + ArrayRef const_value, Type type); + +// Returns the inverse permutation array for a permutation array. +llvm::SmallVector GetInversePermutationArray( + llvm::ArrayRef permutation_array); + +// Returns the DenseIntElementsAttr for an inverse permutation given a +// permutation_array. +DenseIntElementsAttr GetInversePermutation( + llvm::ArrayRef permutation_array, + ConversionPatternRewriter& rewriter); + +// Returns a DenseIntElementsAttr for an inverse permutation and the shape after +// applying the inverse permutation to a given shape through a transpose. +PermutationAndShape GetInversePermutationAndShape( + llvm::ArrayRef permutation_array, ShapedType input_type, + ConversionPatternRewriter& rewriter); + +// If index_vector_dim == indices.rank() then insert the implicit extra +// dimension into indices to normalize everything to index_vector_dim == +// indices.rank() - 1. +LogicalResult NormalizeIndexVector(Operation* parent_op, Value& indices, + ShapedType& indices_type, + int64_t index_vector_dim, + ConversionPatternRewriter& rewriter); + +// Checks if the specified region is a binary reduction function that takes 2 +// inputs, passes it to an instance of the specified reduction op and then +// returns the result. +template +LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { + Block& body = function.front(); + if (body.getNumArguments() != 2) return failure(); + + mhlo::ReturnOp return_op = dyn_cast(body.back()); + if (!return_op) return failure(); + if (return_op.getNumOperands() != 1) return failure(); + + ReductionOp reduce_op = dyn_cast_or_null( + return_op.getOperands().front().getDefiningOp()); + if (!reduce_op) return failure(); + if (reduce_op.getLhs() != body.getArgument(0) || + reduce_op.getRhs() != body.getArgument(1)) + return failure(); + + return success(); +} + +// Check if the specified region is a binary reduction function that takes 2 +// inputs and returns the second input. Functions like this are used by update +// scatter like ops. +template <> +LogicalResult MatchBinaryReduceFunction(mlir::Region& function); +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_UTIL_H_