提交 484f0e5f 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Support folding TF::TransposeOp when perm is a constant instead of TF::ConstOp

PiperOrigin-RevId: 328149666
Change-Id: I0c5561152383f12126ab9568c0facc4c3043c6a3
上级 9578a394
...@@ -1939,11 +1939,9 @@ void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, ...@@ -1939,11 +1939,9 @@ void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
namespace { namespace {
OpFoldResult FoldIdentityTranspose(TransposeOp op) { OpFoldResult FoldIdentityTranspose(TransposeOp op) {
auto const_perm = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp()); DenseIntElementsAttr perm;
if (!const_perm) return {}; if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
const auto elements = perm.getValues<APInt>();
auto const_value = const_perm.value();
const auto elements = const_value.getValues<APInt>();
for (auto it : llvm::enumerate(elements)) { for (auto it : llvm::enumerate(elements)) {
if (it.index() != it.value()) return {}; if (it.index() != it.value()) return {};
...@@ -1966,14 +1964,14 @@ OpFoldResult FoldCancellableTranspose(TransposeOp op) { ...@@ -1966,14 +1964,14 @@ OpFoldResult FoldCancellableTranspose(TransposeOp op) {
if (!transpose) return {}; if (!transpose) return {};
// Permutations defined by constant operations. // Permutations defined by constant operations.
auto perm0 = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp()); DenseIntElementsAttr perm0;
auto perm1 = dyn_cast_or_null<TF::ConstOp>(transpose.perm().getDefiningOp()); DenseIntElementsAttr perm1;
if (!perm0 || !perm1) return {}; if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
!matchPattern(transpose.perm(), m_Constant(&perm1)))
return {};
// With permutation indices that cancel each other // With permutation indices that cancel each other
auto perm0_value = perm0.value().cast<DenseIntElementsAttr>(); if (!AreCancellablePermutations(perm0, perm1)) return {};
auto perm1_value = perm1.value().cast<DenseIntElementsAttr>();
if (!AreCancellablePermutations(perm0_value, perm1_value)) return {};
return transpose.x(); return transpose.x();
} }
......
...@@ -702,6 +702,15 @@ func @identityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> { ...@@ -702,6 +702,15 @@ func @identityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> {
// CHECK: return %arg0 // CHECK: return %arg0
} }
// CHECK-LABEL: @identityTransposeConst
func @identityTransposeConst(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> {
%0 = constant dense<[0, 1, 2, 3, 4]> : tensor<5xi32>
%1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x5x6xf32>
return %1 : tensor<2x3x4x5x6xf32>
// CHECK: return %arg0
}
// CHECK-LABEL: @nonIdentityTranspose // CHECK-LABEL: @nonIdentityTranspose
func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32> { func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32> {
%0 = "tf.Const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32> %0 = "tf.Const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
...@@ -724,6 +733,17 @@ func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> { ...@@ -724,6 +733,17 @@ func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
// CHECK: return %arg0 // CHECK: return %arg0
} }
// CHECK-LABEL: @cancellableTransposeConst
func @cancellableTransposeConst(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
%0 = constant dense<[0, 3, 1, 2]> : tensor<4xi32>
%1 = constant dense<[0, 2, 3, 1]> : tensor<4xi32>
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
return %3 : tensor<1x4x4x8xf32>
// CHECK: return %arg0
}
// CHECK-LABEL: @nonCancellableTranspose // CHECK-LABEL: @nonCancellableTranspose
func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> { func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> {
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册