Commit 8cf6ece8

Authored Sep 08, 2023 by Johannes Reifferscheid
Committed by TensorFlower Gardener on Sep 08, 2023

Fusion analysis: Remove remaining dependencies on fusion_.

PiperOrigin-RevId: 563691979

Parent: 453a8537

Showing 5 changed files with 111 additions and 83 deletions (+111 -83).
third_party/xla/xla/service/gpu/fusions/reduction.cc       +15 -16
third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc     +68 -58
third_party/xla/xla/service/gpu/hlo_fusion_analysis.h      +2 -8
third_party/xla/xla/service/gpu/hlo_traversal.cc           +17 -0
third_party/xla/xla/service/gpu/hlo_traversal.h            +9 -1
third_party/xla/xla/service/gpu/fusions/reduction.cc

```diff
@@ -323,7 +323,7 @@ Status EmitExtraOutputsForReduce(llvm::IRBuilder<>* builder,
 StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
     IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion,
-    const HloFusionAnalysis& fusion_analysis,
+    const HloComputation* fused_computation,
     ElementalIrEmitter& elemental_emitter, KernelReuseCache& kernel_cache,
     int output_index, llvm::IRBuilder<>* builder) {
   auto reduce = mlir::dyn_cast_or_null<mlir::mhlo::ReduceOp>(
@@ -349,9 +349,6 @@ StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
   auto builder_fn = [&](std::vector<llvm_ir::IrArray> inputs,
                         std::vector<llvm_ir::IrArray> outputs) -> Status {
-    const HloComputation* fused_computation =
-        fusion_analysis.fused_computation();
-
     FusedIrEmitter fused_emitter(elemental_emitter);
     for (int i = 0; i < fused_computation->num_parameters(); i++) {
       fused_emitter.BindGenerator(
@@ -375,11 +372,11 @@ StatusOr<std::unique_ptr<Thunk>> BuildFusedInitializerThunk(
     return OkStatus();
   };
-  return BuildKernelThunkForFusion(
-      ir_emitter_context, kernel_cache, fusion,
-      fusion_analysis.fused_computation(), launch_dimensions,
-      /*discriminator=*/absl::StrCat("init_", output_index), builder_fn,
-      builder);
+  return BuildKernelThunkForFusion(
+      ir_emitter_context, kernel_cache, fusion, fused_computation,
+      launch_dimensions,
+      /*discriminator=*/absl::StrCat("init_", output_index), builder_fn,
+      builder);
 }

 // Gets the output offset as calculated from thread_id.x (to be applied to the
@@ -972,14 +969,17 @@ StatusOr<FusionEmissionResult> ReductionFusion::Emit(
   VLOG(3) << "Launch dimensions of "
           << mlir::mhlo::GetDebugNameFromLocation(fusion_op.getLoc()) << ": "
           << launch_dimensions.ToString();
+  const HloComputation* fused_computation =
+      fusion.fused_instructions_computation();
   if (!reduction_codegen_info->IsRaceFree()) {
     absl::Span<HloInstruction* const> fusion_roots = analysis_.fusion_roots();
     for (int i = 0; i < fusion_roots.size(); ++i) {
       if (IsReductionFromOrToContiguousDimensions(*fusion_roots[i])) {
-        TF_ASSIGN_OR_RETURN(result.thunks.emplace_back(),
-                            BuildFusedInitializerThunk(
-                                ir_emitter_context, fusion_op, analysis_,
-                                elemental_emitter, kernel_cache, i, builder));
+        TF_ASSIGN_OR_RETURN(result.thunks.emplace_back(),
+                            BuildFusedInitializerThunk(
+                                ir_emitter_context, fusion_op, fused_computation,
+                                elemental_emitter, kernel_cache, i, builder));
       }
     }
   }
@@ -987,7 +987,6 @@ StatusOr<FusionEmissionResult> ReductionFusion::Emit(
   auto builder_fn = [&, this](std::vector<llvm_ir::IrArray> inputs,
                               std::vector<llvm_ir::IrArray> outputs) -> Status {
     FusedIrEmitter fused_emitter(elemental_emitter);
-    const HloComputation* fused_computation = analysis_.fused_computation();
     for (int i = 0; i < fused_computation->num_parameters(); i++) {
       HloInstruction* fused_operand =
           fused_computation->parameter_instruction(i);
@@ -1042,8 +1041,8 @@ StatusOr<FusionEmissionResult> ReductionFusion::Emit(
   TF_ASSIGN_OR_RETURN(
       result.thunks.emplace_back(),
       BuildKernelThunkForFusion(ir_emitter_context, kernel_cache, fusion_op,
-                                analysis_.fused_computation(),
-                                launch_dimensions, "", builder_fn, builder));
+                                fused_computation,
+                                launch_dimensions, "", builder_fn, builder));
   return result;
 }
```
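Taken together, the reduction.cc hunks apply one pattern: `ReductionFusion::Emit` derives `fused_computation` once from the fusion instruction and threads it into the helpers, instead of each helper re-deriving it from the `HloFusionAnalysis` object. A minimal sketch of that narrowing of dependencies, with hypothetical stand-in types rather than the real XLA classes:

```cpp
#include <iostream>
#include <string>

// Hypothetical stand-ins for the analysis object and the computation.
struct Computation { std::string name; };
struct Analysis { Computation computation{"fused_computation"}; };

// Before: the helper takes the whole analysis just to pull one field out.
void HelperBefore(const Analysis& analysis) {
  std::cout << analysis.computation.name << "\n";
}

// After: the helper depends only on what it actually uses, so it can be
// called by code that never constructed an Analysis at all.
void HelperAfter(const Computation& computation) {
  std::cout << computation.name << "\n";
}

int main() {
  Analysis analysis;
  HelperBefore(analysis);
  const Computation& computation = analysis.computation;  // hoisted once
  HelperAfter(computation);
}
```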
third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc

```diff
@@ -79,37 +79,30 @@ bool AllSliceInputsAreCompatible(
   });
 }

-bool MayPreventVectorization(const HloComputation* fusion) {
+bool MayPreventVectorization(const std::vector<HloInstruction*>& fusion_roots) {
   // An empirically chosen constant: unrolling concat with a large amount of
   // arguments causes excessive register spilling.
   static constexpr int kMaxConcatArgumentsForUnrolling = 10;
-  for (const HloInstruction* instr : fusion->instructions()) {
-    switch (instr->opcode()) {
-      case HloOpcode::kReduceWindow:
-      case HloOpcode::kSort:
-      case HloOpcode::kDot:
-      case HloOpcode::kSin:
-      case HloOpcode::kCos:
-      case HloOpcode::kTan:
-      case HloOpcode::kPower:
-      case HloOpcode::kAtan2:
-        return true;
-      case HloOpcode::kConcatenate:
-        if (instr->operand_count() > kMaxConcatArgumentsForUnrolling) {
-          return true;
-        }
-        break;
-      case HloOpcode::kReduce:
-        if (instr->shape().tuple_shapes_size() > 1) {
-          return true;
-        }
-        break;
-      default:
-        break;
-    }
-  }
-  return false;
+  return HloAnyOf(fusion_roots, DefaultFusionBoundaryFn,
+                  [&](const HloInstruction& node) {
+                    switch (node.opcode()) {
+                      case HloOpcode::kReduceWindow:
+                      case HloOpcode::kSort:
+                      case HloOpcode::kDot:
+                      case HloOpcode::kSin:
+                      case HloOpcode::kCos:
+                      case HloOpcode::kTan:
+                      case HloOpcode::kPower:
+                      case HloOpcode::kAtan2:
+                        return true;
+                      case HloOpcode::kConcatenate:
+                        return node.operand_count() >
+                               kMaxConcatArgumentsForUnrolling;
+                      case HloOpcode::kReduce:
+                        return node.shape().tuple_shapes_size() > 1;
+                      default:
+                        return false;
+                    }
+                  });
 }

 // Determines if we enable the row optimized codegen. When we have a fusion with
```
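The hunk above replaces a linear scan over every instruction in the fused computation with a predicate evaluated during a traversal from the fusion roots, so the check no longer needs a materialized `HloComputation`. A self-contained sketch of that any-of-with-short-circuit shape, using a hypothetical `Node` type rather than XLA's classes (the boundary callback of the real `HloAnyOf` is omitted here; see the traversal sketch further below):

```cpp
#include <functional>
#include <iostream>
#include <queue>
#include <string_view>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an HLO instruction: a name tag plus operands.
struct Node {
  const char* name;
  std::vector<const Node*> operands;
};

// Visits nodes reachable from `roots` (consumers before producers) and
// returns true as soon as `pred` matches; the early return mirrors
// TraversalResult::kAbortTraversal in the real code.
bool AnyOf(const std::vector<const Node*>& roots,
           const std::function<bool(const Node&)>& pred) {
  std::queue<const Node*> q;
  std::unordered_set<const Node*> seen;
  for (const Node* r : roots) {
    if (seen.insert(r).second) q.push(r);
  }
  while (!q.empty()) {
    const Node* n = q.front();
    q.pop();
    if (pred(*n)) return true;  // Short-circuit: skip the rest of the graph.
    for (const Node* op : n->operands) {
      if (seen.insert(op).second) q.push(op);
    }
  }
  return false;
}

int main() {
  Node p{"parameter", {}};
  Node sin{"sin", {&p}};
  Node add{"add", {&sin, &p}};
  // True, because `sin` is reachable from the root `add`.
  std::cout << AnyOf({&add}, [](const Node& n) {
    return std::string_view(n.name) == "sin";
  }) << "\n";  // prints 1
}
```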
```diff
@@ -303,10 +296,10 @@ StatusOr<HloFusionAnalysis> HloFusionAnalysis::Create(
   std::optional<TransposeDescription> tiled_transpose_hero =
       FindConsistentTransposeHero(hlo_roots, heroes);

-  return HloFusionAnalysis(fusion, std::move(backend_config),
-                           std::move(hlo_roots),
+  return HloFusionAnalysis(std::move(backend_config), std::move(hlo_roots),
                            std::move(fusion_parameter_inputs),
                            std::move(heroes), device_info,
                            tiled_transpose_hero);
 }

 // Returns true if the fusion has consistent transpose heros.
@@ -473,7 +466,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() {
   int64_t n_threads_max =
       device_info_->threads_per_core_limit * device_info_->core_count;
   if (num_elements >= n_threads_max &&
-      !MayPreventVectorization(fused_computation_)) {
+      !MayPreventVectorization(fusion_roots_)) {
     unroll_factor = ComputeMaxUnrollFactor(num_elements);
   }
   VLOG(2) << "Unroll factor: " << unroll_factor;
@@ -488,25 +481,23 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() {
   int num_big_inputs;
   std::tie(row_vectorized, num_big_inputs) =
       RowVectorizationEnabled(fusion_roots(), GetElementShape().rank());
-  bool few_waves = [this, row_vectorized, num_big_inputs]() {
-    for (const HloInstruction* instr : fused_computation_->instructions()) {
-      if (instr->opcode() == HloOpcode::kParameter ||
-          instr->opcode() == HloOpcode::kConstant ||
-          HloInstruction::IsOpElementwise(instr->opcode())) {
-        continue;
-      }
-      if (auto broadcast = DynCast<HloBroadcastInstruction>(instr)) {
-        if (broadcast->dimensions().empty() ||
-            // More than 3 big inputs cause a speed regression.
-            (row_vectorized && num_big_inputs <= 3)) {
-          continue;
-        }
-      }
-      VLOG(2) << "few_waves not enabled due to: " << instr->ToString();
-      return false;
-    }
-    return true;
-  }();
+  bool few_waves = !HloAnyOf(
+      fusion_roots_, DefaultFusionBoundaryFn,
+      [&](const HloInstruction& instr) {
+        if (instr.opcode() == HloOpcode::kParameter ||
+            instr.opcode() == HloOpcode::kConstant ||
+            HloInstruction::IsOpElementwise(instr.opcode())) {
+          return false;
+        }
+        if (auto broadcast = DynCast<HloBroadcastInstruction>(&instr)) {
+          if (broadcast->dimensions().empty() ||
+              // More than 3 big inputs cause a speed regression.
+              (row_vectorized && num_big_inputs <= 3)) {
+            return false;
+          }
+        }
+        VLOG(2) << "few_waves not enabled due to: " << instr.ToString();
+        return true;
+      });

   LaunchDimensionsConfig launch_config{unroll_factor, few_waves,
                                        row_vectorized};
@@ -523,7 +514,7 @@ const LaunchDimensionsConfig* HloFusionAnalysis::GetLoopFusionConfig() {
 }

 const Shape& HloFusionAnalysis::GetElementShape() const {
-  const Shape* shape = &fusion_->shape();
+  const Shape* shape = &fusion_roots_.front()->shape();
   while (shape->IsTuple()) {
     shape = &shape->tuple_shapes(0);
   }
```
```diff
@@ -581,16 +572,20 @@ HloFusionAnalysis::GroupDisjointReductions() const {
     return {{fusion_roots()[0]}};
   }

-  HloInstructionMap<tensorflow::UnionFind<HloInstruction*>> disjoint_sets;
+  ConstHloInstructionMap<tensorflow::UnionFind<const HloInstruction*>>
+      disjoint_sets;

   // TODO(b/249976438): we currently do not treat properly
   // aliasing between inputs and outputs of the fusion, so for now put all
   // non-reduction roots into one group to avoid read-after-write conflicts.
   HloInstruction* first_non_reduction_root = nullptr;

+  ConstHloInstructionMap<absl::flat_hash_set<const HloInstruction*>>
+      reachable_outputs;
   absl::flat_hash_set<HloInstruction*> roots_with_reduction;
   for (auto [root, hero] : llvm::zip(fusion_roots(), fusion_heroes_)) {
     disjoint_sets[root].Get() = root;
+    reachable_outputs[root].insert(root);
     if (IsRealReductionHero(*root, *hero)) {
       roots_with_reduction.insert(root);
     } else if (first_non_reduction_root) {
@@ -600,9 +595,23 @@ HloFusionAnalysis::GroupDisjointReductions() const {
     }
   }

-  std::unique_ptr<HloReachabilityMap> reachability_map =
-      HloReachabilityMap::Build(fused_computation_);
-  for (HloInstruction* instr : fused_computation_->instructions()) {
+  std::vector<const HloInstruction*> instructions;
+  HloBfsConsumersFirstTraversal(
+      fusion_roots_,
+      [&](const HloInstruction& producer, const HloInstruction& consumer) {
+        auto& producer_reachable = reachable_outputs[&producer];
+        for (auto* instruction : reachable_outputs[&consumer]) {
+          producer_reachable.insert(instruction);
+        }
+        return DefaultFusionBoundaryFn(producer, consumer);
+      },
+      [&](const HloInstruction& node) {
+        instructions.push_back(&node);
+        return TraversalResult::kVisitOperands;
+      });
+
+  for (const HloInstruction* instr : instructions) {
+    const auto& reachable = reachable_outputs[instr];
     std::vector<HloInstruction*> reached_output_ids;
     bool added_to_reduce = false;
     for (HloInstruction* output : fusion_roots()) {
```
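The hunk above drops the precomputed `HloReachabilityMap` in favor of reachable-output sets that are folded together along each producer/consumer edge while the BFS walks consumers first. A self-contained sketch of that propagation, again with a hypothetical `Node` type (names are illustrative, not XLA's API):

```cpp
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<const Node*> operands;
};

// For every node, computes which roots it can reach, by pushing each
// consumer's set onto its producers. Assumes `order` lists the nodes
// consumers-first, as a consumers-first BFS would produce them.
std::map<const Node*, std::set<std::string>> ReachableRoots(
    const std::vector<const Node*>& roots,
    const std::vector<const Node*>& order) {
  std::map<const Node*, std::set<std::string>> reachable;
  for (const Node* r : roots) reachable[r].insert(r->name);
  for (const Node* consumer : order) {
    for (const Node* producer : consumer->operands) {
      // Everything reachable from the consumer is also reachable from
      // each of its producers.
      reachable[producer].insert(reachable[consumer].begin(),
                                 reachable[consumer].end());
    }
  }
  return reachable;
}

int main() {
  Node p{"param", {}};
  Node mul{"mul", {&p}};
  Node root_a{"root_a", {&mul}};
  Node root_b{"root_b", {&mul}};
  auto reach = ReachableRoots({&root_a, &root_b},
                              {&root_a, &root_b, &mul, &p});
  // `mul` feeds both roots, so the two roots share a predecessor and
  // would be placed in the same group.
  for (const auto& name : reach[&mul]) std::cout << name << " ";
  std::cout << "\n";  // prints: root_a root_b
}
```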
```diff
@@ -618,7 +627,7 @@ HloFusionAnalysis::GroupDisjointReductions() const {
       }
     }
     // Now group output instructions if they have common predecessors.
-    if (reachability_map->IsReachable(instr, output)) {
+    if (reachable.contains(output)) {
       VLOG(3) << "Reaching " << output->ToString() << " from "
               << instr->ToString();
       reached_output_ids.push_back(output);
@@ -634,12 +643,13 @@ HloFusionAnalysis::GroupDisjointReductions() const {
   }

   // Place output instructions in the same set into the same group.
-  HloInstructionMap<std::vector<HloInstruction*>> groups;
+  ConstHloInstructionMap<std::vector<HloInstruction*>> groups;
   for (HloInstruction* root : fusion_roots()) {
     groups[disjoint_sets[root].Get()].push_back(root);
   }
   std::vector<std::vector<HloInstruction*>> ret;
   ret.reserve(groups.size());
   absl::c_for_each(groups,
                    [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
   return ret;
@@ -725,7 +735,7 @@ bool HloFusionAnalysis::CanVectorizeReduction(
   }
   if (reduction_dimensions.dimensions[kDimX] % 2 != 0 ||
-      MayPreventVectorization(fusion_->fused_instructions_computation())) {
+      MayPreventVectorization(fusion_roots_)) {
     return false;
   }
```
third_party/xla/xla/service/gpu/hlo_fusion_analysis.h

```diff
@@ -50,7 +50,6 @@ class HloFusionAnalysis {
   static StatusOr<HloFusionAnalysis> Create(const HloFusionInstruction* fusion,
                                             const GpuDeviceInfo* device_info);

-  const HloComputation* fused_computation() const { return fused_computation_; }
   const std::vector<HloInstruction*>& fusion_roots() const {
     return fusion_roots_;
   }
@@ -78,16 +77,13 @@ class HloFusionAnalysis {
   const HloInstruction* FindHeroReduction() const;

  private:
-  HloFusionAnalysis(const HloFusionInstruction* fusion,
-                    FusionBackendConfig fusion_backend_config,
+  HloFusionAnalysis(FusionBackendConfig fusion_backend_config,
                     std::vector<HloInstruction*> fusion_roots,
                     std::vector<const HloInstruction*> fusion_parameters,
                     std::vector<const HloInstruction*> fusion_heroes,
                     const GpuDeviceInfo* device_info,
                     std::optional<TransposeDescription> tiled_transpose)
-      : fusion_(fusion),
-        fusion_backend_config_(std::move(fusion_backend_config)),
-        fused_computation_(fusion->fused_instructions_computation()),
+      : fusion_backend_config_(std::move(fusion_backend_config)),
         fusion_roots_(std::move(fusion_roots)),
         fusion_parameter_inputs_(std::move(fusion_parameters)),
         fusion_heroes_(std::move(fusion_heroes)),
@@ -111,9 +107,7 @@ class HloFusionAnalysis {
       const HloInstruction* hero_reduction) const;
   bool HasConsistentTransposeHeros() const;

-  const HloFusionInstruction* fusion_;
   FusionBackendConfig fusion_backend_config_;
-  const HloComputation* fused_computation_;
   std::vector<HloInstruction*> fusion_roots_;
   // The HLO instructions that are inputs into the fusion. These instructions
   // are /outside/ the fusion.
```
third_party/xla/xla/service/gpu/hlo_traversal.cc

```diff
@@ -103,5 +103,22 @@ void FindFusionParameters(
       [&](const HloInstruction&) { return TraversalResult::kVisitOperands; });
 }

+bool HloAnyOf(absl::Span<const HloInstruction* const> roots,
+              const std::function<bool(const HloInstruction& producer,
+                                       const HloInstruction& consumer)>& boundary,
+              const std::function<bool(const HloInstruction& node)>& visit) {
+  bool result = false;
+  HloBfsConsumersFirstTraversal(roots, boundary,
+                                [&](const HloInstruction& node) {
+                                  if (visit(node)) {
+                                    result = true;
+                                    return TraversalResult::kAbortTraversal;
+                                  }
+                                  return TraversalResult::kVisitOperands;
+                                });
+  return result;
+}
+
 }  // namespace gpu
 }  // namespace xla
```
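`HloAnyOf` adapts a bool predicate to the `TraversalResult`-returning visitor that `HloBfsConsumersFirstTraversal` expects, aborting as soon as a match is found. A self-contained model of that layering, including the boundary callback that prunes edges (hypothetical types, not XLA's real signatures):

```cpp
#include <functional>
#include <iostream>
#include <queue>
#include <string>
#include <unordered_set>
#include <vector>

struct Node {
  std::string name;
  std::vector<const Node*> operands;
};

enum class TraversalResult { kVisitOperands, kAbortTraversal };

// Consumers-first BFS. An operand edge is not followed when `boundary`
// returns true for the (producer, consumer) pair.
void BfsConsumersFirst(
    const std::vector<const Node*>& roots,
    const std::function<bool(const Node&, const Node&)>& boundary,
    const std::function<TraversalResult(const Node&)>& visit) {
  std::queue<const Node*> q;
  std::unordered_set<const Node*> seen;
  for (const Node* r : roots)
    if (seen.insert(r).second) q.push(r);
  while (!q.empty()) {
    const Node* n = q.front();
    q.pop();
    if (visit(*n) == TraversalResult::kAbortTraversal) return;
    for (const Node* op : n->operands)
      if (!boundary(*op, *n) && seen.insert(op).second) q.push(op);
  }
}

// Bool-predicate adapter, mirroring the shape of HloAnyOf above.
bool AnyOf(const std::vector<const Node*>& roots,
           const std::function<bool(const Node&, const Node&)>& boundary,
           const std::function<bool(const Node&)>& pred) {
  bool result = false;
  BfsConsumersFirst(roots, boundary, [&](const Node& n) {
    if (pred(n)) {
      result = true;
      return TraversalResult::kAbortTraversal;
    }
    return TraversalResult::kVisitOperands;
  });
  return result;
}

int main() {
  Node param{"param", {}};
  Node sqrt{"sqrt", {&param}};
  Node root{"root", {&sqrt}};
  // Treat `param` as outside the fusion: the edge into it is a boundary.
  auto boundary = [](const Node& producer, const Node&) {
    return producer.name == "param";
  };
  std::cout << AnyOf({&root}, boundary,
                     [](const Node& n) { return n.name == "sqrt"; })
            << "\n";  // prints 1
  std::cout << AnyOf({&root}, boundary,
                     [](const Node& n) { return n.name == "param"; })
            << "\n";  // prints 0: the boundary prunes the edge to `param`
}
```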
third_party/xla/xla/service/gpu/hlo_traversal.h

```diff
@@ -41,7 +41,7 @@ using FusionBoundaryFn = std::function<bool(const HloInstruction& producer,
 bool DefaultFusionBoundaryFn(const HloInstruction& producer,
                              const HloInstruction& consumer);

-// Visit the HLO nodes starting from `root` in BFS order (consumers before
+// Visit the HLO nodes starting from `roots` in BFS order (consumers before
 // producers). Each node will be visited exactly once. The graph is not
 // traversed along edges for which `boundary` returns true.
 void HloBfsConsumersFirstTraversal(
@@ -50,6 +50,14 @@ void HloBfsConsumersFirstTraversal(
                                  const HloInstruction& consumer)>& boundary,
     const std::function<TraversalResult(const HloInstruction& node)>& visit);

+// Visit the HLO nodes starting from `roots`, returning true if `visit`
+// returns true for any of them.
+bool HloAnyOf(absl::Span<const HloInstruction* const> roots,
+              const std::function<bool(const HloInstruction& producer,
+                                       const HloInstruction& consumer)>& boundary,
+              const std::function<bool(const HloInstruction& node)>& visit);
+
 // Visit the producers of all parameters that are needed by the fusion.
 void FindFusionParameters(
     absl::Span<const HloInstruction* const> roots,
```