Commit e0d0c676 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Refactor logic from buffer_liveness to use in HeapSimulator.

Also added some simple tests.
Change: 150144113
Parent 830cde87
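For orientation, a minimal caller-side sketch of the two helpers this commit moves into liveness_util.h (function names and signatures are taken from the diff below; the surrounding setup, i.e. building a TuplePointsToAnalysis over the module, is assumed and abbreviated):

  #include "tensorflow/compiler/xla/service/liveness_util.h"
  #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"

  // Sketch only: 'operand' must be an operand of 'user', and
  // 'points_to_analysis' must already be built for their module.
  void SketchedCaller(xla::HloInstruction* operand, xla::HloInstruction* user,
                      const xla::TuplePointsToAnalysis& points_to_analysis) {
    // True if 'user' provably never touches the top-level buffer of 'operand'.
    const bool unused = xla::DoesNotUseOperandBuffer(
        operand, /*index=*/{}, user, points_to_analysis);
    // True if 'user' may write its result in place over 'operand's buffer.
    const bool shareable = xla::CanShareOperandBufferWithUser(
        operand, /*operand_index=*/{}, user, /*user_index=*/{},
        points_to_analysis);
    (void)unused;
    (void)shareable;
  }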
@@ -493,6 +493,36 @@ cc_library(
],
)
cc_library(
name = "liveness_util",
srcs = ["liveness_util.cc"],
hdrs = ["liveness_util.h"],
deps = [
":hlo",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
)
cc_test(
name = "liveness_util_test",
srcs = ["liveness_util_test.cc"],
deps = [
":hlo",
":liveness_util",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "buffer_liveness",
srcs = [
@@ -504,6 +534,7 @@ cc_library(
deps = [
":hlo",
":hlo_ordering",
":liveness_util",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -586,6 +617,7 @@ cc_library(
],
deps = [
":hlo",
":liveness_util",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:statusor",
......
@@ -17,11 +17,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include <set>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -92,128 +92,6 @@ string BufferLiveness::ToString() const {
return tensorflow::str_util::Join(pieces, "\n");
}
namespace {
// Returns false if 'user' cannot possibly use the buffer at 'index' in
// 'operand'. Returns true otherwise.
// Precondition: 'operand' is an operand of 'user'.
bool MayUseBufferInOperand(HloInstruction* operand, const ShapeIndex& index,
HloInstruction* user,
const TuplePointsToAnalysis& points_to_analysis) {
if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
// GetTupleElement instructions only access the top-level buffer of their
// operand.
return false;
} else if (user->opcode() == HloOpcode::kFusion &&
user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
// Find fusion parameter associated with 'operand'.
auto it = std::find_if(
user->fused_parameters().begin(), user->fused_parameters().end(),
[=](HloInstruction* fused_param) {
return user->operand(fused_param->parameter_number()) == operand;
});
CHECK(it != user->fused_parameters().end());
// Iterate through all users of all buffer aliases of the buffer in the
// points-to set of fusion parameter at 'index'.
// Return true if any uses are detected at 'index'; return false otherwise.
const LogicalBuffer* buffer =
points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
for (const BufferAlias& alias :
points_to_analysis.GetBufferAliases(*buffer)) {
for (HloInstruction* alias_user : alias.instruction()->users()) {
if (!MayUseBufferInOperand(alias.instruction(), alias.index(),
alias_user, points_to_analysis)) {
continue;
}
// Return true: use detected at 'buffer' -> 'alias' -> 'alias_user'.
return true;
}
}
// Return false: found no uses of 'operand' at 'index' in 'user'.
return false;
}
return true;
}
// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
// where 'user' is a user of an alias of 'instruction' at 'index', and
// 'operand_index' is the operand index at which the alias appears in the
// operand list of 'user'.
std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
HloInstruction* instruction, const ShapeIndex& index,
const TuplePointsToAnalysis& points_to_analysis) {
std::vector<std::pair<HloInstruction*, int64>> uses;
const std::vector<const LogicalBuffer*>& points_to =
points_to_analysis.GetPointsToSet(instruction).element(index);
for (const LogicalBuffer* buffer : points_to) {
for (const BufferAlias& alias :
points_to_analysis.GetBufferAliases(*buffer)) {
for (HloInstruction* alias_user : alias.instruction()->users()) {
if (!MayUseBufferInOperand(alias.instruction(), alias.index(),
alias_user, points_to_analysis)) {
continue;
}
for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
uses.emplace_back(alias_user, op_idx);
}
}
}
}
return uses;
}
// Returns true if 'user' (at 'user_index') can share a buffer with its operand
// 'operand' (at 'operand_index').
// Returns false otherwise.
// User and operand can share buffers iff both instructions emit the same shape
// and layout, and 'user' meets one of the following two qualifications:
// *) Is element-wise.
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
// at operand 0.
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
const TuplePointsToAnalysis& points_to_analysis) {
Shape operand_subshape =
ShapeUtil::GetSubshape(operand->shape(), operand_index);
Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index);
// Check that operand and user emit the same shape and layout.
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
return false;
}
// Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
// fused root instruction.
if (user->opcode() == HloOpcode::kFusion &&
user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
user->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
for (auto& fused_param : user->fused_parameters()) {
// Find fusion parameter associated with 'operand'.
if (user->operand(fused_param->parameter_number()) != operand) {
continue;
}
// Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
fused_param, operand_index, points_to_analysis);
// Return true iff there is exactly one use of 'operand' at 'index', and
// this singleton use is the fused root at operand index 0.
if (fused_param_uses.size() == 1 &&
fused_param_uses[0].first == user->fused_expression_root() &&
fused_param_uses[0].second == 0) {
return true;
}
break;
}
return false;
}
// Check if 'user' is element-wise.
return user->IsElementwise();
}
} // anonymous namespace
bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
const LogicalBuffer& b) const {
TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a));
@@ -226,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
// Every user of 'a' must be a predecessor of 'b' or 'b' itself.
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
for (auto user : alias.instruction()->users()) {
if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user,
points_to_analysis())) {
if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user,
points_to_analysis())) {
continue;
}
if (user != b.instruction() &&
......
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -26,6 +27,8 @@ namespace xla {
using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
namespace {
// Returns the set of buffers that may be sources of all operands of the given
// instruction. The returned buffers are guaranteed to have no duplicates, and
// to be sorted in a deterministic order.
@@ -46,6 +49,8 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
return sorted;
}
} // namespace
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm,
@@ -145,13 +150,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
// we must be the last user of the buffer.
bool shared = false;
for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) {
// The operand buffer can be shared if we have the same shape, and we're
// an elementwise instruction.
//
// TODO(b/35903632): Refactor and use the CanShareOperandBufferWithUser
// logic from buffer_liveness.cc
if (ShapeUtil::Equal(buffer->shape(), operand_buffer->shape()) &&
instruction->IsElementwise()) {
if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
CanShareOperandBufferWithUser(
operand_buffer->instruction(), operand_buffer->index(),
buffer->instruction(), buffer->index(), points_to_analysis)) {
heap.ShareBuffer(buffer, operand_buffer);
shared = true;
break;
......
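Restated, the new sharing condition in HeapSimulator: a freed operand buffer may back the newly defined buffer only if the defining instruction is a direct user of the operand's defining instruction and CanShareOperandBufferWithUser accepts the pair. A hedged sketch of that condition pulled out as a standalone predicate (the helper name is hypothetical, not part of this commit):

  // Hypothetical helper mirroring the check added in HeapSimulator::Run above.
  bool MayReuseOperandBuffer(const xla::LogicalBuffer* buffer,
                             const xla::LogicalBuffer* operand_buffer,
                             const xla::TuplePointsToAnalysis& points_to_analysis) {
    return buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
           xla::CanShareOperandBufferWithUser(
               operand_buffer->instruction(), operand_buffer->index(),
               buffer->instruction(), buffer->index(), points_to_analysis);
  }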
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index,
HloInstruction* user,
const TuplePointsToAnalysis& points_to_analysis) {
CHECK(user->IsUserOf(operand))
<< "user: " << user->ToString() << " operand: " << operand->ToString();
if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
// GetTupleElement instructions only access the top-level buffer of their
// operand.
return true;
} else if (user->opcode() == HloOpcode::kFusion &&
user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
// Find fusion parameter associated with 'operand'.
auto it = std::find_if(
user->fused_parameters().begin(), user->fused_parameters().end(),
[=](HloInstruction* fused_param) {
return user->operand(fused_param->parameter_number()) == operand;
});
CHECK(it != user->fused_parameters().end());
// Iterate through all users of all buffer aliases of the buffer in the
// points-to set of fusion parameter at 'index'.
// Return false if any uses are detected at 'index'; return true otherwise.
const LogicalBuffer* buffer =
points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
for (const BufferAlias& alias :
points_to_analysis.GetBufferAliases(*buffer)) {
for (HloInstruction* alias_user : alias.instruction()->users()) {
if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
alias_user, points_to_analysis)) {
continue;
}
// Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
return false;
}
}
// Return true: found no uses of 'operand' at 'index' in 'user'.
return true;
}
return false;
}
namespace {
// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
// where 'user' is a user of an alias of 'instruction' at 'index', and
// 'operand_index' is the operand index at which the alias appears in the
// operand list of 'user'.
std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
HloInstruction* instruction, const ShapeIndex& index,
const TuplePointsToAnalysis& points_to_analysis) {
std::vector<std::pair<HloInstruction*, int64>> uses;
const std::vector<const LogicalBuffer*>& points_to =
points_to_analysis.GetPointsToSet(instruction).element(index);
for (const LogicalBuffer* buffer : points_to) {
for (const BufferAlias& alias :
points_to_analysis.GetBufferAliases(*buffer)) {
for (HloInstruction* alias_user : alias.instruction()->users()) {
if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
alias_user, points_to_analysis)) {
continue;
}
for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
uses.emplace_back(alias_user, op_idx);
}
}
}
}
return uses;
}
} // namespace
// User and operand can share buffers iff both instructions emit the same shape
// and layout, and 'user' meets one of the following two qualifications:
// *) Is element-wise.
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
// at operand 0.
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
const TuplePointsToAnalysis& points_to_analysis) {
CHECK(user->IsUserOf(operand))
<< "user: " << user->ToString() << " operand: " << operand->ToString();
Shape operand_subshape =
ShapeUtil::GetSubshape(operand->shape(), operand_index);
Shape user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index);
// Check that operand and user emit the same shape and layout.
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
return false;
}
// Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
// fused root instruction.
if (user->opcode() == HloOpcode::kFusion &&
user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
user->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
for (auto& fused_param : user->fused_parameters()) {
// Find fusion parameter associated with 'operand'.
if (user->operand(fused_param->parameter_number()) != operand) {
continue;
}
// Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
fused_param, operand_index, points_to_analysis);
// Return true iff there is exactly one use of 'operand' at 'index', and
// this singleton use is the fused root at operand index 0.
if (fused_param_uses.size() == 1 &&
fused_param_uses[0].first == user->fused_expression_root() &&
fused_param_uses[0].second == 0) {
return true;
}
break;
}
return false;
}
// Check if 'user' is element-wise.
return user->IsElementwise();
}
} // namespace xla
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A collection of utilities on the HLO graph.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
// Returns true if 'user' cannot possibly use the buffer at 'index' in
// 'operand'. Returns false otherwise.
//
// REQUIRES: 'operand' is an operand of 'user'.
bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index,
HloInstruction* user,
const TuplePointsToAnalysis& points_to_analysis);
// Returns true if 'user' (at 'user_index') can share a buffer with its operand
// 'operand' (at 'operand_index').
// Returns false otherwise.
//
// REQUIRES: 'operand' is an operand of 'user'.
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
const TuplePointsToAnalysis& points_to_analysis);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
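For reference, the two contracts in caller-side form; this condenses what the liveness_util_test.cc cases below assert ('pta' is shorthand for a TuplePointsToAnalysis built over the module):

  // Given: tuple = (a, b) and gte0 = GetTupleElement(tuple, 0).
  // GetTupleElement reads only the top-level tuple buffer, so:
  //   DoesNotUseOperandBuffer(tuple, {0}, gte0, pta)   // true: never reads 'a'
  //   DoesNotUseOperandBuffer(tuple, {},  gte0, pta)   // false: reads the tuple itself
  // Given: exp = Exp(param), element-wise with the same shape and layout:
  //   CanShareOperandBufferWithUser(param, {}, exp, {}, pta)  // true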
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include <memory>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
namespace {
class PointsToAnalysisTestBase : public HloTestBase {
protected:
void BuildModule(std::unique_ptr<HloComputation> computation) {
module_ = MakeUnique<HloModule>(TestName());
computation_ = module_->AddEntryComputation(std::move(computation));
}
void RunAnalysis() {
CHECK_NOTNULL(module_.get());
points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get(),
/*include_loop_fusion_instructions=*/true)
.ConsumeValueOrDie();
}
void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
BuildModule(std::move(computation));
RunAnalysis();
}
std::unique_ptr<HloModule> module_;
HloComputation* computation_ = nullptr;
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
};
class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
auto builder = HloComputation::Builder(TestName());
Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1));
builder.AddInstruction(
HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));
BuildModuleAndRunAnalysis(builder.Build());
// GetTupleElement instructions only access the top-level buffer of their
// operand.
EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_));
EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_));
EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_));
EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_));
}
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
auto builder = HloComputation::Builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
BuildModule(builder.Build());
auto fusion = computation_->CreateFusionInstruction(
{dynamic_update_slice, starts, update, gte1},
HloInstruction::FusionKind::kLoop);
RunAnalysis();
// The fusion instruction never uses tuple element 0, but does use element 1.
EXPECT_TRUE(
DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_));
EXPECT_FALSE(
DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_));
}
class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {};
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {8});
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
auto log = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp));
BuildModuleAndRunAnalysis(builder.Build());
EXPECT_TRUE(
CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
EXPECT_TRUE(
CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
auto builder = HloComputation::Builder(TestName());
Shape in_shape = ShapeUtil::MakeShape(F32, {8});
Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, in_shape, "param0"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, in_shape, "param1"));
auto result = builder.AddInstruction(
HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
BuildModuleAndRunAnalysis(builder.Build());
EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
*points_to_analysis_));
EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
*points_to_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
auto builder = HloComputation::Builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
BuildModule(builder.Build());
auto fusion = computation_->CreateFusionInstruction(
{dynamic_update_slice, starts, update, gte1},
HloInstruction::FusionKind::kLoop);
RunAnalysis();
// The fusion instruction can share with tuple element 1.
EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
*points_to_analysis_));
EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
*points_to_analysis_));
}
} // namespace
} // namespace xla