提交 63a74434 编写于 作者: H Haitang Hu 提交者: TensorFlower Gardener

Add OpTraits helpers for verifying operand/result rank.

PiperOrigin-RevId: 339497489
Change-Id: Ic28a1b83ce19ae5bde086495f9d937abbcea4cc0
上级 f486d439
......@@ -1335,7 +1335,7 @@ broadcasted shape. `s0`, `s1` and `r0` are all integer vectors.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect, SameOperandsAndResultElementType]> {
def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect, SameOperandsAndResultElementType, TF_OperandHasRank<0, 1>, TF_OperandHasRank<1, 1>, TF_ResultHasRank<0, 1>, TF_ResultHasRank<1, 1>]> {
let summary = [{
Return the reduction indices for computing gradients of s0 op s1 with broadcast.
}];
......
......@@ -100,6 +100,30 @@ class TF_AllTypesMatch<list<string> names> :
TF_AllTypesMatchPred<
!foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
//===----------------------------------------------------------------------===//
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//
// Predicate: the n-th operand's type is an unranked tensor (its rank is not
// known statically).
class TF_OperandIsUnrankedPred<int n> :
CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
// Predicate: the n-th result's type is an unranked tensor (its rank is not
// known statically).
class TF_ResultIsUnrankedPred<int n> :
CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;
// Returns true if the n-th operand has unknown rank or has rank m.
// Unranked operands are accepted so the trait does not reject types whose
// rank has not been inferred yet; only a statically-known, mismatching rank
// fails verification (diagnostic: "operand <n> is <m>-D").
class TF_OperandHasRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[TF_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
").getType().cast<ShapedType>().getRank() == " # m>]>>;
// Returns true if the n-th result has unknown rank or has rank m.
// Same acceptance rule as the operand variant, applied to the n-th result
// (diagnostic: "result <n> is <m>-D").
class TF_ResultHasRank<int n, int m> :
PredOpTrait<"result " # n # " is " # m # "-D",
Or<[TF_ResultIsUnrankedPred<n>,
CPred<"$_op.getResult(" # n #
").getType().cast<ShapedType>().getRank() == " # m>]>>;
//===----------------------------------------------------------------------===//
// TensorFlow op side effects
//===----------------------------------------------------------------------===//
......
......@@ -546,8 +546,8 @@ OpFoldResult BroadcastToOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
namespace {
// Returns `true` if both s0 & s1 are defined via constant op, and fills
// s0_shape & s1_shape.
bool ExtractInputConstShape(BroadcastGradientArgsOp op,
DenseIntElementsAttr &s0, DenseIntElementsAttr &s1,
SmallVectorImpl<int64_t> &s0_shape,
......@@ -594,37 +594,11 @@ void GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape,
} // namespace
// Verifies that,
// * all inputs/outputs have rank 1.
// For cases where static shape can be extracted from inputs, verifies that,
// * Broadcast compatibility for input shapes.
// * Output shape dimension matches the expected dimension size for input
// shapes.
static LogicalResult Verify(BroadcastGradientArgsOp op) {
// Check rank = 1 for input/outputs.
RankedTensorType s0_ty = GetRankedTensorTypeForOperand(op.s0());
RankedTensorType s1_ty = GetRankedTensorTypeForOperand(op.s1());
RankedTensorType r0_ty = GetRankedTensorTypeForOperand(op.r0());
RankedTensorType r1_ty = GetRankedTensorTypeForOperand(op.r1());
if (s0_ty && s0_ty.getRank() != 1)
return op.emitOpError()
<< "requires 's0' to be a rank 1 tensor, but got rank "
<< s0_ty.getRank();
if (s1_ty && s1_ty.getRank() != 1)
return op.emitOpError()
<< "requires 's1' to be a rank 1 tensor, but got rank "
<< s1_ty.getRank();
if (r0_ty && r0_ty.getRank() != 1)
return op.emitOpError()
<< "requires 'r0' to be a rank 1 tensor, but got rank "
<< r0_ty.getRank();
if (r1_ty && r1_ty.getRank() != 1)
return op.emitOpError()
<< "requires 'r1' to be a rank 1 tensor, but got rank "
<< r1_ty.getRank();
SmallVector<int64_t, 4> s0_shape;
SmallVector<int64_t, 4> s1_shape;
SmallVector<int64_t, 4> s0_shape, s1_shape;
DenseIntElementsAttr s0, s1;
if (!ExtractInputConstShape(op, s0, s1, s0_shape, s1_shape)) return success();
......@@ -635,11 +609,12 @@ static LogicalResult Verify(BroadcastGradientArgsOp op) {
"for 's0' and 's1', but got "
<< s0 << " and " << s1;
SmallVector<int64_t, 4> r0;
SmallVector<int64_t, 4> r1;
SmallVector<int64_t, 4> r0, r1;
GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
r1);
RankedTensorType r0_ty = GetRankedTensorTypeForOperand(op.r0());
RankedTensorType r1_ty = GetRankedTensorTypeForOperand(op.r1());
if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getShape()[0] != r0.size())
return op.emitOpError() << "requires dimension 0 size of 'r0' to be "
<< r0.size() << " but got " << r0_ty.getShape()[0];
......@@ -652,9 +627,8 @@ static LogicalResult Verify(BroadcastGradientArgsOp op) {
LogicalResult BroadcastGradientArgsOp::fold(
ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
SmallVector<int64_t, 4> s0_shape, s1_shape;
DenseIntElementsAttr s0, s1;
SmallVector<int64_t, 4> s0_shape;
SmallVector<int64_t, 4> s1_shape;
if (!ExtractInputConstShape(*this, s0, s1, s0_shape, s1_shape))
return failure();
......@@ -667,8 +641,7 @@ LogicalResult BroadcastGradientArgsOp::fold(
assert(bcast_compatible);
(void)bcast_compatible;
SmallVector<int64_t, 4> r0;
SmallVector<int64_t, 4> r1;
SmallVector<int64_t, 4> r0, r1;
GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
r1);
......
......@@ -185,7 +185,7 @@ func @testBroadcastGradientArgsIncompatibleBroadcastShape() -> (tensor<1xi32>, t
// Negative test: BroadcastGradientArgs requires operand 0 (s0) to be 1-D.
// Here s0 is rank 2, so verification must fail with the PredOpTrait
// diagnostic produced by TF_OperandHasRank<0, 1>.
func @testBroadcastGradientArgsInvalidS0Rank() -> (tensor<2x2xi32>, tensor<0xi32>) {
  %s0 = "tf.Const"() {value = dense<[[4, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
  %s1 = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
  // expected-error @+1 {{failed to verify that operand 0 is 1-D}}
  %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) : (tensor<2x2xi32>, tensor<2xi32>) -> (tensor<1xi32>, tensor<0xi32>)
  return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
}
......@@ -195,7 +195,7 @@ func @testBroadcastGradientArgsInvalidS0Rank() -> (tensor<2x2xi32>, tensor<0xi32
// Negative test: BroadcastGradientArgs requires operand 1 (s1) to be 1-D.
// Here s1 is a rank-0 (scalar) tensor, so verification must fail with the
// PredOpTrait diagnostic produced by TF_OperandHasRank<1, 1>.
func @testBroadcastGradientArgsInvalidS1Rank() -> (tensor<2xi32>, tensor<i32>) {
  %s0 = "tf.Const"() {value = dense<[4, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
  %s1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // expected-error @+1 {{failed to verify that operand 1 is 1-D}}
  %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) : (tensor<2xi32>, tensor<i32>) -> (tensor<1xi32>, tensor<0xi32>)
  return %r0, %r1 : tensor<1xi32>, tensor<0xi32>
}
......@@ -205,7 +205,7 @@ func @testBroadcastGradientArgsInvalidS1Rank() -> (tensor<2xi32>, tensor<i32>) {
// Negative test: BroadcastGradientArgs requires result 0 (r0) to be 1-D.
// Here r0 is declared rank 2, so verification must fail with the PredOpTrait
// diagnostic produced by TF_ResultHasRank<0, 1>.
func @testBroadcastGradientArgsInvalidR0Rank() -> (tensor<2x2xi32>, tensor<0xi32>) {
  %s0 = "tf.Const"() {value = dense<[4, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
  %s1 = "tf.Const"() {value = dense<[4, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
  // expected-error @+1 {{failed to verify that result 0 is 1-D}}
  %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) : (tensor<2xi32>, tensor<2xi32>) -> (tensor<2x2xi32>, tensor<0xi32>)
  return %r0, %r1 : tensor<2x2xi32>, tensor<0xi32>
}
......@@ -213,7 +213,7 @@ func @testBroadcastGradientArgsInvalidR0Rank() -> (tensor<2x2xi32>, tensor<0xi32
// -----
// Negative test: BroadcastGradientArgs requires result 1 (r1) to be 1-D.
// Here r1 is declared rank 0 (scalar), so verification must fail with the
// PredOpTrait diagnostic produced by TF_ResultHasRank<1, 1>. Inputs are
// function arguments (non-constant), exercising the trait path rather than
// the constant-shape verifier.
func @testBroadcastGradientArgsInvalidR1Rank(%s0: tensor<4xi32>, %s1: tensor<4xi32>) -> (tensor<1xi32>, tensor<i32>) {
  // expected-error @+1 {{failed to verify that result 1 is 1-D}}
  %r0, %r1 = "tf.BroadcastGradientArgs"(%s0, %s1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<1xi32>, tensor<i32>)
  return %r0, %r1 : tensor<1xi32>, tensor<i32>
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册