diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index bd1dcdf06ea7d5f84f797fbad068c3c5d240b925..555c11779f5b64eda1a9ae97af2d16d4f28ce034 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -237,28 +237,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "constant_utils",
-    srcs = [
-        "utils/constant_utils.cc",
-    ],
-    hdrs = [
-        "utils/constant_utils.h",
-    ],
-    copts = ["-std=c++14"],
-    deps = [
-        "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/platform:status",
-        "//tensorflow/stream_executor/lib",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:StandardOps",
-        "@llvm-project//mlir:Support",
-    ],
-)
-
 cc_library(
     name = "lstm_utils",
     srcs = [
@@ -369,7 +347,6 @@ cc_library(
         "transforms/passes.h",
     ],
     deps = [
-        ":constant_utils",
         ":lstm_utils",
         ":stateful_ops_utils",
         ":tensorflow_lite",
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir
index 3c390df74b4cb6546ad1ce2cefefb7ece26dfc4a..90266b4e78eb98fd339e3bb0a2da143133bb20a0 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir
@@ -1,11 +1,12 @@
-// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
+// RUN: tf-opt %s -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
 
 func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> {
   %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
   return %0: tensor<3x3xbf16>
 
 // CHECK-LABEL: broadcast_to_bf16
-// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xbf16>
-// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[CST]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
+// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
+// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
+// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
 // CHECK: return [[MUL]] : tensor<3x3xbf16>
 }
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index d02e4e705f4524d445535c08267bc52c5570916e..7cb9c4dd22cfc51a29ace4a5c5ad779093f55c96 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -1482,6 +1482,28 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
 // CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
 // CHECK: }
 
+func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
+  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
+  return %0: tensor<3x3xf32>
+
+// CHECK-LABEL: broadcast_to_f32
+// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
+// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
+// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
+// CHECK: return [[MUL]] : tensor<3x3xf32>
+}
+
+func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
+  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
+  return %0: tensor<3x3xi32>
+
+// CHECK-LABEL: broadcast_to_i32
+// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
+// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
+// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
+// CHECK: return [[MUL]] : tensor<3x3xi32>
+}
+
 func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 6ee5b67d65ec8ca6e634d5c993b8766793404842..066139e179b870b81664f50d0e003c10018aec38 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -595,24 +595,4 @@ func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
 // CHECK: return %[[RES]]
 }
 
-func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
-  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
-  return %0: tensor<3x3xf32>
-
-// CHECK-LABEL: broadcast_to_f32
-// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
-// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
-// CHECK: return [[MUL]] : tensor<3x3xf32>
-}
-
-func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
-  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
-  return %0: tensor<3x3xi32>
-
-// CHECK-LABEL: broadcast_to_i32
-// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
-// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
-// CHECK: return [[MUL]] : tensor<3x3xi32>
-}
-
 }
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 297b1459fc551918a178020fd42d1b4927ffa70e..7a16e475ce3a9ed413c624e6d09bf39e157440fb 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -45,7 +45,6 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
-#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
@@ -138,6 +137,7 @@ DECL_CONVERT_OP(StridedSlice);
 DECL_CONVERT_OP(Unpack);
 DECL_CONVERT_OP(Reciprocal);
 DECL_CONVERT_OP(RandomUniform);
+DECL_CONVERT_OP(BroadcastTo);
 
 #undef DECL_CONVERT_OP
 
@@ -483,6 +483,89 @@ LogicalResult ConvertTFAssertOp::matchAndRewrite(
   return success();
 }
 
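+// Returns a ConstantOp holding a single scalar of `shaped_type`'s element
+// type, set to `value`; unsupported element types produce an error status.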
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" -#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" @@ -138,6 +137,7 @@ DECL_CONVERT_OP(StridedSlice); DECL_CONVERT_OP(Unpack); DECL_CONVERT_OP(Reciprocal); DECL_CONVERT_OP(RandomUniform); +DECL_CONVERT_OP(BroadcastTo); #undef DECL_CONVERT_OP @@ -483,6 +483,89 @@ LogicalResult ConvertTFAssertOp::matchAndRewrite( return success(); } +StatusOr CreateConstOpWithSingleValue(PatternRewriter* rewriter, + Location loc, + ShapedType shaped_type, + int value) { + Type element_type = shaped_type.getElementType(); + ShapedType scalar_type = RankedTensorType::get({}, element_type); + Attribute attr; + switch (element_type.getKind()) { + case mlir::StandardTypes::F16: { + auto floatType = mlir::FloatType::getF16(element_type.getContext()); + auto floatAttr = + mlir::FloatAttr::get(floatType, static_cast(value)); + std::vector floatValues({floatAttr}); + attr = DenseElementsAttr::get(scalar_type, floatValues); + break; + } + case mlir::StandardTypes::BF16: { + auto floatType = mlir::FloatType::getBF16(element_type.getContext()); + auto floatAttr = + mlir::FloatAttr::get(floatType, static_cast(value)); + std::vector floatValues({floatAttr}); + attr = DenseElementsAttr::get(scalar_type, floatValues); + break; + } + case mlir::StandardTypes::F32: { + attr = + DenseElementsAttr::get(scalar_type, static_cast(value)); + break; + } + case mlir::StandardTypes::Complex: { + auto etype = element_type.cast().getElementType(); + if (etype.isF32()) { + auto dialect = etype.getContext()->getRegisteredDialect("tf"); + tensorflow::TensorProto repr; + repr.set_dtype(tensorflow::DT_COMPLEX64); + + tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape(); + shape->set_unknown_rank(false); + shape->add_dim()->set_size(int64_t{1}); + std::string content; + auto complex_value = + std::complex(static_cast(value), 0.0f); + content.assign(reinterpret_cast(&complex_value), + sizeof(complex_value)); + repr.set_tensor_content(content); + std::string mangled = tensorflow::mangling_util::MangleTensor(repr); + + attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled); + break; + } + return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type"); + } + case mlir::StandardTypes::Integer: { + const auto& itype = element_type.cast(); + switch (itype.getWidth()) { + case 8: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 16: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 32: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 64: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + default: + return Status(tensorflow::error::INVALID_ARGUMENT, + "Unsupported type"); + } + break; + } + default: + return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type"); + } + return rewriter->create(loc, scalar_type, attr); +} + LogicalResult ConvertTFReciprocalOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_reciprocal_op = cast(op); @@ -503,6 +586,31 @@ LogicalResult ConvertTFReciprocalOp::matchAndRewrite( return success(); } +LogicalResult 
+LogicalResult ConvertTFBroadcastToOp::matchAndRewrite(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
+  auto element_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
+  auto output_type = tf_broadcast_to_op.output().getType();
+
+  auto status_or_const_op =
+      CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1);
+  if (!status_or_const_op.ok()) {
+    return failure();
+  }
+
+  auto tfl_fill_op = rewriter.create<TFL::FillOp>(
+      op->getLoc(), output_type, tf_broadcast_to_op.shape(),
+      status_or_const_op.ValueOrDie());
+
+  StringAttr fused_activation_function =
+      StringAttr::get("NONE", rewriter.getContext());
+
+  rewriter.replaceOpWithNewOp<TFL::MulOp>(
+      op, output_type, tf_broadcast_to_op.input(), tfl_fill_op,
+      fused_activation_function);
+  return success();
+}
+
 // Legalize unidirectional sequence lstm.
 struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
   explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
@@ -643,7 +751,7 @@ void LegalizeTF::runOnFunction() {
       ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
       ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
       ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
-      ConvertTFRandomUniformOp>(context);
+      ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
 
   // Ophint python converter converted tf node pattern.
   patterns.insert<LegalizeUnidirectionalSequenceLstm,
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ ... @@
-struct ConvertTFBroadcastTo : public RewritePattern {
-  explicit ConvertTFBroadcastTo(MLIRContext* context)
-      : RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
-
-  LogicalResult matchAndRewrite(Operation* op,
-                                PatternRewriter& rewriter) const override {
-    auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
-    auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
-    auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
-    auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
-    Type element_type = input_type.getElementType();
-
-    // Allow lowering when low dimension inputs are given and its type is F32 or
-    // I32.
-    if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
-          (shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
-           shape_type.getDimSize(0) <= 4)))
-      return failure();
-
-    if (!((element_type.getKind() == mlir::StandardTypes::F32) ||
-          (element_type.getKind() == mlir::StandardTypes::BF16) ||
-          (element_type.getKind() == mlir::StandardTypes::Integer &&
-           element_type.cast<mlir::IntegerType>().getWidth() == 32)))
-      return failure();
-
-    auto status_or_const_op =
-        CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
-    if (!status_or_const_op.ok()) {
-      return failure();
-    }
-
-    auto tf_fill_op = rewriter.create<TF::FillOp>(
-        op->getLoc(), output_type, tf_broadcast_to_op.shape(),
-        status_or_const_op.ValueOrDie());
-
-    auto mul_op = rewriter.create<TF::MulOp>(
-        op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
-    rewriter.replaceOp(op, mul_op.getResult());
-    return success();
-  }
-};
-
 #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
 
 // Returns success if all the operations in the `op`'s regions including `op`
@@ -810,7 +767,7 @@ void PrepareTFPass::runOnFunction() {
     patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                     TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
   }
-  patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo,
+  patterns.insert<TF::ConvertTFEinsumOp,
                   ConvertTFStridedSlice>(ctx);
   applyPatternsAndFoldGreedily(func, patterns);
 }
diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc
deleted file mode 100644
index 8562f623258733e534157e39f8150a95e2912502..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
-
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
-#include "tensorflow/core/framework/tensor.pb.h"
-#include "tensorflow/core/framework/tensor_shape.pb.h"
-#include "tensorflow/core/platform/status.h"
-
-namespace mlir {
-namespace TFL {
-
-stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
-    PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
-    int value) {
-  Type element_type = shaped_type.getElementType();
-  ShapedType scalar_type = RankedTensorType::get({}, element_type);
-  Attribute attr;
-  switch (element_type.getKind()) {
-    case mlir::StandardTypes::F16: {
-      auto floatType = mlir::FloatType::getF16(element_type.getContext());
-      auto floatAttr =
-          mlir::FloatAttr::get(floatType, static_cast<float>(value));
-      std::vector<Attribute> floatValues({floatAttr});
-      attr = DenseElementsAttr::get(scalar_type, floatValues);
-      break;
-    }
-    case mlir::StandardTypes::BF16: {
-      auto floatType = mlir::FloatType::getBF16(element_type.getContext());
-      auto floatAttr =
-          mlir::FloatAttr::get(floatType, static_cast<float>(value));
-      std::vector<Attribute> floatValues({floatAttr});
-      attr = DenseElementsAttr::get(scalar_type, floatValues);
-      break;
-    }
-    case mlir::StandardTypes::F32: {
-      attr =
-          DenseElementsAttr::get(scalar_type, static_cast<float>(value));
-      break;
-    }
-    case mlir::StandardTypes::Complex: {
-      auto etype = element_type.cast<mlir::ComplexType>().getElementType();
-      if (etype.isF32()) {
-        auto dialect = etype.getContext()->getRegisteredDialect("tf");
-        tensorflow::TensorProto repr;
-        repr.set_dtype(tensorflow::DT_COMPLEX64);
-
-        tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
-        shape->set_unknown_rank(false);
-        shape->add_dim()->set_size(int64_t{1});
-        std::string content;
-        auto complex_value =
-            std::complex<float>(static_cast<float>(value), 0.0f);
-        content.assign(reinterpret_cast<const char*>(&complex_value),
-                       sizeof(complex_value));
-        repr.set_tensor_content(content);
-        std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
-
-        attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
-        break;
-      }
-      return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
-                                "Unsupported type");
-    }
-    case mlir::StandardTypes::Integer: {
-      const auto& itype = element_type.cast<mlir::IntegerType>();
-      switch (itype.getWidth()) {
-        case 8:
-          attr = DenseElementsAttr::get(scalar_type,
-                                        static_cast<int8_t>(value));
-          break;
-        case 16:
-          attr = DenseElementsAttr::get(scalar_type,
-                                        static_cast<int16_t>(value));
-          break;
-        case 32:
-          attr = DenseElementsAttr::get(scalar_type,
-                                        static_cast<int32_t>(value));
-          break;
-        case 64:
-          attr = DenseElementsAttr::get(scalar_type,
-                                        static_cast<int64_t>(value));
-          break;
-        default:
-          return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
-                                    "Unsupported type");
-      }
-      break;
-    }
-    default:
-      return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
-                                "Unsupported type");
-  }
-  return rewriter->create<ConstantOp>(loc, scalar_type, attr);
-}
-
-}  // namespace TFL
-}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.h b/tensorflow/compiler/mlir/lite/utils/constant_utils.h
deleted file mode 100644
index 5c348021b5e5c5f539d5790e97b302ee83cdfa82..0000000000000000000000000000000000000000
--- a/tensorflow/compiler/mlir/lite/utils/constant_utils.h
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
-#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
-
-#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
-#include "mlir/IR/Location.h"  // from @llvm-project
-#include "mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/IR/PatternMatch.h"  // from @llvm-project
-#include "mlir/IR/StandardTypes.h"  // from @llvm-project
-#include "tensorflow/stream_executor/lib/statusor.h"
-
-namespace mlir {
-namespace TFL {
-
-// Returns a Constant op with a single value.
-stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
-    PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value);
-
-}  // namespace TFL
-}  // namespace mlir
-#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_