提交 d1b68813 编写于 作者: B Berkin Ilbeyi 提交者: TensorFlower Gardener

[XLA] Use max size of buffer in heap simulator and buffer assignment and make...

[XLA] Use max size of buffer in heap simulator and buffer assignment and make size DCHECKs into CHECKs

PiperOrigin-RevId: 564421004
上级 2ed84d32
......@@ -1761,9 +1761,12 @@ xla_cc_test(
name = "heap_simulator_test",
srcs = ["heap_simulator_test.cc"],
deps = [
":async_op_canonicalizer",
":buffer_value",
":heap_simulator",
":hlo_dce",
":hlo_ordering",
":hlo_parser",
":hlo_value",
":tuple_points_to_analysis",
"//xla:literal",
......
......@@ -551,9 +551,9 @@ class BufferAssignment {
BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
// Returns the size to use for `buffer`, computed by applying buffer_size_ to
// the HloValues in buffer.values().
//
// NOTE(review): this span comes from a diff hunk rendered without +/- markers,
// so it appears to contain both the pre-change lines (seeding `result` from
// values()[0] and DCHECK_EQ-ing every value against it) and the post-change
// lines (starting at 0 and taking std::max over all values). The commit title
// ("Use max size of buffer ...") suggests the max-based lines are the current
// ones -- confirm against the real header before editing.
int64_t HloBufferSize(const HloBuffer& buffer) {
int64_t result = buffer_size_(*buffer.values()[0]);
int64_t result = 0;
for (const HloValue* value : buffer.values()) {
DCHECK_EQ(result, buffer_size_(*value));
// Keep the largest size seen across the buffer's values.
result = std::max(result, buffer_size_(*value));
}
return result;
}
......
......@@ -2999,6 +2999,45 @@ ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] {
get_slice("negate_5", {}) == get_slice("negate_1", {}));
}
// Runs buffer assignment on an async call whose called computation is written
// in terms of f32[4] while the entry-level call-start/call-done operate on
// f32[8]. Both the entry parameter and the dynamic-update-slice inside the
// called computation are expected to receive 32-byte slices.
TEST_F(BufferAssignmentTest, AsyncCallImplicitSharding) {
  const std::string kHloText = R"(
HloModule module, is_scheduled=true
called_computation {
param0 = f32[4] parameter(0)
constant = f32[1] constant(1)
dynamic-update-slice = f32[4] dynamic-update-slice(param0, constant, constant)
ROOT negate = f32[4] negate(dynamic-update-slice)
}
ENTRY entry {
p0 = f32[8] parameter(0)
call-start = ((f32[8]), f32[8], s32[]) call-start(p0), async_execution_thread="foo", to_apply=called_computation
ROOT call-done = f32[8] call-done(call-start), async_execution_thread="foo", to_apply=called_computation
}
)";

  // The module is parsed without verification, then canonicalized and DCE'd
  // before buffer assignment runs.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHloText));
  auto* hlo_module = module.get();
  TF_ASSERT_OK(AsyncOpCanonicalizer().Run(hlo_module).status());
  TF_ASSERT_OK(HloDCE().Run(hlo_module).status());

  auto assignment = RunBufferAssignmentWithSequentialOrdering(hlo_module);
  LOG(INFO) << assignment->ToString();

  // Size in bytes of the unique top-level slice assigned to the named
  // instruction.
  const auto top_level_slice_size = [&](std::string_view instruction_name) {
    const ShapeIndex top_level{};
    return assignment
        ->GetUniqueSlice(FindInstruction(hlo_module, instruction_name),
                         top_level)
        .value()
        .size();
  };

  EXPECT_EQ(top_level_slice_size("p0"), 32);
  EXPECT_EQ(top_level_slice_size("dynamic-update-slice"), 32);
}
TEST_F(BufferAssignmentTest, BufferInfoStringTest) {
absl::string_view module_str = R"(
HloModule test_module
......
......@@ -257,6 +257,17 @@ Status HeapSimulator::RunComputation(
VLOG(1) << "Program time" << hlo_live_range->schedule_end_time();
// Populate buffer sizes with the maximum size of the constituent HloValues.
for (const HloBuffer& buffer : alias_analysis.buffers()) {
int64_t size = 0;
for (const HloValue* value : buffer.values()) {
size = std::max(size, size_fn_(*value));
}
for (const HloValue* value : buffer.values()) {
buffer_sizes_[value] = size;
}
}
// Go through each step in the program and replay each buffer define and free
// events.
for (int64_t i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
......@@ -406,7 +417,7 @@ void HeapSimulator::Alloc(const HloValue* buffer,
<< "Alloc called on freed buffer: " << *buffer;
allocated_buffers_.insert(buffer);
const int64_t size = size_fn_(*buffer);
const int64_t size = GetBufferSize(buffer);
algorithm_->Alloc(buffer, size);
no_fragmentation_stats_->Alloc(buffer, size);
FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
......@@ -419,7 +430,7 @@ void HeapSimulator::Alloc(const HloValue* buffer,
// causes Free to be called on the underlying algorithm.
void HeapSimulator::Free(const HloValue* buffer,
const HloInstruction* instruction) {
const int64_t size = size_fn_(*buffer);
const int64_t size = GetBufferSize(buffer);
algorithm_->Free(buffer, size);
no_fragmentation_stats_->Free(buffer, size);
FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
......@@ -432,12 +443,18 @@ void HeapSimulator::Free(const HloValue* buffer,
// SharedGroup.
void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
const HloInstruction* instruction) {
// Forwards the share event to both the heap algorithm and the
// no-fragmentation statistics, passing the size associated with `shared`.
//
// NOTE(review): this span comes from a diff hunk rendered without +/- markers,
// so it appears to contain both the pre-change pair of ShareWith calls (sized
// via size_fn_(*shared)) and the post-change pair (sized via
// GetBufferSize(shared)). Confirm against the real file which pair is current
// before editing.
algorithm_->ShareWith(buffer, shared, size_fn_(*shared));
no_fragmentation_stats_->ShareWith(buffer, shared, size_fn_(*shared));
algorithm_->ShareWith(buffer, shared, GetBufferSize(shared));
no_fragmentation_stats_->ShareWith(buffer, shared, GetBufferSize(shared));
// Record the SHARE_WITH event, including the value being shared with, in the
// debug trace.
FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
shared);
}
// Returns the size recorded in buffer_sizes_ for `buffer`.
//
// CHECK-fails if `buffer` has no entry in buffer_sizes_. The failure message
// names the offending value, in the same style as the other CHECKs in this
// file that stream `*buffer`, so the missing value can be identified from the
// crash log.
int64_t HeapSimulator::GetBufferSize(const HloValue* buffer) const {
  const auto it = buffer_sizes_.find(buffer);
  CHECK(it != buffer_sizes_.end())
      << "No size recorded for buffer: " << *buffer;
  return it->second;
}
HeapSimulator::Result<HloValue> HeapSimulator::Finish() {
Result<HloValue> result = algorithm_->Finish();
......@@ -591,7 +608,7 @@ void GlobalDecreasingSizeBestFitHeap<BufferType>::Alloc(
auto emplace_result = buffer_intervals_.emplace(
buffer, BufferInterval{buffer, size, current_time_, -1, {}, true});
DCHECK(emplace_result.second);
CHECK(emplace_result.second);
++current_time_;
}
......@@ -603,11 +620,11 @@ void GlobalDecreasingSizeBestFitHeap<BufferType>::ShareWith(
result_.chunk_map.emplace(buffer, Chunk::FromOffsetSize(0, 0));
return;
}
DCHECK_NE(buffer_intervals_.count(share_with), 0);
CHECK_NE(buffer_intervals_.count(share_with), 0);
buffer_intervals_[share_with].colocations.push_back(buffer);
auto emplace_result = buffer_intervals_.emplace(
buffer, BufferInterval{buffer, size, current_time_, -1, {}, false});
DCHECK(emplace_result.second);
CHECK(emplace_result.second);
++current_time_;
}
......@@ -638,9 +655,9 @@ void GlobalDecreasingSizeBestFitHeap<BufferType>::Free(const BufferType* buffer,
return;
}
BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
DCHECK_EQ(buffer_interval.buffer, buffer);
DCHECK_EQ(buffer_interval.size, size);
DCHECK_EQ(buffer_interval.end, -1);
CHECK_EQ(buffer_interval.buffer, buffer);
CHECK_EQ(buffer_interval.size, size);
CHECK_EQ(buffer_interval.end, -1);
if (buffer_interval.end != -1) {
return;
}
......
......@@ -223,6 +223,10 @@ class HeapSimulator {
void ShareBuffer(const HloValue* buffer, const HloValue* shared,
const HloInstruction* instruction);
// Returns the size of the HloValue, which is the max size of the HloValues
// that are part of the HloBuffer.
int64_t GetBufferSize(const HloValue* buffer) const;
// Returns true if:
// Two buffers belong to the same shared group.
// Either of the buffers has no shared group assigned.
......@@ -253,6 +257,8 @@ class HeapSimulator {
absl::flat_hash_set<const HloValue*> allocated_buffers_;
absl::flat_hash_set<const HloValue*> freed_buffers_;
absl::flat_hash_map<const HloValue*, int64_t> buffer_sizes_;
// Debugging information filled in while the heap simulator runs.
HeapSimulatorTrace debug_trace_;
};
......
......@@ -29,8 +29,11 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/service/async_op_canonicalizer.h"
#include "xla/service/buffer_value.h"
#include "xla/service/hlo_dce.h"
#include "xla/service/hlo_ordering.h"
#include "xla/service/hlo_parser.h"
#include "xla/service/hlo_value.h"
#include "xla/service/tuple_points_to_analysis.h"
#include "xla/status_macros.h"
......@@ -940,6 +943,53 @@ TEST_F(HeapSimulatorTest, WholeModule) {
});
}
// Runs the heap simulator over a whole module containing an async call whose
// called computation is written in terms of f32[4] while the entry-level
// call-start/call-done operate on f32[8]. Any chunk assigned to the
// dynamic-update-slice value is expected to be 32 bytes.
//
// NOTE(review): the expectation only fires for chunks whose defining
// instruction is named "dynamic-update-slice"; if no such chunk exists the
// test passes without checking anything -- confirm whether that is intended.
TEST_F(HeapSimulatorTest, AsyncCallImplicitSharding) {
  const std::string kHloText = R"(
HloModule module, is_scheduled=true
called_computation {
param0 = f32[4] parameter(0)
constant = f32[1] constant(1)
dynamic-update-slice = f32[4] dynamic-update-slice(param0, constant, constant)
ROOT negate = f32[4] negate(dynamic-update-slice)
}
ENTRY entry {
p0 = f32[8] parameter(0)
call-start = ((f32[8]), f32[8], s32[]) call-start(p0), async_execution_thread="foo", to_apply=called_computation
ROOT call-done = f32[8] call-done(call-start), async_execution_thread="foo", to_apply=called_computation
}
)";

  // The module is parsed without verification, then canonicalized and DCE'd
  // before alias analysis runs.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHloText));
  auto* hlo_module = module.get();
  TF_ASSERT_OK(AsyncOpCanonicalizer().Run(hlo_module).status());
  TF_ASSERT_OK(HloDCE().Run(hlo_module).status());
  TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis,
                          HloAliasAnalysis::Run(hlo_module));

  // Array-shaped values are sized by their byte size; everything else is
  // treated as zero-sized.
  auto size_fn = [](const BufferValue& buffer) -> int64_t {
    const Shape& shape = buffer.shape();
    return shape.IsArray() ? ShapeUtil::ByteSizeOf(shape) : 0;
  };

  auto heap = std::make_unique<GlobalDecreasingSizeBestFitHeap<HloValue>>(
      /*alignment=*/1);
  HeapSimulator::Result<HloValue> result =
      HeapSimulator::Run(std::move(heap), *module, module->schedule(),
                         *alias_analysis, size_fn)
          .value();

  const auto& chunks = result.heap_results[0].chunk_map;
  for (const auto& entry : chunks) {
    if (entry.first->instruction()->name() != "dynamic-update-slice") {
      continue;
    }
    EXPECT_EQ(entry.second.size, 32);
  }
}
// Base class for heap algorithm tests.
class HeapAlgorithmTestBase : public ::testing::Test {
protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册