提交 63593f36 编写于 作者: B Berkin Ilbeyi 提交者: TensorFlower Gardener

[XLA] Free up cross-program-prefetched buffers after the last use.

PiperOrigin-RevId: 327899057
Change-Id: I602aa480c35b8734b50395d1c7e0fb621ad2d0fb
上级 58f4a3e0
......@@ -1400,33 +1400,79 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
// Find the earliest use.
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
auto uses = buffer->uses();
auto first_use =
absl::c_min_element(uses, [&](const HloUse& lhs, const HloUse& rhs) {
return instruction_schedule.at(lhs.instruction) <
instruction_schedule.at(rhs.instruction);
});
auto use_schedule_compare = [&](const HloUse& lhs, const HloUse& rhs) {
return instruction_schedule.at(lhs.instruction) <
instruction_schedule.at(rhs.instruction);
};
auto first_use = absl::c_min_element(uses, use_schedule_compare);
int64 latest_prefetch_time = instruction_schedule.at(first_use->instruction);
// Find the latest use time.
int64 last_use_time = instruction_schedule.at(
absl::c_max_element(uses, use_schedule_compare)->instruction);
for (const HloValue* colocation : prefetch_candidate->colocations) {
last_use_time = std::max(
last_use_time,
instruction_schedule.at(
absl::c_max_element(colocation->uses(), use_schedule_compare)
->instruction));
}
int64 end_of_program_prefetch_end_time = instruction_schedule.size() - 1;
int64 end_of_program_prefetch_start_time =
options_.prefetch_interval_picker->PreferredPrefetchStartTime(
buffer->defining_position().shape(), last_use_time,
end_of_program_prefetch_end_time, end_of_program_prefetch_end_time);
VLOG(2) << "last use time = " << last_use_time
<< ", end-of-program prefetch start time = "
<< end_of_program_prefetch_start_time;
bool free_buffer =
(end_of_program_prefetch_start_time > last_use_time &&
end_of_program_prefetch_start_time < end_of_program_prefetch_end_time);
int64 cross_program_prefetch_end_time =
free_buffer ? last_use_time : prefetch_candidate->end;
AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
chunk_candidate.chunk, prefetch_candidate->start,
prefetch_candidate->end, latest_prefetch_time, &allocations,
cross_program_prefetch_end_time, latest_prefetch_time,
&allocations,
/*is_cross_program_prefetch=*/true);
absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
int64 cross_program_prefetch_offset = allocations.back()->chunk().offset;
if (free_buffer) {
VLOG(2) << "Adding an end-of-program prefetch for freed "
"cross-program-prefetched buffer.";
AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate,
chunk_candidate.chunk, end_of_program_prefetch_start_time,
end_of_program_prefetch_end_time,
end_of_program_prefetch_end_time, &allocations);
CHECK_EQ(cross_program_prefetch_offset, allocations.back()->chunk().offset);
}
for (auto& allocation : allocations) {
allocations_->push_back(std::move(allocation));
}
// Add a repack allocation block for the Allocation object in alternate
// Add a repack allocation block for the Allocation objects in alternate
// memory.
CHECK_EQ(allocations_->size(), 2);
MemorySpaceAssignment::Allocation* last_allocation =
allocations_->at(1).get();
CHECK(last_allocation->memory_space() == MemorySpace::kAlternate);
repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
last_allocation->start_time(), last_allocation->end_time(),
last_allocation->chunk().size, last_allocation->chunk().offset,
static_cast<int64>(repack_allocation_blocks_.size()), last_allocation));
repack_allocation_blocks_.back().colocations.push_back(
&repack_allocation_blocks_.back());
CHECK_EQ(repack_allocation_blocks_.size(), 0);
for (const auto& allocation : *allocations_) {
if (allocation->memory_space() == MemorySpace::kAlternate) {
repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
allocation->start_time(), allocation->end_time(),
allocation->chunk().size, allocation->chunk().offset,
static_cast<int64>(repack_allocation_blocks_.size()),
allocation.get()));
RepackAllocationBlock* inserted = &repack_allocation_blocks_.back();
for (RepackAllocationBlock& colocation : repack_allocation_blocks_) {
colocation.colocations.push_back(inserted);
if (&colocation != inserted) {
inserted->colocations.push_back(&colocation);
}
}
}
}
ClearPendingChunks();
}
......@@ -2478,7 +2524,9 @@ FindCrossProgramPrefetchCandidate(
const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range,
const MemorySpaceAssignment::Options& options) {
std::vector<MemorySpaceAssignment::BufferInterval> candidates;
for (HloValue* value : alias_analysis.dataflow_analysis().values()) {
for (const HloBuffer& buffer : alias_analysis.buffers()) {
CHECK_GE(buffer.values().size(), 1);
const HloValue* value = buffer.values().at(0);
if (IsCrossProgramPrefetchCandidate(*value, options)) {
MemorySpaceAssignment::BufferInterval interval;
interval.buffer = value;
......@@ -2486,6 +2534,7 @@ FindCrossProgramPrefetchCandidate(
interval.start = 0;
interval.end = hlo_live_range.schedule_end_time();
interval.need_allocation = true;
interval.colocations = {++buffer.values().begin(), buffer.values().end()};
candidates.emplace_back(interval);
}
}
......
......@@ -4566,6 +4566,125 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTest) {
EXPECT_EQ(cross_program_prefetches.size(), 0);
}
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchReuse) {
  // Scenario: the cross-program-prefetched operand (tuple index 1 of p0) is
  // consumed only by the first dot, followed by a long chain of negates. The
  // test checks that the value ends up with both a cross-program prefetch and
  // a second, ordinary prefetch (the end-of-program prefetch).
  const absl::string_view module_text = R"(
HloModule cross_program_prefetch, is_scheduled=true
ENTRY CrossProgramPrefetch {
p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
negate.1 = f32[8,2]{1,0} negate(dot)
negate.2 = f32[8,2]{1,0} negate(negate.1)
negate.3 = f32[8,2]{1,0} negate(negate.2)
negate.4 = f32[8,2]{1,0} negate(negate.3)
negate.5 = f32[8,2]{1,0} negate(negate.4)
negate.6 = f32[8,2]{1,0} negate(negate.5)
negate.7 = f32[8,2]{1,0} negate(negate.6)
negate.8 = f32[8,2]{1,0} negate(negate.7)
ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);

  // Exactly one cross-program prefetch is expected, targeting parameter 0 at
  // shape index {1}.
  const auto prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(prefetches.size(), 1);
  if (!prefetches.empty()) {
    const auto& prefetch = prefetches.front();
    EXPECT_EQ(prefetch.first, 0);
    EXPECT_EQ(prefetch.second, ShapeIndex({1}));
  }

  // Re-run dataflow analysis on the transformed module to inspect the uses of
  // the prefetched value.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                          HloDataflowAnalysis::Run(*module));
  const auto* parameter = module->entry_computation()->parameter_instruction(0);
  const HloValue& prefetched_value =
      dataflow->GetValueDefinedAt(parameter, {1});

  const auto is_copy_start = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart;
  };
  const auto is_cross_program_copy_start = [&](const HloUse& use) {
    return is_copy_start(use) && use.instruction->is_cross_program_prefetch();
  };
  const auto is_regular_copy_start = [&](const HloUse& use) {
    return is_copy_start(use) && !use.instruction->is_cross_program_prefetch();
  };

  // Two copy-starts must consume the value: the cross-program prefetch itself
  // and the end-of-program prefetch.
  EXPECT_EQ(
      absl::c_count_if(prefetched_value.uses(), is_cross_program_copy_start),
      1);
  EXPECT_EQ(absl::c_count_if(prefetched_value.uses(), is_regular_copy_start),
            1);
}
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
  // Scenario: the cross-program-prefetched operand (tuple index 1 of p0) is
  // consumed by the first dot and again by the ROOT dot at the very end of the
  // computation. With a use that late, the buffer should stay resident, so no
  // end-of-program prefetch is expected.
  const absl::string_view module_text = R"(
