提交 9d53fd3a 编写于 作者: A Andy Ly 提交者: TensorFlower Gardener

Disable TPUVariableRuntimeReformattingPass when model parallelism is detected.

Variable reformatting is currently not supported with model parallelism.

PiperOrigin-RevId: 306529947
Change-Id: I49f83bc1c76762d7bfb4607905c65e413f7bf278
上级 4a071b39
......@@ -38,7 +38,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
return
}
// CHECK: func @while_body_7560
// CHECK-LABEL: func @while_body_7560
func @while_body_7560(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
......@@ -112,7 +112,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// -----
// Tests that the pass does not format variabls with other uses.
// Tests that the pass does not format variables with other uses.
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
......@@ -135,7 +135,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
return
}
// CHECK: func @while_body_7560
// CHECK-LABEL: func @while_body_7560
// CHECK-NOT: TPUReshardVariables
func @while_body_7560(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
......@@ -198,3 +198,87 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
return %1 : tensor<i1>
}
}
// -----
// Tests that the pass does not format variables when model parallelism is
// present.
// This module exercises the model-parallelism case: the tf_device.replicate
// body contains a tf_device.parallel_execute, so the variable-reformatting
// pass must leave every variable untouched (hence the CHECK-NOTs on
// TPUReshardVariables in both @main and the loop body).
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
// CHECK-LABEL: func @main
// CHECK-NOT: TPUReshardVariables
// Entry point: drives a tf.While over four per-replica resource variables
// (two f32 scalars on TPU:0/TPU:1 and two 3x3x1x32 tensors on TPU:0/TPU:1),
// starting the loop counter at 100.
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"}) {
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
%1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3)
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
cond = @while_cond_7550, device = "", is_stateless = false,
output_shapes = ["tfshape$", "tfshape$", "tfshape$", "tfshape$", "tfshape$"]}
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
return
}
// CHECK-LABEL: func @while_body_7560
// CHECK-NOT: TPUReshardVariables
// Loop body: decrements the counter, compiles a TPU program, then replicates
// execution across two replicas. The replicate body wraps the execute op in a
// tf_device.parallel_execute, which marks this as model-parallel.
func @while_body_7560(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) {
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// Compile on the host; %compile#1 is the program key consumed by the
// execute op below.
%compile:2 = "tf_device.launch"() ( {
%2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes two parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
%rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
// The presence of this tf_device.parallel_execute inside the replicate
// body is what the pass detects as model parallelism and bails out on.
"tf_device.parallel_execute"() ({
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
tf_device.return
}, {
// Second (empty) region of the parallel_execute.
tf_device.return
}) {} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
}
return %1, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
}
// CHECK-LABEL: func @while_cond_7550
// Loop condition: continue while the counter (%arg0) is >= 0.
func @while_cond_7550(%arg0: tensor<i32>,
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
-> tensor<i1> {
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}
}
......@@ -570,7 +570,11 @@ void TPUVariableRuntimeReformattingPass::runOnOperation() {
replicate = nullptr;
return WalkResult::interrupt();
});
if (replicate) HandleReplicateOp(while_op, replicate, &getContext());
// Model parallelism is not supported, and can be detected when a
// `tf_device.parallel_execute` op in the `tf_device.replicate` is present.
if (replicate &&
replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
HandleReplicateOp(while_op, replicate, &getContext());
});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册