提交 a07965ca 编写于 作者: Y Yunxing Dai 提交者: TensorFlower Gardener

[XLA] Use module scheduler to reduce compile time.

Previously, we only had a computation scheduler, which runs the heap
simulator once per computation. For models with a large number of
computations, this created extremely slow compilation times.

This CL introduces a module scheduler that runs the heap simulator only
once, after the whole module has been scheduled. It also contains a
helper function that automatically converts a computation scheduler
into a module scheduler.

PiperOrigin-RevId: 258436352
上级 502e5e62
......@@ -655,7 +655,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// and reduced memory usage (as compared to using DependencyHloOrdering).
TF_ASSIGN_OR_RETURN(HloSchedule schedule,
ScheduleModule(module.get(), BufferSizeBytesFunction(),
DFSMemoryScheduler));
ComputationSchedulerToModuleScheduler(
DFSMemoryScheduler)));
// Run buffer allocation on the HLO graph.
TF_ASSIGN_OR_RETURN(
......
......@@ -500,6 +500,33 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
return sequence;
} // namespace xla
// Lifts a per-computation scheduler into a module-level scheduler: every
// non-fusion computation is sequenced by `computation_scheduler`, and the
// heap simulator runs once over the finished module (instead of once per
// computation) when the caller asks for `peak_memory`.
ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler(
    const MemorySchedulerAlgorithm& computation_scheduler) {
  // Capture by value so the returned std::function owns its own copy of
  // the wrapped scheduler.
  return [computation_scheduler](
             HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
             const HloAliasAnalysis& alias_analysis,
             const LogicalBuffer::SizeFunction& size_func,
             int64* peak_memory) -> StatusOr<HloSchedule> {
    HloSchedule schedule(module);
    // Never populated: per-computation heap simulation is skipped here on
    // purpose; memory is measured once for the whole module below.
    absl::flat_hash_map<const HloComputation*, int64> empty_memory_map;
    for (auto* computation : module->MakeComputationPostOrder()) {
      // Fusion computations are not scheduled on their own.
      if (computation->IsFusionComputation()) {
        continue;
      }
      TF_ASSIGN_OR_RETURN(
          HloInstructionSequence sequence,
          ScheduleComputationHelper(computation, points_to_analysis,
                                    alias_analysis, size_func,
                                    computation_scheduler, empty_memory_map,
                                    /*peak_memory=*/nullptr));
      schedule.set_sequence(computation, std::move(sequence));
    }
    if (peak_memory) {
      // Single heap-simulator pass over the fully scheduled module.
      TF_ASSIGN_OR_RETURN(*peak_memory, HeapSimulator::MinimumMemoryForModule(
                                            schedule, size_func));
    }
    // std::move is required: HloSchedule must be moved into the StatusOr.
    return std::move(schedule);
  };
}
StatusOr<HloInstructionSequence> ListMemoryScheduler(
HloComputation* computation,
const TuplePointsToAnalysis& points_to_analysis,
......@@ -597,36 +624,75 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
}
}
// Default module-level scheduler: produces three candidate schedules for
// the whole module (list, DFS, post-order), measures each one's min-memory
// with the heap simulator, and returns the candidate with the lowest
// simulated peak memory.
//
// Args:
//   module: the module to schedule.
//   points_to_analysis / alias_analysis: precomputed analyses forwarded to
//       the underlying computation schedulers.
//   size_function: returns the size in bytes of a BufferValue.
//   peak_memory: if non-null, set to the winning schedule's min-memory
//       (heap-simulator estimate, not accounting for fragmentation).
//
// Note: all three candidates are fully computed (scheduling plus one
// module-wide heap-simulator pass each) before a winner is chosen, so this
// costs roughly 3x a single scheduler run.
StatusOr<HloSchedule> DefaultModuleScheduler(
    HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function, int64* peak_memory) {
  // We try a few schedulers and choose whichever returns a lower min-memory,
  // not accounting for fragmentation.
  // - List is a scheduler that uses greedy heuristics.
  // - DFS visits HLOs in postorder, with a heuristic to decide the order of
  //   children.
  // - Postorder does not use any heuristics.
  // List wins for most of our benchmarks; postorder-based schedulers win for
  // some RNNs.
  int64 list_memory;
  TF_ASSIGN_OR_RETURN(
      HloSchedule list_sequence,
      ComputationSchedulerToModuleScheduler(ListMemoryScheduler)(
          module, points_to_analysis, alias_analysis, size_function,
          &list_memory));
  VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);

  int64 dfs_memory;
  TF_ASSIGN_OR_RETURN(HloSchedule dfs_sequence,
                      ComputationSchedulerToModuleScheduler(DFSMemoryScheduler)(
                          module, points_to_analysis, alias_analysis,
                          size_function, &dfs_memory));
  VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);

  int64 post_order_memory;
  TF_ASSIGN_OR_RETURN(
      HloSchedule post_order_sequence,
      ComputationSchedulerToModuleScheduler(PostOrderMemoryScheduler)(
          module, points_to_analysis, alias_analysis, size_function,
          &post_order_memory));
  VLOG(2) << "Min-memory post order sequence: "
          << HumanReadableNumBytes(post_order_memory);

  auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
  if (peak_memory) {
    *peak_memory = min_memory;
  }

  // Ties are broken in favor of list, then DFS, then post-order, matching
  // the comparison order below.
  if (min_memory == list_memory) {
    VLOG(2) << "Chose min-memory list sequence: "
            << HumanReadableNumBytes(list_memory);
    return list_sequence;
  } else if (min_memory == dfs_memory) {
    VLOG(2) << "Chose min-memory dfs sequence: "
            << HumanReadableNumBytes(dfs_memory);
    return dfs_sequence;
  } else {
    VLOG(2) << "Chose min-memory post_order sequence: "
            << HumanReadableNumBytes(post_order_memory);
    return post_order_sequence;
  }
}
StatusOr<HloSchedule> ScheduleModule(
HloModule* module, const BufferValue::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm, int64* peak_memory) {
HloSchedule schedule(module);
const ModuleSchedulerAlgorithm& algorithm, int64* peak_memory) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(module));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module));
absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
for (auto* computation : module->MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) {
int64 computation_peak_memory;
TF_ASSIGN_OR_RETURN(
HloInstructionSequence computation_sequence,
ScheduleComputationHelper(
computation, *points_to_analysis, *alias_analysis, size_function,
algorithm, memory_by_computation, &computation_peak_memory));
memory_by_computation[computation] = computation_peak_memory;
schedule.set_sequence(computation, std::move(computation_sequence));
}
}
VLOG(1) << "Module schedule:\n" << schedule;
if (peak_memory) {
*peak_memory = 0;
for (const auto& computation_and_peak : memory_by_computation) {
*peak_memory = std::max(*peak_memory, computation_and_peak.second);
}
}
TF_ASSIGN_OR_RETURN(HloSchedule schedule,
(algorithm ? algorithm : DefaultModuleScheduler)(
module, *points_to_analysis, *alias_analysis,
size_function, peak_memory));
TF_RETURN_IF_ERROR(schedule.Verify());
......@@ -649,7 +715,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
HloMemoryScheduler::HloMemoryScheduler(
const BufferValue::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm)
const ModuleSchedulerAlgorithm& algorithm)
: size_function_(size_function), algorithm_(algorithm) {}
StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
......
......@@ -47,6 +47,18 @@ typedef std::function<StatusOr<HloInstructionSequence>(
/*peak_memory*/ int64*)>
MemorySchedulerAlgorithm;
// Scheduler for the entire module.
typedef std::function<StatusOr<HloSchedule>(
HloModule*, const TuplePointsToAnalysis&, const HloAliasAnalysis&,
const LogicalBuffer::SizeFunction&,
/*peak_memory*/ int64*)>
ModuleSchedulerAlgorithm;
// Lift a computation scheduler into a module scheduler by calling the
// computation scheduler on all computations in a module.
ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler(
const MemorySchedulerAlgorithm&);
// List scheduler
StatusOr<HloInstructionSequence> ListMemoryScheduler(
HloComputation* computation,
......@@ -90,13 +102,18 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
memory_by_computation,
int64* peak_memory);
StatusOr<HloSchedule> DefaultModuleScheduler(
HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
const HloAliasAnalysis& alias_analysis,
const LogicalBuffer::SizeFunction& size_function, int64* peak_memory);
// Returns an HloSchedule which seeks to minimize the memory required for the
// module. size_function is the function returning the number of bytes required
// for a LogicalBuffer. peak_memory (if not nullptr) is set to the largest peak
// memory (according to the HeapSimulator) of all computations in the module.
StatusOr<HloSchedule> ScheduleModule(
HloModule* module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm = {},
const ModuleSchedulerAlgorithm& algorithm = {},
int64* peak_memory = nullptr);
// Computes the schedule for a single computation.
......@@ -114,7 +131,7 @@ class HloMemoryScheduler : public HloModulePass {
// LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
// specified, then DefaultMemoryScheduler is used.
HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm = {});
const ModuleSchedulerAlgorithm& algorithm = {});
~HloMemoryScheduler() override = default;
......@@ -125,7 +142,7 @@ class HloMemoryScheduler : public HloModulePass {
private:
LogicalBuffer::SizeFunction size_function_;
MemorySchedulerAlgorithm algorithm_;
ModuleSchedulerAlgorithm algorithm_;
};
// A pass which produces a naive, but correct schedule. The schedule is produced
......
......@@ -145,7 +145,9 @@ ENTRY root {
int64 peak_memory;
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(module.get(), size_fn, ListMemoryScheduler, &peak_memory));
ScheduleModule(module.get(), size_fn,
ComputationSchedulerToModuleScheduler(ListMemoryScheduler),
&peak_memory));
TF_ASSERT_OK(module->set_schedule(schedule));
// Verify that all instructions are in the sequence.
const std::vector<HloInstruction*>& sequence =
......@@ -194,9 +196,10 @@ ENTRY entry {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
ScheduleModule(module.get(), size_fn,
ComputationSchedulerToModuleScheduler(
ListMemoryScheduler)));
// Verify that all instructions are in the sequence.
const std::vector<HloInstruction*>& sequence =
schedule.sequence(module->entry_computation()).instructions();
......@@ -239,14 +242,14 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
ScheduleModule(
module.get(),
[](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(),
TUPLE_SIZE);
},
ListMemoryScheduler));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(
module.get(),
[](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), TUPLE_SIZE);
},
ComputationSchedulerToModuleScheduler(ListMemoryScheduler)));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
......@@ -290,13 +293,14 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
auto fusion = computation->CreateFusionInstruction(
{tuple, mul, add}, HloInstruction::FusionKind::kLoop);
TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
ScheduleModule(
module.get(),
[](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
},
ListMemoryScheduler));
TF_ASSERT_OK_AND_ASSIGN(
HloSchedule schedule,
ScheduleModule(
module.get(),
[](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
},
ComputationSchedulerToModuleScheduler(ListMemoryScheduler)));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册