提交 e1b51397 编写于 作者: B Blake Hechtman 提交者: TensorFlower Gardener

PR #23661: [XLA] Simplify transposes that are really reshapes

Please approve this CL. It will be submitted automatically, and its GitHub pull request will be marked as merged.

Imported from GitHub PR #23661

A transpose like
```
f32[1,1,64,1] = transpose(f32[1,64,1,1]), dimensions={3,2,1,0}
```
is really just a reshape (because there's only one non-1 dimension).
Teach algebraic simplifier to make that substitution, to enable applying
reshape-combining optimizations to such instructions.

Copybara import of the project:

  - c51f19ef50f993677d7d58d9dcf3de6785540e0b [XLA] Simplify transposes that are really reshapes by Keno Fischer <keno@juliacomputing.com>
  - 10fe3503be362e28906e6a01d0d272903f693817 [XLA] Canonicalize Transpose by dropping degenerate dims by Keno Fischer <keno@juliacomputing.com>
  - 333cdccc3a045ebdb36ca03e8877706d5659642e Merge 10fe3503be362e28906e6a01d0d272903f693817 into 3dfb4... by Keno Fischer <keno@alumni.harvard.edu>

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/23661 from JuliaComputing:kf/transposereshape 10fe3503be362e28906e6a01d0d272903f693817
PiperOrigin-RevId: 225416731
上级 fe4328b4
......@@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
......@@ -2026,6 +2027,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
reshape, HloInstruction::CreateReshape(reshape->shape(),
operand->mutable_operand(0)));
}
if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
*operand->mutable_shape() = reshape->shape();
return ReplaceInstruction(reshape, operand);
......@@ -2748,6 +2750,22 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
return Status::OK();
}
namespace {
bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape,
absl::Span<const int64> perm) {
std::vector<int64> new_permutation;
int64 degenerate_count = 0;
for (int64 i = 0; i < perm.size(); ++i) {
if (shape.dimensions(i) != 1) {
new_permutation.push_back(perm[i]);
} else {
++degenerate_count;
}
}
return degenerate_count > 1 && absl::c_is_sorted(new_permutation);
}
} // namespace
Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
auto operand = transpose->mutable_operand(0);
if (std::is_sorted(transpose->dimensions().begin(),
......@@ -2764,6 +2782,15 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
transpose->dimensions())));
}
// Replace transpose with a reshape if more than one degenerate method is
// permuted.
if (OnlyPermutesMoreThanOneDegenerateDim(transpose->shape(),
transpose->dimensions())) {
return ReplaceWithNewInstruction(
transpose, HloInstruction::CreateReshape(
transpose->shape(), transpose->mutable_operand(0)));
}
if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
*operand->mutable_shape() = transpose->shape();
return ReplaceInstruction(transpose, operand);
......
......@@ -2047,6 +2047,27 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
computation->root_instruction()->dimensions());
}
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
const char* hlo_string = R"(
HloModule module
ENTRY test {
param = f32[10] parameter(0)
reshaped = f32[1,1,10] reshape(f32[10] param)
transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0}
ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module,
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Parameter()));
}
// Test merging reshape and broadcast.
TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
auto m = CreateNewVerifiedModule();
......
......@@ -1067,6 +1067,11 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
return absl::c_linear_search(shape.dimensions(), 1);
}
/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) {
return FilterDimensions(
[&](int64 dim) -> bool { return shape.dimensions()[dim] != 1; }, shape);
}
namespace {
// Helper for ForEachSubshape which visits the subshapes of the given shape in
......
......@@ -551,6 +551,9 @@ class ShapeUtil {
// (dimensions with bound 1).
static bool HasDegenerateDimensions(const Shape& shape);
// Drops any degenerate dimensions (i.e. dimensions of size 1)
static Shape DropDegenerateDimensions(const Shape& shape);
// Permutes the dimensions by the given permutation, so
// return_value.dimensions[permutation[i]] = argument.dimensions[i].
//
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册