提交 484f0e5f 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Support folding TF::TransposeOp when perm is a constant instead of TF::ConstOp

PiperOrigin-RevId: 328149666
Change-Id: I0c5561152383f12126ab9568c0facc4c3043c6a3
上级 9578a394
......@@ -1939,11 +1939,9 @@ void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
namespace {
OpFoldResult FoldIdentityTranspose(TransposeOp op) {
auto const_perm = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
if (!const_perm) return {};
auto const_value = const_perm.value();
const auto elements = const_value.getValues<APInt>();
DenseIntElementsAttr perm;
if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
const auto elements = perm.getValues<APInt>();
for (auto it : llvm::enumerate(elements)) {
if (it.index() != it.value()) return {};
......@@ -1966,14 +1964,14 @@ OpFoldResult FoldCancellableTranspose(TransposeOp op) {
if (!transpose) return {};
// Permutations defined by constant operations.
auto perm0 = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
auto perm1 = dyn_cast_or_null<TF::ConstOp>(transpose.perm().getDefiningOp());
if (!perm0 || !perm1) return {};
DenseIntElementsAttr perm0;
DenseIntElementsAttr perm1;
if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
!matchPattern(transpose.perm(), m_Constant(&perm1)))
return {};
// With permutation indices that cancel each other
auto perm0_value = perm0.value().cast<DenseIntElementsAttr>();
auto perm1_value = perm1.value().cast<DenseIntElementsAttr>();
if (!AreCancellablePermutations(perm0_value, perm1_value)) return {};
if (!AreCancellablePermutations(perm0, perm1)) return {};
return transpose.x();
}
......
......@@ -702,6 +702,15 @@ func @identityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> {
// CHECK: return %arg0
}
// CHECK-LABEL: @identityTransposeConst
func @identityTransposeConst(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> {
%0 = constant dense<[0, 1, 2, 3, 4]> : tensor<5xi32>
%1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x5x6xf32>
return %1 : tensor<2x3x4x5x6xf32>
// CHECK: return %arg0
}
// CHECK-LABEL: @nonIdentityTranspose
func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32> {
%0 = "tf.Const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
......@@ -724,6 +733,17 @@ func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
// CHECK: return %arg0
}
// CHECK-LABEL: @cancellableTransposeConst
func @cancellableTransposeConst(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
%0 = constant dense<[0, 3, 1, 2]> : tensor<4xi32>
%1 = constant dense<[0, 2, 3, 1]> : tensor<4xi32>
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
return %3 : tensor<1x4x4x8xf32>
// CHECK: return %arg0
}
// CHECK-LABEL: @nonCancellableTranspose
func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> {
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册