Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
e0d0c676
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e0d0c676
编写于
3月 14, 2017
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
3月 14, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor logic from buffer_liveness to use in HeapSimulator.
Also added some simple tests. Change: 150144113
上级
830cde87
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
435 addition
and
132 deletion
+435
-132
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/BUILD
+32
-0
tensorflow/compiler/xla/service/buffer_liveness.cc
tensorflow/compiler/xla/service/buffer_liveness.cc
+3
-125
tensorflow/compiler/xla/service/heap_simulator.cc
tensorflow/compiler/xla/service/heap_simulator.cc
+9
-7
tensorflow/compiler/xla/service/liveness_util.cc
tensorflow/compiler/xla/service/liveness_util.cc
+151
-0
tensorflow/compiler/xla/service/liveness_util.h
tensorflow/compiler/xla/service/liveness_util.h
+51
-0
tensorflow/compiler/xla/service/liveness_util_test.cc
tensorflow/compiler/xla/service/liveness_util_test.cc
+189
-0
未找到文件。
tensorflow/compiler/xla/service/BUILD
浏览文件 @
e0d0c676
...
...
@@ -493,6 +493,36 @@ cc_library(
],
)
cc_library
(
name
=
"liveness_util"
,
srcs
=
[
"liveness_util.cc"
],
hdrs
=
[
"liveness_util.h"
],
deps
=
[
":hlo"
,
":logical_buffer"
,
":tuple_points_to_analysis"
,
"//tensorflow/compiler/xla:shape_util"
,
"//tensorflow/compiler/xla:types"
,
"//tensorflow/compiler/xla:util"
,
"//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"liveness_util_test"
,
srcs
=
[
"liveness_util_test.cc"
],
deps
=
[
":hlo"
,
":liveness_util"
,
":tuple_points_to_analysis"
,
"//tensorflow/compiler/xla:shape_util"
,
"//tensorflow/compiler/xla:types"
,
"//tensorflow/compiler/xla:util"
,
"//tensorflow/compiler/xla/tests:hlo_test_base"
,
"//tensorflow/core:test_main"
,
],
)
cc_library
(
name
=
"buffer_liveness"
,
srcs
=
[
...
...
@@ -504,6 +534,7 @@ cc_library(
deps
=
[
":hlo"
,
":hlo_ordering"
,
":liveness_util"
,
":logical_buffer"
,
":tuple_points_to_analysis"
,
"//tensorflow/compiler/xla:shape_util"
,
...
...
@@ -586,6 +617,7 @@ cc_library(
],
deps
=
[
":hlo"
,
":liveness_util"
,
":logical_buffer"
,
":tuple_points_to_analysis"
,
"//tensorflow/compiler/xla:statusor"
,
...
...
tensorflow/compiler/xla/service/buffer_liveness.cc
浏览文件 @
e0d0c676
...
...
@@ -17,11 +17,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include <set>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
...
...
@@ -92,128 +92,6 @@ string BufferLiveness::ToString() const {
return
tensorflow
::
str_util
::
Join
(
pieces
,
"
\n
"
);
}
namespace
{
// Returns false if 'user' cannot possibly use the buffer at 'index' in
// 'operand'. Returns true otherwise.
// Precondition: 'operand' is an operand of 'user'.
bool
MayUseBufferInOperand
(
HloInstruction
*
operand
,
const
ShapeIndex
&
index
,
HloInstruction
*
user
,
const
TuplePointsToAnalysis
&
points_to_analysis
)
{
if
(
user
->
opcode
()
==
HloOpcode
::
kGetTupleElement
&&
!
index
.
empty
())
{
// GetTupleElement instructions only access the top-level buffer of their
// operand.
return
false
;
}
else
if
(
user
->
opcode
()
==
HloOpcode
::
kFusion
&&
user
->
fusion_kind
()
==
HloInstruction
::
FusionKind
::
kLoop
)
{
// Find fusion parameter associated with 'operand'.
auto
it
=
std
::
find_if
(
user
->
fused_parameters
().
begin
(),
user
->
fused_parameters
().
end
(),
[
=
](
HloInstruction
*
fused_param
)
{
return
user
->
operand
(
fused_param
->
parameter_number
())
==
operand
;
});
CHECK
(
it
!=
user
->
fused_parameters
().
end
());
// Iterate through all users of all buffer aliases of the buffer in the
// points-to set of fusion parameter at 'index'.
// Return true if any uses are detected at 'index', returns false otherwise.
const
LogicalBuffer
*
buffer
=
points_to_analysis
.
GetBufferDefinedAt
(
*
it
,
index
).
ValueOrDie
();
for
(
const
BufferAlias
&
alias
:
points_to_analysis
.
GetBufferAliases
(
*
buffer
))
{
for
(
HloInstruction
*
alias_user
:
alias
.
instruction
()
->
users
())
{
if
(
!
MayUseBufferInOperand
(
alias
.
instruction
(),
alias
.
index
(),
alias_user
,
points_to_analysis
))
{
continue
;
}
// Return true: use detected at 'buffer' -> 'alias' -> 'alias_user'.
return
true
;
}
}
// Return false: found no uses of 'operand' at 'index' in 'user'.
return
false
;
}
return
true
;
}
// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
// where 'user' is a user of an alias of 'intruction' at 'index', and
// 'operand_index' is the operand index at which the alias appears in the
// operand list of 'user'.
std
::
vector
<
std
::
pair
<
HloInstruction
*
,
int64
>>
GetAllUsesOfInstructionAtIndex
(
HloInstruction
*
instruction
,
const
ShapeIndex
&
index
,
const
TuplePointsToAnalysis
&
points_to_analysis
)
{
std
::
vector
<
std
::
pair
<
HloInstruction
*
,
int64
>>
uses
;
const
std
::
vector
<
const
LogicalBuffer
*>&
points_to
=
points_to_analysis
.
GetPointsToSet
(
instruction
).
element
(
index
);
for
(
const
LogicalBuffer
*
buffer
:
points_to
)
{
for
(
const
BufferAlias
&
alias
:
points_to_analysis
.
GetBufferAliases
(
*
buffer
))
{
for
(
HloInstruction
*
alias_user
:
alias
.
instruction
()
->
users
())
{
if
(
!
MayUseBufferInOperand
(
alias
.
instruction
(),
alias
.
index
(),
alias_user
,
points_to_analysis
))
{
continue
;
}
for
(
int64
op_idx
:
alias_user
->
OperandIndices
(
alias
.
instruction
()))
{
uses
.
emplace_back
(
alias_user
,
op_idx
);
}
}
}
}
return
uses
;
}
// Returns true if 'user' (at 'user_index') can share a buffer with its operand
// 'operand' (at 'operand_index').
// Returns false otherwise.
// User and operand can share buffers iff both instructions emit the same shape
// and layout, and 'user' meets one of the following two qualifications:
// *) Is element-wise.
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
// at operand 0.
bool
CanShareOperandBufferWithUser
(
HloInstruction
*
operand
,
const
ShapeIndex
&
operand_index
,
HloInstruction
*
user
,
const
ShapeIndex
&
user_index
,
const
TuplePointsToAnalysis
&
points_to_analysis
)
{
Shape
operand_subshape
=
ShapeUtil
::
GetSubshape
(
operand
->
shape
(),
operand_index
);
Shape
user_subshape
=
ShapeUtil
::
GetSubshape
(
user
->
shape
(),
user_index
);
// Check that operand and user emit the same shape and layout.
if
(
!
ShapeUtil
::
Equal
(
operand_subshape
,
user_subshape
))
{
return
false
;
}
// Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
// fused root instruction.
if
(
user
->
opcode
()
==
HloOpcode
::
kFusion
&&
user
->
fusion_kind
()
==
HloInstruction
::
FusionKind
::
kLoop
&&
user
->
fused_expression_root
()
->
opcode
()
==
HloOpcode
::
kDynamicUpdateSlice
)
{
for
(
auto
&
fused_param
:
user
->
fused_parameters
())
{
// Find fusion parameter associated with 'operand'.
if
(
user
->
operand
(
fused_param
->
parameter_number
())
!=
operand
)
{
continue
;
}
// Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
auto
fused_param_uses
=
GetAllUsesOfInstructionAtIndex
(
fused_param
,
operand_index
,
points_to_analysis
);
// Return true iff there is exactly one use of 'operand' at 'index', and
// this singleton use is the fused root at operand index 0.
if
(
fused_param_uses
.
size
()
==
1
&&
fused_param_uses
[
0
].
first
==
user
->
fused_expression_root
()
&&
fused_param_uses
[
0
].
second
==
0
)
{
return
true
;
}
break
;
}
return
false
;
}
// Check if 'user' is element-wise.
return
user
->
IsElementwise
();
}
}
// anonymous namespace
bool
BufferLiveness
::
live_range_strictly_before
(
const
LogicalBuffer
&
a
,
const
LogicalBuffer
&
b
)
const
{
TF_CHECK_OK
(
points_to_analysis_
->
VerifyBuffer
(
a
));
...
...
@@ -226,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
// Every user of 'a' must be a predecessor of 'b' or 'b' itself.
for
(
const
BufferAlias
&
alias
:
points_to_analysis_
->
GetBufferAliases
(
a
))
{
for
(
auto
user
:
alias
.
instruction
()
->
users
())
{
if
(
!
MayUseBufferInOperand
(
alias
.
instruction
(),
alias
.
index
(),
user
,
points_to_analysis
()))
{
if
(
DoesNotUseOperandBuffer
(
alias
.
instruction
(),
alias
.
index
(),
user
,
points_to_analysis
()))
{
continue
;
}
if
(
user
!=
b
.
instruction
()
&&
...
...
tensorflow/compiler/xla/service/heap_simulator.cc
浏览文件 @
e0d0c676
...
...
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/util.h"
namespace
xla
{
...
...
@@ -26,6 +27,8 @@ namespace xla {
using
tensorflow
::
gtl
::
FlatMap
;
using
tensorflow
::
gtl
::
FlatSet
;
namespace
{
// Returns the set of buffers that may be sources of all operands of the given
// instruction. The returned buffers are guaranteed to have no duplicates, and
// to be sorted in a deterministic order.
...
...
@@ -46,6 +49,8 @@ std::vector<const LogicalBuffer*> UniqueOperandSourceBuffers(
return
sorted
;
}
}
// namespace
/*static*/
StatusOr
<
HeapSimulator
::
Result
>
HeapSimulator
::
Run
(
std
::
unique_ptr
<
HeapAlgorithm
>
algorithm
,
...
...
@@ -145,13 +150,10 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
// we must be the last user of the buffer.
bool
shared
=
false
;
for
(
const
LogicalBuffer
*
operand_buffer
:
operand_buffers_to_free
)
{
// The operand buffer can be shared if we have the same shape, and we're
// an elementwise instruction.
//
// TODO(b/35903632): Refactor and use the CanShareOperandBufferWithUser
// logic from buffer_liveness.cc
if
(
ShapeUtil
::
Equal
(
buffer
->
shape
(),
operand_buffer
->
shape
())
&&
instruction
->
IsElementwise
())
{
if
(
buffer
->
instruction
()
->
IsUserOf
(
operand_buffer
->
instruction
())
&&
CanShareOperandBufferWithUser
(
operand_buffer
->
instruction
(),
operand_buffer
->
index
(),
buffer
->
instruction
(),
buffer
->
index
(),
points_to_analysis
))
{
heap
.
ShareBuffer
(
buffer
,
operand_buffer
);
shared
=
true
;
break
;
...
...
tensorflow/compiler/xla/service/liveness_util.cc
0 → 100644
浏览文件 @
e0d0c676
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
namespace
xla
{
bool
DoesNotUseOperandBuffer
(
HloInstruction
*
operand
,
const
ShapeIndex
&
index
,
HloInstruction
*
user
,
const
TuplePointsToAnalysis
&
points_to_analysis
)
{
CHECK
(
user
->
IsUserOf
(
operand
))
<<
"user: "
<<
user
->
ToString
()
<<
" operand: "
<<
operand
->
ToString
();
if
(
user
->
opcode
()
==
HloOpcode
::
kGetTupleElement
&&
!
index
.
empty
())
{
// GetTupleElement instructions only access the top-level buffer of their
// operand.
return
true
;
}
else
if
(
user
->
opcode
()
==
HloOpcode
::
kFusion
&&
user
->
fusion_kind
()
==
HloInstruction
::
FusionKind
::
kLoop
)
{
// Find fusion parameter associated with 'operand'.
auto
it
=
std
::
find_if
(
user
->
fused_parameters
().
begin
(),
user
->
fused_parameters
().
end
(),
[
=
](
HloInstruction
*
fused_param
)
{
return
user
->
operand
(
fused_param
->
parameter_number
())
==
operand
;
});
CHECK
(
it
!=
user
->
fused_parameters
().
end
());
// Iterate through all users of all buffer aliases of the buffer in the
// points-to set of fusion parameter at 'index'.
// Return false if any uses are detected at 'index', returns true otherwise.
const
LogicalBuffer
*
buffer
=
points_to_analysis
.
GetBufferDefinedAt
(
*
it
,
index
).
ValueOrDie
();
for
(
const
BufferAlias
&
alias
:
points_to_analysis
.
GetBufferAliases
(
*
buffer
))
{
for
(
HloInstruction
*
alias_user
:
alias
.
instruction
()
->
users
())
{
if
(
DoesNotUseOperandBuffer
(
alias
.
instruction
(),
alias
.
index
(),
alias_user
,
points_to_analysis
))
{
continue
;
}
// Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
return
false
;
}
}
// Return true: found no uses of 'operand' at 'index' in 'user'.
return
true
;
}
return
false
;
}
namespace
{
// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
// where 'user' is a user of an alias of 'intruction' at 'index', and
// 'operand_index' is the operand index at which the alias appears in the
// operand list of 'user'.
std
::
vector
<
std
::
pair
<
HloInstruction
*
,
int64
>>
GetAllUsesOfInstructionAtIndex
(
HloInstruction
*
instruction
,
const
ShapeIndex
&
index
,
const
TuplePointsToAnalysis
&
points_to_analysis
)
{
std
::
vector
<
std
::
pair
<
HloInstruction
*
,
int64
>>
uses
;
const
std
::
vector
<
const
LogicalBuffer
*>&
points_to
=
points_to_analysis
.
GetPointsToSet
(
instruction
).
element
(
index
);
for
(
const
LogicalBuffer
*
buffer
:
points_to
)
{
for
(
const
BufferAlias
&
alias
:
points_to_analysis
.
GetBufferAliases
(
*
buffer
))
{
for
(
HloInstruction
*
alias_user
:
alias
.
instruction
()
->
users
())
{
if
(
DoesNotUseOperandBuffer
(
alias
.
instruction
(),
alias
.
index
(),
alias_user
,
points_to_analysis
))
{
continue
;
}
for
(
int64
op_idx
:
alias_user
->
OperandIndices
(
alias
.
instruction
()))
{
uses
.
emplace_back
(
alias_user
,
op_idx
);
}
}
}
}
return
uses
;
}
}
// namespace
// User and operand can share buffers iff both instructions emit the same shape
// and layout, and 'user' meets one of the following two qualifications:
// *) Is element-wise.
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
// at operand 0.
bool
CanShareOperandBufferWithUser
(
HloInstruction
*
operand
,
const
ShapeIndex
&
operand_index
,
HloInstruction
*
user
,
const
ShapeIndex
&
user_index
,
const
TuplePointsToAnalysis
&
points_to_analysis
)
{
CHECK
(
user
->
IsUserOf
(
operand
))
<<
"user: "
<<
user
->
ToString
()
<<
" operand: "
<<
operand
->
ToString
();
Shape
operand_subshape
=
ShapeUtil
::
GetSubshape
(
operand
->
shape
(),
operand_index
);
Shape
user_subshape
=
ShapeUtil
::
GetSubshape
(
user
->
shape
(),
user_index
);
// Check that operand and user emit the same shape and layout.
if
(
!
ShapeUtil
::
Equal
(
operand_subshape
,
user_subshape
))
{
return
false
;
}
// Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice
// fused root instruction.
if
(
user
->
opcode
()
==
HloOpcode
::
kFusion
&&
user
->
fusion_kind
()
==
HloInstruction
::
FusionKind
::
kLoop
&&
user
->
fused_expression_root
()
->
opcode
()
==
HloOpcode
::
kDynamicUpdateSlice
)
{
for
(
auto
&
fused_param
:
user
->
fused_parameters
())
{
// Find fusion parameter associated with 'operand'.
if
(
user
->
operand
(
fused_param
->
parameter_number
())
!=
operand
)
{
continue
;
}
// Get all uses of 'operand' at 'index' from 'user.fused_instructions'.
auto
fused_param_uses
=
GetAllUsesOfInstructionAtIndex
(
fused_param
,
operand_index
,
points_to_analysis
);
// Return true iff there is exactly one use of 'operand' at 'index', and
// this singleton use is the fused root at operand index 0.
if
(
fused_param_uses
.
size
()
==
1
&&
fused_param_uses
[
0
].
first
==
user
->
fused_expression_root
()
&&
fused_param_uses
[
0
].
second
==
0
)
{
return
true
;
}
break
;
}
return
false
;
}
// Check if 'user' is element-wise.
return
user
->
IsElementwise
();
}
}
// namespace xla
tensorflow/compiler/xla/service/liveness_util.h
0 → 100644
浏览文件 @
e0d0c676
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A collection of utilities on the HLO graph.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
namespace
xla
{
// Returns true if 'user' cannot possibly use the buffer at 'index' in
// 'operand'. Returns false otherwise.
//
// REQUIRES: 'operand' is an operand of 'user'.
bool
DoesNotUseOperandBuffer
(
HloInstruction
*
operand
,
const
ShapeIndex
&
index
,
HloInstruction
*
user
,
const
TuplePointsToAnalysis
&
points_to_analysis
);
// Returns true if 'user' (at 'user_index') can share a buffer with its operand
// 'operand' (at 'operand_index').
// Returns false otherwise.
//
// REQUIRES: 'operand' is an operand of 'user'.
bool
CanShareOperandBufferWithUser
(
HloInstruction
*
operand
,
const
ShapeIndex
&
operand_index
,
HloInstruction
*
user
,
const
ShapeIndex
&
user_index
,
const
TuplePointsToAnalysis
&
points_to_analysis
);
}
// namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
tensorflow/compiler/xla/service/liveness_util_test.cc
0 → 100644
浏览文件 @
e0d0c676
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include <memory>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace
xla
{
namespace
{
class
PointsToAnalysisTestBase
:
public
HloTestBase
{
protected:
void
BuildModule
(
std
::
unique_ptr
<
HloComputation
>
computation
)
{
module_
=
MakeUnique
<
HloModule
>
(
TestName
());
computation_
=
module_
->
AddEntryComputation
(
std
::
move
(
computation
));
}
void
RunAnalysis
()
{
CHECK_NOTNULL
(
module_
.
get
());
points_to_analysis_
=
TuplePointsToAnalysis
::
Run
(
module_
.
get
(),
/*include_loop_fusion_instructions=*/
true
)
.
ConsumeValueOrDie
();
}
void
BuildModuleAndRunAnalysis
(
std
::
unique_ptr
<
HloComputation
>
computation
)
{
BuildModule
(
std
::
move
(
computation
));
RunAnalysis
();
}
std
::
unique_ptr
<
HloModule
>
module_
;
HloComputation
*
computation_
=
nullptr
;
std
::
unique_ptr
<
TuplePointsToAnalysis
>
points_to_analysis_
;
};
class
DoesNotUseOperandBufferTest
:
public
PointsToAnalysisTestBase
{};
TEST_F
(
DoesNotUseOperandBufferTest
,
GetTupleElement
)
{
auto
builder
=
HloComputation
::
Builder
(
TestName
());
Shape
elem_shape
=
ShapeUtil
::
MakeShape
(
F32
,
{
8
});
auto
tuple
=
builder
.
AddInstruction
(
HloInstruction
::
CreateParameter
(
0
,
ShapeUtil
::
MakeTupleShape
({
elem_shape
,
elem_shape
}),
"tuple"
));
auto
gte0
=
builder
.
AddInstruction
(
HloInstruction
::
CreateGetTupleElement
(
elem_shape
,
tuple
,
0
));
auto
gte1
=
builder
.
AddInstruction
(
HloInstruction
::
CreateGetTupleElement
(
elem_shape
,
tuple
,
1
));
builder
.
AddInstruction
(
HloInstruction
::
CreateBinary
(
elem_shape
,
HloOpcode
::
kAdd
,
gte0
,
gte1
));
BuildModuleAndRunAnalysis
(
builder
.
Build
());
// GetTupleElement instructions only access the top-level buffer of their
// operand.
EXPECT_TRUE
(
DoesNotUseOperandBuffer
(
tuple
,
{
0
},
gte0
,
*
points_to_analysis_
));
EXPECT_TRUE
(
DoesNotUseOperandBuffer
(
tuple
,
{
1
},
gte1
,
*
points_to_analysis_
));
EXPECT_FALSE
(
DoesNotUseOperandBuffer
(
tuple
,
{},
gte0
,
*
points_to_analysis_
));
EXPECT_FALSE
(
DoesNotUseOperandBuffer
(
tuple
,
{},
gte1
,
*
points_to_analysis_
));
}
TEST_F
(
DoesNotUseOperandBufferTest
,
FusedDynamicUpdateSlice
)
{
auto
builder
=
HloComputation
::
Builder
(
TestName
());
Shape
data_shape
=
ShapeUtil
::
MakeShape
(
F32
,
{
8
});
auto
tuple
=
builder
.
AddInstruction
(
HloInstruction
::
CreateParameter
(
0
,
ShapeUtil
::
MakeTupleShape
({
data_shape
,
data_shape
}),
"tuple"
));
auto
gte0
=
builder
.
AddInstruction
(
HloInstruction
::
CreateGetTupleElement
(
data_shape
,
tuple
,
0
));
auto
gte1
=
builder
.
AddInstruction
(
HloInstruction
::
CreateGetTupleElement
(
data_shape
,
tuple
,
1
));
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto
starts
=
builder
.
AddInstruction
(
HloInstruction
::
CreateConstant
(
LiteralUtil
::
CreateR1
<
int32
>
({
2
})));
auto
update
=
builder
.
AddInstruction
(
HloInstruction
::
CreateConstant
(
LiteralUtil
::
CreateR1
<
float
>
({
2.
f
,
2.
f
,
2.
f
})));
auto
dynamic_update_slice
=
builder
.
AddInstruction
(
HloInstruction
::
CreateDynamicUpdateSlice
(
data_shape
,
gte1
,
update
,
starts
));
builder
.
AddInstruction
(
HloInstruction
::
CreateTuple
({
gte0
,
dynamic_update_slice
}));
BuildModule
(
builder
.
Build
());
auto
fusion
=
computation_
->
CreateFusionInstruction
(
{
dynamic_update_slice
,
starts
,
update
,
gte1
},
HloInstruction
::
FusionKind
::
kLoop
);
RunAnalysis
();
// The fusion instruction never uses tuple element 0, but does use element 1.
EXPECT_TRUE
(
DoesNotUseOperandBuffer
(
tuple
,
{
0
},
fusion
,
*
points_to_analysis_
));
EXPECT_FALSE
(
DoesNotUseOperandBuffer
(
tuple
,
{
1
},
fusion
,
*
points_to_analysis_
));
}
class
CanShareOperandBufferWithUserTest
:
public
PointsToAnalysisTestBase
{};
TEST_F
(
CanShareOperandBufferWithUserTest
,
ElementWiseSameShape
)
{
auto
builder
=
HloComputation
::
Builder
(
TestName
());
Shape
shape
=
ShapeUtil
::
MakeShape
(
F32
,
{
8
});
auto
param
=
builder
.
AddInstruction
(
HloInstruction
::
CreateParameter
(
0
,
shape
,
"param"
));
auto
exp
=
builder
.
AddInstruction
(
HloInstruction
::
CreateUnary
(
shape
,
HloOpcode
::
kExp
,
param
));
auto
log
=
builder
.
AddInstruction
(
HloInstruction
::
CreateUnary
(
shape
,
HloOpcode
::
kLog
,
exp
));
BuildModuleAndRunAnalysis
(
builder
.
Build
());
EXPECT_TRUE
(
CanShareOperandBufferWithUser
(
param
,
{},
exp
,
{},
*
points_to_analysis_
));
EXPECT_TRUE
(
CanShareOperandBufferWithUser
(
exp
,
{},
log
,
{},
*
points_to_analysis_
));
}
TEST_F
(
CanShareOperandBufferWithUserTest
,
ElementWiseDifferentShape
)
{
auto
builder
=
HloComputation
::
Builder
(
TestName
());
Shape
in_shape
=
ShapeUtil
::
MakeShape
(
F32
,
{
8
});
Shape
out_shape
=
ShapeUtil
::
MakeShape
(
PRED
,
{
8
});
auto
param0
=
builder
.
AddInstruction
(
HloInstruction
::
CreateParameter
(
0
,
in_shape
,
"param0"
));
auto
param1
=
builder
.
AddInstruction
(
HloInstruction
::
CreateParameter
(
1
,
in_shape
,
"param1"
));
auto
result
=
builder
.
AddInstruction
(
HloInstruction
::
CreateBinary
(
out_shape
,
HloOpcode
::
kEq
,
param0
,
param1
));
BuildModuleAndRunAnalysis
(
builder
.
Build
());
EXPECT_FALSE
(
CanShareOperandBufferWithUser
(
param0
,
{},
result
,
{},
*
points_to_analysis_
));
EXPECT_FALSE
(
CanShareOperandBufferWithUser
(
param1
,
{},
result
,
{},
*
points_to_analysis_
));
}
TEST_F
(
CanShareOperandBufferWithUserTest
,
FusedDynamicUpdateSlice
)
{
auto
builder
=
HloComputation
::
Builder
(
TestName
());
Shape
data_shape
=
ShapeUtil
::
MakeShape
(
F32
,
{
8
});
auto
tuple
=
builder
.
AddInstruction
(
HloInstruction
::
CreateParameter
(
0
,
ShapeUtil
::
MakeTupleShape
({
data_shape
,
data_shape
}),
"tuple"
));
auto
gte0
=
builder
.
AddInstruction
(
HloInstruction
::
CreateGetTupleElement
(
data_shape
,
tuple
,
0
));
auto
gte1
=
builder
.
AddInstruction
(
HloInstruction
::
CreateGetTupleElement
(
data_shape
,
tuple
,
1
));
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto
starts
=
builder
.
AddInstruction
(
HloInstruction
::
CreateConstant
(
LiteralUtil
::
CreateR1
<
int32
>
({
2
})));
auto
update
=
builder
.
AddInstruction
(
HloInstruction
::
CreateConstant
(
LiteralUtil
::
CreateR1
<
float
>
({
2.
f
,
2.
f
,
2.
f
})));
auto
dynamic_update_slice
=
builder
.
AddInstruction
(
HloInstruction
::
CreateDynamicUpdateSlice
(
data_shape
,
gte1
,
update
,
starts
));
builder
.
AddInstruction
(
HloInstruction
::
CreateTuple
({
gte0
,
dynamic_update_slice
}));
BuildModule
(
builder
.
Build
());
auto
fusion
=
computation_
->
CreateFusionInstruction
(
{
dynamic_update_slice
,
starts
,
update
,
gte1
},
HloInstruction
::
FusionKind
::
kLoop
);
RunAnalysis
();
// The fusion instruction can share with tuple element 1.
EXPECT_FALSE
(
CanShareOperandBufferWithUser
(
tuple
,
{
0
},
fusion
,
{},
*
points_to_analysis_
));
EXPECT_TRUE
(
CanShareOperandBufferWithUser
(
tuple
,
{
1
},
fusion
,
{},
*
points_to_analysis_
));
}
}
// namespace
}
// namespace xla
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录