Commit 413551b9 authored by Sanjoy Das, committed by TensorFlower Gardener

[XLA:CPU] Make instruction order compulsory in IrEmitter::EmitComputation; NFC

PiperOrigin-RevId: 225127595
Parent 92f67536
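The whole change is one API migration: the nullable const std::vector<HloInstruction*>* parameter of EmitComputation becomes a mandatory absl::Span<HloInstruction* const>. That is why every call site below drops its leading '&', and why the nullptr fallback inside the emitter can be deleted. A minimal, self-contained sketch of the pattern, using hypothetical names (Instr, EmitInOrder) rather than the XLA types:

#include <iostream>
#include <vector>

#include "absl/types/span.h"

struct Instr {
  int id;
};

// After the migration: the order is part of the contract. The span is a
// cheap, non-owning {pointer, length} view, so it is taken by value.
void EmitInOrder(absl::Span<Instr* const> order) {
  for (Instr* i : order) {
    std::cout << "emit " << i->id << "\n";
  }
}

int main() {
  Instr a{1}, b{2};
  std::vector<Instr*> schedule = {&a, &b};
  // std::vector<T*> converts implicitly to absl::Span<T* const>, which is
  // why every call site in the diff passes the container itself and drops
  // the leading `&` it needed for the old pointer parameter.
  EmitInOrder(schedule);
  return 0;
}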
......
@@ -635,18 +635,17 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
             .EmitComputation(
                 embedded_computation, embedded_computation->name(),
                 /*is_top_level_computation=*/false,
-                &schedule.sequence(embedded_computation).instructions())
+                schedule.sequence(embedded_computation).instructions())
             .status());
   }
 
   string function_name_prefix = entry_computation->name().empty()
                                     ? "__compute"
                                     : entry_computation->name();
-  TF_ASSIGN_OR_RETURN(
-      llvm::Function * entry_function,
-      ir_emitter.EmitComputation(
-          entry_computation, function_name_prefix,
-          /*is_top_level_computation=*/true,
-          &schedule.sequence(entry_computation).instructions()));
+  TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
+                      ir_emitter.EmitComputation(
+                          entry_computation, function_name_prefix,
+                          /*is_top_level_computation=*/true,
+                          schedule.sequence(entry_computation).instructions()));
   string function_name = [&]() {
     llvm::SmallVector<char, 40> function_name_vector;
......
@@ -835,7 +834,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
             .EmitComputation(
                 embedded_computation, embedded_computation->name(),
                 /*is_top_level_computation=*/false,
-                &schedule.sequence(embedded_computation).instructions())
+                schedule.sequence(embedded_computation).instructions())
             .status());
   }
   const string& entry_point_name = options.entry_point_name();
......
@@ -843,7 +842,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
       ir_emitter.EmitComputation(
           computation, entry_point_name,
           /*is_top_level_computation=*/true,
-          &schedule.sequence(computation).instructions()));
+          schedule.sequence(computation).instructions()));
   CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name));
......
@@ -111,10 +111,9 @@ IrEmitter::IrEmitter(
 StatusOr<llvm::Function*> IrEmitter::EmitComputation(
     HloComputation* computation, const string& function_name_prefix,
     bool is_top_level_computation,
-    const std::vector<HloInstruction*>* instruction_order) {
+    absl::Span<HloInstruction* const> instruction_order) {
   string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
-  VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
-          << "]; ordered? " << (instruction_order != nullptr);
+  VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
   is_top_level_computation_ = is_top_level_computation;
   num_dynamic_loop_bounds_ = 0;
   if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
......
@@ -141,11 +140,7 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
   bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
                     arch_type_ == llvm::Triple::ArchType::x86_64;
   profiling_state_ = ProfilingState(use_rdtscp);
-  if (instruction_order == nullptr) {
-    TF_RETURN_IF_ERROR(computation->Accept(this));
-  } else {
-    TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order));
-  }
+  TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
   llvm::Function* ir_function = compute_function_->function();
   InsertOrDie(&emitted_functions_, computation, ir_function);
   // Delete 'compute_function', finalizing 'ir_function' and restoring caller
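One property the new signature relies on: absl::Span is a non-owning view, so the viewed sequence must stay alive for the duration of the call. That holds at every call site in this diff, since the schedule outlives each EmitComputation call. A small generic sketch of the pitfall this style has to avoid (MakeScratch and Sum are hypothetical names):

#include <numeric>
#include <vector>

#include "absl/types/span.h"

std::vector<int> MakeScratch() { return {1, 2, 3}; }

int Sum(absl::Span<const int> xs) {
  // Reads only while the viewed storage is alive.
  return std::accumulate(xs.begin(), xs.end(), 0);
}

int main() {
  // Fine: the temporary vector outlives the full expression, and Sum()
  // finishes with the span before the temporary is destroyed.
  int s = Sum(MakeScratch());

  // Dangling (don't do this): the temporary dies at the end of this
  // statement, leaving `bad` pointing at freed storage.
  // absl::Span<const int> bad = MakeScratch();

  return s == 6 ? 0 : 1;
}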
......
@@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   StatusOr<llvm::Function*> EmitComputation(
       HloComputation* computation, const string& function_name_prefix,
       bool is_top_level_computation,
-      const std::vector<HloInstruction*>* instruction_order);
+      absl::Span<HloInstruction* const> instruction_order);
 
   llvm::IRBuilder<>* b() { return &b_; }
......
@@ -797,7 +797,7 @@ Status HloComputation::AcceptWithOperandOrder(
 template <typename HloInstructionPtr>
 Status HloComputation::AcceptOrdered(
     DfsHloVisitorBase<HloInstructionPtr>* visitor,
-    const std::vector<HloInstruction*>& order) const {
+    absl::Span<HloInstruction* const> order) const {
   VLOG(3) << "Accepting visitor with order.";
   for (HloInstruction* root : CollectUnreachableRoots()) {
     TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end())
......
@@ -827,9 +827,9 @@ Status HloComputation::AcceptOrdered(
 // Explicit instantiations.
 template Status HloComputation::AcceptOrdered(
-    DfsHloVisitor*, const std::vector<HloInstruction*>&) const;
+    DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
 template Status HloComputation::AcceptOrdered(
-    ConstDfsHloVisitor*, const std::vector<HloInstruction*>&) const;
+    ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;
 
 Status HloComputation::Accept(
     const std::function<Status(HloInstruction*)>& visitor_func) {
......
@@ -307,7 +307,7 @@ class HloComputation {
   // be a topological sort of all instructions in the computation.
   template <typename HloInstructionPtr>
   Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
-                       const std::vector<HloInstruction*>& order) const;
+                       absl::Span<HloInstruction* const> order) const;
 
   // Same as Accept() above, but the visitor is given as a function.
   Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
......
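The second hlo_computation.cc hunk shows a detail that is easy to miss: the explicit instantiations of the AcceptOrdered template must repeat the new parameter type verbatim, because an explicit instantiation that names a signature no template declares is a compile error. A generic, single-file sketch of the same obligation, with hypothetical names (Thing, Printer, VisitOrdered):

#include <iostream>
#include <vector>

#include "absl/types/span.h"

struct Thing {
  int id;
};

struct Printer {
  void Handle(Thing* t) { std::cout << "thing " << t->id << "\n"; }
};

// Only the visitor type is a template parameter; the order parameter is a
// concrete span type, just as in HloComputation::AcceptOrdered.
template <typename Visitor>
int VisitOrdered(Visitor* v, absl::Span<Thing* const> order) {
  int n = 0;
  for (Thing* t : order) {
    v->Handle(t);
    ++n;
  }
  return n;
}

// The explicit instantiation mirrors what hlo_computation.cc does for
// DfsHloVisitor and ConstDfsHloVisitor: the argument type must be spelled
// exactly as the template now declares it. Keeping the old
// const std::vector<Thing*>& here would match no template and fail to compile.
template int VisitOrdered<Printer>(Printer*, absl::Span<Thing* const>);

int main() {
  Thing a{7};
  std::vector<Thing*> order = {&a};
  Printer p;
  // The vector converts implicitly to absl::Span<Thing* const>.
  return VisitOrdered(&p, order) == 1 ? 0 : 1;
}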