提交 25310d07 编写于 作者: H Haoliang Zhang 提交者: TensorFlower Gardener

Add legalization pattern for tf.SelectV2 op.

PiperOrigin-RevId: 262622407
上级 c9552455
......@@ -509,6 +509,15 @@ func @select(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) ->
// CHECK: return %0 : tensor<8xf32>
}
func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
return %0: tensor<8xf32>
// CHECK-LABEL: select_v2
// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8xf32>
}
func @sin(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "tf.Sin"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
......
......@@ -141,6 +141,7 @@ def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
// TODO(jpienaar): this is not true for all selects, TF's select supports rank 0
// condition
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
def : Pat<(TF_SelectV2Op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
def : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>;
def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>;
def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
......
......@@ -2823,6 +2823,63 @@ For example:
select(condition, t, e) # => [[1, 6], [7, 4]]
# 'condition' tensor is [True, False]
# 't' is [[1, 2],
# [3, 4]]
# 'e' is [[5, 6],
# [7, 8]]
select(condition, t, e) ==> [[1, 2],
[7, 8]]
```
}];
let arguments = (ins
I1Tensor:$condition,
TF_Tensor:$t,
TF_Tensor:$e
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> {
let summary = "Selects elements from `x` or `y`, depending on `condition`.";
let description = [{
The `x`, and `y` tensors must all have the same shape, and the
output will also have that shape.
The `condition` tensor must be a scalar if `x` and `y` are scalars.
If `x` and `y` are vectors or higher rank, then `condition` must be either a
scalar, a vector with size matching the first dimension of `x`, or must have
the same shape as `x`.
The `condition` tensor acts as a mask that chooses, based on the value at each
element, whether the corresponding element / row in the output should be
taken from `x` (if true) or `y` (if false).
If `condition` is a vector and `x` and `y` are higher rank matrices, then
it chooses which row (outer dimension) to copy from `x` and `y`.
If `condition` has the same shape as `x` and `y`, then it chooses which
element to copy from `x` and `y`.
For example:
```python
# 'condition' tensor is [[True, False]
# [False, True]]
# 't' is [[1, 2],
# [3, 4]]
# 'e' is [[5, 6],
# [7, 8]]
select(condition, t, e) # => [[1, 6], [7, 4]]
# 'condition' tensor is [True, False]
# 't' is [[1, 2],
# [3, 4]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册