提交 f83f6b9e 编写于 作者: C Chris Leary 提交者: TensorFlower Gardener

[XLA] Handle higher-order HLOs (e.g. While) in CallInliner and test.

PiperOrigin-RevId: 168029345
上级 8988ae36
......@@ -20,30 +20,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
// Inlines kCall instructions found in the entry computation of `module`.
//
// The work queue is seeded with every kCall instruction visited in the entry
// computation and is then drained front-to-back. The queue itself is handed to
// ReplaceWithInlinedBody, so it may grow while it is being drained.
//
// Returns true iff at least one call instruction was processed. Any error
// reported by the traversal or by ReplaceWithInlinedBody is returned as-is.
StatusOr<bool> CallInliner::Run(HloModule* module) {
  std::deque<HloInstruction*> pending_calls;

  // Visitor that records call instructions and ignores everything else.
  auto collect_call = [&pending_calls](HloInstruction* instruction) {
    if (instruction->opcode() == HloOpcode::kCall) {
      pending_calls.push_back(instruction);
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(module->entry_computation()->Accept(collect_call));
  VLOG(1) << "Work queue seeded with " << pending_calls.size() << " entries.";

  // Entries are only ever appended by inlining an existing entry, so the
  // module is mutated exactly when the seeded queue is non-empty.
  const bool mutated = !pending_calls.empty();
  while (!pending_calls.empty()) {
    HloInstruction* call = pending_calls.front();
    pending_calls.pop_front();
    TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(call, &pending_calls));
  }
  return mutated;
}
namespace {
// Traverses the callee computation, inlining cloned nodes into the caller
// computation and connecting them to producers/consumers appropriately.
......@@ -141,6 +118,64 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
std::deque<HloInstruction*>* work_queue_;
};
} // namespace
// Inlines kCall instructions in `module`, starting from the entry computation
// and descending into the computations referenced by higher-order
// instructions (kWhile, kSelectAndScatter, kMap, kReduceWindow, kReduce).
//
// Returns true iff at least one kCall instruction was handed to
// ReplaceWithInlinedBody. Returns an Unimplemented error when an instruction
// with called computations has an opcode this pass does not know how to
// handle; any error from traversal or inlining is propagated unchanged.
StatusOr<bool> CallInliner::Run(HloModule* module) {
  std::deque<HloInstruction*> pending;
  tensorflow::gtl::FlatSet<HloComputation*> visited;

  // Appends to `pending` every instruction of `computation` that has at least
  // one called computation. A computation is scanned at most once, no matter
  // how many instructions reference it.
  auto enqueue_callers = [&pending,
                          &visited](HloComputation* computation) -> Status {
    const bool first_visit = visited.insert(computation).second;
    if (first_visit) {
      return computation->Accept([&](HloInstruction* instruction) {
        if (!instruction->called_computations().empty()) {
          pending.push_back(instruction);
        }
        return Status::OK();
      });
    }
    return Status::OK();  // Already scanned.
  };

  // Seed the queue with the higher-order instructions of the entry
  // computation.
  TF_RETURN_IF_ERROR(enqueue_callers(module->entry_computation()));
  VLOG(1) << "Work queue seeded with " << pending.size() << " entries.";

  bool mutated = false;
  while (!pending.empty()) {
    HloInstruction* instruction = pending.front();
    pending.pop_front();

    const HloOpcode opcode = instruction->opcode();
    switch (opcode) {
      case HloOpcode::kFusion:
        // Fusion nodes don't represent true calls, but instead delimit a
        // boundary for the backend-specific fusion capabilities.
        break;
      case HloOpcode::kCall:
        // The only opcode that actually rewrites the module. The queue is
        // passed along so the inliner can schedule follow-up work.
        mutated = true;
        TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(instruction, &pending));
        break;
      case HloOpcode::kMap:
      case HloOpcode::kReduce:
      case HloOpcode::kReduceWindow:
        TF_RETURN_IF_ERROR(enqueue_callers(instruction->to_apply()));
        break;
      case HloOpcode::kSelectAndScatter:
        TF_RETURN_IF_ERROR(enqueue_callers(instruction->select()));
        TF_RETURN_IF_ERROR(enqueue_callers(instruction->scatter()));
        break;
      case HloOpcode::kWhile:
        TF_RETURN_IF_ERROR(enqueue_callers(instruction->while_condition()));
        TF_RETURN_IF_ERROR(enqueue_callers(instruction->while_body()));
        break;
      default:
        return Unimplemented("Unknown higher-order HLO opcode: %s",
                             instruction->ToString().c_str());
    }
  }
  return mutated;
}
Status CallInliner::ReplaceWithInlinedBody(
HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
......
......@@ -73,5 +73,44 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24);
}
// Checks that call instructions living inside a while loop's condition and
// body computations are inlined. After the pass, a computation that merely
// calls a "return false" computation should itself have a constant as its
// root (referential transparency).
TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
  const Shape pred_shape = ShapeUtil::MakeShape(PRED, {});
  auto module = CreateNewModule();

  // Innermost computation: just produces the constant `false`.
  HloComputation::Builder false_builder(TestName() + ".false");
  false_builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
  HloComputation* false_computation =
      module->AddEmbeddedComputation(false_builder.Build());

  // Wrapper computation whose only instruction is a call to the one above.
  HloComputation::Builder wrapper_builder(TestName() + ".call_false");
  wrapper_builder.AddInstruction(
      HloInstruction::CreateCall(pred_shape, {}, false_computation));
  HloComputation* call_false =
      module->AddEmbeddedComputation(wrapper_builder.Build());

  // Entry computation: a while loop that deliberately uses the same wrapper
  // as both its condition and its body, to make the test a little trickier.
  HloComputation::Builder entry_builder(TestName() + ".outer");
  HloInstruction* init_value = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
  entry_builder.AddInstruction(HloInstruction::CreateWhile(
      pred_shape, call_false, call_false, init_value));
  auto entry = module->AddEntryComputation(entry_builder.Build());

  CallInliner call_inliner;
  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
  ASSERT_TRUE(mutated);

  // Both sub-computations of the loop must now end in a constant.
  HloInstruction* loop = entry->root_instruction();
  EXPECT_THAT(loop->while_condition()->root_instruction(), op::Constant());
  EXPECT_THAT(loop->while_body()->root_instruction(), op::Constant());
}
} // namespace
} // namespace xla
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册