Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wux_labs
Tensorflow
提交
2c8798d2
T
Tensorflow
项目概览
wux_labs
/
Tensorflow
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2c8798d2
编写于
9月 12, 2023
作者:
J
Johannes Reifferscheid
提交者:
TensorFlower Gardener
9月 12, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
NFC: Extract LaunchDimension computation from Triton codegen.
PiperOrigin-RevId: 564686943
上级
4aa2ee61
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
92 addition
and
55 deletion
+92
-55
third_party/xla/xla/service/gpu/ir_emitter_triton.cc
third_party/xla/xla/service/gpu/ir_emitter_triton.cc
+49
-31
third_party/xla/xla/service/gpu/ir_emitter_triton.h
third_party/xla/xla/service/gpu/ir_emitter_triton.h
+29
-16
third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc
third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc
+10
-6
third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
+4
-2
未找到文件。
third_party/xla/xla/service/gpu/ir_emitter_triton.cc
浏览文件 @
2c8798d2
...
...
@@ -934,14 +934,23 @@ void ValidateMatMulConfig(const AutotuneResult::TritonGemmKey& config,
}
// namespace
// Computes the launch dimensions for the Triton MatMul fusion in
// `computation`.
//
// Looks up the first kDot instruction of `computation`, builds a MatMulDims
// from (`config`, the dot, `analysis`) and a MatMulLaunchConfig from
// (`config`, the dot, the dims), and returns that launch config's
// `launch_dims`.
//
// `computation` must be non-null and must contain a dot instruction.
LaunchDimensions GetMatMulLaunchDimensions(
    const TritonFusionAnalysis& analysis, const HloComputation* computation,
    const AutotuneResult::TritonGemmKey& config) {
  const HloDotInstruction* dot_instr = DynCast<HloDotInstruction>(
      hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot));
  // `dot_instr` is dereferenced below; make the precondition explicit so a
  // computation without an HloDotInstruction fails with a clear CHECK message
  // instead of a null dereference. This matches the CHECK_NE on the reduce
  // instruction in GetSoftMaxLaunchDimensions.
  CHECK_NE(dot_instr, nullptr);
  const MatMulDims dims(config, *dot_instr, analysis);
  const MatMulLaunchConfig launch_config(config, *dot_instr, dims);
  return launch_config.launch_dims;
}
// Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n].
// TODO(b/270937368): Split this up into smaller functions.
StatusOr
<
LaunchDimensions
>
MatMul
(
mlir
::
OpBuilder
builder
,
absl
::
string_view
libdevice_path
,
const
HloComputation
*
computation
,
mlir
::
triton
::
FuncOp
fn
,
const
AutotuneResult
::
TritonGemmKey
&
config
,
int
shmem_budget
)
{
Status
EmitMatMul
(
mlir
::
OpBuilder
builder
,
absl
::
string_view
libdevice_path
,
const
TritonFusionAnalysis
&
analysis
,
const
HloComputation
*
computation
,
mlir
::
triton
::
FuncOp
fn
,
const
AutotuneResult
::
TritonGemmKey
&
config
,
int
shmem_budget
)
{
const
HloDotInstruction
*
dot_instr
=
DynCast
<
HloDotInstruction
>
(
hlo_query
::
GetFirstInstructionWithOpcode
(
*
computation
,
HloOpcode
::
kDot
));
// Use 32-bit indexing if addressing any of the inputs or the output (which
...
...
@@ -970,9 +979,6 @@ StatusOr<LaunchDimensions> MatMul(mlir::OpBuilder builder,
const
int
block_k
=
config
.
block_k
();
const
int
block_n
=
config
.
block_n
();
TF_ASSIGN_OR_RETURN
(
const
TritonFusionAnalysis
analysis
,
TritonFusionAnalysis
::
Execute
(
*
dot_instr
->
parent
(),
split_k
));
const
MatMulDims
dims
(
config
,
*
dot_instr
,
analysis
);
const
MatMulLaunchConfig
launch_config
(
config
,
*
dot_instr
,
dims
);
VLOG
(
6
)
<<
analysis
.
ToString
();
...
...
@@ -1331,15 +1337,29 @@ StatusOr<LaunchDimensions> MatMul(mlir::OpBuilder builder,
b
.
create
<
mt
::
StoreOp
>
(
tensor_pointer
,
values_out
[
producer
],
boundary_checks
,
mt
::
CacheModifier
::
NONE
,
mt
::
EvictionPolicy
::
NORMAL
);
}
return
launch_config
.
launch_dims
;
return
OkStatus
()
;
}
StatusOr
<
LaunchDimensions
>
SoftMax
(
mlir
::
OpBuilder
builder
,
absl
::
string_view
libdevice_path
,
const
HloComputation
*
computation
,
mlir
::
triton
::
FuncOp
fn
,
const
AutotuneResult
::
TritonGemmKey
&
config
,
int
)
{
// Computes the launch dimensions for the Triton SoftMax fusion in
// `computation`.
//
// The first kReduce instruction of `computation` is located (it must exist)
// and the shape of its operand 0 is inspected. The x extent of the first
// returned triple is the product of the sizes of all dimensions of that shape
// except the one at minor index 0; the x extent of the second triple is
// config.num_warps() * WarpSize(). Every other extent is 1.
//
// The TritonFusionAnalysis argument is unused.
LaunchDimensions GetSoftMaxLaunchDimensions(
    const TritonFusionAnalysis&, const HloComputation* computation,
    const AutotuneResult::TritonGemmKey& config) {
  const HloInstruction* reduce_instr = hlo_query::GetFirstInstructionWithOpcode(
      *computation, HloOpcode::kReduce);
  CHECK_NE(reduce_instr, nullptr);

  const Shape& operand_shape = reduce_instr->operand(0)->shape();

  // Multiply together every dimension but the one at minor index 0.
  int row_count = 1;
  int axis = 1;
  while (axis < operand_shape.rank()) {
    row_count *= operand_shape.dimensions_minor(axis);
    ++axis;
  }

  return {{row_count, 1, 1}, {config.num_warps() * WarpSize(), 1, 1}};
}
Status
EmitSoftMax
(
mlir
::
OpBuilder
builder
,
absl
::
string_view
libdevice_path
,
const
TritonFusionAnalysis
&
analysis
,
const
HloComputation
*
computation
,
mlir
::
triton
::
FuncOp
fn
,
const
AutotuneResult
::
TritonGemmKey
&
config
,
int
)
{
const
HloInstruction
*
root
=
computation
->
root_instruction
();
auto
loc
=
mlir
::
NameLoc
::
get
(
builder
.
getStringAttr
(
root
->
name
()));
ImplicitLocOpBuilder
b
(
loc
,
builder
);
...
...
@@ -1377,10 +1397,6 @@ StatusOr<LaunchDimensions> SoftMax(mlir::OpBuilder builder,
block_row
*=
2
;
}
int
num_rows
=
1
;
for
(
int
minor_axis
=
1
;
minor_axis
<
reduce_input_shape
.
rank
();
++
minor_axis
)
num_rows
*=
reduce_input_shape
.
dimensions_minor
(
minor_axis
);
Value
row_index
=
b
.
create
<
ma
::
ExtSIOp
>
(
b
.
getI64Type
(),
b
.
create
<
mt
::
GetProgramIdOp
>
(
mt
::
ProgramIDDim
::
X
));
Value
row_stride
=
CreateConst
(
b
,
b
.
getI32Type
(),
row_len
);
...
...
@@ -1404,8 +1420,6 @@ StatusOr<LaunchDimensions> SoftMax(mlir::OpBuilder builder,
}
values_out
[
computation
->
parameter_instruction
(
0
)]
=
EmitParameterLoad
(
b
,
make_tensor_pointer
(
fn
.
getArgument
(
0
)),
boundary_checks
);
TF_ASSIGN_OR_RETURN
(
const
auto
analysis
,
TritonFusionAnalysis
::
Execute
(
*
computation
));
// Dimension 0 is the reduced one by construction and it's the only one
// present in the tile shapes.
std
::
vector
<
DimProperties
>
tiled_dims
=
{{
0
,
row_index
,
block_row
}};
...
...
@@ -1418,11 +1432,7 @@ StatusOr<LaunchDimensions> SoftMax(mlir::OpBuilder builder,
b
.
create
<
mt
::
StoreOp
>
(
make_tensor_pointer
(
fn
.
getArgument
(
1
)),
result
,
std
::
vector
<
int32_t
>
{
0
},
mt
::
CacheModifier
::
NONE
,
mt
::
EvictionPolicy
::
NORMAL
);
const
LaunchDimensions
launch_dimensions
{
{
num_rows
,
1
,
1
},
{
config
.
num_warps
()
*
WarpSize
(),
1
,
1
}};
return
launch_dimensions
;
return
OkStatus
();
}
// Simplified copy of translateLLVMToLLVMIR which in addition takes
...
...
@@ -1463,7 +1473,8 @@ StatusOr<LaunchDimensions> TritonWrapper(
absl
::
string_view
fusion_kind
,
const
se
::
CudaComputeCapability
&
cc
,
const
GpuDeviceInfo
&
device_info
,
const
AutotuneResult
::
TritonGemmKey
&
config
,
llvm
::
Module
*
llvm_module
,
LaunchDimensionsGenerator
generator
,
mlir
::
MLIRContext
&
mlir_context
)
{
LaunchDimensionsGenerator
launch_dims_generator
,
TritonIrEmitter
ir_emitter
,
mlir
::
MLIRContext
&
mlir_context
)
{
if
(
fusion_kind
==
kTritonGemmFusionKind
)
{
// This is a heuristic that serves as a proxy for register usage and code
// size.
...
...
@@ -1537,8 +1548,13 @@ StatusOr<LaunchDimensions> TritonWrapper(
.
debug_options
()
.
xla_gpu_cuda_data_dir
());
TF_ASSIGN_OR_RETURN
(
LaunchDimensions
launch_dimensions
,
generator
(
b
,
libdevice_path
,
hlo_computation
,
fn
,
config
,
TF_ASSIGN_OR_RETURN
(
auto
analysis
,
fusion_kind
==
kTritonGemmFusionKind
?
TritonFusionAnalysis
::
Execute
(
*
hlo_computation
,
config
.
split_k
())
:
TritonFusionAnalysis
::
Execute
(
*
hlo_computation
));
TF_RETURN_IF_ERROR
(
ir_emitter
(
b
,
libdevice_path
,
analysis
,
hlo_computation
,
fn
,
config
,
device_info
.
shared_memory_per_block_optin
));
b
.
create
<
mt
::
ReturnOp
>
(
loc
);
...
...
@@ -1613,7 +1629,6 @@ StatusOr<LaunchDimensions> TritonWrapper(
if
(
shared_mem_bytes
>
device_info
.
shared_memory_per_block_optin
)
{
return
ResourceExhausted
(
"Shared memory size limit exceeded."
);
}
launch_dimensions
.
SetSharedMemBytes
(
shared_mem_bytes
);
TF_ASSIGN_OR_RETURN
(
std
::
unique_ptr
<
llvm
::
Module
>
ll_triton_module
,
TranslateLLVMToLLVMIR
(
&
llvm_module
->
getContext
(),
...
...
@@ -1630,6 +1645,9 @@ StatusOr<LaunchDimensions> TritonWrapper(
llvm
::
Linker
::
Flags
::
OverrideFromSrc
));
LogAndVerify
(
llvm_module
);
LaunchDimensions
launch_dimensions
=
launch_dims_generator
(
analysis
,
hlo_computation
,
config
);
launch_dimensions
.
SetSharedMemBytes
(
shared_mem_bytes
);
return
launch_dimensions
;
}
...
...
third_party/xla/xla/service/gpu/ir_emitter_triton.h
浏览文件 @
2c8798d2
...
...
@@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "xla/autotuning.pb.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/service/gpu/gemm_rewriter_triton.h"
#include "xla/service/gpu/gpu_device_info.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/statusor.h"
...
...
@@ -30,26 +31,36 @@ limitations under the License.
namespace
xla
{
namespace
gpu
{
// Compute the launch dimensions for the given Triton MatMul.
//
// `analysis`    - fusion analysis of `computation`; used to derive the
//                 MatMul dimensions.
// `computation` - fusion computation; its first dot instruction is used.
// `config`      - Triton GEMM configuration used to build the launch config.
//
// Returns the `launch_dims` of the launch config built from these inputs.
LaunchDimensions
GetMatMulLaunchDimensions
(
const
TritonFusionAnalysis
&
analysis
,
const
HloComputation
*
computation
,
const
AutotuneResult
::
TritonGemmKey
&
config
);
// Use tiling and execution parameters from 'config'.
StatusOr
<
LaunchDimensions
>
MatMul
(
mlir
::
OpBuilder
b
,
absl
::
string_view
libdevice_path
,
const
HloComputation
*
computation
,
mlir
::
triton
::
FuncOp
fn
,
const
AutotuneResult
::
TritonGemmKey
&
config
,
int
shmem_budget
);
// Emit the Triton IR for a MatMul using builder `b`.
// Use tiling and execution parameters from 'config'.
//
// `analysis` is supplied by the caller rather than computed here; launch
// dimensions are obtained separately via GetMatMulLaunchDimensions.
// Has the TritonIrEmitter signature. Returns a Status.
Status
EmitMatMul
(
mlir
::
OpBuilder
b
,
absl
::
string_view
libdevice_path
,
const
TritonFusionAnalysis
&
analysis
,
const
HloComputation
*
computation
,
mlir
::
triton
::
FuncOp
fn
,
const
AutotuneResult
::
TritonGemmKey
&
config
,
int
shmem_budget
);
// Compute the launch dimensions for the given Triton SoftMax.
//
// `analysis`    - not used by the definition (its parameter is unnamed
//                 there); present so the function matches the
//                 LaunchDimensionsGenerator signature.
// `computation` - fusion computation; must contain a reduce instruction,
//                 whose operand 0 shape determines the first extent.
// `config`      - supplies num_warps() for the second extent.
LaunchDimensions
GetSoftMaxLaunchDimensions
(
const
TritonFusionAnalysis
&
analysis
,
const
HloComputation
*
computation
,
const
AutotuneResult
::
TritonGemmKey
&
config
);
// Generate Softmax in Triton IR inside 'fn'.
// Use execution parameters from 'config'.
StatusOr
<
LaunchDimensions
>
SoftMax
(
mlir
::
OpBuilder
b
,
absl
::
string_view
libdevice_path
,
const
HloComputation
*
computation
,
mlir
::
triton
::
FuncOp
fn
,
const
AutotuneResult
::
TritonGemmKey
&
config
,
int
shmem_budget
);
// Emit the Triton IR for a Softmax inside 'fn' using builder `b`.
// Use execution parameters from 'config'.
//
// `analysis` is supplied by the caller rather than computed here; launch
// dimensions are obtained separately via GetSoftMaxLaunchDimensions.
// The definition leaves the trailing int parameter unnamed, i.e.
// `shmem_budget` is not used by it.
// Has the TritonIrEmitter signature. Returns a Status.
Status
EmitSoftMax
(
mlir
::
OpBuilder
b
,
absl
::
string_view
libdevice_path
,
const
TritonFusionAnalysis
&
analysis
,
const
HloComputation
*
computation
,
mlir
::
triton
::
FuncOp
fn
,
const
AutotuneResult
::
TritonGemmKey
&
config
,
int
shmem_budget
);
using
LaunchDimensionsGenerator
=
std
::
function
<
StatusOr
<
LaunchDimensions
>
(
mlir
::
OpBuilder
,
absl
::
string_view
,
const
HloComputation
*
,
mlir
::
triton
::
FuncOp
,
const
AutotuneResult
::
TritonGemmKey
&
,
int
)
>
;
// Callback that computes the LaunchDimensions of a Triton fusion from its
// analysis, its computation and the Triton config. It returns the dimensions
// directly (no Status) and takes no IR builder arguments.
// GetMatMulLaunchDimensions and GetSoftMaxLaunchDimensions have this
// signature.
using
LaunchDimensionsGenerator
=
std
::
function
<
LaunchDimensions
(
const
TritonFusionAnalysis
&
,
const
HloComputation
*
,
const
AutotuneResult
::
TritonGemmKey
&
)
>
;
// Callback that emits the Triton IR of a fusion and reports success or
// failure as a Status. Arguments, in order: the op builder, the libdevice
// path, the fusion analysis, the fusion computation, the Triton function to
// emit into, the Triton config, and an int that TritonWrapper fills with
// device_info.shared_memory_per_block_optin.
// EmitMatMul and EmitSoftMax have this signature.
using
TritonIrEmitter
=
std
::
function
<
Status
(
mlir
::
OpBuilder
,
absl
::
string_view
,
const
TritonFusionAnalysis
&
analysis
,
const
HloComputation
*
,
mlir
::
triton
::
FuncOp
,
const
AutotuneResult
::
TritonGemmKey
&
,
int
)
>
;
// Generate Triton IR by running the provided generator, compile it into LLVM IR
// and return launch dimensions.
...
...
@@ -59,7 +70,9 @@ StatusOr<LaunchDimensions> TritonWrapper(
absl
::
string_view
fusion_kind
,
const
se
::
CudaComputeCapability
&
cc
,
const
GpuDeviceInfo
&
device_info
,
const
AutotuneResult
::
TritonGemmKey
&
config
,
llvm
::
Module
*
llvm_module
,
LaunchDimensionsGenerator
generator
,
mlir
::
MLIRContext
&
mlir_context
);
LaunchDimensionsGenerator
launch_dims_generator
,
TritonIrEmitter
ir_emitter
,
mlir
::
MLIRContext
&
mlir_context
);
}
// namespace gpu
}
// namespace xla
...
...
third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc
浏览文件 @
2c8798d2
...
...
@@ -215,7 +215,8 @@ ENTRY entry {
TritonWrapper
(
"test_fn"
,
triton_dot_computation
,
kTritonGemmFusionKind
,
se
::
CudaComputeCapability
{
se
::
CudaComputeCapability
::
AMPERE
,
/*minor=*/
0
},
dev_info
,
config
,
&
llvm_module
,
&
MatMul
,
mlir_context
),
dev_info
,
config
,
&
llvm_module
,
&
GetMatMulLaunchDimensions
,
&
EmitMatMul
,
mlir_context
),
tsl
::
testing
::
StatusIs
(
tsl
::
error
::
RESOURCE_EXHAUSTED
,
"Shared memory size limit exceeded."
));
...
...
@@ -228,7 +229,8 @@ ENTRY entry {
TritonWrapper
(
"test_fn"
,
triton_dot_computation
,
kTritonGemmFusionKind
,
se
::
CudaComputeCapability
{
se
::
CudaComputeCapability
::
AMPERE
,
/*minor=*/
0
},
dev_info
,
config
,
&
llvm_module
,
&
MatMul
,
mlir_context
));
dev_info
,
config
,
&
llvm_module
,
&
GetMatMulLaunchDimensions
,
&
EmitMatMul
,
mlir_context
));
// Use optin shared memory which is > shared_memory_per_block.
EXPECT_GT
(
launch_dimensions
.
SharedMemBytes
(),
dev_info
.
shared_memory_per_block
);
...
...
@@ -642,7 +644,8 @@ ENTRY entry {
TritonWrapper
(
"test_fn"
,
triton_dot_computation
,
kTritonGemmFusionKind
,
se
::
CudaComputeCapability
{
se
::
CudaComputeCapability
::
AMPERE
,
/*minor=*/
0
},
dev_info
,
config
,
&
llvm_module
,
&
MatMul
,
mlir_context
),
dev_info
,
config
,
&
llvm_module
,
&
GetMatMulLaunchDimensions
,
&
EmitMatMul
,
mlir_context
),
tsl
::
testing
::
StatusIs
(
tsl
::
error
::
RESOURCE_EXHAUSTED
,
"Tiling complexity heuristic exceeded: 147456 > 9000"
));
...
...
@@ -655,7 +658,8 @@ ENTRY entry {
TritonWrapper
(
"test_fn"
,
triton_dot_computation
,
kTritonGemmFusionKind
,
se
::
CudaComputeCapability
{
se
::
CudaComputeCapability
::
AMPERE
,
/*minor=*/
0
},
dev_info
,
config
,
&
llvm_module
,
&
MatMul
,
mlir_context
)
dev_info
,
config
,
&
llvm_module
,
&
GetMatMulLaunchDimensions
,
&
EmitMatMul
,
mlir_context
)
.
status
());
}
...
...
@@ -1438,8 +1442,8 @@ ENTRY e {
const
LaunchDimensions
launch_dimensions
,
TritonWrapper
(
"test_fn"
,
triton_dot_computation
,
kTritonGemmFusionKind
,
GetCudaComputeCapability
(),
dev_info
,
config
.
triton_gemm_config
(),
&
llvm_module
,
&
MatMul
,
mlir_context
));
config
.
triton_gemm_config
(),
&
llvm_module
,
&
GetMatMulLaunchDimensions
,
&
EmitMatMul
,
mlir_context
));
// The config is chosen so that the used memory size is slightly above the
// 48 kB boundary of standard / optin shared memory so that any GPU that
// has the optin one should be able to execute the test.
...
...
third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
浏览文件 @
2c8798d2
...
...
@@ -1745,7 +1745,8 @@ Status IrEmitterUnnested::EmitTritonFusion(
TritonWrapper
(
impl_fn_name
,
hlo_computation
,
kTritonSoftmaxFusionKind
,
ir_emitter_context_
->
cuda_compute_capability
(),
ir_emitter_context_
->
gpu_device_info
(),
config
,
module_
,
&
SoftMax
,
*
ir_emitter_context_
->
mlir_context
()));
&
GetSoftMaxLaunchDimensions
,
&
EmitSoftMax
,
*
ir_emitter_context_
->
mlir_context
()));
}
else
{
// Must be a MatMul
CHECK_EQ
(
fusion_kind
,
kTritonGemmFusionKind
);
TF_ASSIGN_OR_RETURN
(
...
...
@@ -1753,7 +1754,8 @@ Status IrEmitterUnnested::EmitTritonFusion(
TritonWrapper
(
impl_fn_name
,
hlo_computation
,
kTritonGemmFusionKind
,
ir_emitter_context_
->
cuda_compute_capability
(),
ir_emitter_context_
->
gpu_device_info
(),
config
,
module_
,
&
MatMul
,
*
ir_emitter_context_
->
mlir_context
()));
&
GetMatMulLaunchDimensions
,
&
EmitMatMul
,
*
ir_emitter_context_
->
mlir_context
()));
}
llvm
::
Function
*
impl_fn
=
module_
->
getFunction
(
impl_fn_name
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录