提交 f36cd515 编写于 作者: S Smit Hinsu 提交者: TensorFlower Gardener

Add type constraints in patterns using BinBroadcastDimensions and ConstantSplat helpers

ConstantSplat requires static shaped types
BinBroadcastDimensions requires ranked types

PiperOrigin-RevId: 285863722
Change-Id: Ia2e1220568ab4eae8683b4c7b74ab1c4a38a1240
上级 42afc3e5
......@@ -94,6 +94,20 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
return %0 : tensor<4xi32>
}
// tf.Div with dynamic but *ranked* operands still lowers to xla_hlo.div:
// broadcast_dimensions can be computed from the ranks alone (the
// BinBroadcastDimensions helper only requires ranked types).
// CHECK-LABEL: func @div_dynamic
func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
%0 = "tf.Div"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0: tensor<?x?xi32>
}
// tf.Div with an unranked operand must NOT be lowered: computing
// broadcast_dimensions requires ranked types, so the op is left untouched
// (the CHECK matches the original tf.Div still being present).
// CHECK-LABEL: func @div_unranked
func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK: tf.Div
%0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0: tensor<?x?xi32>
}
// CHECK-LABEL: func @maximum
func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.max %arg0, %arg1 : tensor<4xf32>
......@@ -201,6 +215,13 @@ func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
return %0: tensor<?xi1>
}
// tf.LogicalAnd with unranked operands is not legalized: the logical binary
// patterns need ranked types to compute broadcast dimensions, so the op
// remains as tf.LogicalAnd.
// CHECK-LABEL: func @and_unranked
func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> {
// CHECK: tf.LogicalAnd
%0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @or
func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
// CHECK-NEXT: xla_hlo.or
......@@ -345,6 +366,20 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te
return %0: tensor<2x3xf16>
}
// Integer tf.FloorDiv is not legalized for dynamic shapes: the integer
// FloorDiv expansion builds constant splats, which require statically shaped
// types, so the op is expected to survive unchanged.
// CHECK-LABEL: func @floordiv_dynamic
func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
// CHECK: tf.FloorDiv
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
return %0: tensor<?x?xi32>
}
// tf.FloorDiv with unranked operands is not legalized either (unranked types
// satisfy neither the ranked nor the static-shape constraints of the
// FloorDiv patterns).
// CHECK-LABEL: func @floordiv_unranked
func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: tf.FloorDiv
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
}
// CHECK-LABEL: func @floormod_broadcast_numerator
func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-DAG: [[REM:%.+]] = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
......@@ -379,6 +414,20 @@ func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32
return %0: tensor<2x3xi32>
}
// tf.FloorMod with dynamic shapes is not legalized: the FloorMod expansion
// creates constant splats, which require statically shaped operands.
// CHECK-LABEL: func @floormod_dynamic
func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
// CHECK: tf.FloorMod
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
return %0: tensor<?x?xi32>
}
// tf.FloorMod with unranked operands is likewise left untouched (unranked
// types cannot satisfy the static-shape requirement of the pattern).
// CHECK-LABEL: func @floormod_unranked
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: tf.FloorMod
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
}
// CHECK-LABEL: func @broadcast_to
func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
%cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>
......@@ -450,6 +499,13 @@ func @equal_incompatible_shape_both_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<
return %0: tensor<*xi1>
}
// tf.Equal with unranked operands is not legalized to xla_hlo.compare:
// the equality pattern requires ranked tensors to compute broadcast
// dimensions, so the original tf.Equal must remain.
// CHECK-LABEL: func @equal_unranked
func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> {
// CHECK: "tf.Equal"
%0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @notequal
func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"}
......@@ -517,6 +573,20 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<
return %0: tensor<1x2xi1>
}
// tf.Greater with a dynamic but ranked operand still lowers to
// xla_hlo.compare with direction "GT" — rank information is enough for the
// comparison pattern.
// CHECK-LABEL: func @greater_dynamic
func @greater_dynamic(%arg0: tensor<?xi32>) -> tensor<?xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
}
// tf.Greater with an unranked operand is not legalized (compare patterns
// require ranked tensors), so the tf.Greater op must remain.
// Fix: function/label name typo "uranked" -> "unranked", kept consistent
// between the CHECK-LABEL directive and the func symbol.
// CHECK-LABEL: func @greater_unranked
func @greater_unranked(%arg0: tensor<*xi32>) -> tensor<*xi1> {
// CHECK: "tf.Greater"
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
return %0: tensor<*xi1>
}
// CHECK-LABEL: func @greater_equal
func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"}
......
......@@ -92,7 +92,7 @@ def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
"types must be broadcastable">;
class DirectBinaryPat<Op FromOp, Op ToOp>
: Pat<(FromOp $l, $r),
: Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;
foreach fromToBinPair = [[TF_AddOp, HLO_AddOp],
......@@ -120,8 +120,9 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
// Performs a substitution of FloorDiv, pseudo code below:
//
// return floor(div(x, y))
def : Pat<(TF_FloorDivOp IEEEFloatTensor:$l, IEEEFloatTensor:$r),
(HLO_FloorOp (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r)))>;
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLO_FloorOp (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r))),
[(IEEEFloatTensor $l)]>;
// Performs a substitution of FloorDiv for integer tensors, which requires
// additional correction for a negative numerator / denominator. Equivalent
......@@ -140,7 +141,9 @@ def : Pat<(TF_FloorDivOp IEEEFloatTensor:$l, IEEEFloatTensor:$r),
// without returning the broadcast of 'r' to broadcast('l', 'r').
//
// NOTE: This should be optimized for unsigned integers.
def : Pat<(TF_FloorDivOp SignedIntTensor:$l, SignedIntTensor:$r),
// Requires static shaped inputs to create constant splats and computation of
// broadcast attributes.
def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r),
(HLO_SelectOp
(HLO_CompareOp
(HLO_CompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)),
......@@ -155,14 +158,17 @@ def : Pat<(TF_FloorDivOp SignedIntTensor:$l, SignedIntTensor:$r),
(HLO_ConstOp (ConstantSplat<"1"> $r)),
(NullDenseIntElementsAttr)),
(BinBroadcastDimensions $l, $r))),
(HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs)))>;
(HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))),
[(SignedIntTensor $l)]>;
// Performs a substitution of FloorMod designed to correct for possibly negative
// values. Pseudocode shown below:
//
// T trunc_mod = std::fmod(x, y);
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y : trunc_mod
def : Pat<(TF_FloorModOp $l, $r),
// Requires static shaped inputs to create constant splats and computation of
// broadcast attributes.
def : Pat<(TF_FloorModOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r),
(HLO_SelectOp
(HLO_AndOp
(HLO_CompareOp
......@@ -195,8 +201,9 @@ def : Pat<(TF_BroadcastToOp:$result AnyRankedTensor:$input, $shape),
//===----------------------------------------------------------------------===//
class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
: Pat<(FromOp SignedIntTensor:$l, SignedIntTensor:$r),
(ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;
: Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(ToOp $l, $r, (BinBroadcastDimensions $l, $r)),
[(SignedIntTensor $l)]>;
foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp],
[TF_LogicalOrOp, HLO_OrOp],
......@@ -208,7 +215,7 @@ foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp],
//===----------------------------------------------------------------------===//
class DirectComparePat<Op FromOp, StrEnumAttrCase direction>
: Pat<(FromOp $l, $r),
: Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLO_CompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction)>;
def : DirectComparePat<TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT>;
......@@ -217,7 +224,7 @@ def : DirectComparePat<TF_LessOp, HLO_COMPARISON_DIRECTION_LT>;
def : DirectComparePat<TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE>;
class EqualityPat<Op FromOp, StrEnumAttrCase direction>
: Pat<(FromOp $l, $r,
: Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r,
TrueBoolAttr:$incompatible_shape_error),
(HLO_CompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction),
[(AreBroadcastCompatible $l, $r)]>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册