From 63a744342488a6b9dc0bf6f3429572e5ded46594 Mon Sep 17 00:00:00 2001
From: Haitang Hu
Date: Wed, 28 Oct 2020 11:10:31 -0700
Subject: [PATCH] Add OpTraits helpers for verifying operand/result rank.

PiperOrigin-RevId: 339497489
Change-Id: Ic28a1b83ce19ae5bde086495f9d937abbcea4cc0
---
 .../mlir/tensorflow/ir/tf_generated_ops.td    |  2 +-
 .../compiler/mlir/tensorflow/ir/tf_op_base.td | 24 +++++++++++
 .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 43 ++++---------
 .../mlir/tensorflow/tests/tf-ops.mlir         |  8 ++--
 4 files changed, 37 insertions(+), 40 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 03a7fa10cf0..9c35f9aa89c 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -1335,7 +1335,7 @@ broadcasted shape. `s0`, `s1` and `r0` are all integer vectors.
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect, SameOperandsAndResultElementType]> {
+def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect, SameOperandsAndResultElementType, TF_OperandHasRank<0, 1>, TF_OperandHasRank<1, 1>, TF_ResultHasRank<0, 1>, TF_ResultHasRank<1, 1>]> {
   let summary = [{
 Return the reduction indices for computing gradients of s0 op s1 with broadcast.
   }];
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
index 15c0d7b10f7..7b68337d3cf 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -100,6 +100,30 @@ class TF_AllTypesMatch<list<string> names> :
     TF_AllTypesMatchPred<
       !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
 
+//===----------------------------------------------------------------------===//
+// Rank/Shape helpers.
+//===----------------------------------------------------------------------===//
+
+class TF_OperandIsUnrankedPred<int n> :
+  CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
+
+class TF_ResultIsUnrankedPred<int n> :
+  CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;
+
+// Returns true if the n-th operand has unknown rank or has rank m.
+class TF_OperandHasRank<int n, int m> :
+  PredOpTrait<"operand " # n # " is " # m # "-D",
+    Or<[TF_OperandIsUnrankedPred<n>,
+        CPred<"$_op.getOperand(" # n #
+              ").getType().cast<RankedTensorType>().getRank() == " # m>]>>;
+
+// Returns true if the n-th result has unknown rank or has rank m.
+class TF_ResultHasRank<int n, int m> :
+  PredOpTrait<"result " # n # " is " # m # "-D",
+    Or<[TF_ResultIsUnrankedPred<n>,
+        CPred<"$_op.getResult(" # n #
+              ").getType().cast<RankedTensorType>().getRank() == " # m>]>>;
+
 //===----------------------------------------------------------------------===//
 // TensorFlow op side effects
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 58603d6e4da..7bbb7f3bad0 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -546,8 +546,8 @@ OpFoldResult BroadcastToOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 namespace {
-// Returns `true` if both s0 & s1 is defined via constant op, and fills s0_shape
-// & s1_shape.
+// Returns `true` if both s0 & s1 are defined via constant op, and fills
+// s0_shape & s1_shape.
 bool ExtractInputConstShape(BroadcastGradientArgsOp op,
                             DenseIntElementsAttr &s0, DenseIntElementsAttr &s1,
                             SmallVectorImpl<int64_t> &s0_shape,
@@ -594,37 +594,11 @@ void GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape,
 }  // namespace
 
 // Verifies that,
-// * all inputs/outputs rank is 1.
-// For cases where static shape can be extracted from inputs, verifies that,
 // * Broadcast compatibility for input shapes.
 // * Output shape dimension matches the expected dimension size for input
 // shapes.
 static LogicalResult Verify(BroadcastGradientArgsOp op) {
-  // Check rank = 1 for input/outputs.
-  RankedTensorType s0_ty = GetRankedTensorTypeForOperand(op.s0());
-  RankedTensorType s1_ty = GetRankedTensorTypeForOperand(op.s1());
-  RankedTensorType r0_ty = GetRankedTensorTypeForOperand(op.r0());
-  RankedTensorType r1_ty = GetRankedTensorTypeForOperand(op.r1());
-  if (s0_ty && s0_ty.getRank() != 1)
-    return op.emitOpError()
-           << "requires 's0' to be a rank 1 tensor, but got rank "
-           << s0_ty.getRank();
-  if (s1_ty && s1_ty.getRank() != 1)
-    return op.emitOpError()
-           << "requires 's1' to be a rank 1 tensor, but got rank "
-           << s1_ty.getRank();
-  if (r0_ty && r0_ty.getRank() != 1)
-    return op.emitOpError()
-           << "requires 'r0' to be a rank 1 tensor, but got rank "
-           << r0_ty.getRank();
-  if (r1_ty && r1_ty.getRank() != 1)
-    return op.emitOpError()
-           << "requires 'r1' to be a rank 1 tensor, but got rank "
-           << r1_ty.getRank();
-
-  SmallVector<int64_t, 4> s0_shape;
-  SmallVector<int64_t, 4> s1_shape;
-
+  SmallVector<int64_t, 4> s0_shape, s1_shape;
   DenseIntElementsAttr s0, s1;
   if (!ExtractInputConstShape(op, s0, s1, s0_shape, s1_shape)) return success();
 
@@ -635,11 +609,12 @@ static LogicalResult Verify(BroadcastGradientArgsOp op) {
                                "for 's0' and 's1', but got "
                             << s0 << " and " << s1;
 
-  SmallVector<int64_t, 4> r0;
-  SmallVector<int64_t, 4> r1;
+  SmallVector<int64_t, 4> r0, r1;
   GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
                                          r1);
 
+  RankedTensorType r0_ty = GetRankedTensorTypeForOperand(op.r0());
+  RankedTensorType r1_ty = GetRankedTensorTypeForOperand(op.r1());
   if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getShape()[0] != r0.size())
     return op.emitOpError() << "requires dimension 0 size of 'r0' to be "
                             << r0.size() << " but got " << r0_ty.getShape()[0];
@@ -652,9 +627,8 @@ static LogicalResult Verify(BroadcastGradientArgsOp op) {
 
 LogicalResult BroadcastGradientArgsOp::fold(
     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
+  SmallVector<int64_t, 4> s0_shape, s1_shape;
   DenseIntElementsAttr s0, s1;
-  SmallVector<int64_t, 4> s0_shape;
-  SmallVector<int64_t, 4> s1_shape;
   if (!ExtractInputConstShape(*this, s0, s1, s0_shape, s1_shape))
     return failure();
 
@@ -667,8 +641,7 @@ LogicalResult BroadcastGradientArgsOp::fold(
   assert(bcast_compatible);
   (void)bcast_compatible;
 
-  SmallVector<int64_t, 4> r0;
-  SmallVector<int64_t, 4> r1;
+  SmallVector<int64_t, 4> r0, r1;
   GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
                                          r1);
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 62975090869..02b9d63cebb 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -185,7 +185,7 @@ func @testBroadcastGradientArgsIncompatibleBroadcastShape() -> (tensor<1xi32>, t
 func @testBroadcastGradientArgsInvalidS0Rank() -> (tensor<2x2xi32>, tensor<0xi32>) {
   %s0 = "tf.Const"() {value = dense<[[4, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
   %s1 = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
-  // expected-error @+1 {{requires 's0' to be a rank 1 tensor, but got rank 2}}
+  // expected-error @+1 {{failed to verify that operand 0 is 1-D}}
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) : (tensor<2x2xi32>, tensor<2xi32>) -> (tensor<1xi32>, tensor<0xi32>)
   return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
 }
@@ -195,7 +195,7 @@ func @testBroadcastGradientArgsInvalidS0Rank() -> (tensor<2x2xi32>, tensor<0xi32
 func @testBroadcastGradientArgsInvalidS1Rank() -> (tensor<2xi32>, tensor<i32>) {
   %s0 = "tf.Const"() {value = dense<[4, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
   %s1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
-  // expected-error @+1 {{requires 's1' to be a rank 1 tensor, but got rank 0}}
+  // expected-error @+1 {{failed to verify that operand 1 is 1-D}}
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) : (tensor<2xi32>, tensor<i32>) -> (tensor<1xi32>, tensor<0xi32>)
   return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
 }
@@ -205,7 +205,7 @@ func @testBroadcastGradientArgsInvalidS1Rank() -> (tensor<2xi32>, tensor<i32>) {
 func @testBroadcastGradientArgsInvalidR0Rank() -> (tensor<2x2xi32>, tensor<0xi32>) {
   %s0 = "tf.Const"() {value = dense<[4, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
   %s1 = "tf.Const"() {value = dense<[4, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
-  // expected-error @+1 {{requires 'r0' to be a rank 1 tensor, but got rank 2}}
+  // expected-error @+1 {{failed to verify that result 0 is 1-D}}
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) : (tensor<2xi32>, tensor<2xi32>) -> (tensor<2x2xi32>, tensor<0xi32>)
   return %r0, %r1 : tensor<2x2xi32>, tensor<0xi32>
 }
@@ -213,7 +213,7 @@ func @testBroadcastGradientArgsInvalidR0Rank() -> (tensor<2x2xi32>, tensor<0xi32
 // -----
 
 func @testBroadcastGradientArgsInvalidR1Rank(%s0: tensor<4xi32>, %s1: tensor<4xi32>) -> (tensor<1xi32>, tensor<i32>) {
-  // expected-error @+1 {{requires 'r1' to be a rank 1 tensor, but got rank 0}}
+  // expected-error @+1 {{failed to verify that result 1 is 1-D}}
   %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<1xi32>, tensor<i32>)
   return %r0, %r1 : tensor<1xi32>, tensor<i32>
 }
-- 
GitLab
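
Usage sketch (illustration only, not part of the patch): the traits added in
tf_op_base.td let any TableGen-defined TF op declare rank requirements
declaratively instead of open-coding them in a C++ verifier, exactly as the
patch does for BroadcastGradientArgs. TF_OperandHasRank<n, m> and
TF_ResultHasRank<n, m> accept the n-th operand/result when it is unranked or
has exactly rank m, so ops with not-yet-inferred shapes still pass
verification. The op below is hypothetical; its name and the I32Tensor
argument/result types (from MLIR's OpBase.td) are chosen only to make the
snippet self-contained:

  // Hypothetical op: operand 0 must be unranked or 2-D, result 0 must be
  // unranked or 1-D. A violating op is rejected by the generated verifier
  // with a message of the form "failed to verify that operand 0 is 2-D",
  // matching the expected-error strings in the tests above.
  def TF_ExampleOp : TF_Op<"Example", [NoSideEffect,
      TF_OperandHasRank<0, 2>, TF_ResultHasRank<0, 1>]> {
    let arguments = (ins I32Tensor:$input);
    let results = (outs I32Tensor:$output);
  }

Allowing unranked types through is the same design choice the deleted C++
checks made (they only fired when GetRankedTensorTypeForOperand returned a
ranked type), so moving the checks into traits preserves behavior while
shrinking the hand-written verifier.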