HloModule cross_program_prefetch, is_scheduled=true
ENTRY CrossProgramPrefetch {
p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
negate.1 = f32[8,2]{1,0} negate(dot)
negate.2 = f32[8,2]{1,0} negate(negate.1)
negate.3 = f32[8,2]{1,0} negate(negate.2)
negate.4 = f32[8,2]{1,0} negate(negate.3)
negate.5 = f32[8,2]{1,0} negate(negate.4)
negate.6 = f32[8,2]{1,0} negate(negate.5)
negate.7 = f32[8,2]{1,0} negate(negate.6)
negate.8 = f32[8,2]{1,0} negate(negate.7)
ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, get-tuple-element.1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);

  // Exactly one cross-program prefetch is expected, targeting parameter 0 at
  // shape index {1}.
  const auto prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(prefetches.size(), 1);
  if (!prefetches.empty()) {
    const auto& prefetch = prefetches.front();
    EXPECT_EQ(prefetch.first, 0);
    EXPECT_EQ(prefetch.second, ShapeIndex({1}));
  }

  // Re-run dataflow analysis on the transformed module to inspect the uses of
  // the prefetched value.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                          HloDataflowAnalysis::Run(*module));
  const auto* parameter = module->entry_computation()->parameter_instruction(0);
  const HloValue& prefetched_value =
      dataflow->GetValueDefinedAt(parameter, {1});

  const auto is_copy_start = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart;
  };
  const auto is_cross_program_copy_start = [&](const HloUse& use) {
    return is_copy_start(use) && use.instruction->is_cross_program_prefetch();
  };
  const auto is_regular_copy_start = [&](const HloUse& use) {
    return is_copy_start(use) && !use.instruction->is_cross_program_prefetch();
  };

  // Only one copy-start may consume the value: the cross-program prefetch.
  // There must be no end-of-program prefetch.
  EXPECT_EQ(
      absl::c_count_if(prefetched_value.uses(), is_cross_program_copy_start),
      1);
  EXPECT_EQ(absl::c_count_if(prefetched_value.uses(), is_regular_copy_start),
            0);
}
// Fixture name for the TEST_F case(s) that follow; it is a plain alias of
// HloTestBase and adds no members of its own.
using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册