提交 ace5d01e 编写于 作者: M Majid Dadashi 提交者: TensorFlower Gardener

Remove all the scatter re-write patterns from mhlo -> TF.

These patterns are used only by TFLite which now has direct support for stablehlo.scatter.

PiperOrigin-RevId: 564569917
上级 6e3904f9
......@@ -475,6 +475,7 @@ cc_library(
],
deps = [
":passes_inc_gen",
"//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:util",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib",
"//tensorflow/compiler/mlir/tensorflow/transforms:tf_device_pass_inc_gen",
......
......@@ -3313,14 +3313,15 @@ func.func @convert_dynamic_slice_ui32(%arg0: tensor<7x3xf32>, %arg1: tensor<ui32
func.return %0 : tensor<4x2xf32>
}
// CHECK-LABEL: func @convert_scatter_update(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK: %[[VAL_3:.*]] = arith.constant dense<[4, 1]> : tensor<2xi64>
// CHECK: %[[VAL_4:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_3]]) : {{.*}} -> tensor<4x1xi32>
// CHECK: %[[VAL_5:.*]] = "tf.TensorScatterUpdate"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) : {{.*}} -> tensor<20x6xf32>
// CHECK: return %[[VAL_5]] : tensor<20x6xf32>
// CHECK-LABEL: func.func @convert_scatter_update(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
// CHECK: ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK: mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
// CHECK: return %[[VAL_3]] : tensor<20x6xf32>
// CHECK: }
func.func @convert_scatter_update(%arg0: tensor<20x6xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
%0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
......@@ -3338,14 +3339,15 @@ func.func @convert_scatter_update(%arg0: tensor<20x6xf32>, %arg1: tensor<4xi32>,
func.return %0 : tensor<20x6xf32>
}
// CHECK-LABEL: func @convert_scatter_update_with_non_trailing_update_window_dims(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x10xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<10x3xf32>) -> tensor<5x10xf32> {
// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_2]], %[[VAL_3]]) : (tensor<10x3xf32>, tensor<2xi64>) -> tensor<3x10xf32>
// CHECK: %[[VAL_5:.*]] = "tf.TensorScatterUpdate"(%[[VAL_0]], %[[VAL_1]], %[[VAL_4]]) : (tensor<5x10xf32>, tensor<3x1xi32>, tensor<3x10xf32>) -> tensor<5x10xf32>
// CHECK: return %[[VAL_5]] : tensor<5x10xf32>
// CHECK-LABEL: func.func @convert_scatter_update_with_non_trailing_update_window_dims(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x10xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<10x3xf32>) -> tensor<5x10xf32> {
// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
// CHECK: ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK: mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<5x10xf32>, tensor<3x1xi32>, tensor<10x3xf32>) -> tensor<5x10xf32>
// CHECK: return %[[VAL_3]] : tensor<5x10xf32>
// CHECK: }
func.func @convert_scatter_update_with_non_trailing_update_window_dims(
%arg0: tensor<5x10xf32>,
......@@ -3367,16 +3369,15 @@ func.func @convert_scatter_update_with_non_trailing_update_window_dims(
func.return %0 : tensor<5x10xf32>
}
// CHECK-LABEL: func @convert_scatter_update_to_non_trailing_operand_dimensions(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x4x3x7xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x2xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32> {
// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<[1, 3, 0, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<5x4x3x7xf32>, tensor<4xi64>) -> tensor<4x7x5x3xf32>
// CHECK: %[[VAL_5:.*]] = "tf.TensorScatterUpdate"(%[[VAL_4]], %[[VAL_1]], %[[VAL_2]]) : (tensor<4x7x5x3xf32>, tensor<2x2xi32>, tensor<2x5x3xf32>) -> tensor<4x7x5x3xf32>
// CHECK: %[[VAL_6:.*]] = "tf.Const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
// CHECK: %[[VAL_7:.*]] = "tf.Transpose"(%[[VAL_5]], %[[VAL_6]]) : (tensor<4x7x5x3xf32>, tensor<4xi64>) -> tensor<5x4x3x7xf32>
// CHECK: return %[[VAL_7]] : tensor<5x4x3x7xf32>
// CHECK-LABEL: func.func @convert_scatter_update_to_non_trailing_operand_dimensions(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x4x3x7xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x2xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32> {
// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
// CHECK: ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK: mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1, 2], inserted_window_dims = [1, 3], scatter_dims_to_operand_dims = [1, 3], index_vector_dim = 1>, unique_indices = false} : (tensor<5x4x3x7xf32>, tensor<2x2xi32>, tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32>
// CHECK: return %[[VAL_3]] : tensor<5x4x3x7xf32>
// CHECK: }
func.func @convert_scatter_update_to_non_trailing_operand_dimensions(
%arg0: tensor<5x4x3x7xf32>,
......@@ -3397,20 +3398,15 @@ func.func @convert_scatter_update_to_non_trailing_operand_dimensions(
func.return %0 : tensor<5x4x3x7xf32>
}
// CHECK-LABEL: func @convert_scatter_update_reshape_indices_and_updates(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<16x1504xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<1xi32>,
// CHECK-SAME: %[[ARG_2:.*]]: tensor<16xf32>) -> tensor<16x1504xf32> {
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: %[[VAL_0:.*]] = "tf.Transpose"(%[[ARG_0]], %[[CST]]) : (tensor<16x1504xf32>, tensor<2xi64>) -> tensor<1504x16xf32>
// CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[VAL_1:.*]] = "tf.Reshape"(%[[ARG_1]], %[[CST_0]]) : (tensor<1xi32>, tensor<2xi32>) -> tensor<1x1xi32>
// CHECK: %[[CST_1:.*]] = "tf.Const"() {value = dense<[1, 16]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[ARG_2]], %[[CST_1]]) : (tensor<16xf32>, tensor<2xi32>) -> tensor<1x16xf32>
// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterUpdate"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<1504x16xf32>, tensor<1x1xi32>, tensor<1x16xf32>) -> tensor<1504x16xf32>
// CHECK: %[[CST_2:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_3]], %[[CST_2]]) : (tensor<1504x16xf32>, tensor<2xi64>) -> tensor<16x1504xf32>
// CHECK: return %[[VAL_4]]
// CHECK-LABEL: func.func @convert_scatter_update_reshape_indices_and_updates(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x1504xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<16xf32>) -> tensor<16x1504xf32> {
// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
// CHECK: ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK: mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK: }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<16x1504xf32>, tensor<1xi32>, tensor<16xf32>) -> tensor<16x1504xf32>
// CHECK: return %[[VAL_3]] : tensor<16x1504xf32>
// CHECK: }
func.func @convert_scatter_update_reshape_indices_and_updates(
%arg0: tensor<16x1504xf32>,
......@@ -3430,11 +3426,15 @@ func.func @convert_scatter_update_reshape_indices_and_updates(
func.return %0 : tensor<16x1504xf32>
}
// CHECK-LABEL: func @convert_scatter_add(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterAdd"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : {{.*}} -> tensor<20x6xf32>
// CHECK-LABEL: func.func @convert_scatter_add(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
// CHECK: ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK: %[[VAL_6:.*]] = "tf.AddV2"(%[[VAL_4]], %[[VAL_5]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: mhlo.return %[[VAL_6]] : tensor<f32>
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
// CHECK: return %[[VAL_3]] : tensor<20x6xf32>
// CHECK: }
func.func @convert_scatter_add(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
......@@ -3454,11 +3454,15 @@ func.func @convert_scatter_add(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>,
func.return %0 : tensor<20x6xf32>
}
// CHECK-LABEL: func @convert_scatter_max(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterMax"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : {{.*}} -> tensor<20x6xf32>
// CHECK-LABEL: func.func @convert_scatter_max(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
// CHECK: ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK: %[[VAL_6:.*]] = "tf.Maximum"(%[[VAL_4]], %[[VAL_5]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: mhlo.return %[[VAL_6]] : tensor<f32>
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
// CHECK: return %[[VAL_3]] : tensor<20x6xf32>
// CHECK: }
func.func @convert_scatter_max(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
......@@ -3478,11 +3482,15 @@ func.func @convert_scatter_max(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>,
func.return %0 : tensor<20x6xf32>
}
// CHECK-LABEL: func @convert_scatter_min(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterMin"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : {{.*}} -> tensor<20x6xf32>
// CHECK-LABEL: func.func @convert_scatter_min(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
// CHECK: ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK: %[[VAL_6:.*]] = "tf.Minimum"(%[[VAL_4]], %[[VAL_5]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: mhlo.return %[[VAL_6]] : tensor<f32>
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
// CHECK: return %[[VAL_3]] : tensor<20x6xf32>
// CHECK: }
func.func @convert_scatter_min(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
......@@ -3502,11 +3510,15 @@ func.func @convert_scatter_min(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>,
func.return %0 : tensor<20x6xf32>
}
// CHECK-LABEL: func @convert_scatter_sub(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK: %[[VAL_3:.*]] = "tf.TensorScatterSub"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : {{.*}} -> tensor<20x6xf32>
// CHECK-LABEL: func.func @convert_scatter_sub(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK: %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) ({
// CHECK: ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK: %[[VAL_6:.*]] = "tf.Sub"(%[[VAL_4]], %[[VAL_5]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: mhlo.return %[[VAL_6]] : tensor<f32>
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
// CHECK: return %[[VAL_3]] : tensor<20x6xf32>
// CHECK: }
func.func @convert_scatter_sub(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
......
......@@ -59,6 +59,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "stablehlo/dialect/BroadcastUtils.h" // from @stablehlo
#include "stablehlo/dialect/ChloOps.h" // from @stablehlo
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
......@@ -110,65 +111,6 @@ LogicalResult GetConstantSplatValue(Value value, SplatValueType& splat_value) {
return success();
}
// Result bundle used by the permutation helpers below: a transpose
// permutation together with the type of the value produced by applying it.
struct PermutationAndShape {
// The permutation, stored as a dense integer elements attribute.
DenseIntElementsAttr permutation;
// The shaped type obtained after transposing with `permutation`.
ShapedType shape;
};
// Builds the i64 attribute for `permutation_array` and the ranked tensor type
// produced by transposing `input_type` with it: result dimension i has the
// size of input dimension permutation_array[i]. The element type is kept.
PermutationAndShape GetPermutationAndTransposedShape(
    llvm::ArrayRef<int64_t> permutation_array, ShapedType input_type,
    ConversionPatternRewriter& rewriter) {
  assert(permutation_array.size() == input_type.getRank());

  // Collect the dimension sizes in permuted order.
  llvm::SmallVector<int64_t> permuted_dims;
  for (int64_t source_dim : permutation_array) {
    permuted_dims.push_back(input_type.getDimSize(source_dim));
  }
  auto transposed_type =
      RankedTensorType::get(permuted_dims, input_type.getElementType());

  // The permutation itself is a rank-1 i64 tensor attribute.
  auto permutation_type =
      RankedTensorType::get(permutation_array.size(), rewriter.getI64Type());
  DenseIntElementsAttr permutation_attr =
      DenseIntElementsAttr::get(permutation_type, permutation_array);

  return {permutation_attr, transposed_type};
}
// Computes the inverse of `permutation_array`: if the input maps position i
// to value p, the result maps position p to value i.
llvm::SmallVector<int64_t> GetInversePermutationArray(
    llvm::ArrayRef<int64_t> permutation_array) {
  llvm::SmallVector<int64_t> inverse(permutation_array.size());
  int64_t position = 0;
  for (int64_t target : permutation_array) {
    inverse[target] = position;
    ++position;
  }
  return inverse;
}
// Builds the rank-1 i64 DenseIntElementsAttr holding the inverse of
// `permutation_array`.
DenseIntElementsAttr GetInversePermutation(
    llvm::ArrayRef<int64_t> permutation_array,
    ConversionPatternRewriter& rewriter) {
  SmallVector<int64_t, 4> inverse =
      GetInversePermutationArray(permutation_array);
  auto attr_type =
      RankedTensorType::get(inverse.size(), rewriter.getI64Type());
  return DenseIntElementsAttr::get(attr_type, inverse);
}
// Same as GetPermutationAndTransposedShape, but for the inverse of
// `permutation_array`: returns the inverse permutation attribute and the type
// obtained by transposing `input_type` with that inverse permutation.
PermutationAndShape GetInversePermutationAndShape(
    llvm::ArrayRef<int64_t> permutation_array, ShapedType input_type,
    ConversionPatternRewriter& rewriter) {
  return GetPermutationAndTransposedShape(
      GetInversePermutationArray(permutation_array), input_type, rewriter);
}
// Common functionality for ConvertConvOp classes.
template <int SupportedSpatialDims>
struct ConvertNdConvOp {
......@@ -1165,33 +1107,6 @@ struct DimensionVector {
llvm::SmallVector<int64_t, 4> sizes;
};
// Emits a tf.Const holding the single integer `const_value` with integer
// type `type`, and returns its result.
Value BuildIntConstOp(ImplicitLocOpBuilder& builder,
                      ConversionPatternRewriter& rewriter, int64_t const_value,
                      Type type) {
  auto value_attr = rewriter.getIntegerAttr(type, const_value);
  return builder.create<TF::ConstOp>(value_attr);
}
// Emits a tf.Const holding `const_value` as a 1-D integer tensor and returns
// its result. The values are stored as i64 when `type` is the i64 type;
// for any other `type` they are narrowed to i32.
Value BuildIntArrayConstOp(ImplicitLocOpBuilder& builder,
                           ConversionPatternRewriter& rewriter,
                           ArrayRef<int64_t> const_value, Type type) {
  if (type == rewriter.getI64Type()) {
    DenseIntElementsAttr i64_attr = rewriter.getI64TensorAttr(const_value);
    return builder.create<TF::ConstOp>(i64_attr);
  }

  // Narrow every element from i64 to i32 before building the attribute.
  llvm::SmallVector<int32_t> narrowed;
  for (int64_t wide : const_value) {
    narrowed.push_back(static_cast<int32_t>(wide));
  }
  DenseIntElementsAttr i32_attr = rewriter.getI32TensorAttr(narrowed);
  return builder.create<TF::ConstOp>(i32_attr);
}
// Create a tensor that is reshaped from input.
Value BuildReshapeOp(ImplicitLocOpBuilder& builder,
ConversionPatternRewriter& rewriter, Value input,
......@@ -1925,43 +1840,6 @@ Value ConvertDotGeneralOp(PatternRewriter& rewriter, Operation* old_op) {
dot_general_op.getLoc());
}
// Succeeds when the first block of `function` is a binary reduction of the
// form
//   ^bb0(%a, %b):  %r = ReductionOp(%a, %b);  mhlo.return %r
// i.e. it takes exactly two arguments, ends in an mhlo.return with a single
// operand, and that operand is produced by a `ReductionOp` whose lhs/rhs are
// the first/second block argument respectively. Fails otherwise.
template <typename ReductionOp>
LogicalResult MatchBinaryReduceFunction(mlir::Region& function) {
  Block& body = function.front();
  if (body.getNumArguments() != 2) return failure();

  auto terminator = dyn_cast<mhlo::ReturnOp>(body.back());
  if (!terminator || terminator.getNumOperands() != 1) return failure();

  // The returned value must come straight out of the reduction op.
  auto combiner = dyn_cast_or_null<ReductionOp>(
      terminator.getOperands().front().getDefiningOp());
  if (!combiner) return failure();

  const bool consumes_block_args_in_order =
      combiner.getLhs() == body.getArgument(0) &&
      combiner.getRhs() == body.getArgument(1);
  return success(consumes_block_args_in_order);
}
// Specialization for "update" style computations: succeeds when the first
// block of `function` takes exactly two arguments and ends in an mhlo.return
// whose only operand is the second block argument. Scatter-update style ops
// use computations of this form.
template <>
LogicalResult MatchBinaryReduceFunction<void>(mlir::Region& function) {
  Block& body = function.front();
  if (body.getNumArguments() != 2) return failure();

  auto terminator = dyn_cast<mhlo::ReturnOp>(body.back());
  if (!terminator || terminator.getNumOperands() != 1) return failure();

  return success(terminator.getOperands().front() == body.getArgument(1));
}
// Replace BinaryOp with a combination of TfBinaryOp and TfReduceOp if the
// init value doesn't match the expectation of TfReduceOp.
template <typename TfReduceOp, typename TfBinOp>
......@@ -3019,124 +2897,6 @@ bool SameTypeOrDefaultCompare(mhlo::ComparisonTypeAttr comparison_type_attr,
return false;
}
// Returns true iff `arr` has exactly `size` elements and they are
// 0, 1, ..., size - 1 in that order.
bool IsIotaAttr(ArrayRef<int64_t> arr, int64_t size) {
  if (arr.size() != size) return false;
  for (int64_t expected = 0; expected < size; ++expected) {
    if (arr[expected] != expected) return false;
  }
  return true;
}
// Brings `updates` into the layout expected by the tf.scatter ops, where the
// `update_window_dims` are the trailing dimensions of the updates tensor.
//
// This exists to support scatter ops generated by numpy-like slice updates:
//   nd_array[:, [i,j]] = [i_values, j_values]
//
// Two layouts are accepted; everything else is reported as a match failure:
//
// 1. Canonical form (nothing to do):
//    * indices shape(A,B,C)
//    * updates shape(A,B,D,E,F)
//    Then:
//    * D,E,F are the update window dims [2,3,4]
//    * C is the index vector dimension
//    * A,B iterate over the updates and indices
//
// 2. Update window dims are the leading dimensions:
//    * indices shape(a,b,c)
//    * updates shape(d,e,f,a,b)
//    Then:
//    * d,e,f are the update window dims [0,1,2]
//    * c is the index vector dimension
//    * a,b iterate over the updates and indices
//    Here the updates are replaced with
//      transpose(updates, permutation={3,4,0,1,2})
//    so that they take the form (a,b,d,e,f).
//
// On success `updates` and `updates_type` are overwritten with the (possibly
// transposed) value and its type.
//
// Note: NormalizeIndexVector is assumed to have run on the indices already so
// that the index_vector_dim is the trailing dimension in `indices`.
LogicalResult CanonicalizeScatterUpdates(
    Operation* scatter_op, llvm::ArrayRef<int64_t> update_window_dims,
    const Value& indices, const ShapedType& indices_type, Value& updates,
    ShapedType& updates_type, ConversionPatternRewriter& rewriter) {
  auto canonical_update_window_dims = llvm::to_vector(
      llvm::seq<int64_t>(indices_type.getRank() - 1, updates_type.getRank()));

  if (canonical_update_window_dims == update_window_dims) return success();

  // Permute updates if `update_window_dims` are leading indices.
  // Other possibilities for `update_window_dims` are not supported yet.
  const int64_t num_window_dims = update_window_dims.size();
  if (!IsIotaAttr(update_window_dims, num_window_dims))
    return rewriter.notifyMatchFailure(
        scatter_op, "update_window_dims are not leading or trailing indices");

  const int64_t updates_rank = updates_type.getRank();
  if (num_window_dims > updates_rank)
    return rewriter.notifyMatchFailure(
        scatter_op, "more update_window_dims than dimensions in updates");

  // Entry i of a transpose permutation names the *source* dimension that
  // becomes result dimension i. The scatter dimensions (currently trailing)
  // must come first, followed by the window dimensions (currently leading).
  SmallVector<int64_t, 4> permutation_array;
  permutation_array.reserve(updates_rank);
  for (int64_t dim = num_window_dims; dim < updates_rank; ++dim) {
    permutation_array.push_back(dim);
  }
  for (int64_t dim = 0; dim < num_window_dims; ++dim) {
    permutation_array.push_back(dim);
  }

  auto permutation_and_shape = GetPermutationAndTransposedShape(
      permutation_array, updates_type, rewriter);

  auto transposed_updates = rewriter.create<mhlo::TransposeOp>(
      scatter_op->getLoc(), permutation_and_shape.shape, updates,
      permutation_and_shape.permutation);

  updates = transposed_updates;
  updates_type = permutation_and_shape.shape;
  return success();
}
// Normalizes `indices` so that the index vector lives in its last dimension,
// i.e. index_vector_dim == indices.rank() - 1:
//  * already the last dimension: nothing to do;
//  * index_vector_dim == indices.rank() (implicit trailing dimension): a
//    size-1 dimension is appended through an mhlo.reshape, and `indices` /
//    `indices_type` are updated to the reshaped value and type;
//  * anything else is not supported and reported as a match failure.
LogicalResult NormalizeIndexVector(Operation* parent_op, Value& indices,
                                   ShapedType& indices_type,
                                   int64_t index_vector_dim,
                                   ConversionPatternRewriter& rewriter) {
  const int64_t indices_rank = indices_type.getRank();

  if (index_vector_dim == indices_rank - 1) return success();

  if (index_vector_dim != indices_rank) {
    // If index_vector_dim isn't the last dimension in indices then it isn't
    // supported yet.
    // TODO(tberghammer): Transpose indices to support this usecase.
    return rewriter.notifyMatchFailure(
        parent_op,
        "index vector dim isn't the last dimension in start indices");
  }

  // Make the implicit trailing index-vector dimension explicit.
  ArrayRef<int64_t> current_shape = indices_type.getShape();
  llvm::SmallVector<int64_t, 4> expanded_shape(current_shape.begin(),
                                               current_shape.end());
  expanded_shape.push_back(1);
  indices_type =
      RankedTensorType::get(expanded_shape, indices_type.getElementType());
  indices = rewriter.create<mhlo::ReshapeOp>(parent_op->getLoc(), indices_type,
                                             indices);
  return success();
}
class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
public:
using OpConversionPattern::OpConversionPattern;
......@@ -3551,157 +3311,6 @@ class ConvertIfOp : public OpConversionPattern<mhlo::IfOp> {
}
};
// Conversion pattern that rewrites an mhlo.scatter into the TF tensor-scatter
// op `TfOp`. `BinaryOp` selects which update computation is accepted; it is
// checked with MatchBinaryReduceFunction<BinaryOp> (with `void` selecting the
// "return the second argument" computation). When the scattered operand
// dimensions are not the leading ones, the operand is transposed before the
// TF op and the result is transposed back afterwards.
template <typename BinaryOp, typename TfOp>
class ConvertScatterOp : public OpConversionPattern<mhlo::ScatterOp> {
public:
using OpConversionPattern::OpConversionPattern;
// Matches and rewrites `scatter_op`. `adaptor` is not used; all operands are
// read from `scatter_op` directly. Returns failure() for any scatter that
// does not fit the supported forms described in the comments below.
LogicalResult matchAndRewrite(
mhlo::ScatterOp scatter_op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
OperandRange operands = scatter_op.getInputs();
Value indices = scatter_op.getScatterIndices();
OperandRange updates = scatter_op.getUpdates();
// Only the single-operand / single-updates form is handled.
if (operands.size() != 1 || updates.size() != 1) return failure();
ShapedType operand_type = operands[0].getType().cast<ShapedType>();
ShapedType indices_type = indices.getType().cast<ShapedType>();
ShapedType updates_type = updates[0].getType().cast<ShapedType>();
// `new_updates` / `updates_type` and `indices` / `indices_type` may be
// replaced by the normalization helpers called below.
Value new_updates = updates[0];
// Can only convert with static shaped scatter.
if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() ||
!updates_type.hasStaticShape()) {
return failure();
}
// Match the scatter computation against computations supported by TF.
if (failed(MatchBinaryReduceFunction<BinaryOp>(
scatter_op.getUpdateComputation()))) {
return failure();
}
auto scatter_dimension_numbers = scatter_op.getScatterDimensionNumbers();
// Normalize indices so index_vector_dim == indices.rank() - 1.
int64_t index_vector_dim = scatter_dimension_numbers.getIndexVectorDim();
if (failed(NormalizeIndexVector(scatter_op, indices, indices_type,
index_vector_dim, rewriter))) {
return failure();
}
// Transform updates so that update window dims are the trailing dimensions
// in the update tensor.
auto update_window_dims = scatter_dimension_numbers.getUpdateWindowDims();
if (failed(CanonicalizeScatterUpdates(scatter_op, update_window_dims,
indices, indices_type, new_updates,
updates_type, rewriter))) {
return failure();
}
auto inserted_window_dims =
scatter_dimension_numbers.getInsertedWindowDims();
auto scatter_dims_to_operand_dims =
scatter_dimension_numbers.getScatterDimsToOperandDims();
// Direct case: both dimension lists are 0, 1, ..., N-1 where N is the size
// of the last dimension of `indices`. The TF op can then be applied to the
// operand as-is.
if (IsIotaAttr(inserted_window_dims, indices_type.getShape().back()) &&
IsIotaAttr(scatter_dims_to_operand_dims,
indices_type.getShape().back())) {
rewriter.replaceOpWithNewOp<TfOp>(scatter_op,
scatter_op.getResult(0).getType(),
operands[0], indices, new_updates);
return success();
}
// Insert transposes to support scatter operations generated from
// numpy-like slice operations:
// nd_array[:, [i,j]] = [i_values, j_values]
//
if (scatter_dims_to_operand_dims != inserted_window_dims) {
// Support only dimension numbers generated by numpy-like slice
// operations.
return rewriter.notifyMatchFailure(
scatter_op, "unsupported scatter_dims_to_operand_dims");
}
// Transpose the operand so that the dimensions being updated become its
// leading dimensions (the permutation below lists the scattered dimensions
// first, then the remaining ones in their original order). Then apply a
// tf.scatter op and transpose back the result to get the same shape as the
// original operand.
SmallVector<int64_t, 4> permutation_array;
for (int64_t i = 0; i < scatter_dims_to_operand_dims.size(); ++i) {
permutation_array.push_back(scatter_dims_to_operand_dims[i]);
}
for (int64_t i = 0; i < operand_type.getRank(); ++i) {
if (!llvm::is_contained(scatter_dims_to_operand_dims, i)) {
permutation_array.push_back(i);
}
}
auto permutation_and_shape = GetPermutationAndTransposedShape(
permutation_array, operand_type, rewriter);
Location loc = scatter_op.getLoc();
auto transposed_operand = rewriter.create<mhlo::TransposeOp>(
loc, permutation_and_shape.shape, operands[0],
permutation_and_shape.permutation);
Value new_indices = indices;
// index_depth: operand rank minus the number of inserted window dims.
// num_updates: size of the first dimension of `indices`.
int64_t index_depth =
permutation_and_shape.shape.getRank() - inserted_window_dims.size();
int64_t num_updates = indices_type.getDimSize(0);
// For TF::TensorScatterUpdateOp, `indices` must have at least 2 axes:
// `(num_updates, index_depth)`. Reshape indices and updates if necessary.
// This only triggers for rank-1 indices and rank-1 updates with a single
// update of depth 1; both are reshaped to rank 2 with tf.Reshape.
if (std::is_same<TfOp, TF::TensorScatterUpdateOp>::value &&
indices_type.getRank() == 1 && updates_type.getRank() == 1 &&
index_depth == 1 && num_updates == 1) {
ImplicitLocOpBuilder builder(loc, rewriter);
auto indices_shape = BuildIntArrayConstOp(
builder, rewriter,
llvm::SmallVector<int64_t>({num_updates, index_depth}),
rewriter.getI32Type());
new_indices = rewriter.create<TF::ReshapeOp>(
loc,
RankedTensorType::get({num_updates, index_depth},
indices_type.getElementType()),
indices, indices_shape);
auto updates_shape = BuildIntArrayConstOp(
builder, rewriter,
llvm::SmallVector<int64_t>({num_updates, updates_type.getDimSize(0)}),
rewriter.getI32Type());
// num_updates == 1 in this branch, so the literal 1 in the result type
// below agrees with the shape constant built above.
new_updates = rewriter.create<TF::ReshapeOp>(
loc,
RankedTensorType::get({1, updates_type.getDimSize(0)},
updates_type.getElementType()),
new_updates, updates_shape);
}
// Apply TF scatter to update the leading dimensions of the
// transposed operand.
auto tf_scatter_op =
rewriter.create<TfOp>(loc, permutation_and_shape.shape,
transposed_operand, new_indices, new_updates);
// Reverse the earlier transpose.
auto inverse_permutation =
GetInversePermutation(permutation_array, rewriter);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(
scatter_op, scatter_op.getResult(0).getType(), tf_scatter_op,
inverse_permutation);
return success();
}
};
// Concrete scatter conversions: each alias pairs the mhlo op expected in the
// scatter's update computation with the TF tensor-scatter op it lowers to.
// mhlo.add -> tf.TensorScatterAdd
using ConvertScatterAddOp =
ConvertScatterOp<mhlo::AddOp, TF::TensorScatterAddOp>;
// mhlo.max -> tf.TensorScatterMax
using ConvertScatterMaxOp =
ConvertScatterOp<mhlo::MaxOp, TF::TensorScatterMaxOp>;
// mhlo.min -> tf.TensorScatterMin
using ConvertScatterMinOp =
ConvertScatterOp<mhlo::MinOp, TF::TensorScatterMinOp>;
// mhlo.subtract -> tf.TensorScatterSub
using ConvertScatterSubOp =
ConvertScatterOp<mhlo::SubtractOp, TF::TensorScatterSubOp>;
// `void` selects the MatchBinaryReduceFunction<void> matcher, i.e. an update
// computation that simply returns its second argument.
using ConvertScatterUpdateOp =
ConvertScatterOp<void, TF::TensorScatterUpdateOp>;
// Converts mhlo.pad to tf.PadV2
Value ConvertPadOp(PatternRewriter& rewriter, Operation* old_op) {
auto pad_op = cast<mhlo::PadOp>(old_op);
......@@ -4140,19 +3749,18 @@ void LegalizeHloToTf::runOnOperation() {
// Adds the hand-written HLO-to-TF conversion patterns listed below to
// `patterns`, constructing each with `context`, and then appends the generated
// patterns via populateWithGenerated().
void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns,
MLIRContext* context) {
patterns->add<
ConvertAvgPoolOp, Convert2DConvOp, Convert1DConvOp,
ConvertNonTrivialConvOp, ConvertDynamicSliceOp,
ConvertDynamicUpdateSliceOp, ConvertGatherOp, ConvertIfOp,
ConvertMaxPoolOp, ConvertPopulationCountOp, ConvertScatterAddOp,
ConvertScatterMaxOp, ConvertScatterMinOp, ConvertScatterSubOp,
ConvertScatterUpdateOp, ConvertSliceOp, ConvertReduceOpToTfArgmax,
ConvertReduceOpToTfArgmin, ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
ConvertReduceOpToTfAll, ConvertReduceOpToTfProd, ConvertReduceOpToTfAny,
ConvertReduceOpToTfSum, ConvertSortToTfTopk, ConvertIotaOpToTfRange,
ConvertWhileOp, ConvertLoweredCumSumOp, ConvertLoweredCumProdOp,
ConvertGetDimensionSizeOp, ConvertDynamicIotaOp,
ConvertRealDynamicSliceOp>(context);
patterns
->add<ConvertAvgPoolOp, Convert2DConvOp, Convert1DConvOp,
ConvertNonTrivialConvOp, ConvertDynamicSliceOp,
ConvertDynamicUpdateSliceOp, ConvertGatherOp, ConvertIfOp,
ConvertMaxPoolOp, ConvertPopulationCountOp, ConvertSliceOp,
ConvertReduceOpToTfArgmax, ConvertReduceOpToTfArgmin,
ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
ConvertReduceOpToTfAll, ConvertReduceOpToTfProd,
ConvertReduceOpToTfAny, ConvertReduceOpToTfSum, ConvertSortToTfTopk,
ConvertIotaOpToTfRange, ConvertWhileOp, ConvertLoweredCumSumOp,
ConvertLoweredCumProdOp, ConvertGetDimensionSizeOp,
ConvertDynamicIotaOp, ConvertRealDynamicSliceOp>(context);
populateWithGenerated(*patterns);
}
......
# Groups the implementations of the HLO to TF operation conversions.
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [
"//visibility:public",
],
licenses = ["notice"],
)
# Helper routines (util.cc / util.h) shared by the HLO to TF conversions.
cc_library(
name = "util",
srcs = [
"util.cc",
],
hdrs = [
"util.h",
],
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_a_m_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@local_xla//xla/mlir_hlo",
],
)
# mhlo.scatter conversion patterns (scatter.cc / scatter.h); built on top of
# the helpers in ":util".
cc_library(
name = "scatter",
srcs = [
"scatter.cc",
],
hdrs = [
"scatter.h",
],
deps = [
":util",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@local_xla//xla/mlir_hlo",
],
)
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h"
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace odml {
// Brings `updates` into the layout expected by the TF scatter ops, i.e. with
// the update window dimensions as the trailing dimensions.
//
// * If `update_window_dims` already are the trailing dimensions, nothing is
//   changed.
// * If they are the leading dimensions [0, k), an mhlo.transpose is created
//   that moves them to the back; `updates` and `updates_type` are overwritten
//   with the transposed value and its type.
// * Any other layout is reported as a match failure on `scatter_op`.
//
// `indices` is unused; it is kept so the signature stays unchanged.
LogicalResult CanonicalizeScatterUpdates(
    Operation* scatter_op, llvm::ArrayRef<int64_t> update_window_dims,
    const Value& indices, const ShapedType& indices_type, Value& updates,
    ShapedType& updates_type, ConversionPatternRewriter& rewriter) {
  auto canonical_update_window_dims = llvm::to_vector(
      llvm::seq<int64_t>(indices_type.getRank() - 1, updates_type.getRank()));

  if (canonical_update_window_dims == update_window_dims) return success();

  // Permute updates if `update_window_dims` are leading indices.
  // Other possibilities for `update_window_dims` are not supported yet.
  if (!IsIotaAttr(update_window_dims, update_window_dims.size()))
    return rewriter.notifyMatchFailure(
        scatter_op, "update_window_dims are not leading or trailing indices");

  const int64_t rank = updates_type.getRank();
  const int64_t num_window_dims =
      static_cast<int64_t>(update_window_dims.size());
  // Guard against malformed dimension numbers; without it the loops below
  // would write outside `permutation_array`.
  if (num_window_dims > rank)
    return rewriter.notifyMatchFailure(
        scatter_op, "update_window_dims exceed the rank of updates");
  const int64_t num_scatter_dims = rank - num_window_dims;

  // Result dimension `i` of the transpose takes input dimension
  // `permutation_array[i]` (see GetPermutationAndTransposedShape). For updates
  // of shape (d,e,f,a,b) with window dims [0,1,2] the permutation must be
  // {3,4,0,1,2} so that the result has shape (a,b,d,e,f).
  //
  // The previous code filled in the inverse of this permutation, which only
  // coincides with the right one when num_window_dims == num_scatter_dims.
  SmallVector<int64_t, 4> permutation_array(rank);
  // The scatter (non-window) dimensions come first in the result ...
  for (int64_t i = 0; i < num_scatter_dims; ++i) {
    permutation_array[i] = num_window_dims + i;
  }
  // ... followed by the window dimensions, in their original order.
  for (int64_t i = 0; i < num_window_dims; ++i) {
    permutation_array[num_scatter_dims + i] = i;
  }

  auto permutation_and_shape = GetPermutationAndTransposedShape(
      permutation_array, updates_type, rewriter);

  auto transposed_updates = rewriter.create<mhlo::TransposeOp>(
      scatter_op->getLoc(), permutation_and_shape.shape, updates,
      permutation_and_shape.permutation);

  updates = transposed_updates;
  updates_type = permutation_and_shape.shape;
  return success();
}
} // end namespace odml
} // end namespace mlir
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_
#include <cstdint>
#include <type_traits>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace odml {
// Convert updates into canonical form as expected by tf.scatter ops.
//
// tf.scatter expects `update_window_dims` to be the trailing dimensions.
//
// To support scatter ops generated by numpy-like slice updates:
// nd_array[:, [i,j]] = [i_values, j_values]
//
// `updates` must be transposed when the update_window_dims are the leading
// dimensions of `updates`.
//
// Other values of `update_window_dims` are left unsupported.
//
// Eg 1. An update in canonical form:
// * indices shape(A,B,C)
// * updates shape(A,B,D,E,F)
// Then:
// * D,E,F are the update window dims [2,3,4]
// * C is the index vector dimension
// * A,B iterate over the updates and indices
//
// If `update_window_dims` are not the trailing dimensions then updates must be
// transposed.
//
// Eg 2. An update in non-canonical form:
// * indices shape(a,b,c)
// * updates shape(d,e,f,a,b)
// Then:
// * d,e,f are the update window dims [0,1,2]
// * c is the index vector dimension
// * a,b iterate over the updates and indices
//
// The update needs permuting to be in the form (a,b,d,e,f) so that the update
// window dims are the trailing dimensions.
//
// To canonicalize the updates above, replace the updates with:
// transpose(updates, permutation={3,4,0,1,2})
//
// Note: NormalizeIndexVector is assumed to have run on the indices already so
// that the index_vector_dim is the trailing dimension in `indices`.
LogicalResult CanonicalizeScatterUpdates(
Operation* scatter_op, llvm::ArrayRef<int64_t> update_window_dims,
const Value& indices, const ShapedType& indices_type, Value& updates,
ShapedType& updates_type, ConversionPatternRewriter& rewriter);
// Conversion pattern that rewrites an `mhlo.scatter` into the TF scatter op
// `TfOp`.
//
// Template parameters:
//   BinaryOp - forwarded to MatchBinaryReduceFunction<BinaryOp> to decide
//              whether the scatter's update computation is supported.
//   TfOp     - the TF op created in place of the scatter.
//
// The pattern only applies to scatters with exactly one input and one update
// tensor whose operand, indices and updates all have static shapes. When the
// dimension numbers do not directly line up with what `TfOp` expects, the
// operand is transposed before the TF op and transposed back afterwards.
template <typename BinaryOp, typename TfOp>
class ConvertScatterOp : public OpConversionPattern<mhlo::ScatterOp> {
public:
using OpConversionPattern::OpConversionPattern;
// Returns success() after replacing `scatter_op`, or failure() (possibly with
// a match-failure note) when the scatter is not in a supported form.
LogicalResult matchAndRewrite(
mhlo::ScatterOp scatter_op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
{
OperandRange operands = scatter_op.getInputs();
Value indices = scatter_op.getScatterIndices();
OperandRange updates = scatter_op.getUpdates();
// Only the single-input / single-update form is handled.
if (operands.size() != 1 || updates.size() != 1) return failure();
ShapedType operand_type = operands[0].getType().cast<ShapedType>();
ShapedType indices_type = indices.getType().cast<ShapedType>();
ShapedType updates_type = updates[0].getType().cast<ShapedType>();
Value new_updates = updates[0];
// Can only convert with static shaped scatter.
if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() ||
!updates_type.hasStaticShape()) {
return failure();
}
// Match the scatter computation against computations supported by TF.
if (failed(MatchBinaryReduceFunction<BinaryOp>(
scatter_op.getUpdateComputation()))) {
return failure();
}
auto scatter_dimension_numbers = scatter_op.getScatterDimensionNumbers();
// Normalize indices so index_vector_dim == indices.rank() - 1.
// `indices` and `indices_type` are passed by reference and may be replaced.
int64_t index_vector_dim = scatter_dimension_numbers.getIndexVectorDim();
if (failed(NormalizeIndexVector(scatter_op, indices, indices_type,
index_vector_dim, rewriter))) {
return failure();
}
// Transform updates so that update window dims are the trailing
// dimensions in the update tensor.
// `new_updates` and `updates_type` are passed by reference and may be
// replaced.
auto update_window_dims = scatter_dimension_numbers.getUpdateWindowDims();
if (failed(CanonicalizeScatterUpdates(scatter_op, update_window_dims,
indices, indices_type, new_updates,
updates_type, rewriter))) {
return failure();
}
auto inserted_window_dims =
scatter_dimension_numbers.getInsertedWindowDims();
auto scatter_dims_to_operand_dims =
scatter_dimension_numbers.getScatterDimsToOperandDims();
// Direct case: both attributes are the iota [0, N), where N is the size of
// the last dimension of the (normalized) indices. No operand transpose is
// needed and the TF op replaces the scatter directly.
if (IsIotaAttr(inserted_window_dims, indices_type.getShape().back()) &&
IsIotaAttr(scatter_dims_to_operand_dims,
indices_type.getShape().back())) {
rewriter.replaceOpWithNewOp<TfOp>(scatter_op,
scatter_op.getResult(0).getType(),
operands[0], indices, new_updates);
return success();
}
// Insert transposes to support scatter operations generated from
// numpy-like slice operations:
// nd_array[:, [i,j]] = [i_values, j_values]
//
if (scatter_dims_to_operand_dims != inserted_window_dims) {
// Support only dimension numbers generated by numpy-like slice
// operations.
return rewriter.notifyMatchFailure(
scatter_op, "unsupported scatter_dims_to_operand_dims");
}
// Transpose the operand so that the trailing dimensions of the
// operand are being updated. Then apply a tf.scatter op and transpose
// back the result to get the same shape as the original operand.
// The permutation lists the scattered operand dims first, followed by the
// remaining operand dims in increasing order.
SmallVector<int64_t, 4> permutation_array;
for (int64_t i = 0; i < scatter_dims_to_operand_dims.size(); ++i) {
permutation_array.push_back(scatter_dims_to_operand_dims[i]);
}
for (int64_t i = 0; i < operand_type.getRank(); ++i) {
if (!llvm::is_contained(scatter_dims_to_operand_dims, i)) {
permutation_array.push_back(i);
}
}
auto permutation_and_shape = GetPermutationAndTransposedShape(
permutation_array, operand_type, rewriter);
Location loc = scatter_op.getLoc();
auto transposed_operand = rewriter.create<mhlo::TransposeOp>(
loc, permutation_and_shape.shape, operands[0],
permutation_and_shape.permutation);
Value new_indices = indices;
// NOTE(review): index_depth is computed as the operand rank minus the number
// of inserted window dims; confirm this is the intended index depth for
// operands of rank > 2. It is only consulted by the reshape guard below.
int64_t index_depth =
permutation_and_shape.shape.getRank() - inserted_window_dims.size();
int64_t num_updates = indices_type.getDimSize(0);
// For TF::TensorScatterUpdateOp, `indices` must have at least 2 axes:
// `(num_updates, index_depth)`. Reshape indices and updates if necessary.
// The reshape is only applied to rank-1 indices and updates with a single
// update of depth one.
if (std::is_same<TfOp, TF::TensorScatterUpdateOp>::value &&
indices_type.getRank() == 1 && updates_type.getRank() == 1 &&
index_depth == 1 && num_updates == 1) {
ImplicitLocOpBuilder builder(loc, rewriter);
auto indices_shape = BuildIntArrayConstOp(
builder, rewriter,
llvm::SmallVector<int64_t>({num_updates, index_depth}),
rewriter.getI32Type());
new_indices = rewriter.create<TF::ReshapeOp>(
loc,
RankedTensorType::get({num_updates, index_depth},
indices_type.getElementType()),
indices, indices_shape);
auto updates_shape =
BuildIntArrayConstOp(builder, rewriter,
llvm::SmallVector<int64_t>(
{num_updates, updates_type.getDimSize(0)}),
rewriter.getI32Type());
new_updates = rewriter.create<TF::ReshapeOp>(
loc,
RankedTensorType::get({1, updates_type.getDimSize(0)},
updates_type.getElementType()),
new_updates, updates_shape);
}
// Apply TF scatter to update the trailing dimensions of the
// transposed operand.
auto tf_scatter_op =
rewriter.create<TfOp>(loc, permutation_and_shape.shape,
transposed_operand, new_indices, new_updates);
// Reverse the earlier transpose.
auto inverse_permutation =
GetInversePermutation(permutation_array, rewriter);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(
scatter_op, scatter_op.getResult(0).getType(), tf_scatter_op,
inverse_permutation);
return success();
}
}
};
// Instantiations of ConvertScatterOp. The first template argument selects the
// update computation accepted by MatchBinaryReduceFunction; the second is the
// TF op created as the replacement.
using ConvertScatterAddOp =
ConvertScatterOp<mhlo::AddOp, TF::TensorScatterAddOp>;
using ConvertScatterMaxOp =
ConvertScatterOp<mhlo::MaxOp, TF::TensorScatterMaxOp>;
using ConvertScatterMinOp =
ConvertScatterOp<mhlo::MinOp, TF::TensorScatterMinOp>;
using ConvertScatterSubOp =
ConvertScatterOp<mhlo::SubtractOp, TF::TensorScatterSubOp>;
// `void` selects the MatchBinaryReduceFunction<void> specialization, which
// accepts an update computation that returns its second argument.
using ConvertScatterUpdateOp =
ConvertScatterOp<void, TF::TensorScatterUpdateOp>;
} // end namespace odml
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"
#include <stdint.h>
#include <cassert>
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace odml {
// Returns true iff `arr` holds exactly `size` elements and they are
// 0, 1, ..., size - 1 in that order.
bool IsIotaAttr(ArrayRef<int64_t> arr, int64_t size) {
  const int64_t count = static_cast<int64_t>(arr.size());
  if (count != size) return false;
  for (int64_t position = 0; position < count; ++position) {
    if (arr[position] != position) return false;
  }
  return true;
}
// Builds the permutation attribute for `permutation_array` together with the
// type obtained by permuting the dimensions of `input_type` accordingly:
// dimension `i` of the result has the size of input dimension
// `permutation_array[i]`.
PermutationAndShape GetPermutationAndTransposedShape(
    llvm::ArrayRef<int64_t> permutation_array, ShapedType input_type,
    ConversionPatternRewriter& rewriter) {
  assert(permutation_array.size() == input_type.getRank());

  // Collect the permuted dimension sizes.
  llvm::SmallVector<int64_t> permuted_dims;
  permuted_dims.reserve(permutation_array.size());
  for (int64_t source_dim : permutation_array) {
    permuted_dims.push_back(input_type.getDimSize(source_dim));
  }
  auto result_type =
      RankedTensorType::get(permuted_dims, input_type.getElementType());

  // The permutation itself is stored as a rank-1 i64 tensor attribute.
  DenseIntElementsAttr permutation_attr = DenseIntElementsAttr::get(
      RankedTensorType::get(permutation_array.size(), rewriter.getI64Type()),
      permutation_array);
  return {permutation_attr, result_type};
}
// Creates a tf.Const holding the single integer `const_value` as an integer
// attribute of `type`, at the builder's implicit location.
Value BuildIntConstOp(ImplicitLocOpBuilder& builder,
                      ConversionPatternRewriter& rewriter, int64_t const_value,
                      Type type) {
  auto value_attr = rewriter.getIntegerAttr(type, const_value);
  return builder.create<TF::ConstOp>(value_attr);
}
// Creates a tf.Const holding `const_value` as a 1-D integer tensor. The
// elements are stored as i64 when `type` is i64; for every other `type` they
// are narrowed to i32.
Value BuildIntArrayConstOp(ImplicitLocOpBuilder& builder,
                           ConversionPatternRewriter& rewriter,
                           ArrayRef<int64_t> const_value, Type type) {
  DenseIntElementsAttr elements;
  if (type != rewriter.getI64Type()) {
    // Convert I64 const array to I32.
    llvm::SmallVector<int32_t> narrowed;
    narrowed.reserve(const_value.size());
    for (int64_t wide : const_value) {
      narrowed.push_back(static_cast<int32_t>(wide));
    }
    elements = rewriter.getI32TensorAttr(narrowed);
  } else {
    elements = rewriter.getI64TensorAttr(const_value);
  }
  return builder.create<TF::ConstOp>(elements);
}
// Returns the inverse of `permutation_array`: for every position `i`,
// result[permutation_array[i]] == i.
llvm::SmallVector<int64_t> GetInversePermutationArray(
    llvm::ArrayRef<int64_t> permutation_array) {
  const int64_t rank = static_cast<int64_t>(permutation_array.size());
  llvm::SmallVector<int64_t> inverse(permutation_array.size());
  for (int64_t position = 0; position < rank; ++position) {
    const int64_t target = permutation_array[position];
    inverse[target] = position;
  }
  return inverse;
}
// Returns the inverse of `permutation_array` as a rank-1 i64 tensor
// attribute.
DenseIntElementsAttr GetInversePermutation(
    llvm::ArrayRef<int64_t> permutation_array,
    ConversionPatternRewriter& rewriter) {
  SmallVector<int64_t, 4> inverse =
      GetInversePermutationArray(permutation_array);
  auto attr_type = RankedTensorType::get(inverse.size(), rewriter.getI64Type());
  return DenseIntElementsAttr::get(attr_type, inverse);
}
// Inverts `permutation_array` and returns the inverse permutation attribute
// together with the type obtained by applying that inverse to `input_type`.
PermutationAndShape GetInversePermutationAndShape(
    llvm::ArrayRef<int64_t> permutation_array, ShapedType input_type,
    ConversionPatternRewriter& rewriter) {
  SmallVector<int64_t, 4> inverse =
      GetInversePermutationArray(permutation_array);
  PermutationAndShape inverse_and_shape =
      GetPermutationAndTransposedShape(inverse, input_type, rewriter);
  return inverse_and_shape;
}
// Makes the index vector dimension the last dimension of `indices`.
//
// * index_vector_dim == rank - 1: already normalized, nothing is changed.
// * index_vector_dim == rank: the index vector dimension is implicit; an
//   mhlo.reshape appending a trailing dimension of size 1 is created and
//   `indices` / `indices_type` are overwritten with its result and type.
// * anything else: reported as a match failure on `parent_op`.
LogicalResult NormalizeIndexVector(Operation* parent_op, Value& indices,
                                   ShapedType& indices_type,
                                   int64_t index_vector_dim,
                                   ConversionPatternRewriter& rewriter) {
  const int64_t rank = indices_type.getRank();
  if (index_vector_dim == rank - 1) return success();

  if (index_vector_dim != rank) {
    // If index_vector_dim isn't the last dimension in indices then it isn't
    // supported yet.
    // TODO(tberghammer): Transpose indices to support this usecase.
    return rewriter.notifyMatchFailure(
        parent_op,
        "index vector dim isn't the last dimension in start indices");
  }

  // Materialize the implicit trailing dimension.
  llvm::ArrayRef<int64_t> current_shape = indices_type.getShape();
  llvm::SmallVector<int64_t, 4> expanded_shape(current_shape.begin(),
                                               current_shape.end());
  expanded_shape.push_back(1);
  indices_type =
      RankedTensorType::get(expanded_shape, indices_type.getElementType());
  indices = rewriter.create<mhlo::ReshapeOp>(parent_op->getLoc(), indices_type,
                                             indices);
  return success();
}
// Specialization for "update"-style scatters: succeeds when the first block of
// `function` takes two arguments and ends in an mhlo.return whose only operand
// is the second block argument.
template <>
LogicalResult MatchBinaryReduceFunction<void>(mlir::Region& function) {
  Block& entry = function.front();
  if (entry.getNumArguments() != 2) return failure();

  auto terminator = dyn_cast<mhlo::ReturnOp>(entry.back());
  if (!terminator || terminator.getNumOperands() != 1) return failure();

  const bool returns_second_arg =
      terminator.getOperands().front() == entry.getArgument(1);
  return success(returns_second_arg);
}
} // namespace odml
} // namespace mlir
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_UTIL_H_
#include <stdint.h>
#include <cassert>
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace odml {
// Pairs a permutation attribute with the shaped type that results from
// applying it; returned by GetPermutationAndTransposedShape and
// GetInversePermutationAndShape.
struct PermutationAndShape {
// The permutation, stored as an integer elements attribute.
DenseIntElementsAttr permutation;
// The type after the permutation has been applied to the input type.
ShapedType shape;
};
// Check that `arr` is an R1 iota with integer element type starting from `0`
// with `size` number of values.
bool IsIotaAttr(ArrayRef<int64_t> arr, int64_t size);
// Returns a DenseIntElementsAttr for a permutation and the shape after
// applying the permutation to a given shape through a transpose.
PermutationAndShape GetPermutationAndTransposedShape(
llvm::ArrayRef<int64_t> permutation_array, ShapedType input_type,
ConversionPatternRewriter& rewriter);
// Create a single const integer.
Value BuildIntConstOp(ImplicitLocOpBuilder& builder,
ConversionPatternRewriter& rewriter, int64_t const_value,
Type type);
// Create a const integer vector tensor (1-dim).
Value BuildIntArrayConstOp(ImplicitLocOpBuilder& builder,
ConversionPatternRewriter& rewriter,
ArrayRef<int64_t> const_value, Type type);
// Returns the inverse permutation array for a permutation array.
llvm::SmallVector<int64_t> GetInversePermutationArray(
llvm::ArrayRef<int64_t> permutation_array);
// Returns the DenseIntElementsAttr for an inverse permutation given a
// permutation_array.
DenseIntElementsAttr GetInversePermutation(
llvm::ArrayRef<int64_t> permutation_array,
ConversionPatternRewriter& rewriter);
// Returns a DenseIntElementsAttr for an inverse permutation and the shape after
// applying the inverse permutation to a given shape through a transpose.
PermutationAndShape GetInversePermutationAndShape(
llvm::ArrayRef<int64_t> permutation_array, ShapedType input_type,
ConversionPatternRewriter& rewriter);
// If index_vector_dim == indices.rank() then insert the implicit extra
// dimension into indices to normalize everything to index_vector_dim ==
// indices.rank() - 1.
LogicalResult NormalizeIndexVector(Operation* parent_op, Value& indices,
ShapedType& indices_type,
int64_t index_vector_dim,
ConversionPatternRewriter& rewriter);
// Succeeds when the first block of `function` has the form
//   ^bb0(%a, %b):
//     %r = ReductionOp(%a, %b)
//     mhlo.return %r
// i.e. it takes two arguments, feeds them in order as lhs/rhs to a
// `ReductionOp`, and returns that op's result as the only returned value.
template <typename ReductionOp>
LogicalResult MatchBinaryReduceFunction(mlir::Region& function) {
  Block& entry = function.front();
  if (entry.getNumArguments() != 2) return failure();

  auto terminator = dyn_cast<mhlo::ReturnOp>(entry.back());
  if (!terminator || terminator.getNumOperands() != 1) return failure();

  // The returned value must be produced by a `ReductionOp`.
  auto reduction = dyn_cast_or_null<ReductionOp>(
      terminator.getOperands().front().getDefiningOp());
  if (!reduction) return failure();

  // The block arguments must be forwarded in order.
  const bool forwards_args_in_order =
      reduction.getLhs() == entry.getArgument(0) &&
      reduction.getRhs() == entry.getArgument(1);
  return success(forwards_args_in_order);
}
// Check if the specified region is a binary reduction function that takes 2
// inputs and returns the second input. Functions like this are used by update
// scatter like ops.
template <>
LogicalResult MatchBinaryReduceFunction<void>(mlir::Region& function);
} // namespace odml
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_UTIL_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册