From 9d53fd3a010d6e1545badc5933d6a26d2ec734d4 Mon Sep 17 00:00:00 2001
From: Andy Ly
Date: Tue, 14 Apr 2020 15:45:50 -0700
Subject: [PATCH] Disable TPUVariableRuntimeReformattingPass when model
 parallelism is detected.

Variable reformatting is currently not supported with model parallelism.

PiperOrigin-RevId: 306529947
Change-Id: I49f83bc1c76762d7bfb4607905c65e413f7bf278
---
 .../tpu-variable-runtime-reformatting.mlir    | 90 ++++++++++++++++++-
 .../tpu_variable_runtime_reformatting.cc      |  6 +-
 2 files changed, 92 insertions(+), 4 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
index d0ca8c09457..c2faef929d7 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
@@ -38,7 +38,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
     // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
     return
   }
-  // CHECK: func @while_body_7560
+  // CHECK-LABEL: func @while_body_7560
   func @while_body_7560(%arg0: tensor<i32>,
            %arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
            %arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
@@ -112,7 +112,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
 
 // -----
 
-// Tests that the pass does not format variabls with other uses.
+// Tests that the pass does not format variables with other uses.
 
 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
   // CHECK-LABEL: func @main
@@ -135,7 +135,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
            tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
     return
   }
-  // CHECK: func @while_body_7560
+  // CHECK-LABEL: func @while_body_7560
   // CHECK-NOT: TPUReshardVariables
   func @while_body_7560(%arg0: tensor<i32>,
            %arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
@@ -198,3 +198,87 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
     return %1 : tensor<i1>
   }
 }
+
+// -----
+
+// Tests that the pass does not format variables when model parallelism is
+// present.
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
+  // CHECK-LABEL: func @main
+  // CHECK-NOT: TPUReshardVariables
+  func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
+             %arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
+             %arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
+             %arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"}) {
+
+    %0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
+    %1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3)
+               {T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
+                     "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
+                     "tfdtype$DT_RESOURCE"], body = @while_body_7560,
+                cond = @while_cond_7550, device = "", is_stateless = false,
+                output_shapes = ["tfshape$", "tfshape$", "tfshape$", "tfshape$", "tfshape$"]}
+         : (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
+            tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
+        -> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
+            tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
+    return
+  }
+  // CHECK-LABEL: func @while_body_7560
+  // CHECK-NOT: TPUReshardVariables
+  func @while_body_7560(%arg0: tensor<i32>,
+           %arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
+           %arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
+           %arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
+           %arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
+      -> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
+          tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) {
+    %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+    %1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %compile:2 = "tf_device.launch"() ( {
+      %2:2 = "tf._TPUCompileMlir"() {
+        NumDynamicShapes = 0 : i64,
+        // The metadata encodes two parameters and two return values.
+        metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
+        mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+      tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
+    }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+    "tf_device.launch"() ( {
+      "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
+      tf_device.return
+    }) {device = "/device:CPU:0"} : () -> ()
+    %rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
+                  [%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
+                  {_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
+      %id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
+      "tf_device.parallel_execute"() ({
+        "tf_device.launch"() ( {
+          "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
+                {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
+              : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
+          tf_device.return
+        }) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
+        tf_device.return
+      }, {
+        tf_device.return
+      }) {} : () -> ()
+      %ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      tf_device.return %ret : tensor<i32>
+    }
+    return %1, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
+           tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
+           tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
+  }
+  // CHECK-LABEL: func @while_cond_7550
+  func @while_cond_7550(%arg0: tensor<i32>,
+           %arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
+           %arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
+           %arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
+           %arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
+      -> tensor<i1> {
+    %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+    %1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    return %1 : tensor<i1>
+  }
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
index a6ea26b1ebf..3b832b2e78e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
@@ -570,7 +570,11 @@ void TPUVariableRuntimeReformattingPass::runOnOperation() {
       replicate = nullptr;
      return WalkResult::interrupt();
     });
-    if (replicate) HandleReplicateOp(while_op, replicate, &getContext());
+    // Model parallelism is not supported, and is detected by the presence of
+    // a `tf_device.parallel_execute` op in the `tf_device.replicate` op.
+    if (replicate &&
+        replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
+      HandleReplicateOp(while_op, replicate, &getContext());
   });
 }
-- 
GitLab
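
Note: the core of this change is the new guard in runOnOperation. The sketch
below restates that detection logic as a standalone predicate. It is not part
of the patch: IsModelParallelReplicate is a hypothetical helper name, and the
sketch assumes the tf_device dialect header from the TensorFlow MLIR tree is
available. A replicate is treated as model parallel exactly when its body
contains a tf_device.parallel_execute op, and in that case the pass now leaves
the variables untouched.

    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"

    // Hypothetical helper, not part of the patch: returns true iff the body
    // of the tf_device.replicate holds a tf_device.parallel_execute op, i.e.
    // the computation is model parallel and reformatting must be skipped.
    static bool IsModelParallelReplicate(mlir::tf_device::ReplicateOp replicate) {
      return !replicate.GetBody()
                  .getOps<mlir::tf_device::ParallelExecuteOp>()
                  .empty();
    }

With such a helper, the patched condition reads
`if (replicate && !IsModelParallelReplicate(replicate))`, which makes explicit
that HandleReplicateOp, and hence the TPUReshardVariables rewrite, now runs
only for purely data-parallel training loops.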