wux_labs / Tensorflow · Commit 4ae35f9f

Authored Sep 13, 2023 by George Karpenkov; committed by TensorFlower Gardener on Sep 13, 2023.
[NFC] [XLA:GPU] Split up Triton emitter into multiple functions.
PiperOrigin-RevId: 565024582
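The change is a mechanical refactoring: lambdas in EmitMatMul that captured shared locals by reference (emit_tensor_pointer, make_input, convert_scalar, and the accumulator-type computation) are lifted into a new helper class, MatMulEmitterHelper, whose members hold the shared state, so each lambda becomes an ordinary member function. A minimal sketch of that pattern follows, assuming hypothetical names (EmitterState, MakeOffset, scale) that are not taken from ir_emitter_triton.cc:

// Sketch only: EmitterState, MakeOffset and scale_ are illustrative names,
// not identifiers from the XLA sources.
#include <cstdint>
#include <vector>

class EmitterState {
 public:
  explicit EmitterState(int64_t scale) : scale_(scale) {}

  // Previously: `auto make_offset = [&](int64_t i) { return i * scale; };`
  int64_t MakeOffset(int64_t i) const { return i * scale_; }

 private:
  int64_t scale_;  // Previously a local captured by reference in each lambda.
};

std::vector<int64_t> EmitAll(int64_t scale, int64_t n) {
  EmitterState state(scale);  // Shared state is set up once...
  std::vector<int64_t> offsets;
  offsets.reserve(n);
  for (int64_t i = 0; i < n; ++i) {
    offsets.push_back(state.MakeOffset(i));  // ...and reused by each member call.
  }
  return offsets;
}

The diff below applies this transformation to the Triton matmul emitter without changing the emitted IR, which is what the [NFC] (no functional change) tag asserts.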
Parent: af5aa0e5

Showing 1 changed file with 232 additions and 190 deletions (+232, −190)
third_party/xla/xla/service/gpu/ir_emitter_triton.cc
@@ -946,6 +946,203 @@ void ValidateMatMulConfig(const AutotuneResult::TritonGemmKey& config,
            2 + (config.split_k() > 1 ? 1 : 0) + num_batch_dims);
 }
 
+struct Side {
+  TritonFusionAnalysis::Scope scope;
+  std::vector<DimProperties> tiled_dims;
+  std::optional<int64_t> batch_dim_idx;
+};
+
+class MatMulEmitterHelper {
+ public:
+  MatMulEmitterHelper(mlir::OpBuilder builder,
+                      absl::string_view libdevice_path,
+                      const HloDotInstruction* dot_instr,
+                      ImplicitLocOpBuilder& b, Type index_ty, MatMulDims dims,
+                      const MatMulLaunchConfig& launch_config,
+                      const TritonFusionAnalysis& analysis)
+      : b_(b),
+        libdevice_path_(libdevice_path),
+        dot_instr_(dot_instr),
+        index_ty_(index_ty),
+        analysis_(analysis),
+        dims_(dims),
+        launch_config_(launch_config) {}
+
+  // TODO(b/266862493): Accumulator can be integer too.
+  // Otherwise only f64 x f64 -> f64 uses f64 accumulator.
+  mlir::FloatType GetDotAccumulatorType() {
+    Type dot_output_ty = TritonType(b_, dot_instr_->shape().element_type());
+    // Data type of dot() immediate inputs.
+    Type dot_input_ty = [&] {
+      const Type lhs_ty =
+          TritonType(b_, dot_instr_->operand(0)->shape().element_type());
+      const Type rhs_ty =
+          TritonType(b_, dot_instr_->operand(1)->shape().element_type());
+      CHECK(lhs_ty == rhs_ty);
+      return lhs_ty;
+    }();
+    // TODO(b/266862493): Accumulator can be integer too.
+    // Otherwise only f64 x f64 -> f64 uses f64 accumulator.
+    return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type()
+                                                           : b_.getF32Type();
+  }
+
+  std::vector<const HloInstruction*> EpiloguePostOrderTransitiveOperands(
+      const HloInstruction* root) {
+    // Collect all instructions of the dot's output scope.
+    absl::flat_hash_set<const HloInstruction*> to_order;
+    {
+      std::queue<const HloInstruction*> to_add;
+      if (root != dot_instr_) {
+        to_add.push(root);
+      }
+      while (!to_add.empty()) {
+        const HloInstruction* current = to_add.front();
+        for (const HloInstruction* operand : current->operands()) {
+          if (!to_order.contains(operand)) {
+            if (operand != dot_instr_) {
+              to_add.push(operand);
+            }
+          }
+        }
+        CHECK(to_order.insert(current).second);
+        to_add.pop();
+      }
+    }
+    // Order them producers before consumers.
+    std::vector<const HloInstruction*> to_emit;
+    for (const HloInstruction* hlo :
+         dot_instr_->parent()->MakeInstructionPostOrder()) {
+      if (to_order.contains(hlo)) {
+        to_emit.push_back(hlo);
+      }
+    }
+    return to_emit;
+  }
+
+  Value MakeInput(Side& side, int64_t operand_index,
+                  absl::flat_hash_map<const HloInstruction*, Value>& values) {
+    return *EmitScope(
+        b_, libdevice_path_, &analysis_, side.scope, side.tiled_dims,
+        dot_instr_->parent()->MakeInstructionPostOrderFrom(
+            const_cast<HloInstruction&>(*dot_instr_->operand(operand_index))),
+        values);
+  }
+
+  Value EmitTensorPointer(const HloInstruction* hlo, const Side& side,
+                          Value base, Value pid_k,
+                          std::vector<int32_t>& boundary_checks) {
+    auto pid_batch =
+        b_.create<mt::GetProgramIdOp>(launch_config_.batch_program_id_dim);
+
+    std::vector<Value> bounds;
+    std::vector<Value> strides;
+    std::vector<Value> offsets;
+    std::vector<int32_t> block_dims;
+    std::vector<int32_t> dim_order;
+
+    auto add_dim = [&](const DimProperties& properties) {
+      const TensorIterationSpec::DimIterationSpec* spec =
+          analysis_.IterSpec(side.scope, hlo, properties.index);
+      if (spec == nullptr) {
+        return;
+      }
+      const int64_t stride = spec->at(0).stride;
+      int64_t count = spec->at(0).count;
+      if (side.scope == TritonFusionAnalysis::Scope::OUTPUT &&
+          properties.index == dims_.out_lhs_noncontracting_dim_idx &&
+          spec->size() == 1 && dims_.lhs_noncontracting_split.has_value()) {
+        // Dimension of the output produced by the non-contracting LHS one
+        // is logically split, major part is addressed using pid_batch.
+        count /= *dims_.lhs_noncontracting_split;
+      }
+      if (count % properties.block_size != 0) {
+        boundary_checks.push_back(bounds.size());
+      }
+      bounds.push_back(Cst64(count));
+      strides.push_back(Cst64(stride));
+      offsets.push_back(properties.offset);
+      block_dims.push_back(properties.block_size);
+      dim_order.emplace(dim_order.begin(), dim_order.size());
+    };
+    for (const DimProperties& dim : side.tiled_dims) {
+      add_dim(dim);
+    }
+
+    int64_t stride_batch = 0;
+    if (side.scope != TritonFusionAnalysis::Scope::RHS &&
+        dims_.lhs_noncontracting_split) {
+      const TensorIterationSpec::DimIterationSpec* spec =
+          analysis_.IterSpec(side.scope, hlo, side.tiled_dims[0].index);
+      if (spec != nullptr) {
+        if (spec->size() > 1) {
+          // Support one specific kind of output transpose that splits the
+          // dimension originating from the split LHS non-contracting one.
+          stride_batch = spec->at(1).stride;
+        } else {
+          // Because the major part of the split is implemented using the
+          // batch logic stride_batch is populated here as the stride of
+          // the minor part times its size.
+          stride_batch =
+              spec->at(0).stride *
+              (spec->at(0).count / *dims_.lhs_noncontracting_split);
+        }
+        CHECK_NE(stride_batch, 0);
+      }
+    } else if (side.batch_dim_idx.has_value()) {
+      const TensorIterationSpec::DimIterationSpec* spec =
+          analysis_.IterSpec(side.scope, hlo, *side.batch_dim_idx);
+      if (spec != nullptr) {
+        stride_batch = spec->at(0).stride;
+        CHECK_NE(stride_batch, 0);
+      }
+    }
+    if (stride_batch != 0) {
+      Value offset_batch =
+          b_.create<ma::MulIOp>(ConvertScalar(pid_batch), Cst(stride_batch));
+      base = AddPtr(b_, base, offset_batch);
+    }
+
+    if (dims_.out_split_k_dim_idx.has_value()) {
+      const TensorIterationSpec::DimIterationSpec* spec = analysis_.IterSpec(
+          TritonFusionAnalysis::Scope::OUTPUT, hlo, *dims_.out_split_k_dim_idx);
+      if (spec != nullptr) {
+        int64_t stride_split_k = spec->at(0).stride;
+        Value offset_split_k =
+            b_.create<ma::MulIOp>(ConvertScalar(pid_k), Cst(stride_split_k));
+        base = AddPtr(b_, base, offset_split_k);
+      }
+    }
+
+    if (block_dims.empty()) {
+      return base;
+    }
+    return b_.create<mt::MakeTensorPtrOp>(base, bounds, strides, offsets,
+                                          block_dims, dim_order);
+  }
+
+ private:
+  // Extend int32 indexes to int64, if necessary.
+  Value ConvertScalar(Value value) {
+    if (index_ty_.getIntOrFloatBitWidth() == 64) {
+      return b_.create<ma::ExtSIOp>(index_ty_, value);
+    }
+    return value;
+  }
+
+  Value Cst(int64_t v) { return CreateConst(b_, index_ty_, v); }
+  Value Cst64(int64_t v) { return CreateConst(b_, i64_ty_, v); }
+
+  ImplicitLocOpBuilder& b_;
+  absl::string_view libdevice_path_;
+  const HloDotInstruction* dot_instr_;
+  Type index_ty_;
+  TritonFusionAnalysis analysis_;
+  MatMulDims dims_;
+  MatMulLaunchConfig launch_config_;
+  Type i32_ty_ = b_.getI32Type();
+  Type i64_ty_ = b_.getI64Type();
+};
+
 }  // namespace
 
 LaunchDimensions GetMatMulLaunchDimensions(
...
@@ -964,7 +1161,6 @@ LaunchDimensions GetMatMulLaunchDimensions(
 }
 
 // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n].
-// TODO(b/270937368): Split this up into smaller functions.
 Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
                   const TritonFusionAnalysis& analysis,
                   const HloComputation* computation, mlir::triton::FuncOp fn,
...
@@ -979,7 +1175,7 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
       ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX ||
       ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX ||
       ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k() > INT_MAX;
-  mlir::Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32);
+  Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32);
 
   const HloInstruction* root = dot_instr->parent()->root_instruction();
   CHECK(!root->shape().IsTuple());
...
@@ -990,7 +1186,6 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
   auto loc = mlir::NameLoc::get(builder.getStringAttr(dot_instr->name()));
   ImplicitLocOpBuilder b(loc, builder);
   Type i32_ty = b.getI32Type();
-  Type i64_ty = b.getI64Type();
 
   ValidateMatMulConfig(config, *dot_instr);
   const int split_k = config.split_k();
...
@@ -1002,62 +1197,38 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
   const MatMulLaunchConfig launch_config(config, *dot_instr, dims);
   VLOG(6) << analysis.ToString();
 
-  constexpr int group_m = 8;
+  MatMulEmitterHelper emitter(builder, libdevice_path, dot_instr, b, index_ty,
+                              dims, launch_config, analysis);
+
+  constexpr int group_m = 8;
   const int64_t width = group_m * launch_config.grid_n;
 
-  auto pid_batch =
-      b.create<mt::GetProgramIdOp>(launch_config.batch_program_id_dim);
+  auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); };
+
   auto pid_nc = b.create<mt::GetProgramIdOp>(
       launch_config.noncontracting_program_id_dim);
   auto pid_k = b.create<mt::GetProgramIdOp>(mt::ProgramIDDim::Z);
 
-  auto group_id = b.create<ma::DivSIOp>(pid_nc, CreateConst(b, i32_ty, width));
-  ma::ConstantOp group_m_op = CreateConst(b, i32_ty, group_m);
+  auto group_id = b.create<ma::DivSIOp>(pid_nc, c32(width));
+  ma::ConstantOp group_m_op = c32(group_m);
   auto first_pid_m = b.create<ma::MulIOp>(group_id, group_m_op);
-  auto sub0 = b.create<ma::SubIOp>(
-      CreateConst(b, i32_ty, launch_config.grid_m), first_pid_m);
+  auto sub0 = b.create<ma::SubIOp>(c32(launch_config.grid_m), first_pid_m);
   auto group_size = b.create<ma::SelectOp>(
       b.create<ma::CmpIOp>(ma::CmpIPredicate::slt, sub0, group_m_op), sub0,
       group_m_op);
 
-  // Extend int32 indexes to int64, if necessary.
-  auto convert_scalar = [&](Value value) -> Value {
-    if (index_ty.getIntOrFloatBitWidth() == 64) {
-      return b.create<ma::ExtSIOp>(index_ty, value);
-    }
-    return value;
-  };
-
   auto pid_m = b.create<ma::AddIOp>(first_pid_m,
                                     b.create<ma::RemSIOp>(pid_nc, group_size));
-  auto pid_m_offset =
-      b.create<ma::MulIOp>(pid_m, CreateConst(b, i32_ty, block_m));
+  auto pid_m_offset = b.create<ma::MulIOp>(pid_m, c32(block_m));
 
-  auto pid_n = b.create<ma::DivSIOp>(
-      b.create<ma::RemSIOp>(pid_nc, CreateConst(b, i32_ty, width)), group_size);
-  auto pid_n_offset =
-      b.create<ma::MulIOp>(pid_n, CreateConst(b, i32_ty, block_n));
+  auto pid_n = b.create<ma::DivSIOp>(b.create<ma::RemSIOp>(pid_nc, c32(width)),
+                                     group_size);
+  auto pid_n_offset = b.create<ma::MulIOp>(pid_n, c32(block_n));
 
-  auto pid_k_offset =
-      b.create<ma::MulIOp>(pid_k, CreateConst(b, i32_ty, block_k));
+  auto pid_k_offset = b.create<ma::MulIOp>(pid_k, c32(block_k));
 
-  Type dot_output_ty = TritonType(b, dot_instr->shape().element_type());
-  // Data type of dot() immediate inputs.
-  Type dot_input_ty = b.getF32Type();
-  {
-    const Type lhs_ty =
-        TritonType(b, dot_instr->operand(0)->shape().element_type());
-    const Type rhs_ty =
-        TritonType(b, dot_instr->operand(1)->shape().element_type());
-    CHECK(lhs_ty == rhs_ty);
-    dot_input_ty = lhs_ty;
-  }
-  // TODO(b/266862493): Accumulator can be integer too.
-  // Otherwise only f64 x f64 -> f64 uses f64 accumulator.
-  mlir::FloatType acc_ty = (dot_output_ty.isF64() && dot_input_ty.isF64())
-                               ? b.getF64Type()
-                               : b.getF32Type();
+  mlir::FloatType acc_ty = emitter.GetDotAccumulatorType();
 
   ma::ConstantOp accumulator_init =
       CreateConst(b, acc_ty, 0, {block_m, block_n});
...
@@ -1066,11 +1237,6 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
   absl::flat_hash_map<int, const HloInstruction*> iter_args_to_parameters;
   absl::flat_hash_map<int, std::vector<int32_t>> iter_args_to_boundary_checks;
 
-  struct Side {
-    TritonFusionAnalysis::Scope scope;
-    std::vector<DimProperties> tiled_dims;
-    std::optional<int64_t> batch_dim_idx;
-  };
   Side lhs{TritonFusionAnalysis::Scope::LHS,
            /*tiled_dims=*/{{dims.lhs_noncontracting_dim_idx, pid_m_offset,
                             block_m},
...
@@ -1114,10 +1280,9 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
       // Only the contracting dimensions are advanced.
       if (dim.index == (is_lhs ? dims.lhs_contracting_dim_idx
                                : dims.rhs_contracting_dim_idx)) {
-        increments.push_back(CreateConst(b, i32_ty, dim.block_size * split_k));
+        increments.push_back(c32(dim.block_size * split_k));
       } else {
-        increments.push_back(CreateConst(b, i32_ty, 0));
+        increments.push_back(c32(0));
       }
     }
     if (increments.empty()) {
...
@@ -1129,15 +1294,8 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
   }
 
   // Emit all operations of LHS and RHS scopes.
-  auto make_input = [&](Side& side, int64_t operand_index, auto& values) {
-    return *EmitScope(
-        b, libdevice_path, &analysis, side.scope, side.tiled_dims,
-        dot_instr->parent()->MakeInstructionPostOrderFrom(
-            const_cast<HloInstruction&>(*dot_instr->operand(operand_index))),
-        values);
-  };
-  Value dot_input_lhs = make_input(lhs, 0, values_lhs);
-  Value dot_input_rhs = make_input(rhs, 1, values_rhs);
+  Value dot_input_lhs = emitter.MakeInput(lhs, 0, values_lhs);
+  Value dot_input_rhs = emitter.MakeInput(rhs, 1, values_rhs);
 
   // Operation in the fusion before the dot can alter the elements of the
   // tiles that were zero masked during loads. These have to be zeroed here
...
@@ -1186,151 +1344,35 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
       analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS).size() +
       analysis.ScopeParameters(TritonFusionAnalysis::Scope::RHS).size() + 1);
 
-  auto emit_tensor_pointer =
-      [&](const HloInstruction* hlo, const Side& side, Value base,
-          std::vector<int32_t>& boundary_checks) -> Value {
-    std::vector<Value> bounds;
-    std::vector<Value> strides;
-    std::vector<Value> offsets;
-    std::vector<int32_t> block_dims;
-    std::vector<int32_t> dim_order;
-
-    auto add_dim = [&](const DimProperties& properties) {
-      const TensorIterationSpec::DimIterationSpec* spec =
-          analysis.IterSpec(side.scope, hlo, properties.index);
-      if (spec == nullptr) {
-        return;
-      }
-      const int64_t stride = spec->at(0).stride;
-      int64_t count = spec->at(0).count;
-      if (side.scope == TritonFusionAnalysis::Scope::OUTPUT &&
-          properties.index == dims.out_lhs_noncontracting_dim_idx &&
-          spec->size() == 1 && dims.lhs_noncontracting_split.has_value()) {
-        // Dimension of the output produced by the non-contracting LHS one
-        // is logically split, major part is addressed using pid_batch.
-        count /= *dims.lhs_noncontracting_split;
-      }
-      if (count % properties.block_size != 0) {
-        boundary_checks.push_back(bounds.size());
-      }
-      bounds.push_back(CreateConst(b, i64_ty, count));
-      strides.push_back(CreateConst(b, i64_ty, stride));
-      offsets.push_back(properties.offset);
-      block_dims.push_back(properties.block_size);
-      dim_order.emplace(dim_order.begin(), dim_order.size());
-    };
-    for (const DimProperties& dim : side.tiled_dims) {
-      add_dim(dim);
-    }
-
-    int64_t stride_batch = 0;
-    if (side.scope != TritonFusionAnalysis::Scope::RHS &&
-        dims.lhs_noncontracting_split) {
-      const TensorIterationSpec::DimIterationSpec* spec =
-          analysis.IterSpec(side.scope, hlo, side.tiled_dims[0].index);
-      if (spec != nullptr) {
-        if (spec->size() > 1) {
-          // Support one specific kind of output transpose that splits the
-          // dimension originating from the split LHS non-contracting one.
-          stride_batch = spec->at(1).stride;
-        } else {
-          // Because the major part of the split is implemented using the
-          // batch logic stride_batch is populated here as the stride of
-          // the minor part times its size.
-          stride_batch = spec->at(0).stride *
-                         (spec->at(0).count / *dims.lhs_noncontracting_split);
-        }
-        CHECK_NE(stride_batch, 0);
-      }
-    } else if (side.batch_dim_idx.has_value()) {
-      const TensorIterationSpec::DimIterationSpec* spec =
-          analysis.IterSpec(side.scope, hlo, *side.batch_dim_idx);
-      if (spec != nullptr) {
-        stride_batch = spec->at(0).stride;
-        CHECK_NE(stride_batch, 0);
-      }
-    }
-    if (stride_batch != 0) {
-      Value offset_batch = b.create<ma::MulIOp>(
-          convert_scalar(pid_batch), CreateConst(b, index_ty, stride_batch));
-      base = AddPtr(b, base, offset_batch);
-    }
-
-    if (dims.out_split_k_dim_idx.has_value()) {
-      const TensorIterationSpec::DimIterationSpec* spec =
-          analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, hlo,
-                            *dims.out_split_k_dim_idx);
-      if (spec != nullptr) {
-        int64_t stride_split_k = spec->at(0).stride;
-        Value offset_split_k = b.create<ma::MulIOp>(
-            convert_scalar(pid_k), CreateConst(b, index_ty, stride_split_k));
-        base = AddPtr(b, base, offset_split_k);
-      }
-    }
-
-    if (block_dims.empty()) {
-      return base;
-    }
-    return b.create<mt::MakeTensorPtrOp>(base, bounds, strides, offsets,
-                                         block_dims, dim_order);
-  };
-
   for (const auto& side : {lhs, rhs}) {
     for (const HloInstruction* param : analysis.ScopeParameters(side.scope)) {
       CHECK(iter_args_to_parameters.insert({iter_args.size(), param}).second);
-      iter_args.push_back(emit_tensor_pointer(
-          param, side, fn.getArgument(param->parameter_number()),
-          iter_args_to_boundary_checks[iter_args.size()]));
+      iter_args.push_back(emitter.EmitTensorPointer(
+          param, side, fn.getArgument(param->parameter_number()), pid_k,
+          iter_args_to_boundary_checks[iter_args.size()]));
     }
   }
   iter_args.push_back(accumulator_init);
 
-  Value acc_final =
-      b.create<mlir::scf::ForOp>(
-           /*lowerBound=*/b.create<ma::ConstantIntOp>(0, /*width=*/32),
-           /*upperBound=*/b.create<ma::ConstantIntOp>(dims.k, /*width=*/32),
-           /*step=*/
-           b.create<ma::ConstantIntOp>(block_k * split_k, /*width=*/32),
-           /*iterArgs=*/iter_args, body_builder)
-          .getResult(iter_args.size() - 1);
+  Value acc_final =
+      b.create<mlir::scf::ForOp>(/*lowerBound=*/c32(0),
+                                 /*upperBound=*/c32(dims.k),
+                                 /*step=*/c32(block_k * split_k),
+                                 /*iterArgs=*/iter_args, body_builder)
+          .getResult(iter_args.size() - 1);
 
   absl::flat_hash_map<const HloInstruction*, Value> values_out;
   values_out[dot_instr] =
       Cast(b, acc_final, TritonType(b, dot_instr->shape().element_type()));
 
-  // Collect all instructions of the dot's output scope.
-  absl::flat_hash_set<const HloInstruction*> to_order;
-  {
-    std::queue<const HloInstruction*> to_add;
-    if (root != dot_instr) {
-      to_add.push(root);
-    }
-    while (!to_add.empty()) {
-      const HloInstruction* current = to_add.front();
-      for (const HloInstruction* operand : current->operands()) {
-        if (!to_order.contains(operand)) {
-          if (operand != dot_instr) {
-            to_add.push(operand);
-          }
-        }
-      }
-      CHECK(to_order.insert(current).second);
-      to_add.pop();
-    }
-  }
-  // Order them producers before consumers.
-  std::vector<const HloInstruction*> to_emit;
-  for (const HloInstruction* hlo :
-       dot_instr->parent()->MakeInstructionPostOrder()) {
-    if (to_order.contains(hlo)) {
-      to_emit.push_back(hlo);
-    }
-  }
-
   // Emit the output scope.
-  if (!to_emit.empty()) {
+  if (std::vector<const HloInstruction*> to_emit =
+          emitter.EpiloguePostOrderTransitiveOperands(root);
+      !to_emit.empty()) {
     for (const HloInstruction* parameter :
          analysis.ScopeParameters(TritonFusionAnalysis::Scope::OUTPUT)) {
       std::vector<int32_t> boundary_checks;
-      Value tensor_pointer = emit_tensor_pointer(
-          parameter, out, fn.getArgument(parameter->parameter_number()),
-          boundary_checks);
+      Value tensor_pointer = emitter.EmitTensorPointer(
+          parameter, out, fn.getArgument(parameter->parameter_number()), pid_k,
+          boundary_checks);
       CHECK(values_out.insert({parameter,
...
@@ -1349,9 +1391,9 @@ Status EmitMatMul(mlir::OpBuilder builder, absl::string_view libdevice_path,
     const HloInstruction* producer =
         root->shape().IsTuple() ? root->operand(i) : root;
     std::vector<int32_t> boundary_checks;
-    Value tensor_pointer = emit_tensor_pointer(
-        producer, out,
-        fn.getArgument(i + dot_instr->parent()->num_parameters()),
-        boundary_checks);
+    Value tensor_pointer = emitter.EmitTensorPointer(
+        producer, out,
+        fn.getArgument(i + dot_instr->parent()->num_parameters()), pid_k,
+        boundary_checks);
     b.create<mt::StoreOp>(tensor_pointer, values_out[producer],
                           boundary_checks, mt::CacheModifier::NONE,
                           mt::EvictionPolicy::NORMAL);
...