diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 156cb85f6658f83d5a0ef95f7162f176a85c95dc..692d186b14dfd5396c83ab21114530c301f33292 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -493,6 +493,36 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "liveness_util",
+    srcs = ["liveness_util.cc"],
+    hdrs = ["liveness_util.h"],
+    deps = [
+        ":hlo",
+        ":logical_buffer",
+        ":tuple_points_to_analysis",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_test(
+    name = "liveness_util_test",
+    srcs = ["liveness_util_test.cc"],
+    deps = [
+        ":hlo",
+        ":liveness_util",
+        ":tuple_points_to_analysis",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 cc_library(
     name = "buffer_liveness",
     srcs = [
@@ -504,6 +534,7 @@ cc_library(
     deps = [
         ":hlo",
         ":hlo_ordering",
+        ":liveness_util",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:shape_util",
@@ -586,6 +617,7 @@ cc_library(
     ],
     deps = [
         ":hlo",
+        ":liveness_util",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index b5a2936b670cd8b0b8ee1c834f820fa23fe441f4..0fe6e37c00f283791874163c16fded96f9c827bc 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -17,11 +17,11 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 
-#include <algorithm>
 #include <utility>
 #include <vector>
 
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/liveness_util.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -92,128 +92,6 @@ string BufferLiveness::ToString() const {
   return tensorflow::str_util::Join(pieces, "\n");
 }
 
-namespace {
-
-// Returns false if 'user' cannot possibly use the buffer at 'index' in
-// 'operand'. Returns true otherwise.
-// Precondition: 'operand' is an operand of 'user'.
-bool MayUseBufferInOperand(HloInstruction* operand, const ShapeIndex& index,
-                           HloInstruction* user,
-                           const TuplePointsToAnalysis& points_to_analysis) {
-  if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
-    // GetTupleElement instructions only access the top-level buffer of their
-    // operand.
-    return false;
-  } else if (user->opcode() == HloOpcode::kFusion &&
-             user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
-    // Find fusion parameter associated with 'operand'.
-    auto it = std::find_if(
-        user->fused_parameters().begin(), user->fused_parameters().end(),
-        [=](HloInstruction* fused_param) {
-          return user->operand(fused_param->parameter_number()) == operand;
-        });
-    CHECK(it != user->fused_parameters().end());
-    // Iterate through all users of all buffer aliases of the buffer in the
-    // points-to set of fusion parameter at 'index'.
-    // Return true if any uses are detected at 'index', returns false otherwise.
-    const LogicalBuffer* buffer =
-        points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
-    for (const BufferAlias& alias :
-         points_to_analysis.GetBufferAliases(*buffer)) {
-      for (HloInstruction* alias_user : alias.instruction()->users()) {
-        if (!MayUseBufferInOperand(alias.instruction(), alias.index(),
-                                   alias_user, points_to_analysis)) {
-          continue;
-        }
-        // Return true: use detected at 'buffer' -> 'alias' -> 'alias_user'.
-        return true;
-      }
-    }
-    // Return false: found no uses of 'operand' at 'index' in 'user'.
-    return false;
-  }
-  return true;
-}
-
-// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
-// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
-// where 'user' is a user of an alias of 'intruction' at 'index', and
-// 'operand_index' is the operand index at which the alias appears in the
-// operand list of 'user'.
-std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
-    HloInstruction* instruction, const ShapeIndex& index,
-    const TuplePointsToAnalysis& points_to_analysis) {
-  std::vector<std::pair<HloInstruction*, int64>> uses;
-  const std::vector<const LogicalBuffer*>& points_to =
-      points_to_analysis.GetPointsToSet(instruction).element(index);
-  for (const LogicalBuffer* buffer : points_to) {
-    for (const BufferAlias& alias :
-         points_to_analysis.GetBufferAliases(*buffer)) {
-      for (HloInstruction* alias_user : alias.instruction()->users()) {
-        if (!MayUseBufferInOperand(alias.instruction(), alias.index(),
-                                   alias_user, points_to_analysis)) {
-          continue;
-        }
-        for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
-          uses.emplace_back(alias_user, op_idx);
-        }
-      }
-    }
-  }
-  return uses;
-}
-
-// Returns true if 'user' (at 'user_index') can share a buffer with its operand
-// 'operand' (at 'operand_index').
-// Returns false otherwise.
-// User and operand can share buffers iff both instructions emit the same shape
-// and layout, and 'user' meets one of the following two qualifications:
-// *) Is element-wise.
-// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
-//    in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
-//    at operand 0.
-bool CanShareOperandBufferWithUser(
-    HloInstruction* operand, const ShapeIndex& operand_index,
-    HloInstruction* user, const ShapeIndex& user_index,
-    const TuplePointsToAnalysis& points_to_analysis) {
-  Shape operand_subshape =
-      ShapeUtil::GetSubshape(operand->shape(), operand_index);
-  Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index);
-  // Check that operand and user emit the same shape and layout.
-  if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
-    return false;
-  }
-  // Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
-  // fused root instruction.
-  if (user->opcode() == HloOpcode::kFusion &&
-      user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
-      user->fused_expression_root()->opcode() ==
-          HloOpcode::kDynamicUpdateSlice) {
-    for (auto& fused_param : user->fused_parameters()) {
-      // Find fusion parameter associated with 'operand'.
-      if (user->operand(fused_param->parameter_number()) != operand) {
-        continue;
-      }
-      // Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
-      auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
-          fused_param, operand_index, points_to_analysis);
-      // Return true iff there is exactly one use of 'operand' at 'index', and
-      // this singleton use is the fused root at operand index 0.
-      if (fused_param_uses.size() == 1 &&
-          fused_param_uses[0].first == user->fused_expression_root() &&
-          fused_param_uses[0].second == 0) {
-        return true;
-      }
-      break;
-    }
-    return false;
-  }
-  // Check if 'user' is element-wise.
-  return user->IsElementwise();
-}
-
-}  // anonymous namespace
-
 bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
                                                 const LogicalBuffer& b) const {
   TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a));
@@ -226,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
   // Every user of 'a' must be a predecessor of 'b' or 'b' itself.
   for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
     for (auto user : alias.instruction()->users()) {
-      if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user,
-                                 points_to_analysis())) {
+      if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user,
+                                  points_to_analysis())) {
         continue;
       }
       if (user != b.instruction() &&
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 76702f52e02553302f468049c9e9418535d24bec..46c0d8edead1eaba518fd1040b7dd7d0d6c79159 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/liveness_util.h"
 #include "tensorflow/compiler/xla/util.h"
 
 namespace xla {
@@ -26,6 +27,8 @@ namespace xla {
 using tensorflow::gtl::FlatMap;
 using tensorflow::gtl::FlatSet;
 
+namespace {
+
 // Returns the set of buffers that may be sources of all operands of the given
 // instruction. The returned buffers are guaranteed to have no duplicates, and
 // to be sorted in a deterministic order.
@@ -46,6 +49,8 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
   return sorted;
 }
 
+}  // namespace
+
 /*static*/
 StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     std::unique_ptr<HeapAlgorithm> algorithm,
@@ -145,13 +150,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     // we must be the last user of the buffer.
     bool shared = false;
     for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) {
-      // The operand buffer can be shared if we have the same shape, and we're
-      // an elementwise instruction.
-      //
-      // TODO(b/35903632): Refactor and use the CanShareOperandBufferWithUser
-      // logic from buffer_liveness.cc
-      if (ShapeUtil::Equal(buffer->shape(), operand_buffer->shape()) &&
-          instruction->IsElementwise()) {
+      if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
+          CanShareOperandBufferWithUser(
+              operand_buffer->instruction(), operand_buffer->index(),
+              buffer->instruction(), buffer->index(), points_to_analysis)) {
         heap.ShareBuffer(buffer, operand_buffer);
         shared = true;
         break;
diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7d157e8fd5f42682af51c5195dca1de6903090bf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/liveness_util.cc
@@ -0,0 +1,151 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/liveness_util.h"
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+
+bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index,
+                             HloInstruction* user,
+                             const TuplePointsToAnalysis& points_to_analysis) {
+  CHECK(user->IsUserOf(operand))
+      << "user: " << user->ToString() << " operand: " << operand->ToString();
+  if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
+    // GetTupleElement instructions only access the top-level buffer of their
+    // operand.
+    return true;
+  } else if (user->opcode() == HloOpcode::kFusion &&
+             user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+    // Find fusion parameter associated with 'operand'.
+    auto it = std::find_if(
+        user->fused_parameters().begin(), user->fused_parameters().end(),
+        [=](HloInstruction* fused_param) {
+          return user->operand(fused_param->parameter_number()) == operand;
+        });
+    CHECK(it != user->fused_parameters().end());
+    // Iterate through all users of all buffer aliases of the buffer in the
+    // points-to set of fusion parameter at 'index'.
+    // Returns false if any uses are detected at 'index', true otherwise.
+    const LogicalBuffer* buffer =
+        points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
+    for (const BufferAlias& alias :
+         points_to_analysis.GetBufferAliases(*buffer)) {
+      for (HloInstruction* alias_user : alias.instruction()->users()) {
+        if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
+                                    alias_user, points_to_analysis)) {
+          continue;
+        }
+        // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
+        return false;
+      }
+    }
+    // Return true: found no uses of 'operand' at 'index' in 'user'.
+    return true;
+  }
+  return false;
+}
+
+namespace {
+
+// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
+// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
+// where 'user' is a user of an alias of 'instruction' at 'index', and
+// 'operand_index' is the operand index at which the alias appears in the
+// operand list of 'user'.
+std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
+    HloInstruction* instruction, const ShapeIndex& index,
+    const TuplePointsToAnalysis& points_to_analysis) {
+  std::vector<std::pair<HloInstruction*, int64>> uses;
+  const std::vector<const LogicalBuffer*>& points_to =
+      points_to_analysis.GetPointsToSet(instruction).element(index);
+  for (const LogicalBuffer* buffer : points_to) {
+    for (const BufferAlias& alias :
+         points_to_analysis.GetBufferAliases(*buffer)) {
+      for (HloInstruction* alias_user : alias.instruction()->users()) {
+        if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
+                                    alias_user, points_to_analysis)) {
+          continue;
+        }
+        for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
+          uses.emplace_back(alias_user, op_idx);
+        }
+      }
+    }
+  }
+  return uses;
+}
+
+}  // namespace
+
+// User and operand can share buffers iff both instructions emit the same shape
+// and layout, and 'user' meets one of the following two qualifications:
+// *) Is element-wise.
+// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
+//    in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
+//    at operand 0.
+bool CanShareOperandBufferWithUser(
+    HloInstruction* operand, const ShapeIndex& operand_index,
+    HloInstruction* user, const ShapeIndex& user_index,
+    const TuplePointsToAnalysis& points_to_analysis) {
+  CHECK(user->IsUserOf(operand))
+      << "user: " << user->ToString() << " operand: " << operand->ToString();
+  Shape operand_subshape =
+      ShapeUtil::GetSubshape(operand->shape(), operand_index);
+  Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index);
+  // Check that operand and user emit the same shape and layout.
+  if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
+    return false;
+  }
+  // Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
+  // fused root instruction.
+  if (user->opcode() == HloOpcode::kFusion &&
+      user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
+      user->fused_expression_root()->opcode() ==
+          HloOpcode::kDynamicUpdateSlice) {
+    for (auto& fused_param : user->fused_parameters()) {
+      // Find fusion parameter associated with 'operand'.
+      if (user->operand(fused_param->parameter_number()) != operand) {
+        continue;
+      }
+      // Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
+      auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
+          fused_param, operand_index, points_to_analysis);
+      // Return true iff there is exactly one use of 'operand' at 'index', and
+      // this singleton use is the fused root at operand index 0.
+      if (fused_param_uses.size() == 1 &&
+          fused_param_uses[0].first == user->fused_expression_root() &&
+          fused_param_uses[0].second == 0) {
+        return true;
+      }
+      break;
+    }
+    return false;
+  }
+  // Check if 'user' is element-wise.
+  return user->IsElementwise();
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..410a7b1b519e117f21c01938cb8e4a5b1c358ad2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/liveness_util.h
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A collection of utilities on the HLO graph.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
+
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+// Returns true if 'user' cannot possibly use the buffer at 'index' in
+// 'operand'. Returns false otherwise.
+//
+// REQUIRES: 'operand' is an operand of 'user'.
+bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index,
+                             HloInstruction* user,
+                             const TuplePointsToAnalysis& points_to_analysis);
+
+// Returns true if 'user' (at 'user_index') can share a buffer with its operand
+// 'operand' (at 'operand_index').
+// Returns false otherwise.
+//
+// REQUIRES: 'operand' is an operand of 'user'.
+bool CanShareOperandBufferWithUser(
+    HloInstruction* operand, const ShapeIndex& operand_index,
+    HloInstruction* user, const ShapeIndex& user_index,
+    const TuplePointsToAnalysis& points_to_analysis);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2ff71d6f3c8eff58b83783fc867d5874c6c700a3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/liveness_util_test.cc
@@ -0,0 +1,189 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/liveness_util.h"
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace {
+
+class PointsToAnalysisTestBase : public HloTestBase {
+ protected:
+  void BuildModule(std::unique_ptr<HloComputation> computation) {
+    module_ = MakeUnique<HloModule>(TestName());
+    computation_ = module_->AddEntryComputation(std::move(computation));
+  }
+
+  void RunAnalysis() {
+    CHECK_NOTNULL(module_.get());
+    points_to_analysis_ =
+        TuplePointsToAnalysis::Run(module_.get(),
+                                   /*include_loop_fusion_instructions=*/true)
+            .ConsumeValueOrDie();
+  }
+
+  void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
+    BuildModule(std::move(computation));
+    RunAnalysis();
+  }
+
+  std::unique_ptr<HloModule> module_;
+  HloComputation* computation_ = nullptr;
+  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
+};
+
+class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
+
+TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
+  auto builder = HloComputation::Builder(TestName());
+
+  Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
+  auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1));
+  builder.AddInstruction(
+      HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));
+
+  BuildModuleAndRunAnalysis(builder.Build());
+
+  // GetTupleElement instructions only access the top-level buffer of their
+  // operand.
+  EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_));
+  EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_));
+  EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_));
+  EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_));
+}
+
+TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
+  auto builder = HloComputation::Builder(TestName());
+
+  Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+  auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+  // Create a DynamicUpdateSlice instruction of tuple element 1.
+  auto starts = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint32>({2})));
+  auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+  auto dynamic_update_slice =
+      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+          data_shape, gte1, update, starts));
+  builder.AddInstruction(
+      HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+  BuildModule(builder.Build());
+  auto fusion = computation_->CreateFusionInstruction(
+      {dynamic_update_slice, starts, update, gte1},
+      HloInstruction::FusionKind::kLoop);
+  RunAnalysis();
+
+  // The fusion instruction never uses tuple element 0, but does use element 1.
+  EXPECT_TRUE(
+      DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_));
+  EXPECT_FALSE(
+      DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_));
+}
+
+class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {};
+
+TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
+  auto builder = HloComputation::Builder(TestName());
+
+  Shape shape = ShapeUtil::MakeShape(F32, {8});
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param"));
+  auto exp = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
+  auto log = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp));
+
+  BuildModuleAndRunAnalysis(builder.Build());
+
+  EXPECT_TRUE(
+      CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
+  EXPECT_TRUE(
+      CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
+  auto builder = HloComputation::Builder(TestName());
+
+  Shape in_shape = ShapeUtil::MakeShape(F32, {8});
+  Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, in_shape, "param0"));
+  auto param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, in_shape, "param1"));
+  auto result = builder.AddInstruction(
+      HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
+
+  BuildModuleAndRunAnalysis(builder.Build());
+
+  EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
+                                             *points_to_analysis_));
+  EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
+                                             *points_to_analysis_));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
+  auto builder = HloComputation::Builder(TestName());
+
+  Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+  auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+  // Create a DynamicUpdateSlice instruction of tuple element 1.
+  auto starts = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint32>({2})));
+  auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+  auto dynamic_update_slice =
+      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+          data_shape, gte1, update, starts));
+  builder.AddInstruction(
+      HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+  BuildModule(builder.Build());
+  auto fusion = computation_->CreateFusionInstruction(
+      {dynamic_update_slice, starts, update, gte1},
+      HloInstruction::FusionKind::kLoop);
+  RunAnalysis();
+
+  // The fusion instruction can share with tuple element 1.
+  EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
+                                             *points_to_analysis_));
+  EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
+                                            *points_to_analysis_));
+}
+
+}  // namespace
+}  // namespace xla
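---
Usage note: a minimal sketch of a caller of the two new helpers, not part of
this patch. The function name TryReuseOperandBuffers is hypothetical; the
sketch assumes a TuplePointsToAnalysis computed with
include_loop_fusion_instructions=true, as the callers and test fixture in
this change do. Both helpers REQUIRE that 'operand' is an operand of 'user',
which holds by construction in the loop below.

  #include "tensorflow/compiler/xla/service/hlo_instruction.h"
  #include "tensorflow/compiler/xla/service/liveness_util.h"
  #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"

  namespace xla {

  // Hypothetical client: for each operand of 'user', decide whether an
  // allocator could reuse the operand's top-level buffer for 'user's output.
  void TryReuseOperandBuffers(HloInstruction* user,
                              const TuplePointsToAnalysis& points_to_analysis) {
    for (HloInstruction* operand : user->operands()) {
      // Cheap reject: 'user' provably never touches this operand's buffer.
      if (DoesNotUseOperandBuffer(operand, /*index=*/{}, user,
                                  points_to_analysis)) {
        continue;
      }
      // Reuse is safe only if shapes/layouts match and 'user' is element-wise
      // or a loop fusion whose root is a DynamicUpdateSlice of the operand.
      if (CanShareOperandBufferWithUser(operand, /*operand_index=*/{}, user,
                                        /*user_index=*/{},
                                        points_to_analysis)) {
        // e.g. record the (operand, user) pair for buffer assignment.
      }
    }
  }

  }  // namespace xla

This mirrors how heap_simulator.cc now calls CanShareOperandBufferWithUser
after its IsUserOf() guard, replacing the old shape-equality-plus-elementwise
check.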