提交 780ab29c 编写于 作者: K Ken Franko 提交者: TensorFlower Gardener

When marking ops for outside compilation, use a different attribute value for each op.

These ops are later clustered by _xla_outside_compiled attribute value.  Unsupported ops marked for outside compilation
should be clustered in a separate pass.

PiperOrigin-RevId: 328200978
Change-Id: Ia3c30138ebb6a42d2e1064277c2b7461f2f53204
上级 5451813a
...@@ -137,7 +137,7 @@ func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> ...@@ -137,7 +137,7 @@ func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) ->
// CHECK: "tf.IfRegion" // CHECK: "tf.IfRegion"
// CHECK: "tf.StringToNumber" // CHECK: "tf.StringToNumber"
// CHECK-NOT: _xla_outside_compilation // CHECK-NOT: _xla_outside_compilation
// CHECK: _xla_outside_compilation = "auto", is_stateless = true // CHECK: _xla_outside_compilation = "auto1", is_stateless = true
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.IfRegion"(%arg0) ( { %2 = "tf.IfRegion"(%arg0) ( {
%3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32> %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
...@@ -166,7 +166,7 @@ func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32 ...@@ -166,7 +166,7 @@ func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32
%3 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32> %3 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
"tf.Yield"(%3) : (tensor<f32>) -> () "tf.Yield"(%3) : (tensor<f32>) -> ()
}, { }, {
// CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>} // CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor<!tf.string>}
// CHECK-NEXT: "tf.StringToNumber" // CHECK-NEXT: "tf.StringToNumber"
// CHECK-SAME: _xla_outside_compilation // CHECK-SAME: _xla_outside_compilation
%4 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string> %4 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
...@@ -198,7 +198,7 @@ func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> ten ...@@ -198,7 +198,7 @@ func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> ten
// CHECK-NOT: _xla_outside_compilation // CHECK-NOT: _xla_outside_compilation
%4 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1> %4 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
%5 = "tf.IfRegion"(%4)({ %5 = "tf.IfRegion"(%4)({
// CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>} // CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor<!tf.string>}
// CHECK-NEXT: "tf.StringToNumber" // CHECK-NEXT: "tf.StringToNumber"
// CHECK-SAME: _xla_outside_compilation // CHECK-SAME: _xla_outside_compilation
%6 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string> %6 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
...@@ -229,7 +229,7 @@ func @while_region_captured_string(%arg0: tensor<i32>, %arg1: tensor<!tf.string> ...@@ -229,7 +229,7 @@ func @while_region_captured_string(%arg0: tensor<i32>, %arg1: tensor<!tf.string>
// CHECK-NOT: _xla_outside_compilation // CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.WhileRegion" // CHECK: "tf.WhileRegion"
// CHECK: "tf.StringToNumber" // CHECK: "tf.StringToNumber"
// CHECK: _xla_outside_compilation = "auto", is_stateless = true // CHECK: _xla_outside_compilation = "auto1", is_stateless = true
%1 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32> %1 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
%2:2 = "tf.WhileRegion"(%1, %arg0) ( { %2:2 = "tf.WhileRegion"(%1, %arg0) ( {
^bb0(%carg0: tensor<f32>, %carg1: tensor<i32>): ^bb0(%carg0: tensor<f32>, %carg1: tensor<i32>):
......
...@@ -17,6 +17,7 @@ limitations under the License. ...@@ -17,6 +17,7 @@ limitations under the License.
#include <string> #include <string>
#include <utility> #include <utility>
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project
...@@ -116,15 +117,32 @@ bool HasCapturedStringOperand(Operation* op) { ...@@ -116,15 +117,32 @@ bool HasCapturedStringOperand(Operation* op) {
LogicalResult MarkUncompilableOps( LogicalResult MarkUncompilableOps(
const Dialect* tf_dialect, Block* block, const Dialect* tf_dialect, Block* block,
llvm::DenseSet<OperationName>& supported_ops) { llvm::DenseSet<OperationName>& supported_ops) {
// Automatically marked ops for outside compilation have
// `_xla_outside_compilation` attribute value of "auto" plus
// an increasing counter. Manually marked ops for outside compilation only
// have an increasing counteri for the attribute value. Therefore there is no
// collision in
// `_xla_outside_compilation` attribute between automatically and manually
// marking ops.
int outside_compiled_cluster_counter = 0;
block->walk([&](Operation* op) { block->walk([&](Operation* op) {
if (!IsSupportedOp(*op, supported_ops, tf_dialect)) { if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
op->setAttr(kXlaOutsideCompilationAttr, op->setAttr(
StringAttr::get("auto", op->getContext())); kXlaOutsideCompilationAttr,
StringAttr::get(
llvm::formatv("auto{0}", outside_compiled_cluster_counter).str(),
op->getContext()));
outside_compiled_cluster_counter++;
} }
if (llvm::isa<TF::IfRegionOp, TF::WhileRegionOp>(op)) { if (llvm::isa<TF::IfRegionOp, TF::WhileRegionOp>(op)) {
if (HasCapturedStringOperand(op)) { if (HasCapturedStringOperand(op)) {
op->setAttr(kXlaOutsideCompilationAttr, op->setAttr(
StringAttr::get("auto", op->getContext())); kXlaOutsideCompilationAttr,
StringAttr::get(
llvm::formatv("auto{0}", outside_compiled_cluster_counter)
.str(),
op->getContext()));
outside_compiled_cluster_counter++;
} }
} }
}); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册