提交 ce55348e 编写于 作者: Y Yuanzhong Xu 提交者: TensorFlower Gardener

Fix LowerInvertPermutationOp

<5xi32> to <5x1xi32> is a reshape, not a transpose

PiperOrigin-RevId: 306531112
Change-Id: I1e5541bc43997eda222837691bcbad7107f57982
上级 9d53fd3a
......@@ -3,8 +3,8 @@
// CHECK-LABEL: invert_permutation
func @invert_permutation(%arg0: tensor<5xi32>) -> tensor<5xi32> {
// CHECK-NEXT: %[[UPDATES:.*]] = "tf.Const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
// CHECK-NEXT: %[[PERM:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK-NEXT: %[[INDICES:.*]] = "tf.Transpose"(%arg0, %[[PERM]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32>
// CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK-NEXT: %[[INDICES:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32>
// CHECK-NEXT: "tf.TensorScatterUpdate"(%arg0, %[[INDICES]], %[[UPDATES]]) : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
%0 = "tf.InvertPermutation"(%arg0) : (tensor<5xi32>) -> tensor<5xi32>
return %0 : tensor<5xi32>
......
......@@ -253,8 +253,8 @@ class LowerDynamicStitchOp : public OpRewritePattern<TF::DynamicStitchOp> {
// %delta = "tf.Const"() {value = dense<1> : tensor<i32>}
// %updates = "tf.Range"(%start, %limit, %delta) :
// (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xi32>
// %perm = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>}
// %indices = "tf.Transpose"(%x, %perm) : (tensor<5xi32, tensor<2xi32) ->
// %shape = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>}
// %indices = "tf.Reshape"(%x, %shape) : (tensor<5xi32, tensor<2xi32) ->
// tensor<5x1xi32>
// "tf.TensorScatterUpdate"(%x, %indices, %updates) :
// (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
......@@ -268,13 +268,12 @@ class LowerInvertPermutationOp
LogicalResult matchAndRewrite(TF::InvertPermutationOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto x_type = op.x().getType().cast<TensorType>();
Type int_type = x_type.getElementType(); // Could be i32 or i64.
auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
// x input must have static shape.
if (!x_type.hasStaticShape()) {
if (!x_type || !x_type.hasStaticShape()) {
return failure();
}
Type int_type = x_type.getElementType(); // Could be i32 or i64.
auto result_type = x_type;
auto start =
......@@ -287,13 +286,11 @@ class LowerInvertPermutationOp
auto updates =
rewriter.create<TF::RangeOp>(loc, result_type, start, limit, delta);
auto perm_type = RankedTensorType::get({2}, int_type);
auto perm = rewriter.create<TF::ConstOp>(
loc, DenseElementsAttr::get(perm_type, {1, 0}));
auto transposed_x_type =
RankedTensorType::get({x_type.getShape()[0], 1}, int_type);
auto indices =
rewriter.create<TF::TransposeOp>(loc, transposed_x_type, op.x(), perm);
auto shape_type = RankedTensorType::get({2}, rewriter.getIntegerType(32));
auto shape = rewriter.create<TF::ConstOp>(
loc, DenseElementsAttr::get(
shape_type, {static_cast<int>(x_type.getDimSize(0)), 1}));
auto indices = rewriter.create<TF::ReshapeOp>(loc, op.x(), shape);
rewriter.replaceOpWithNewOp<TF::TensorScatterUpdateOp>(
op, result_type, op.x(), indices, updates);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册