diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index ae31135a1aeb2807649aceb6e77d6050525ce5a6..cb6ea8b2359d46e25f5633050e83ebc0d4da8716 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -200,10 +200,10 @@ BufferAllocation* BufferAssignment::GetMutableAllocation(
   return const_cast<BufferAllocation*>(&GetAllocation(index));
 }
 
-bool BufferAssignment::HasTopLevelAllocation(
-    const HloInstruction* instruction) const {
+bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
+                                       const ShapeIndex& index) const {
   for (const LogicalBuffer* buffer :
-       GetPointsToSet(instruction).element(/*index=*/{})) {
+       GetPointsToSet(instruction).element(index)) {
     if (allocation_index_for_buffer_.count(buffer) > 0) {
       return true;
     }
@@ -211,6 +211,11 @@ bool BufferAssignment::HasTopLevelAllocation(
   return false;
 }
 
+bool BufferAssignment::HasTopLevelAllocation(
+    const HloInstruction* instruction) const {
+  return HasAllocationAt(instruction, /*index=*/{});
+}
+
 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
     const HloInstruction* instruction, const ShapeIndex& index) const {
   VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 35c904df130564a4848d3cb2db21ed8fa209e7e8..0cbd339dc014d18cb1b342a32860e920419e0c61 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -281,6 +281,11 @@ class BufferAssignment {
   std::set<BufferAllocation::Slice> GetAllSlices(
       const HloInstruction* instruction, const ShapeIndex& index) const;
 
+  // Convenience function which returns whether the buffer of the
+  // instruction at the given index is assigned an allocation.
+  bool HasAllocationAt(const HloInstruction* instruction,
+                       const ShapeIndex& index) const;
+
   // Convenience function which returns whether the top-level buffer of the
   // instruction (index == {}) is assigned an allocation.
   bool HasTopLevelAllocation(const HloInstruction* instruction) const;
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 18acd4f3ae47882bf629c090c510db92049e215a..96422b11164a8bcedf8dd1ab7cab8858c909d574 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -296,6 +296,34 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
   GetAssignedOutputAllocation(*buffers, add);
 }
 
+TEST_F(BufferAssignmentTest, HasAllocationAt) {
+  // Create a tuple with non-const and const elements and check that
+  // HasAllocationAt works correctly.
+  auto builder = HloComputation::Builder(TestName());
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
+  auto constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
+  auto negate = builder.AddInstruction(
+      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
+  auto tuple = builder.AddInstruction(
+      HloInstruction::CreateTuple({negate, param0, constant}));
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+
+  auto buffers = RunBufferAssignment(module.get());
+  // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation()
+  // reports for the instruction directly.
+  EXPECT_EQ(buffers->HasTopLevelAllocation(tuple),
+            buffers->HasAllocationAt(tuple, /*index=*/{}));
+  EXPECT_EQ(buffers->HasTopLevelAllocation(negate),
+            buffers->HasAllocationAt(tuple, /*index=*/{0}));
+  EXPECT_EQ(buffers->HasTopLevelAllocation(param0),
+            buffers->HasAllocationAt(tuple, /*index=*/{1}));
+  EXPECT_EQ(buffers->HasTopLevelAllocation(constant),
+            buffers->HasAllocationAt(tuple, /*index=*/{2}));
+}
+
 TEST_F(BufferAssignmentTest, BufferForOutputConst) {
   // This computation copies a constant to output.
   auto builder = HloComputation::Builder(TestName());