From dab94104e13a7ab5192303c8b2784e3c36af60ac Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 22 Feb 2019 11:56:28 -0800 Subject: [PATCH] Automated rollback of commit 4bf9ea4295a0eff68ab737ab41399ab8bb3464ef PiperOrigin-RevId: 235232101 --- .../compiler/xla/service/layout_assignment.cc | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index bf1d1c233f1..aa791ea195e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1019,6 +1019,16 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( Shape operand_shape = operand->shape(); *operand_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(operand_shape); + if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { + return absl::make_unique(operand_shape.layout()); + } + if (operand_shape.rank() == output_shape.rank()) { + *operand_shape.mutable_layout() = output_layout; + if (ShapeUtil::ReshapeIsBitcast(operand_shape, + output_shape_with_layout)) { + return absl::make_unique(output_layout); + } + } auto aligned_operand_shape = ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); if (aligned_operand_shape) { @@ -1080,6 +1090,16 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( Shape output_shape = user->shape(); *output_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(output_shape); + if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { + return absl::make_unique(output_shape.layout()); + } + if (operand->shape().rank() == output_shape.rank()) { + *output_shape.mutable_layout() = operand_layout; + if (ShapeUtil::ReshapeIsBitcast(output_shape, + operand_shape_with_layout)) { + return absl::make_unique(operand_layout); + } + } auto aligned_user_shape = ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape); if (aligned_user_shape) { -- GitLab