diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e7134325907406219087ba98696c808bfa00e8e2..9f8f74344af5b6644c954d826d9df481daa7431a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -655,7 +655,8 @@ StatusOr> CpuCompiler::RunBackend( // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN(HloSchedule schedule, ScheduleModule(module.get(), BufferSizeBytesFunction(), - DFSMemoryScheduler)); + ComputationSchedulerToModuleScheduler( + DFSMemoryScheduler))); // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index da82b599a6a5131e6b8081d6c7020377507271dc..50eaee954557f2be45aaf1ba6c3300539000f5c8 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -500,6 +500,33 @@ StatusOr DFSMemoryScheduler( return sequence; } // namespace xla +ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler( + const MemorySchedulerAlgorithm& computation_scheduler) { + return [computation_scheduler]( + HloModule* module, const TuplePointsToAnalysis& points_to_analysis, + const HloAliasAnalysis& alias_analysis, + const LogicalBuffer::SizeFunction& size_func, + int64* peak_memory) -> StatusOr { + HloSchedule schedule(module); + absl::flat_hash_map memory_by_computation; + for (auto* computation : module->MakeComputationPostOrder()) { + if (!computation->IsFusionComputation()) { + TF_ASSIGN_OR_RETURN( + HloInstructionSequence computation_sequence, + ScheduleComputationHelper( + computation, points_to_analysis, alias_analysis, size_func, + computation_scheduler, memory_by_computation, nullptr)); + schedule.set_sequence(computation, std::move(computation_sequence)); + } + } + if (peak_memory) { + TF_ASSIGN_OR_RETURN(*peak_memory, HeapSimulator::MinimumMemoryForModule( + schedule, size_func)); + } + return std::move(schedule); + }; +} + StatusOr ListMemoryScheduler( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, @@ -597,36 +624,75 @@ StatusOr DefaultMemoryScheduler( } } +StatusOr DefaultModuleScheduler( + HloModule* module, const TuplePointsToAnalysis& points_to_analysis, + const HloAliasAnalysis& alias_analysis, + const BufferValue::SizeFunction& size_function, int64* peak_memory) { + // We try a few schedulers and choose whichever returns a lower min-memory, + // not accounting for fragmentation. + // - List is a scheduler that uses greedy heuristics. + // - DFS visits HLOs in postorder, with a heuristic to decide the order of + // children. + // - Postorder does not use any heuristics. + // List wins for most of our benchmarks; postorder-based schedulers win for + // some RNNs. + int64 list_memory; + TF_ASSIGN_OR_RETURN( + HloSchedule list_sequence, + ComputationSchedulerToModuleScheduler(ListMemoryScheduler)( + module, points_to_analysis, alias_analysis, size_function, + &list_memory)); + + VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); + + int64 dfs_memory; + TF_ASSIGN_OR_RETURN(HloSchedule dfs_sequence, + ComputationSchedulerToModuleScheduler(DFSMemoryScheduler)( + module, points_to_analysis, alias_analysis, + size_function, &dfs_memory)); + VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); + + int64 post_order_memory; + TF_ASSIGN_OR_RETURN( + HloSchedule post_order_sequence, + ComputationSchedulerToModuleScheduler(PostOrderMemoryScheduler)( + module, points_to_analysis, alias_analysis, size_function, + &post_order_memory)); + VLOG(2) << "Min-memory post order sequence: " + << HumanReadableNumBytes(post_order_memory); + + auto min_memory = std::min({dfs_memory, post_order_memory, list_memory}); + if (peak_memory) { + *peak_memory = min_memory; + } + + if (min_memory == list_memory) { + VLOG(2) << "Chose min-memory list sequence: " + << HumanReadableNumBytes(list_memory); + return list_sequence; + } else if (min_memory == dfs_memory) { + VLOG(2) << "Chose min-memory dfs sequence: " + << HumanReadableNumBytes(dfs_memory); + return dfs_sequence; + } else { + VLOG(2) << "Chose min-memory post_order sequence: " + << HumanReadableNumBytes(post_order_memory); + return post_order_sequence; + } +} + StatusOr ScheduleModule( HloModule* module, const BufferValue::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm, int64* peak_memory) { - HloSchedule schedule(module); + const ModuleSchedulerAlgorithm& algorithm, int64* peak_memory) { TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(module)); TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module)); - absl::flat_hash_map memory_by_computation; - for (auto* computation : module->MakeComputationPostOrder()) { - if (!computation->IsFusionComputation()) { - int64 computation_peak_memory; - TF_ASSIGN_OR_RETURN( - HloInstructionSequence computation_sequence, - ScheduleComputationHelper( - computation, *points_to_analysis, *alias_analysis, size_function, - algorithm, memory_by_computation, &computation_peak_memory)); - memory_by_computation[computation] = computation_peak_memory; - schedule.set_sequence(computation, std::move(computation_sequence)); - } - } - VLOG(1) << "Module schedule:\n" << schedule; - - if (peak_memory) { - *peak_memory = 0; - for (const auto& computation_and_peak : memory_by_computation) { - *peak_memory = std::max(*peak_memory, computation_and_peak.second); - } - } + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + (algorithm ? algorithm : DefaultModuleScheduler)( + module, *points_to_analysis, *alias_analysis, + size_function, peak_memory)); TF_RETURN_IF_ERROR(schedule.Verify()); @@ -649,7 +715,7 @@ StatusOr ScheduleComputation( HloMemoryScheduler::HloMemoryScheduler( const BufferValue::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) + const ModuleSchedulerAlgorithm& algorithm) : size_function_(size_function), algorithm_(algorithm) {} StatusOr HloMemoryScheduler::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index fd416e9413e19ee1ace21d989a62829bb6d3c942..a93e05d2357b55555e5410dd47764433d2127bc2 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -47,6 +47,18 @@ typedef std::function( /*peak_memory*/ int64*)> MemorySchedulerAlgorithm; +// Scheduler for the entire module. +typedef std::function( + HloModule*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, + const LogicalBuffer::SizeFunction&, + /*peak_memory*/ int64*)> + ModuleSchedulerAlgorithm; + +// Lift a computation scheduler into a module scheduler by calling the +// computation scheduler on all computations in a module. +ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler( + const MemorySchedulerAlgorithm&); + // List scheduler StatusOr ListMemoryScheduler( HloComputation* computation, @@ -90,13 +102,18 @@ StatusOr DefaultMemoryScheduler( memory_by_computation, int64* peak_memory); +StatusOr DefaultModuleScheduler( + HloModule* module, const TuplePointsToAnalysis& points_to_analysis, + const HloAliasAnalysis& alias_analysis, + const LogicalBuffer::SizeFunction& size_function, int64* peak_memory); + // Returns an HloSchedule which seeks to minimize the memory required for the // module. size_function is the function returning the number of bytes required // for a LogicalBuffer. peak_memory (if not nullptr) is set to the largest peak // memory (according to the HeapSimulator) of all computations in the module. StatusOr ScheduleModule( HloModule* module, const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}, + const ModuleSchedulerAlgorithm& algorithm = {}, int64* peak_memory = nullptr); // Computes the schedule for a single computation. @@ -114,7 +131,7 @@ class HloMemoryScheduler : public HloModulePass { // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not // specified, then DefaultMemoryScheduler is used. HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); + const ModuleSchedulerAlgorithm& algorithm = {}); ~HloMemoryScheduler() override = default; @@ -125,7 +142,7 @@ class HloMemoryScheduler : public HloModulePass { private: LogicalBuffer::SizeFunction size_function_; - MemorySchedulerAlgorithm algorithm_; + ModuleSchedulerAlgorithm algorithm_; }; // A pass which produces a naive, but correct schedule. The schedule is produced diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 2b1e059c7e5c25974c4782377109d01fe3135fd6..bf10817b3f537e7de74ec0f3572dbdf447cd92eb 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -145,7 +145,9 @@ ENTRY root { int64 peak_memory; TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(module.get(), size_fn, ListMemoryScheduler, &peak_memory)); + ScheduleModule(module.get(), size_fn, + ComputationSchedulerToModuleScheduler(ListMemoryScheduler), + &peak_memory)); TF_ASSERT_OK(module->set_schedule(schedule)); // Verify that all instructions are in the sequence. const std::vector& sequence = @@ -194,9 +196,10 @@ ENTRY entry { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(module.get(), size_fn, ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(module.get(), size_fn, + ComputationSchedulerToModuleScheduler( + ListMemoryScheduler))); // Verify that all instructions are in the sequence. const std::vector& sequence = schedule.sequence(module->entry_computation()).instructions(); @@ -239,14 +242,14 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule( - module.get(), - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), - TUPLE_SIZE); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule( + module.get(), + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), TUPLE_SIZE); + }, + ComputationSchedulerToModuleScheduler(ListMemoryScheduler))); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -290,13 +293,14 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); - TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule( - module.get(), - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), 2); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule( + module.get(), + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ComputationSchedulerToModuleScheduler(ListMemoryScheduler))); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(),