diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index c80123bcd50ddff8bace363f68991b16a7e2f357..00b1b746021f175b5d93fd253c11428ef6b6b196 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -37,23 +37,15 @@ namespace m = match; using absl::optional; using hlo_query::ContainsInstrWithOpcode; -// Tries to remove elements in a while loop's tuple that aren't used within the -// loop. -// -// Specifically, if a loop is tuple-shaped, and there exists some element of -// that tuple that is not used by the loop condition and is not used by the loop -// body except to pass it to the next iteration of the loop, then we can remove -// that element from the loop's tuples. -static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - - // Don't try this transformation if the while loop isn't removable, since if - // it succeeds ultimately we're going to have to replace the old while loop - // with a new one. - if (!while_op->parent()->IsSafelyRemovable(while_op)) { - VLOG(2) << "Can't remove dead parameters from non-removable while op."; - return false; - } +// This is a utility function that removes the given tuple indices from the +// while loop init, body, and condition. The final shape returned is still the +// same as before. +static StatusOr RemoveDeadTupleIndices( + HloInstruction* while_op, absl::flat_hash_set& used_tuple_indices) { + // Build up maps from the old/new to the new/old tuple indices. + std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), + used_tuple_indices.end()); + absl::c_sort(new_to_old_tuple_idx); HloModule* module = while_op->GetModule(); HloComputation* computation = while_op->parent(); @@ -62,107 +54,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloComputation* while_body = while_op->while_body(); HloInstruction* while_body_root = while_body->root_instruction(); - if (!while_init->shape().IsTuple()) { - VLOG(2) << "While op's carried value isn't tuple shaped."; - return false; - } - - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple(...) instruction."; - return false; - } - auto print_no_metadata = HloPrintOptions().set_print_metadata(false); - // Bail if param0 of while_cond or while_body has users which aren't of type - // get-tuple-element. - for (const HloInstruction* instr : {while_body->parameter_instruction(0), - while_cond->parameter_instruction(0)}) { - for (const HloInstruction* user : instr->users()) { - if (user->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "Cowardly refusing to analyze while loop with " - << instr->ToString(print_no_metadata) - << " used by non-GTE instruction " - << user->ToString(print_no_metadata) << " in computation " - << instr->parent()->name(); - return false; - } - } - } - - const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); - if (tuple_size == 0) { - VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " - "empty."; - return false; - } - - absl::flat_hash_set used_tuple_indices; - for (HloComputation* comp : {while_body, while_cond}) { - // The HLO verifier ensures that while_input's shape matches while_init's - // shape, which we verified above is a tuple. - HloInstruction* while_input = comp->parameter_instruction(0); - - for (const HloInstruction* user : while_input->users()) { - // This user doesn't count if it's only used by the while body's root, and - // the root places the tuple element into the same index of the tuple as - // it came from. That just amounts to us carrying the variable through - // the loop. - // - // Careful: HloInstruction::operand_index returns the first index the - // operand appears in, but it may appear more than once! - if (user->user_count() == 1 && user->users().front() == while_body_root && - while_body_root->operand_index(user) == user->tuple_index() && - absl::c_count(while_body_root->operands(), user) == 1) { - continue; - } - - used_tuple_indices.insert(user->tuple_index()); - if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) - << " uses all of its inputs; no simplification possible."; - return false; - } - } - } - - // If a tuple element is not passed unmodified from the while body's param0 - // through to the while body's root, count that element as "used", since - // removing that element would be observable. - for (int64 i = 0; i < while_body_root->operand_count(); ++i) { - if (used_tuple_indices.contains(i)) { - continue; - } - - auto* operand = while_body_root->operand(i); - if (operand->opcode() != HloOpcode::kGetTupleElement || - operand->operand(0) != while_body->parameter_instruction(0) || - operand->tuple_index() != i) { - VLOG(2) << "Tuple index " << i - << " is not passed through loop body unmodified."; - used_tuple_indices.insert(i); - - if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) - << " uses all of its inputs; no simplification possible."; - return false; - } - } - } - - // If we got here, used_tuple_indices.size() < tuple_size, meaning some - // elements of the loop's tuple aren't used by while_body or while_cond. - CHECK_LT(used_tuple_indices.size(), tuple_size); - - VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() - << " elements from tuple of " - << while_op->ToString(print_no_metadata); - - // Build up maps from the old/new to the new/old tuple indices. - std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), - used_tuple_indices.end()); - absl::c_sort(new_to_old_tuple_idx); - absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { int64 old_idx = new_to_old_tuple_idx[new_idx]; @@ -288,6 +181,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // The tuple simplifier will then simplify this if possible, removing // new_tuple and while_init. std::vector new_tuple_elems; + const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) { auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx); if (new_tuple_idx_it != old_to_new_tuple_idx.end()) { @@ -305,9 +199,291 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloInstruction* new_tuple = computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems)); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple)); + + return new_while_op; +} + +// Tries to remove elements in a while loop's tuple that aren't used within the +// loop. +// +// Specifically, if a loop is tuple-shaped, and there exists some element of +// that tuple that is not used by the loop condition and is not used by the loop +// body except to pass it to the next iteration of the loop, then we can remove +// that element from the loop's tuples. +static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // Don't try this transformation if the while loop isn't removable, since if + // it succeeds ultimately we're going to have to replace the old while loop + // with a new one. + if (!while_op->parent()->IsSafelyRemovable(while_op)) { + VLOG(2) << "Can't remove dead parameters from non-removable while op."; + return false; + } + + HloInstruction* while_init = while_op->mutable_operand(0); + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_body_root = while_body->root_instruction(); + + if (!while_init->shape().IsTuple()) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) instruction."; + return false; + } + + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + + // Bail if param0 of while_cond or while_body has users which aren't of type + // get-tuple-element. + for (const HloInstruction* instr : {while_body->parameter_instruction(0), + while_cond->parameter_instruction(0)}) { + for (const HloInstruction* user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "Cowardly refusing to analyze while loop with " + << instr->ToString(print_no_metadata) + << " used by non-GTE instruction " + << user->ToString(print_no_metadata) << " in computation " + << instr->parent()->name(); + return false; + } + } + } + + const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); + if (tuple_size == 0) { + VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " + "empty."; + return false; + } + + absl::flat_hash_set used_tuple_indices; + for (HloComputation* comp : {while_body, while_cond}) { + // The HLO verifier ensures that while_input's shape matches while_init's + // shape, which we verified above is a tuple. + HloInstruction* while_input = comp->parameter_instruction(0); + + for (const HloInstruction* user : while_input->users()) { + // This user doesn't count if it's only used by the while body's root, and + // the root places the tuple element into the same index of the tuple as + // it came from. That just amounts to us carrying the variable through + // the loop. + // + // Careful: HloInstruction::operand_index returns the first index the + // operand appears in, but it may appear more than once! + if (user->user_count() == 1 && user->users().front() == while_body_root && + while_body_root->operand_index(user) == user->tuple_index() && + absl::c_count(while_body_root->operands(), user) == 1) { + continue; + } + + used_tuple_indices.insert(user->tuple_index()); + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If a tuple element is not passed unmodified from the while body's param0 + // through to the while body's root, count that element as "used", since + // removing that element would be observable. + for (int64 i = 0; i < while_body_root->operand_count(); ++i) { + if (used_tuple_indices.contains(i)) { + continue; + } + + auto* operand = while_body_root->operand(i); + if (operand->opcode() != HloOpcode::kGetTupleElement || + operand->operand(0) != while_body->parameter_instruction(0) || + operand->tuple_index() != i) { + VLOG(2) << "Tuple index " << i + << " is not passed through loop body unmodified."; + used_tuple_indices.insert(i); + + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If we got here, used_tuple_indices.size() < tuple_size, meaning some + // elements of the loop's tuple aren't used by while_body or while_cond. + CHECK_LT(used_tuple_indices.size(), tuple_size); + + VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() + << " elements from tuple of " + << while_op->ToString(print_no_metadata); + + TF_ASSIGN_OR_RETURN(while_op, + RemoveDeadTupleIndices(while_op, used_tuple_indices)); + return true; } +// This is a helper function for TryRemoveRepeatedWhileTupleIndices. It removes +// duplicates by replacing them with tuple_index, followed by a call to +// RemoveDeadTupleIndices. +static StatusOr TryRemoveRepeatedWhileTupleIndicesHelper( + HloInstruction* while_op, const int64 tuple_index, + absl::flat_hash_set& duplicates) { + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_init = while_op->mutable_operand(0); + + VLOG(2) << "while_init " << while_init->ToString() << " operands " + << while_init->operand_count(); + VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString() + << " operands " << while_body->root_instruction()->operand_count(); + + // Change the loop body and condition such that uses of the duplicates are + // replaced with the original tuple element. + for (HloComputation* comp : {while_body, while_cond}) { + auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement( + comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index), + comp->parameter_instruction(0), tuple_index)); + + std::vector instrs_to_replace; + for (auto* instr : comp->instructions()) { + if (instr->opcode() == HloOpcode::kGetTupleElement && + duplicates.contains(instr->tuple_index()) && + instr->operand(0) == comp->parameter_instruction(0)) { + instrs_to_replace.push_back(instr); + } + } + + for (auto instr : instrs_to_replace) { + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get)); + } + } + + // We know which tuple indices are useful; i.e, those which aren't duplicates. + absl::flat_hash_set used_tuple_indices; + for (int index = 0; index < while_init->shape().tuple_shapes_size(); + ++index) { + if (!duplicates.count(index)) { + used_tuple_indices.insert(index); + } + } + + // Remove the duplicate tuple elements. + TF_ASSIGN_OR_RETURN(while_op, + RemoveDeadTupleIndices(while_op, used_tuple_indices)); + + return while_op; +} + +// If the while loop init passes the same values to several tuple indices, and +// if the body keeps on passing them through, we can remove the duplicates. +static StatusOr TryRemoveRepeatedWhileTupleIndices( + HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + int index_to_investigate = 0; + // Don't try this transformation if the while loop isn't removable, since if + // it succeeds ultimately we're going to have to replace the old while loop + // with a new one. + if (!while_op->parent()->IsSafelyRemovable(while_op)) { + VLOG(2) << "Can't remove dead parameters from non-removable while op."; + return false; + } + + HloInstruction* while_init = while_op->mutable_operand(0); + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_body_root = while_body->root_instruction(); + + if (!while_init->shape().IsTuple()) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + bool changed = false; + while (index_to_investigate < while_init->shape().tuple_shapes_size()) { + if (!while_init->shape().IsTuple() || + while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) instruction."; + return false; + } + + auto& while_shape = while_init->shape(); + VLOG(2) << "Iterating " << index_to_investigate; + + absl::flat_hash_set duplicates; + auto* pivot_init_elem = while_init->operand(index_to_investigate); + auto* pivot_body_elem = while_body_root->operand(index_to_investigate); + if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement) { + if (pivot_body_elem->tuple_index() != index_to_investigate) { + VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() " + << pivot_body_elem->tuple_index() << " index_to_investigate " + << index_to_investigate; + index_to_investigate++; + continue; + } + } else { + index_to_investigate++; + continue; + } + + // Look from index_to_investigate onwards to see if it is repeated. + for (int64 i = index_to_investigate + 1; + i < while_shape.tuple_shapes_size(); ++i) { + auto* init_elem = while_init->operand(i); + auto* body_elem = while_body_root->operand(i); + if (body_elem->opcode() == HloOpcode::kGetTupleElement) { + if (body_elem->tuple_index() != i) { + VLOG(2) << "Mismatch between body_elem->tuple_index() " + << body_elem->tuple_index() << " i " << i; + continue; + } + } else { + continue; + } + + if (pivot_init_elem == init_elem /*&& pivot_body_elem == body_elem*/) { + VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem " + << pivot_init_elem->ToString(); + VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem " + << pivot_body_elem->ToString(); + duplicates.insert(i); + } + } + + // If duplicates are found, call the helper to remove them. + if (!duplicates.empty()) { + VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init " + << pivot_init_elem->ToString(); + TF_ASSIGN_OR_RETURN(while_op, + TryRemoveRepeatedWhileTupleIndicesHelper( + while_op, index_to_investigate, duplicates)); + changed = true; + VLOG(2) << "Changed while_op " << while_op->ToString() + << " while_op operand count " << while_op->operand_count(); + // Update the while loop variables so we can continue looking for + // duplicates of a different index. + while_init = while_op->mutable_operand(0); + while_cond = while_op->while_condition(); + while_body = while_op->while_body(); + while_body_root = while_body->root_instruction(); + } + index_to_investigate++; + } + + return changed; +} + // Removes each loop parameter (i.e. member of the while loop tuple) that is a // constant and is the same in the while loop body and the while loop init. static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { @@ -1048,6 +1224,7 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); changed |= result; + if (result) { // Don't continue simplifying after successfully removing the while loop // -- that would result in use-after-free nastiness. @@ -1067,6 +1244,12 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { // successful, meaning that `while_op` is no longer valid after one of these // transformations returns true. + TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op)); + changed |= result; + if (result) { + continue; + } + TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); changed |= result; if (result) { @@ -1074,6 +1257,7 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { } TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); + changed |= result; if (result) { continue; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index d715fb3857a13e25f89588bf5ea90b468b719c93..c93cb5dc34761c18d21774716ac0ab310ede58cf 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -794,5 +794,51 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) { .ValueOrDie()); } +TEST_F(WhileLoopSimplifierTest, RemoveRepeatedParams) { + const string hlo_string = R"( + HloModule SwappingTupleElements + + SwappingTupleElements.body { + loop_var = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element(loop_var), index=0 + get-tuple-element.1 = s32[] get-tuple-element(loop_var), index=1 + get-tuple-element.2 = s32[] get-tuple-element(loop_var), index=2 + y = s32[] add(get-tuple-element.1, get-tuple-element.2) + ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element, y, + s32[] get-tuple-element.2) + } + + SwappingTupleElements.always_true { + param = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element(param), index=0 + get-tuple-element.1 = s32[] get-tuple-element(param), index=1 + ROOT less-than = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT + } + + ENTRY SwappingTupleElements { + x = s32[] parameter(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] y, s32[] x) + ROOT while = (s32[], s32[], s32[]) while(tuple.1), + condition=SwappingTupleElements.always_true, + body=SwappingTupleElements.body + } + )"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape new_while_shape = ParseShape("(s32[], s32[])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); +} + } // namespace } // namespace xla