提交 6ba02f0e 编写于 作者: A Artem Belevich 提交者: TensorFlower Gardener

[XLA] Added HasAllocationAt() helper function.

PiperOrigin-RevId: 163742985
上级 18304683
......@@ -200,10 +200,10 @@ BufferAllocation* BufferAssignment::GetMutableAllocation(
return const_cast<BufferAllocation*>(&GetAllocation(index));
}
bool BufferAssignment::HasTopLevelAllocation(
const HloInstruction* instruction) const {
bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
for (const LogicalBuffer* buffer :
GetPointsToSet(instruction).element(/*index=*/{})) {
GetPointsToSet(instruction).element(index)) {
if (allocation_index_for_buffer_.count(buffer) > 0) {
return true;
}
......@@ -211,6 +211,11 @@ bool BufferAssignment::HasTopLevelAllocation(
return false;
}
bool BufferAssignment::HasTopLevelAllocation(
const HloInstruction* instruction) const {
return HasAllocationAt(instruction, /*index=*/{});
}
StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
const HloInstruction* instruction, const ShapeIndex& index) const {
VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
......
......@@ -281,6 +281,11 @@ class BufferAssignment {
std::set<BufferAllocation::Slice> GetAllSlices(
const HloInstruction* instruction, const ShapeIndex& index) const;
// Convenience function which returns whether the buffer of the
// instruction at the given index is assigned an allocation.
bool HasAllocationAt(const HloInstruction* instruction,
const ShapeIndex& index) const;
// Convenience function which returns whether the top-level buffer of the
// instruction (index == {}) is assigned an allocation.
bool HasTopLevelAllocation(const HloInstruction* instruction) const;
......
......@@ -296,6 +296,34 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
GetAssignedOutputAllocation(*buffers, add);
}
TEST_F(BufferAssignmentTest, HasAllocationAt) {
// Create a tuple with non-const and const elements and check that
// HasAllocationAt works correctly.
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32vec100_, "param0"));
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({negate, param0, constant}));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto buffers = RunBufferAssignment(module.get());
// Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation()
// reports for the instruction directly.
EXPECT_EQ(buffers->HasTopLevelAllocation(tuple),
buffers->HasAllocationAt(tuple, /*index=*/{}));
EXPECT_EQ(buffers->HasTopLevelAllocation(negate),
buffers->HasAllocationAt(tuple, /*index=*/{0}));
EXPECT_EQ(buffers->HasTopLevelAllocation(param0),
buffers->HasAllocationAt(tuple, /*index=*/{1}));
EXPECT_EQ(buffers->HasTopLevelAllocation(constant),
buffers->HasAllocationAt(tuple, /*index=*/{2}));
}
TEST_F(BufferAssignmentTest, BufferForOutputConst) {
// This computation copies a constant to output.
auto builder = HloComputation::Builder(TestName());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册