Commit c8cee324
Authored Sep 08, 2023 by Johannes Reifferscheid
Committed by TensorFlower Gardener on Sep 08, 2023
Triton emitter: remove some unnecessary helpers/templates.
PiperOrigin-RevId: 563735974
Parent d8359707
Showing 2 changed files with 59 additions and 100 deletions (+59 -100)
third_party/xla/xla/service/gpu/BUILD                 +1   -0
third_party/xla/xla/service/gpu/ir_emitter_triton.cc  +58  -100
third_party/xla/xla/service/gpu/BUILD
@@ -487,6 +487,7 @@ cc_library(
"//xla:xla_data_proto_cc"
,
"//xla/hlo/ir:hlo"
,
"//xla/hlo/utils:hlo_query"
,
"//xla/mlir_hlo"
,
"//xla/mlir_hlo:map_mhlo_to_scalar_op"
,
"//xla/service:dump"
,
"//xla/service/gpu/llvm_gpu_backend"
,
...
...
third_party/xla/xla/service/gpu/ir_emitter_triton.cc
@@ -64,6 +64,7 @@ limitations under the License.
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/ValueRange.h" // from @llvm-project
...
...
@@ -86,6 +87,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/literal.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h"
#include "xla/primitive_util.h"
#include "xla/service/dump.h"
...
...
@@ -214,39 +216,13 @@ Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) {
       << llvm_ir::DumpToString(dst_element_ty);
 }
 
-Type ElementType(Value v) {
-  Type src_ty = v.getType();
-  if (auto src_shaped_ty = src_ty.dyn_cast<mlir::ShapedType>()) {
-    return src_shaped_ty.getElementType();
-  }
-  return src_ty;
-}
-
 // Get the value of the scalar constant's literal in a C++ type.
 template <typename T>
-T ScalarConstantValue(const HloInstruction& instr) {
+T ScalarConstantValue(const HloInstruction& instr, PrimitiveType dst_type) {
   CHECK(hlo_query::IsScalarConstant(&instr));
-  PrimitiveType dst_type;
-  if constexpr (std::is_integral_v<T>) {
-    if constexpr (std::numeric_limits<T>::is_signed) {
-      dst_type = S64;
-    } else {
-      dst_type = U64;
-    }
-  } else {
-    dst_type = F64;
-  }
   StatusOr<Literal> converted = instr.literal().Convert(dst_type);
   TF_CHECK_OK(converted.status());
-  if constexpr (std::is_integral_v<T>) {
-    if constexpr (std::numeric_limits<T>::is_signed) {
-      return converted.value().GetFirstElement<int64_t>();
-    } else {
-      return converted.value().GetFirstElement<uint64_t>();
-    }
-  } else {
-    return converted.value().GetFirstElement<double>();
-  }
+  return converted.value().GetFirstElement<T>();
 }
 
 // Create a scalar constant.
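Note: mlir::getElementTypeOrSelf, declared in the newly included "mlir/IR/TypeUtilities.h", returns the element type of a shaped type and the type itself otherwise, which is exactly what the deleted ElementType helper computed. A minimal standalone sketch, not part of this commit (the driver program and printed values are illustrative):

// Minimal sketch, not from the commit: getElementTypeOrSelf subsumes the
// removed ElementType helper.
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  mlir::Type scalar = b.getF32Type();
  mlir::Type tensor = mlir::RankedTensorType::get({4, 4}, scalar);
  // Both lines print "f32": the element type of the tensor, and the scalar
  // type itself.
  llvm::outs() << mlir::getElementTypeOrSelf(tensor) << "\n";
  llvm::outs() << mlir::getElementTypeOrSelf(scalar) << "\n";
  return 0;
}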
@@ -279,7 +255,7 @@ ma::ConstantOp CreateConst(ImplicitLocOpBuilder& b, Type type, T value,
 }
 
 Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) {
-  if (ElementType(values[0]).isa<mlir::IntegerType>()) {
+  if (mlir::getElementTypeOrSelf(values[0]).isa<mlir::IntegerType>()) {
     return b.create<ma::SubIOp>(values[0], values[1]);
   } else {
     return b.create<ma::SubFOp>(values[0], values[1]);
@@ -287,34 +263,28 @@ Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) {
 }
 
 Value Compare(ImplicitLocOpBuilder& b, ValueRange values,
-              ComparisonDirection direction) {
-  if (ElementType(values[0]).isa<mlir::IntegerType>()) {
+              mlir::mhlo::ComparisonDirection direction) {
+  if (mlir::getElementTypeOrSelf(values[0]).isa<mlir::IntegerType>()) {
     return b.create<ma::CmpIOp>(
-        mlir::mhlo::impl::getCmpPredicate<ma::CmpIPredicate>(
-            mlir::mhlo::symbolizeComparisonDirection(
-                ComparisonDirectionToString(direction))
-                .value(),
-            /*isSigned=*/true)
+        mlir::mhlo::impl::getCmpPredicate<ma::CmpIPredicate>(
+            direction, /*isSigned=*/true)
             .value(),
         values[0], values[1]);
   }
   return b.create<ma::CmpFOp>(
-      mlir::mhlo::impl::getCmpPredicate<ma::CmpFPredicate>(
-          mlir::mhlo::symbolizeComparisonDirection(
-              ComparisonDirectionToString(direction))
-              .value(),
-          /*isSigned=*/true)
+      mlir::mhlo::impl::getCmpPredicate<ma::CmpFPredicate>(
+          direction, /*isSigned=*/true)
           .value(),
       values[0], values[1]);
 }
 
 Value Maximum(ImplicitLocOpBuilder& b, ValueRange values) {
-  auto cmp = Compare(b, values, ComparisonDirection::kGt);
+  auto cmp = Compare(b, values, mlir::mhlo::ComparisonDirection::GT);
   return b.create<ma::SelectOp>(cmp, values[0], values[1]);
 }
 
 Value Minimum(ImplicitLocOpBuilder& b, ValueRange values) {
-  auto cmp = Compare(b, values, ComparisonDirection::kLt);
+  auto cmp = Compare(b, values, mlir::mhlo::ComparisonDirection::LT);
   return b.create<ma::SelectOp>(cmp, values[0], values[1]);
 }
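Note: Compare now takes the mlir::mhlo::ComparisonDirection enum directly, so Maximum, Minimum, and the kSelect case no longer round-trip through strings; only the HLO boundary in EmitElementwise (the kCompare case further down) still converts. A hedged sketch of that boundary conversion pulled out into a free function (the function name and include list are assumptions; the conversion chain itself is taken from the diff):

#include "xla/comparison_util.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

// Illustrative helper, not part of the commit: render the XLA enum to its
// string form and re-parse it as the equivalent MHLO enum.
mlir::mhlo::ComparisonDirection ToMhloDirection(xla::ComparisonDirection dir) {
  return mlir::mhlo::symbolizeComparisonDirection(
             xla::ComparisonDirectionToString(dir))
      .value();
}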
@@ -345,9 +315,7 @@ using TensorValue = mlir::TypedValue<mlir::RankedTensorType>;
 
 Value Broadcast(ImplicitLocOpBuilder& b, TensorValue value,
                 ArrayRef<int64_t> shape) {
-  auto type =
-      mlir::RankedTensorType::get(shape, value.getType().getElementType());
-  return b.create<mt::BroadcastOp>(type, value);
+  return b.create<mt::BroadcastOp>(value.getType().clone(shape), value);
 }
 
 Value Range(ImplicitLocOpBuilder& b, int32_t limit) {
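Note: Broadcast now builds its result type with the shaped type's clone(shape), which keeps the element type and swaps in the new shape, instead of spelling out RankedTensorType::get. A minimal standalone sketch, not from the commit (the shapes are made up):

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  auto src = mlir::RankedTensorType::get({16}, b.getF32Type());
  // clone({8, 16}) keeps the f32 element type; this prints "tensor<8x16xf32>".
  llvm::outs() << src.clone({8, 16}) << "\n";
  return 0;
}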
@@ -361,7 +329,8 @@ Value AddPtr(ImplicitLocOpBuilder& b, Value ptr, Value offset) {
 Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
                       const HloInstruction& hlo, ValueRange inputs) {
-  if (ElementType(inputs[0]).isF32() || ElementType(inputs[0]).isF64()) {
+  if (mlir::getElementTypeOrSelf(inputs[0]).isF32() ||
+      mlir::getElementTypeOrSelf(inputs[0]).isF64()) {
     auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode());
     if (dev_fn_id.ok()) {
       return b.create<mt::PureExternElementwiseOp>(
@@ -371,7 +340,8 @@ Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
                 llvm::Triple("nvptx64-unknown-unknown")));
     }
   }
-  const bool is_integer = ElementType(inputs[0]).isa<mlir::IntegerType>();
+  const bool is_integer =
+      mlir::getElementTypeOrSelf(inputs[0]).isa<mlir::IntegerType>();
 
   switch (hlo.opcode()) {
     case HloOpcode::kCopy:
@@ -418,11 +388,15 @@ Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
       }
       return b.create<ma::DivFOp>(inputs[0], inputs[1]);
     case HloOpcode::kCompare:
-      return Compare(b, inputs, hlo.comparison_direction());
+      return Compare(
+          b, inputs,
+          mlir::mhlo::symbolizeComparisonDirection(
+              ComparisonDirectionToString(hlo.comparison_direction()))
+              .value());
     case HloOpcode::kSelect:
       return b.create<ma::SelectOp>(
           Compare(b, {inputs[0], ZerosLike(b, inputs[0])},
-                  ComparisonDirection::kNe),
+                  mlir::mhlo::ComparisonDirection::NE),
           inputs[1], inputs[2]);
     default:
       LOG(FATAL) << "Unsupported operation " << hlo.ToString();
@@ -450,12 +424,12 @@ Value EmitConstant(ImplicitLocOpBuilder& b, const HloInstruction& constant) {
   Type ty = TritonType(b, constant.shape().element_type());
   if (constant.shape().IsInteger()) {
     if (constant.shape().element_type() == U64) {
-      return CreateConst(b, ty, ScalarConstantValue<uint64_t>(constant));
+      return CreateConst(b, ty, ScalarConstantValue<uint64_t>(constant, U64));
     } else {
-      return CreateConst(b, ty, ScalarConstantValue<int64_t>(constant));
+      return CreateConst(b, ty, ScalarConstantValue<int64_t>(constant, S64));
     }
   }
-  return CreateConst(b, ty, ScalarConstantValue<double>(constant));
+  return CreateConst(b, ty, ScalarConstantValue<double>(constant, F64));
 }
 
 struct DimProperties {
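Note: with ScalarConstantValue no longer deducing a widened PrimitiveType from T, these call sites name it explicitly (U64, S64, F64). Under the hood the function just converts the scalar literal and reads back the first element; a hedged standalone sketch of that plumbing (the helper name and sample value are made up, the Literal calls mirror ScalarConstantValue):

#include <cstdint>

#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/xla_data.pb.h"

// Illustrative only: widen an integer scalar literal to int64_t, the way
// ScalarConstantValue<int64_t>(constant, S64) does.
int64_t WidenScalarToS64(const xla::Literal& scalar) {
  return scalar.Convert(xla::S64).value().GetFirstElement<int64_t>();
}

// Usage: WidenScalarToS64(xla::LiteralUtil::CreateR0<int32_t>(42)) == 42.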
@@ -755,13 +729,27 @@ struct GeneralizeKernelSignaturePass
   }
 };
 
+}  // namespace
+
 // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n].
 // TODO(b/270937368): Split this up into smaller functions.
-template <typename IndexT>
-StatusOr<LaunchDimensions> MatMulImpl(
-    mlir::OpBuilder builder, absl::string_view libdevice_path,
-    const HloDotInstruction* dot_instr, mlir::triton::FuncOp fn,
-    const AutotuneResult::TritonGemmKey& config, int shmem_budget) {
+StatusOr<LaunchDimensions> MatMul(mlir::OpBuilder builder,
+                                  absl::string_view libdevice_path,
+                                  const HloComputation* computation,
+                                  mlir::triton::FuncOp fn,
+                                  const AutotuneResult::TritonGemmKey& config,
+                                  int shmem_budget) {
+  const HloDotInstruction* dot_instr = DynCast<HloDotInstruction>(
+      hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot));
+  // Use 32-bit indexing if addressing any of the inputs or the output (which
+  // could grow if split_k is set) does not cross the INT_MAX boundary.
+  // Otherwise, fall back to 64-bit indexing, which is slower.
+  bool use_64bit_indexing =
+      ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX ||
+      ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX ||
+      ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k() > INT_MAX;
+  mlir::Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32);
+
   const HloInstruction* root = dot_instr->parent()->root_instruction();
   CHECK(!root->shape().IsTuple());
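Note: the IndexT template parameter is gone. MatMul absorbs the old wrapper's dot-instruction lookup and picks an index type at runtime instead of instantiating MatMulImpl<int32_t> or MatMulImpl<int64_t>. The 32- vs 64-bit decision itself is unchanged; a runnable sketch of that condition in plain C++ (the free function and the sample element counts are illustrative, the comparison against INT_MAX follows the code above):

#include <climits>
#include <cstdint>
#include <iostream>

// Mirrors the use_64bit_indexing condition: fall back to 64-bit indexing as
// soon as any operand, or the split_k-scaled output, exceeds INT_MAX elements.
bool Use64BitIndexing(int64_t lhs_elements, int64_t rhs_elements,
                      int64_t out_elements, int64_t split_k) {
  return lhs_elements > INT_MAX || rhs_elements > INT_MAX ||
         out_elements * split_k > INT_MAX;
}

int main() {
  // A 2^32-element operand forces 64-bit indexing; small shapes keep i32.
  std::cout << Use64BitIndexing(int64_t{1} << 32, 1 << 20, 1 << 20, 1) << "\n";  // 1
  std::cout << Use64BitIndexing(1 << 20, 1 << 20, 1 << 20, 1) << "\n";           // 0
  return 0;
}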
@@ -772,12 +760,6 @@ StatusOr<LaunchDimensions> MatMulImpl(
   ImplicitLocOpBuilder b(loc, builder);
   Type i32_ty = b.getI32Type();
   Type i64_ty = b.getI64Type();
-  Type int_ty;
-  if constexpr (std::is_same_v<IndexT, int64_t>) {
-    int_ty = i64_ty;
-  } else {
-    int_ty = i32_ty;
-  }
 
   const int split_k = config.split_k();
   const int block_m = config.block_m();
@@ -927,8 +909,8 @@ StatusOr<LaunchDimensions> MatMulImpl(
   // Extend int32 indexes to int64, if necessary.
   auto convert_scalar = [&](Value value) -> Value {
-    if constexpr (std::is_same_v<IndexT, int64_t>) {
-      return b.create<ma::ExtSIOp>(int_ty, value);
+    if (index_ty.getIntOrFloatBitWidth() == 64) {
+      return b.create<ma::ExtSIOp>(index_ty, value);
     }
     return value;
   };
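Note: the compile-time check on IndexT becomes a plain runtime query on the chosen index type. A minimal sketch of that query (the standalone driver is illustrative):

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  mlir::Type index_ty = b.getIntegerType(64);
  // Prints "1": an i64 index type reports a 64-bit width, so convert_scalar
  // would emit a sign extension (ma::ExtSIOp).
  llvm::outs() << (index_ty.getIntOrFloatBitWidth() == 64) << "\n";
  return 0;
}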
@@ -1059,7 +1041,7 @@ StatusOr<LaunchDimensions> MatMulImpl(
               b.create<mt::ExpandDimsOp>(range_k, 0),
               Splat(b, elements_in_tile, {1, block_k}))
               .getResult()
-              .template cast<TensorValue>(),
+              .cast<TensorValue>(),
           {block_m, block_k});
       Value rhs_mask = Broadcast(
           b,
@@ -1067,7 +1049,7 @@ StatusOr<LaunchDimensions> MatMulImpl(
               b.create<mt::ExpandDimsOp>(range_k, 1),
               Splat(b, elements_in_tile, {block_k, 1}))
               .getResult()
-              .template cast<TensorValue>(),
+              .cast<TensorValue>(),
           {block_k, block_n});
       dot_input_lhs = b.create<ma::SelectOp>(lhs_mask, dot_input_lhs,
                                              ZerosLike(b, dot_input_lhs));
@@ -1131,7 +1113,7 @@ StatusOr<LaunchDimensions> MatMulImpl(
       add_dim(dim);
     }
 
-    IndexT stride_batch = 0;
+    int64_t stride_batch = 0;
     if (scope != TritonFusionAnalysis::Scope::RHS && lhs_nc_split) {
       const TensorIterationSpec::DimIterationSpec* spec =
           analysis.IterSpec(scope, hlo, tiled_dimensions[0].index);
@@ -1158,8 +1140,9 @@ StatusOr<LaunchDimensions> MatMulImpl(
       }
     }
     if (stride_batch != 0) {
-      Value offset_batch = b.create<ma::MulIOp>(
-          convert_scalar(pid_batch), CreateConst(b, int_ty, stride_batch));
+      Value offset_batch =
+          b.create<ma::MulIOp>(convert_scalar(pid_batch),
+                               CreateConst(b, index_ty, stride_batch));
       base = AddPtr(b, base, offset_batch);
     }
@@ -1167,9 +1150,10 @@ StatusOr<LaunchDimensions> MatMulImpl(
     const TensorIterationSpec::DimIterationSpec* spec = analysis.IterSpec(
         TritonFusionAnalysis::Scope::OUTPUT, hlo, split_k_out_idx);
     if (spec != nullptr) {
-      IndexT stride_split_k = spec->at(0).stride;
-      Value offset_split_k = b.create<ma::MulIOp>(
-          convert_scalar(pid_k), CreateConst(b, int_ty, stride_split_k));
+      int64_t stride_split_k = spec->at(0).stride;
+      Value offset_split_k =
+          b.create<ma::MulIOp>(convert_scalar(pid_k),
+                               CreateConst(b, index_ty, stride_split_k));
       base = AddPtr(b, base, offset_split_k);
     }
   }
...
@@ -1283,32 +1267,6 @@ StatusOr<LaunchDimensions> MatMulImpl(
           {config.num_warps() * WarpSize(), 1, 1}};
 }
 
-}  // namespace
-
-StatusOr<LaunchDimensions> MatMul(mlir::OpBuilder builder,
-                                  absl::string_view libdevice_path,
-                                  const HloComputation* computation,
-                                  mlir::triton::FuncOp fn,
-                                  const AutotuneResult::TritonGemmKey& config,
-                                  int shmem_budget) {
-  const HloDotInstruction* dot_instr = DynCast<HloDotInstruction>(
-      hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot));
-  // Use 32-bit indexing if addressing any of the inputs or the output (which
-  // could grow if split_k is set) does not cross the INT_MAX boundary.
-  // Otherwise, fall back to 64-bit indexing, which is slower.
-  bool use_64bit_indexing =
-      ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX ||
-      ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX ||
-      ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k() > INT_MAX;
-  if (use_64bit_indexing) {
-    return MatMulImpl<int64_t>(builder, libdevice_path, dot_instr, fn, config,
-                               shmem_budget);
-  } else {
-    return MatMulImpl<int32_t>(builder, libdevice_path, dot_instr, fn, config,
-                               shmem_budget);
-  }
-}
-
 StatusOr<LaunchDimensions> SoftMax(mlir::OpBuilder builder,
                                    absl::string_view libdevice_path,
                                    const HloComputation* computation,