From 780ab29c5c6328348e78a5c75b3ddc25609b4045 Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Mon, 24 Aug 2020 13:37:08 -0700 Subject: [PATCH] When marking ops for outside compilation, use a different attribute value for each op. These ops are later clustered by _xla_outside_compiled attribute value. Unsupported ops marked for outside compilation should be clustered in a separate pass. PiperOrigin-RevId: 328200978 Change-Id: Ia3c30138ebb6a42d2e1064277c2b7461f2f53204 --- .../mark_ops_for_outside_compilation.mlir | 8 +++--- .../mark_ops_for_outside_compilation.cc | 26 ++++++++++++++++--- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir index df2add2208a..dc99d9d6343 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir @@ -137,7 +137,7 @@ func @if_region_captured_string(%arg0: tensor, %arg1: tensor) -> // CHECK: "tf.IfRegion" // CHECK: "tf.StringToNumber" // CHECK-NOT: _xla_outside_compilation - // CHECK: _xla_outside_compilation = "auto", is_stateless = true + // CHECK: _xla_outside_compilation = "auto1", is_stateless = true %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %2 = "tf.IfRegion"(%arg0) ( { %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor) -> tensor @@ -166,7 +166,7 @@ func @if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor : tensor} : () -> tensor "tf.Yield"(%3) : (tensor) -> () }, { - // CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor} + // CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor} // CHECK-NEXT: "tf.StringToNumber" // CHECK-SAME: _xla_outside_compilation %4 = "tf.Const"() {value = dense<"1.0"> : tensor} : () -> tensor @@ -198,7 +198,7 @@ func @nested_if_region_string_op(%arg0: tensor, %arg1: tensor) -> ten // CHECK-NOT: _xla_outside_compilation %4 = "tf.Const"() {value = dense : tensor} : () -> tensor %5 = "tf.IfRegion"(%4)({ - // CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor} + // CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor} // CHECK-NEXT: "tf.StringToNumber" // CHECK-SAME: _xla_outside_compilation %6 = "tf.Const"() {value = dense<"1.0"> : tensor} : () -> tensor @@ -229,7 +229,7 @@ func @while_region_captured_string(%arg0: tensor, %arg1: tensor // CHECK-NOT: _xla_outside_compilation // CHECK: "tf.WhileRegion" // CHECK: "tf.StringToNumber" - // CHECK: _xla_outside_compilation = "auto", is_stateless = true + // CHECK: _xla_outside_compilation = "auto1", is_stateless = true %1 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor %2:2 = "tf.WhileRegion"(%1, %arg0) ( { ^bb0(%carg0: tensor, %carg1: tensor): diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc index 38cbe3f404e..1ffe456405f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "llvm/Support/FormatVariadic.h" #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project @@ -116,15 +117,32 @@ bool HasCapturedStringOperand(Operation* op) { LogicalResult MarkUncompilableOps( const Dialect* tf_dialect, Block* block, llvm::DenseSet& supported_ops) { + // Automatically marked ops for outside compilation have + // `_xla_outside_compilation` attribute value of "auto" plus + // an increasing counter. Manually marked ops for outside compilation only + // have an increasing counteri for the attribute value. Therefore there is no + // collision in + // `_xla_outside_compilation` attribute between automatically and manually + // marking ops. + int outside_compiled_cluster_counter = 0; block->walk([&](Operation* op) { if (!IsSupportedOp(*op, supported_ops, tf_dialect)) { - op->setAttr(kXlaOutsideCompilationAttr, - StringAttr::get("auto", op->getContext())); + op->setAttr( + kXlaOutsideCompilationAttr, + StringAttr::get( + llvm::formatv("auto{0}", outside_compiled_cluster_counter).str(), + op->getContext())); + outside_compiled_cluster_counter++; } if (llvm::isa(op)) { if (HasCapturedStringOperand(op)) { - op->setAttr(kXlaOutsideCompilationAttr, - StringAttr::get("auto", op->getContext())); + op->setAttr( + kXlaOutsideCompilationAttr, + StringAttr::get( + llvm::formatv("auto{0}", outside_compiled_cluster_counter) + .str(), + op->getContext())); + outside_compiled_cluster_counter++; } } }); -- GitLab