提交 4370219c 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

While loop simplification to remove repeated parameters

PiperOrigin-RevId: 328182995
Change-Id: I2bfb11c5043218b9dfb042c984f81c3d6e114e7d
上级 367ea1a9
......@@ -37,23 +37,15 @@ namespace m = match;
using absl::optional;
using hlo_query::ContainsInstrWithOpcode;
// Tries to remove elements in a while loop's tuple that aren't used within the
// loop.
//
// Specifically, if a loop is tuple-shaped, and there exists some element of
// that tuple that is not used by the loop condition and is not used by the loop
// body except to pass it to the next iteration of the loop, then we can remove
// that element from the loop's tuples.
static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
// Don't try this transformation if the while loop isn't removable, since if
// it succeeds ultimately we're going to have to replace the old while loop
// with a new one.
if (!while_op->parent()->IsSafelyRemovable(while_op)) {
VLOG(2) << "Can't remove dead parameters from non-removable while op.";
return false;
}
// This is a utility function that removes the given tuple indices from the
// while loop init, body, and condition. The final shape returned is still the
// same as before.
static StatusOr<HloInstruction*> RemoveDeadTupleIndices(
HloInstruction* while_op, absl::flat_hash_set<int64>& used_tuple_indices) {
// Build up maps from the old/new to the new/old tuple indices.
std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
used_tuple_indices.end());
absl::c_sort(new_to_old_tuple_idx);
HloModule* module = while_op->GetModule();
HloComputation* computation = while_op->parent();
......@@ -62,107 +54,8 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
HloComputation* while_body = while_op->while_body();
HloInstruction* while_body_root = while_body->root_instruction();
if (!while_init->shape().IsTuple()) {
VLOG(2) << "While op's carried value isn't tuple shaped.";
return false;
}
if (while_body_root->opcode() != HloOpcode::kTuple) {
VLOG(2) << "While body's root is not a tuple(...) instruction.";
return false;
}
auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
// Bail if param0 of while_cond or while_body has users which aren't of type
// get-tuple-element.
for (const HloInstruction* instr : {while_body->parameter_instruction(0),
while_cond->parameter_instruction(0)}) {
for (const HloInstruction* user : instr->users()) {
if (user->opcode() != HloOpcode::kGetTupleElement) {
VLOG(2) << "Cowardly refusing to analyze while loop with "
<< instr->ToString(print_no_metadata)
<< " used by non-GTE instruction "
<< user->ToString(print_no_metadata) << " in computation "
<< instr->parent()->name();
return false;
}
}
}
const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
if (tuple_size == 0) {
VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
"empty.";
return false;
}
absl::flat_hash_set<int64> used_tuple_indices;
for (HloComputation* comp : {while_body, while_cond}) {
// The HLO verifier ensures that while_input's shape matches while_init's
// shape, which we verified above is a tuple.
HloInstruction* while_input = comp->parameter_instruction(0);
for (const HloInstruction* user : while_input->users()) {
// This user doesn't count if it's only used by the while body's root, and
// the root places the tuple element into the same index of the tuple as
// it came from. That just amounts to us carrying the variable through
// the loop.
//
// Careful: HloInstruction::operand_index returns the first index the
// operand appears in, but it may appear more than once!
if (user->user_count() == 1 && user->users().front() == while_body_root &&
while_body_root->operand_index(user) == user->tuple_index() &&
absl::c_count(while_body_root->operands(), user) == 1) {
continue;
}
used_tuple_indices.insert(user->tuple_index());
if (used_tuple_indices.size() == tuple_size) {
VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
<< " uses all of its inputs; no simplification possible.";
return false;
}
}
}
// If a tuple element is not passed unmodified from the while body's param0
// through to the while body's root, count that element as "used", since
// removing that element would be observable.
for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
if (used_tuple_indices.contains(i)) {
continue;
}
auto* operand = while_body_root->operand(i);
if (operand->opcode() != HloOpcode::kGetTupleElement ||
operand->operand(0) != while_body->parameter_instruction(0) ||
operand->tuple_index() != i) {
VLOG(2) << "Tuple index " << i
<< " is not passed through loop body unmodified.";
used_tuple_indices.insert(i);
if (used_tuple_indices.size() == tuple_size) {
VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
<< " uses all of its inputs; no simplification possible.";
return false;
}
}
}
// If we got here, used_tuple_indices.size() < tuple_size, meaning some
// elements of the loop's tuple aren't used by while_body or while_cond.
CHECK_LT(used_tuple_indices.size(), tuple_size);
VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
<< " elements from tuple of "
<< while_op->ToString(print_no_metadata);
// Build up maps from the old/new to the new/old tuple indices.
std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
used_tuple_indices.end());
absl::c_sort(new_to_old_tuple_idx);
absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
int64 old_idx = new_to_old_tuple_idx[new_idx];
......@@ -288,6 +181,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// The tuple simplifier will then simplify this if possible, removing
// new_tuple and while_init.
std::vector<HloInstruction*> new_tuple_elems;
const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) {
auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
......@@ -305,9 +199,291 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
HloInstruction* new_tuple =
computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple));
return new_while_op;
}
// Tries to remove elements in a while loop's tuple that aren't used within the
// loop.
//
// Specifically, if a loop is tuple-shaped, and there exists some element of
// that tuple that is not used by the loop condition and is not used by the loop
// body except to pass it to the next iteration of the loop, then we can remove
// that element from the loop's tuples.
static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
// Don't try this transformation if the while loop isn't removable, since if
// it succeeds ultimately we're going to have to replace the old while loop
// with a new one.
if (!while_op->parent()->IsSafelyRemovable(while_op)) {
VLOG(2) << "Can't remove dead parameters from non-removable while op.";
return false;
}
HloInstruction* while_init = while_op->mutable_operand(0);
HloComputation* while_cond = while_op->while_condition();
HloComputation* while_body = while_op->while_body();
HloInstruction* while_body_root = while_body->root_instruction();
if (!while_init->shape().IsTuple()) {
VLOG(2) << "While op's carried value isn't tuple shaped.";
return false;
}
if (while_body_root->opcode() != HloOpcode::kTuple) {
VLOG(2) << "While body's root is not a tuple(...) instruction.";
return false;
}
auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
// Bail if param0 of while_cond or while_body has users which aren't of type
// get-tuple-element.
for (const HloInstruction* instr : {while_body->parameter_instruction(0),
while_cond->parameter_instruction(0)}) {
for (const HloInstruction* user : instr->users()) {
if (user->opcode() != HloOpcode::kGetTupleElement) {
VLOG(2) << "Cowardly refusing to analyze while loop with "
<< instr->ToString(print_no_metadata)
<< " used by non-GTE instruction "
<< user->ToString(print_no_metadata) << " in computation "
<< instr->parent()->name();
return false;
}
}
}
const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
if (tuple_size == 0) {
VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
"empty.";
return false;
}
absl::flat_hash_set<int64> used_tuple_indices;
for (HloComputation* comp : {while_body, while_cond}) {
// The HLO verifier ensures that while_input's shape matches while_init's
// shape, which we verified above is a tuple.
HloInstruction* while_input = comp->parameter_instruction(0);
for (const HloInstruction* user : while_input->users()) {
// This user doesn't count if it's only used by the while body's root, and
// the root places the tuple element into the same index of the tuple as
// it came from. That just amounts to us carrying the variable through
// the loop.
//
// Careful: HloInstruction::operand_index returns the first index the
// operand appears in, but it may appear more than once!
if (user->user_count() == 1 && user->users().front() == while_body_root &&
while_body_root->operand_index(user) == user->tuple_index() &&
absl::c_count(while_body_root->operands(), user) == 1) {
continue;
}
used_tuple_indices.insert(user->tuple_index());
if (used_tuple_indices.size() == tuple_size) {
VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
<< " uses all of its inputs; no simplification possible.";
return false;
}
}
}
// If a tuple element is not passed unmodified from the while body's param0
// through to the while body's root, count that element as "used", since
// removing that element would be observable.
for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
if (used_tuple_indices.contains(i)) {
continue;
}
auto* operand = while_body_root->operand(i);
if (operand->opcode() != HloOpcode::kGetTupleElement ||
operand->operand(0) != while_body->parameter_instruction(0) ||
operand->tuple_index() != i) {
VLOG(2) << "Tuple index " << i
<< " is not passed through loop body unmodified.";
used_tuple_indices.insert(i);
if (used_tuple_indices.size() == tuple_size) {
VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
<< " uses all of its inputs; no simplification possible.";
return false;
}
}
}
// If we got here, used_tuple_indices.size() < tuple_size, meaning some
// elements of the loop's tuple aren't used by while_body or while_cond.
CHECK_LT(used_tuple_indices.size(), tuple_size);
VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
<< " elements from tuple of "
<< while_op->ToString(print_no_metadata);
TF_ASSIGN_OR_RETURN(while_op,
RemoveDeadTupleIndices(while_op, used_tuple_indices));
return true;
}
// This is a helper function for TryRemoveRepeatedWhileTupleIndices. It removes
// duplicates by replacing them with tuple_index, followed by a call to
// RemoveDeadTupleIndices.
static StatusOr<HloInstruction*> TryRemoveRepeatedWhileTupleIndicesHelper(
HloInstruction* while_op, const int64 tuple_index,
absl::flat_hash_set<int64>& duplicates) {
HloComputation* while_cond = while_op->while_condition();
HloComputation* while_body = while_op->while_body();
HloInstruction* while_init = while_op->mutable_operand(0);
VLOG(2) << "while_init " << while_init->ToString() << " operands "
<< while_init->operand_count();
VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString()
<< " operands " << while_body->root_instruction()->operand_count();
// Change the loop body and condition such that uses of the duplicates are
// replaced with the original tuple element.
for (HloComputation* comp : {while_body, while_cond}) {
auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index),
comp->parameter_instruction(0), tuple_index));
std::vector<HloInstruction*> instrs_to_replace;
for (auto* instr : comp->instructions()) {
if (instr->opcode() == HloOpcode::kGetTupleElement &&
duplicates.contains(instr->tuple_index()) &&
instr->operand(0) == comp->parameter_instruction(0)) {
instrs_to_replace.push_back(instr);
}
}
for (auto instr : instrs_to_replace) {
TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get));
}
}
// We know which tuple indices are useful; i.e, those which aren't duplicates.
absl::flat_hash_set<int64> used_tuple_indices;
for (int index = 0; index < while_init->shape().tuple_shapes_size();
++index) {
if (!duplicates.count(index)) {
used_tuple_indices.insert(index);
}
}
// Remove the duplicate tuple elements.
TF_ASSIGN_OR_RETURN(while_op,
RemoveDeadTupleIndices(while_op, used_tuple_indices));
return while_op;
}
// If the while loop init passes the same values to several tuple indices, and
// if the body keeps on passing them through, we can remove the duplicates.
static StatusOr<bool> TryRemoveRepeatedWhileTupleIndices(
HloInstruction* while_op) {
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
int index_to_investigate = 0;
// Don't try this transformation if the while loop isn't removable, since if
// it succeeds ultimately we're going to have to replace the old while loop
// with a new one.
if (!while_op->parent()->IsSafelyRemovable(while_op)) {
VLOG(2) << "Can't remove dead parameters from non-removable while op.";
return false;
}
HloInstruction* while_init = while_op->mutable_operand(0);
HloComputation* while_cond = while_op->while_condition();
HloComputation* while_body = while_op->while_body();
HloInstruction* while_body_root = while_body->root_instruction();
if (!while_init->shape().IsTuple()) {
VLOG(2) << "While op's carried value isn't tuple shaped.";
return false;
}
bool changed = false;
while (index_to_investigate < while_init->shape().tuple_shapes_size()) {
if (!while_init->shape().IsTuple() ||
while_init->opcode() != HloOpcode::kTuple) {
VLOG(2) << "While op's carried value isn't tuple shaped.";
return false;
}
if (while_body_root->opcode() != HloOpcode::kTuple) {
VLOG(2) << "While body's root is not a tuple(...) instruction.";
return false;
}
auto& while_shape = while_init->shape();
VLOG(2) << "Iterating " << index_to_investigate;
absl::flat_hash_set<int64> duplicates;
auto* pivot_init_elem = while_init->operand(index_to_investigate);
auto* pivot_body_elem = while_body_root->operand(index_to_investigate);
if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement) {
if (pivot_body_elem->tuple_index() != index_to_investigate) {
VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() "
<< pivot_body_elem->tuple_index() << " index_to_investigate "
<< index_to_investigate;
index_to_investigate++;
continue;
}
} else {
index_to_investigate++;
continue;
}
// Look from index_to_investigate onwards to see if it is repeated.
for (int64 i = index_to_investigate + 1;
i < while_shape.tuple_shapes_size(); ++i) {
auto* init_elem = while_init->operand(i);
auto* body_elem = while_body_root->operand(i);
if (body_elem->opcode() == HloOpcode::kGetTupleElement) {
if (body_elem->tuple_index() != i) {
VLOG(2) << "Mismatch between body_elem->tuple_index() "
<< body_elem->tuple_index() << " i " << i;
continue;
}
} else {
continue;
}
if (pivot_init_elem == init_elem /*&& pivot_body_elem == body_elem*/) {
VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem "
<< pivot_init_elem->ToString();
VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem "
<< pivot_body_elem->ToString();
duplicates.insert(i);
}
}
// If duplicates are found, call the helper to remove them.
if (!duplicates.empty()) {
VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init "
<< pivot_init_elem->ToString();
TF_ASSIGN_OR_RETURN(while_op,
TryRemoveRepeatedWhileTupleIndicesHelper(
while_op, index_to_investigate, duplicates));
changed = true;
VLOG(2) << "Changed while_op " << while_op->ToString()
<< " while_op operand count " << while_op->operand_count();
// Update the while loop variables so we can continue looking for
// duplicates of a different index.
while_init = while_op->mutable_operand(0);
while_cond = while_op->while_condition();
while_body = while_op->while_body();
while_body_root = while_body->root_instruction();
}
index_to_investigate++;
}
return changed;
}
// Removes each loop parameter (i.e. member of the while loop tuple) that is a
// constant and is the same in the while loop body and the while loop init.
static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) {
......@@ -1048,6 +1224,7 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op));
changed |= result;
if (result) {
// Don't continue simplifying after successfully removing the while loop
// -- that would result in use-after-free nastiness.
......@@ -1067,6 +1244,12 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
// successful, meaning that `while_op` is no longer valid after one of these
// transformations returns true.
TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op));
changed |= result;
if (result) {
continue;
}
TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op));
changed |= result;
if (result) {
......@@ -1074,6 +1257,7 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
}
TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op));
changed |= result;
if (result) {
continue;
......
......@@ -794,5 +794,51 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) {
.ValueOrDie());
}
TEST_F(WhileLoopSimplifierTest, RemoveRepeatedParams) {
const string hlo_string = R"(
HloModule SwappingTupleElements
SwappingTupleElements.body {
loop_var = (s32[], s32[], s32[]) parameter(0)
get-tuple-element = s32[] get-tuple-element(loop_var), index=0
get-tuple-element.1 = s32[] get-tuple-element(loop_var), index=1
get-tuple-element.2 = s32[] get-tuple-element(loop_var), index=2
y = s32[] add(get-tuple-element.1, get-tuple-element.2)
ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element, y,
s32[] get-tuple-element.2)
}
SwappingTupleElements.always_true {
param = (s32[], s32[], s32[]) parameter(0)
get-tuple-element = s32[] get-tuple-element(param), index=0
get-tuple-element.1 = s32[] get-tuple-element(param), index=1
ROOT less-than = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
}
ENTRY SwappingTupleElements {
x = s32[] parameter(0)
y = s32[] parameter(1)
tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] y, s32[] x)
ROOT while = (s32[], s32[], s32[]) while(tuple.1),
condition=SwappingTupleElements.always_true,
body=SwappingTupleElements.body
}
)";
auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
HloInstruction* new_while = FindFirstWhile(m.get());
Shape new_while_shape = ParseShape("(s32[], s32[])").ValueOrDie();
EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape));
EXPECT_TRUE(ShapeUtil::Equal(
new_while->while_body()->root_instruction()->shape(), new_while_shape));
EXPECT_TRUE(ShapeUtil::Equal(
new_while->while_body()->parameter_instruction(0)->shape(),
new_while_shape));
EXPECT_TRUE(ShapeUtil::Equal(
new_while->while_condition()->parameter_instruction(0)->shape(),
new_while_shape));
}
} // namespace
} // namespace xla
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册