Commit 06b8171f authored by Smit Hinsu, committed by TensorFlower Gardener

Lower TensorFlow LogicalNot, ShiftLeft and ShiftRight ops to corresponding HLO ops

PiperOrigin-RevId: 285837508
Change-Id: Ie870171df441467fd3cc06cfd5a2dcacd341437c
Parent e22dfad0
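For illustration only, a minimal before/after sketch of the lowering this change enables (the function and value names here are hypothetical; the rewrites mirror the test cases in the diff below):

// TF dialect input (hypothetical example):
func @example(%lhs: tensor<4xi32>, %rhs: tensor<4xi32>, %pred: tensor<4xi1>)
    -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi1>) {
  %0 = "tf.LeftShift"(%lhs, %rhs) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  %1 = "tf.RightShift"(%lhs, %rhs) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  %2 = "tf.LogicalNot"(%pred) : (tensor<4xi1>) -> tensor<4xi1>
  return %0, %1, %2 : tensor<4xi32>, tensor<4xi32>, tensor<4xi1>
}

// Expected result after legalization, per the CHECK lines in the tests below:
//   %0 = xla_hlo.shift_left %lhs, %rhs : tensor<4xi32>
//   %1 = xla_hlo.shift_right_arithmetic %lhs, %rhs : tensor<4xi32>
//   %2 = "xla_hlo.not"(%pred) : (tensor<4xi1>) -> tensor<4xi1>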
@@ -87,6 +87,13 @@ func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
   return %0: tensor<1x2xi32>
 }
 
+// CHECK-LABEL: func @shift_left
+func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
+  %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
 // CHECK-LABEL: func @maximum
 func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: xla_hlo.max %arg0, %arg1 : tensor<4xf32>
@@ -145,6 +152,34 @@ func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
   return %0: tensor<1x2xi32>
 }
 
+// CHECK-LABEL: func @shift_right
+func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
+  %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: func @broadcast_shift_right
+func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> {
+  // CHECK: "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
+  %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
+  return %0 : tensor<2x4xi32>
+}
+
+// CHECK-LABEL: func @shift_right_unsigned
+func @shift_right_unsigned(%arg0: tensor<4x!tf.uint8>, %arg1: tensor<4x!tf.uint8>) -> tensor<4x!tf.uint8> {
+  // CHECK: tf.RightShift
+  %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x!tf.uint8>, tensor<4x!tf.uint8>) -> tensor<4x!tf.uint8>
+  return %0 : tensor<4x!tf.uint8>
+}
+
+// CHECK-LABEL: func @broadcast_shift_right_unsigned
+func @broadcast_shift_right_unsigned(%arg0: tensor<4x!tf.uint8>, %arg1: tensor<2x4x!tf.uint8>) -> tensor<2x4x!tf.uint8> {
+  // CHECK: tf.RightShift
+  %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x!tf.uint8>, tensor<2x4x!tf.uint8>) -> tensor<2x4x!tf.uint8>
+  return %0 : tensor<2x4x!tf.uint8>
+}
+
 // CHECK-LABEL: func @and
 func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
   // CHECK-NEXT: xla_hlo.and
@@ -1243,6 +1278,13 @@ func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   return %0 : tensor<*xf32>
 }
 
+// CHECK-LABEL: func @not_op_unranked
+func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> {
+  // CHECK: "xla_hlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1>
+  %0 = "tf.LogicalNot"(%arg0) : (tensor<*xi1>) -> tensor<*xi1>
+  return %0 : tensor<*xi1>
+}
+
 // CHECK-LABEL: @neg
 func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK: "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
......
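These FileCheck tests are driven by a lit RUN line through tf-opt; a hedged sketch of the invocation (the exact RUN line and pass flag in the test file are assumptions here, not shown in this diff):

// RUN: tf-opt -xla-legalize-tf %s | FileCheck %s

Note that the @shift_right_unsigned cases expect tf.RightShift to survive legalization, so the pass must tolerate partial conversion: unsigned right shifts have no HLO lowering yet, per the TODO in the patterns file below.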
@@ -20,6 +20,11 @@ include "mlir/Dialect/StandardOps/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
 
+def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>;
+
+// IEEE compliant floating point tensors.
+def IEEEFloatTensor : TensorOf<[F16, F32, F64]>;
+
 //===----------------------------------------------------------------------===//
 // BatchNorm op patterns.
 //===----------------------------------------------------------------------===//
@@ -93,6 +98,7 @@ class DirectBinaryPat<Op FromOp, Op ToOp>
 foreach fromToBinPair = [[TF_AddOp, HLO_AddOp],
                          [TF_AddV2Op, HLO_AddOp],
                          [TF_DivOp, HLO_DivOp],
+                         [TF_LeftShiftOp, HLO_ShiftLeftOp],
                          [TF_MaximumOp, HLO_MaxOp],
                          [TF_MinimumOp, HLO_MinOp],
                          [TF_MulOp, HLO_MulOp],
@@ -101,12 +107,15 @@ foreach fromToBinPair = [[TF_AddOp, HLO_AddOp],
                          [TF_SubOp, HLO_SubOp]] in
   def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
 
-def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
+def LowerRightShiftSigned :
+  Pat<(TF_RightShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r),
+      (HLO_ShiftRightArithmeticOp $l, $r, (BinBroadcastDimensions $l, $r)),
+      [(SignedIntTensor $r)]>;
 
-def IntegerTensor : TensorOf<[I1, I8, I16, I32, I64]>;
+// TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op
+// supports unsigned integers.
 
-// IEEE compliant floating point tensors.
-def IEEEFloatTensor : TensorOf<[F16, F32, F64]>;
+def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
 
 // Performs a substitution of FloorDiv, pseudo code below:
 //
@@ -131,7 +140,7 @@ def : Pat<(TF_FloorDivOp IEEEFloatTensor:$l, IEEEFloatTensor:$r),
 // without returning the broadcast of 'r' to broadcast('l', 'r').
 //
 // NOTE: This should be optimized for unsigned integers.
-def : Pat<(TF_FloorDivOp IntegerTensor:$l, IntegerTensor:$r),
+def : Pat<(TF_FloorDivOp SignedIntTensor:$l, SignedIntTensor:$r),
           (HLO_SelectOp
            (HLO_CompareOp
             (HLO_CompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)),
@@ -186,7 +195,7 @@ def : Pat<(TF_BroadcastToOp:$result AnyRankedTensor:$input, $shape),
 //===----------------------------------------------------------------------===//
 
 class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
-  : Pat<(FromOp IntegerTensor:$l, IntegerTensor:$r),
+  : Pat<(FromOp SignedIntTensor:$l, SignedIntTensor:$r),
         (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;
 
 foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp],
@@ -412,6 +421,7 @@ foreach Mapping = [
                    [TF_ImagOp, HLO_ImagOp],
                    [TF_IsFiniteOp, HLO_IsFiniteOp],
                    [TF_LogOp, HLO_LogOp],
+                   [TF_LogicalNotOp, HLO_NotOp],
                    [TF_NegOp, HLO_NegOp],
                    [TF_RealOp, HLO_RealOp],
                    [TF_RsqrtOp, HLO_RsqrtOp],
......
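For readers unfamiliar with the DirectBinaryPat class (its body is elided from this diff), a hedged sketch of what the new foreach entry is equivalent to, assuming it forwards its operands the same way as the DirectLogicalBinaryPat and LowerRightShiftSigned patterns shown above:

// Assumed expansion of DirectBinaryPat<TF_LeftShiftOp, HLO_ShiftLeftOp>:
def : Pat<(TF_LeftShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r),
          (HLO_ShiftLeftOp $l, $r, (BinBroadcastDimensions $l, $r))>;

The BinBroadcastDimensions helper computes the broadcast_dimensions attribute checked by the @broadcast_shift_right test above when the operand ranks differ.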