Commit 57f61ed3 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Move TF Broadcast op legalization process to the prepare_tf stage

This change lets the lowering benefit from the constant-folding logic of the TF dialect.
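
The rewrite itself expresses tf.BroadcastTo as a multiply against a ones tensor of the requested shape, so the multiply's implicit broadcasting performs the expansion; when the shape is constant, constant folding can then collapse the fill into a dense constant. A minimal sketch of the transformation (illustrative only; the authoritative IR is in the tests below):

// Before:
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>

// After:
%one = constant dense<1.000000e+00> : tensor<f32>
%ones = "tfl.fill"(%shape, %one) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
%0 = "tfl.mul"(%input, %ones) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>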

PiperOrigin-RevId: 326169849
Change-Id: I432eeb902b1ace379388cbd1870e2cf35b828b17
Parent 916bd91c
......@@ -237,28 +237,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "constant_utils",
srcs = [
"utils/constant_utils.cc",
],
hdrs = [
"utils/constant_utils.h",
],
copts = ["-std=c++14"],
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "lstm_utils",
srcs = [
......@@ -369,7 +347,6 @@ cc_library(
"transforms/passes.h",
],
deps = [
":constant_utils",
":lstm_utils",
":stateful_ops_utils",
":tensorflow_lite",
......
// RUN: tf-opt %s -tfl-prepare-tf -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
// RUN: tf-opt %s -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
return %0: tensor<3x3xbf16>
// CHECK-LABEL: broadcast_to_bf16
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xbf16>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[CST]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
// CHECK: return [[MUL]] : tensor<3x3xbf16>
}
......@@ -1482,6 +1482,28 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
// CHECK: }
func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
......
......@@ -595,24 +595,4 @@ func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
// CHECK: return %[[RES]]
}
func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
}
......@@ -45,7 +45,6 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
......@@ -138,6 +137,7 @@ DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Reciprocal);
DECL_CONVERT_OP(RandomUniform);
DECL_CONVERT_OP(BroadcastTo);
#undef DECL_CONVERT_OP
......@@ -483,6 +483,89 @@ LogicalResult ConvertTFAssertOp::matchAndRewrite(
return success();
}
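// Returns a scalar ConstantOp that holds `value` converted to the element
// type of `shaped_type`; returns an INVALID_ARGUMENT status for unsupported
// element types.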
StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
Location loc,
ShapedType shaped_type,
int value) {
Type element_type = shaped_type.getElementType();
ShapedType scalar_type = RankedTensorType::get({}, element_type);
Attribute attr;
switch (element_type.getKind()) {
case mlir::StandardTypes::F16: {
auto floatType = mlir::FloatType::getF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::BF16: {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::F32: {
attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
break;
}
case mlir::StandardTypes::Complex: {
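// Complex constants are encoded by serializing a tensorflow::TensorProto and
// wrapping it in a TF-dialect OpaqueElementsAttr; only complex<f32> is
// handled, anything else falls through to the INVALID_ARGUMENT status below.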
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
if (etype.isF32()) {
auto dialect = etype.getContext()->getRegisteredDialect("tf");
tensorflow::TensorProto repr;
repr.set_dtype(tensorflow::DT_COMPLEX64);
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
shape->set_unknown_rank(false);
shape->add_dim()->set_size(int64_t{1});
std::string content;
auto complex_value =
std::complex<float>(static_cast<float>(value), 0.0f);
content.assign(reinterpret_cast<const char*>(&complex_value),
sizeof(complex_value));
repr.set_tensor_content(content);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
break;
}
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
}
case mlir::StandardTypes::Integer: {
const auto& itype = element_type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
case 8:
attr = DenseElementsAttr::get<int8_t>(scalar_type,
static_cast<int8_t>(value));
break;
case 16:
attr = DenseElementsAttr::get<int16_t>(scalar_type,
static_cast<int16_t>(value));
break;
case 32:
attr = DenseElementsAttr::get<int32_t>(scalar_type,
static_cast<int32_t>(value));
break;
case 64:
attr = DenseElementsAttr::get<int64_t>(scalar_type,
static_cast<int64_t>(value));
break;
default:
return Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
break;
}
default:
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
}
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
}
LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_reciprocal_op = cast<TF::ReciprocalOp>(op);
......@@ -503,6 +586,31 @@ LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
return success();
}
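// Lowers tf.BroadcastTo into tfl.mul(input, tfl.fill(shape, 1)); multiplying
// by a ones tensor of the target shape broadcasts the input to the output
// type.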
LogicalResult ConvertTFBroadcastToOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
auto element_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
auto output_type = tf_broadcast_to_op.output().getType();
auto status_or_const_op =
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1);
if (!status_or_const_op.ok()) {
return failure();
}
auto tfl_fill_op = rewriter.create<TFL::FillOp>(
op->getLoc(), output_type, tf_broadcast_to_op.shape(),
status_or_const_op.ValueOrDie());
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<TFL::MulOp>(
op, output_type, tf_broadcast_to_op.input(), tfl_fill_op,
fused_activation_function);
return success();
}
// Legalize unidirectional sequence lstm.
struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
......@@ -643,7 +751,7 @@ void LegalizeTF::runOnFunction() {
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
ConvertTFRandomUniformOp>(context);
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm,
......
......@@ -57,7 +57,6 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
......@@ -687,48 +686,6 @@ struct ConvertTFStridedSlice : public RewritePattern {
}
};
struct ConvertTFBroadcastTo : public RewritePattern {
explicit ConvertTFBroadcastTo(MLIRContext *context)
: RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
Type element_type = input_type.getElementType();
// Allow the lowering only for low-rank (<= 4) inputs/outputs whose element
// type is F32, BF16, or 32-bit integer.
if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
(shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
shape_type.getDimSize(0) <= 4)))
return failure();
if (!((element_type.getKind() == mlir::StandardTypes::F32) ||
(element_type.getKind() == mlir::StandardTypes::BF16) ||
(element_type.getKind() == mlir::StandardTypes::Integer &&
element_type.cast<mlir::IntegerType>().getWidth() == 32)))
return failure();
auto status_or_const_op =
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
if (!status_or_const_op.ok()) {
return failure();
}
auto tf_fill_op = rewriter.create<TF::FillOp>(
op->getLoc(), output_type, tf_broadcast_to_op.shape(),
status_or_const_op.ValueOrDie());
auto mul_op = rewriter.create<TF::MulOp>(
op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
rewriter.replaceOp(op, mul_op.getResult());
return success();
}
};
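Note that this removed prepare_tf variant built tf.Fill and tf.Mul, leaving the TF dialect's folders free to collapse them into a constant, while the legalize_tf pattern above emits tfl.fill and tfl.mul directly.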
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
// Returns success if all the operations in the `op`'s regions including `op`
......@@ -810,7 +767,7 @@ void PrepareTFPass::runOnFunction() {
patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
}
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFConv2D,
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
applyPatternsAndFoldGreedily(func, patterns);
}
......
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/status.h"
namespace mlir {
namespace TFL {
stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
int value) {
Type element_type = shaped_type.getElementType();
ShapedType scalar_type = RankedTensorType::get({}, element_type);
Attribute attr;
switch (element_type.getKind()) {
case mlir::StandardTypes::F16: {
auto floatType = mlir::FloatType::getF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::BF16: {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::F32: {
attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
break;
}
case mlir::StandardTypes::Complex: {
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
if (etype.isF32()) {
auto dialect = etype.getContext()->getRegisteredDialect("tf");
tensorflow::TensorProto repr;
repr.set_dtype(tensorflow::DT_COMPLEX64);
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
shape->set_unknown_rank(false);
shape->add_dim()->set_size(int64_t{1});
std::string content;
auto complex_value =
std::complex<float>(static_cast<float>(value), 0.0f);
content.assign(reinterpret_cast<const char*>(&complex_value),
sizeof(complex_value));
repr.set_tensor_content(content);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
break;
}
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
case mlir::StandardTypes::Integer: {
const auto& itype = element_type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
case 8:
attr = DenseElementsAttr::get<int8_t>(scalar_type,
static_cast<int8_t>(value));
break;
case 16:
attr = DenseElementsAttr::get<int16_t>(scalar_type,
static_cast<int16_t>(value));
break;
case 32:
attr = DenseElementsAttr::get<int32_t>(scalar_type,
static_cast<int32_t>(value));
break;
case 64:
attr = DenseElementsAttr::get<int64_t>(scalar_type,
static_cast<int64_t>(value));
break;
default:
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
break;
}
default:
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
}
} // namespace TFL
} // namespace mlir
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/stream_executor/lib/statusor.h"
namespace mlir {
namespace TFL {
// Returns a Constant op with a single value.
stream_executor::port::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value);
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_