提交 e918c5c7 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[XLA] Fix issue in conditional code motion regarding sharing of computations...

[XLA] Fix issue in conditional code motion regarding sharing of computations in conditionals and cleanup generated code.

The branch computations inside a conditional may be shared among different Hlo instructions, e.g., different conditionals. When moving instructions across the boundaries of two computations, specifically the branch computations and the parent of a conditional, we must make sure the branch computations being modified are not shared --- if shared, they must be cloned first before being modified.

The transformation code and the cost calculation for moving instructions inside branches are also modified to produce cleaner result and to refrain from modifying a conditional back and forth. The original implementation for moving instructions inside branches merely extends the old roots of the branches with new instructions. The improved transformation now folds the tuple/getTupleElement instructions in the branches to eliminate unnecessary tuple/getTupleElement pairs.

PiperOrigin-RevId: 327764642
Change-Id: Ia7d7fda3f6e8d8d9af6e091f92a94946af096a7e
上级 8a002f22
......@@ -580,6 +580,154 @@ ENTRY main {
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
}
TEST_F(ConditionalCodeMotionTest, MovePowInWithSharedBranch) {
absl::string_view hlo_string =
R"(
HloModule RemoveIdenticalInstruction
branch {
arg_tuple.1 = (f32[10]) parameter(0)
get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0
add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1)
ROOT tuple.3 = (f32[10]) tuple(add.1)
}
ENTRY main {
pred.1 = pred[] parameter(0)
tuple.1 = (f32[10]) parameter(1)
tuple.2 = (f32[10]) parameter(2)
conditional = (f32[10])
conditional(pred.1, tuple.1, tuple.2), true_computation=branch,
false_computation=branch
get-first-index = f32[10] get-tuple-element(conditional), index=0
ROOT pow.1 = f32[10] power(get-first-index, get-first-index)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 5);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 5);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
}
TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleRoot) {
absl::string_view hlo_string =
R"(
HloModule RemoveIdenticalInstruction
branch {
arg_tuple.1 = (f32[10]) parameter(0)
get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0
ROOT add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1)
}
ENTRY main {
pred.1 = pred[] parameter(0)
tuple.1 = (f32[10]) parameter(1)
tuple.2 = (f32[10]) parameter(2)
conditional = f32[10]
conditional(pred.1, tuple.1, tuple.2), true_computation=branch,
false_computation=branch
ROOT pow.1 = f32[10] power(conditional, conditional)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 5);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 5);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
}
TEST_F(ConditionalCodeMotionTest, MovePowInWithEmptyBranch) {
absl::string_view hlo_string =
R"(
HloModule RemoveIdenticalInstruction
branch1 {
arg_tuple.1 = (f32[10]) parameter(0)
get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0
add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1)
ROOT tuple.3 = (f32[10]) tuple(add.1)
}
branch2 {
ROOT arg_tuple.1 = (f32[10]) parameter(0)
}
ENTRY main {
pred.1 = pred[] parameter(0)
tuple.1 = (f32[10]) parameter(1)
tuple.2 = (f32[10]) parameter(2)
conditional = (f32[10])
conditional(pred.1, tuple.1, tuple.2), true_computation=branch1,
false_computation=branch2
get-first-index = f32[10] get-tuple-element(conditional), index=0
ROOT pow.1 = f32[10] power(get-first-index, get-first-index)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 5);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 4);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
}
TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleParameter) {
absl::string_view hlo_string =
R"(
HloModule RemoveIdenticalInstruction
branch {
arg.1 = f32[10] parameter(0)
ROOT add.1 = f32[10] add(arg.1, arg.1)
}
ENTRY main {
pred.1 = pred[] parameter(0)
tuple.1 = f32[10] parameter(1)
tuple.2 = f32[10] parameter(2)
conditional = f32[10]
conditional(pred.1, tuple.1, tuple.2), true_computation=branch,
false_computation=branch
ROOT pow.1 = f32[10] power(conditional, conditional)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 4);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 4);
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
}
} // namespace conditional_opt
} // namespace xla
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册