Commit 50f04155 authored by Matthias Kramm, committed by TensorFlower Gardener

Make mlprogram_util lower to stablehlo by default.

Also, add more optimization/simplification passes.

PiperOrigin-RevId: 481229419
Parent b0b5bc63
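For context, here is a minimal sketch (not part of the commit) of how the updated pipeline might be driven from C++. Only `PopulateLowerToMlProgramAndHloPipeline` comes from the code changed below; the wrapper function, the forward declaration, and the omission of the exact TensorFlow header are illustrative assumptions.

```cpp
#include "mlir/IR/BuiltinOps.h"         // mlir::ModuleOp
#include "mlir/Pass/PassManager.h"      // mlir::PassManager
#include "mlir/Support/LogicalResult.h"

namespace tensorflow {
// Declared in the TensorFlow MLIR transforms library touched by this commit;
// the exact header path is elided here.
void PopulateLowerToMlProgramAndHloPipeline(mlir::OpPassManager& pm);
}  // namespace tensorflow

// Hypothetical driver: lowers a parsed tf_saved_model module in place.
mlir::LogicalResult LowerToMlProgramAndStablehlo(mlir::ModuleOp module) {
  // PassManager is an OpPassManager rooted at the module, so it can be
  // handed directly to the pipeline populator.
  mlir::PassManager pm(module.getContext());
  tensorflow::PopulateLowerToMlProgramAndHloPipeline(pm);
  // After this commit, the resulting module should contain ml_program plus
  // StableHLO ops (previously MHLO), as the updated tests below check.
  return pm.run(module);
}
```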
@@ -403,6 +403,7 @@ cc_library(
":tensorflow_passes",
":tf_saved_model_passes",
"//tensorflow/compiler/mlir/xla:legalize_tf",
"//tensorflow/compiler/xla/mlir_hlo:mhlo_passes",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
......
// RUN: tf-opt --split-input-file -tf-lower-to-mlprogram-and-hlo %s -o - | FileCheck %s
module attributes {tf_saved_model.semantics} {
// CHECK-LABEL: func @lowers_to_mhlo
func.func @lowers_to_mhlo(%arg0: tensor<i32> {tf_saved_model.index_path = []}) -> (tensor<*xi32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["lowers_to_mhlo"]}
// CHECK-LABEL: func @lowers_to_stablehlo
func.func @lowers_to_stablehlo(%arg0: tensor<i32> {tf_saved_model.index_path = []}) -> (tensor<*xi32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["lowers_to_stablehlo"]}
{
// CHECK-DAG: [[one:%.*]] = mhlo.constant dense<1>
// CHECK-DAG: [[twenty:%.*]] = mhlo.constant dense<20>
// CHECK-DAG: [[r3:%.*]] = mhlo.subtract [[twenty]], %arg0
// CHECK-DAG: [[zero:%.*]] = mhlo.constant dense<0>
// CHECK-DAG: [[r4:%.*]] = mhlo.divide [[r3]], [[one]]
// CHECK-DAG: [[r5:%.*]] = mhlo.compare NE
// CHECK-DAG: [[r6:%.*]] = mhlo.compare LT
// CHECK: [[result:%.*]] = "mhlo.select"
// CHECK-DAG: [[one:%.*]] = stablehlo.constant dense<1>
// CHECK-DAG: [[twenty:%.*]] = stablehlo.constant dense<20>
// CHECK-DAG: [[r3:%.*]] = stablehlo.subtract [[twenty]], %arg0
// CHECK-DAG: [[zero:%.*]] = stablehlo.constant dense<0>
// CHECK-DAG: [[r4:%.*]] = stablehlo.divide [[r3]], [[one]]
// CHECK-DAG: [[r5:%.*]] = stablehlo.compare NE
// CHECK-DAG: [[r6:%.*]] = stablehlo.compare LT
// CHECK: [[result:%.*]] = stablehlo.select
// CHECK-NEXT: return [[result]]
%0 = tf_executor.graph {
%outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor<i32>} : () -> tensor<i32>
@@ -91,7 +91,7 @@ module attributes {tf_saved_model.semantics} {
func.func @handles_variables_in_while_loops(%arg0: tensor<!tf_type.resource<tensor<i32>>> {tf._user_specified_name = "arg0", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0", tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["lowers_variable_ops"]}
{
// CHECK: mhlo.while
// CHECK: stablehlo.while
tf_executor.graph {
%0, %c0 = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%1, %c1 = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
......
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace tensorflow {
@@ -35,6 +36,9 @@ void PopulateLowerToMlProgramAndHloPipeline(mlir::OpPassManager& pm) {
// Remove unused global tensors, or make them immutable if possible.
pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
pm.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
pm.addPass(mlir::TF::CreateNameAnonymousIteratorsPass());
// This will add regions to IfOp/WhileOp (turning them into IfRegionOp
// and WhileRegionOp), but be aware that those regions will still contain
// calls.
@@ -46,27 +50,32 @@ void PopulateLowerToMlProgramAndHloPipeline(mlir::OpPassManager& pm) {
pm.addPass(mlir::tf_saved_model::CreateStripSavedModuleMetadataPass());
pm.addPass(mlir::TF::CreateRemoveUnusedArgumentsPass());
pm.addPass(mlir::TF::CreateRemoveUnusedWhileResultsPass());
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createSymbolDCEPass());
// Has to be a non-empty string, so tf2xla fallbacks kick in.
llvm::StringRef tf2xla_fallback_device_type = "cpu/gpu/tpu";
pm.addPass(mlir::TF::CreateTFShapeInferencePass());
llvm::StringRef tf2xla_fallback_device_type = "XLA_CPU_JIT";
pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::createLegalizeTFPass(
/*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
tf2xla_fallback_device_type, /*prefer_tf2xla=*/false));
pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::TF::CreateStripTfAttributesPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::TF::CreateOrderByDialectPass());
pm.addPass(mlir::TF::CreateGroupByDialectPass());
pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
}
} // namespace tensorflow
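The test above invokes this pipeline through tf-opt as `-tf-lower-to-mlprogram-and-hlo`. As a hedged sketch of how such a pipeline name can be wired up, MLIR's `PassPipelineRegistration` can bind the populator to a textual pipeline name; the actual registration in TensorFlow may look different and live in another file.

```cpp
#include "mlir/Pass/PassRegistry.h"

// Illustrative only: binds the populator to the pipeline name used in the
// test's RUN line. TensorFlow's real registration may differ.
static mlir::PassPipelineRegistration<> lower_to_mlprogram_and_hlo(
    "tf-lower-to-mlprogram-and-hlo",
    "Lower a tf_saved_model module to ml_program and StableHLO",
    [](mlir::OpPassManager& pm) {
      tensorflow::PopulateLowerToMlProgramAndHloPipeline(pm);
    });
```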