提交 780ab29c 编写于 作者: K Ken Franko 提交者: TensorFlower Gardener

When marking ops for outside compilation, use a different attribute value for each op.

These ops are later clustered by _xla_outside_compiled attribute value.  Unsupported ops marked for outside compilation
should be clustered in a separate pass.

PiperOrigin-RevId: 328200978
Change-Id: Ia3c30138ebb6a42d2e1064277c2b7461f2f53204
上级 5451813a
......@@ -137,7 +137,7 @@ func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) ->
// CHECK: "tf.IfRegion"
// CHECK: "tf.StringToNumber"
// CHECK-NOT: _xla_outside_compilation
// CHECK: _xla_outside_compilation = "auto", is_stateless = true
// CHECK: _xla_outside_compilation = "auto1", is_stateless = true
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.IfRegion"(%arg0) ( {
%3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
......@@ -166,7 +166,7 @@ func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32
%3 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
"tf.Yield"(%3) : (tensor<f32>) -> ()
}, {
// CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>}
// CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor<!tf.string>}
// CHECK-NEXT: "tf.StringToNumber"
// CHECK-SAME: _xla_outside_compilation
%4 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
......@@ -198,7 +198,7 @@ func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> ten
// CHECK-NOT: _xla_outside_compilation
%4 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
%5 = "tf.IfRegion"(%4)({
// CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>}
// CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor<!tf.string>}
// CHECK-NEXT: "tf.StringToNumber"
// CHECK-SAME: _xla_outside_compilation
%6 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
......@@ -229,7 +229,7 @@ func @while_region_captured_string(%arg0: tensor<i32>, %arg1: tensor<!tf.string>
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.WhileRegion"
// CHECK: "tf.StringToNumber"
// CHECK: _xla_outside_compilation = "auto", is_stateless = true
// CHECK: _xla_outside_compilation = "auto1", is_stateless = true
%1 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
%2:2 = "tf.WhileRegion"(%1, %arg0) ( {
^bb0(%carg0: tensor<f32>, %carg1: tensor<i32>):
......
......@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
#include <utility>
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
......@@ -116,15 +117,32 @@ bool HasCapturedStringOperand(Operation* op) {
LogicalResult MarkUncompilableOps(
const Dialect* tf_dialect, Block* block,
llvm::DenseSet<OperationName>& supported_ops) {
// Automatically marked ops for outside compilation have
// `_xla_outside_compilation` attribute value of "auto" plus
// an increasing counter. Manually marked ops for outside compilation only
// have an increasing counteri for the attribute value. Therefore there is no
// collision in
// `_xla_outside_compilation` attribute between automatically and manually
// marking ops.
int outside_compiled_cluster_counter = 0;
block->walk([&](Operation* op) {
if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
op->setAttr(kXlaOutsideCompilationAttr,
StringAttr::get("auto", op->getContext()));
op->setAttr(
kXlaOutsideCompilationAttr,
StringAttr::get(
llvm::formatv("auto{0}", outside_compiled_cluster_counter).str(),
op->getContext()));
outside_compiled_cluster_counter++;
}
if (llvm::isa<TF::IfRegionOp, TF::WhileRegionOp>(op)) {
if (HasCapturedStringOperand(op)) {
op->setAttr(kXlaOutsideCompilationAttr,
StringAttr::get("auto", op->getContext()));
op->setAttr(
kXlaOutsideCompilationAttr,
StringAttr::get(
llvm::formatv("auto{0}", outside_compiled_cluster_counter)
.str(),
op->getContext()));
outside_compiled_cluster_counter++;
}
}
});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册