Unverified commit 6839a7b9
Authored on Aug 29, 2023 by lzy; committed by GitHub on Aug 29, 2023
make variable_length_memory_efficient_attention supports mask_broadcast_heads (#56673)
Parent: f5d9981e
Showing 4 changed files with 16 additions and 21 deletions (+16 -21)
paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/default_fmha_grouped.h  (+1 -4)
paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped.h  (+10 -2)
paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py  (+3 -10)
paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.cu  (+2 -5)
paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/default_fmha_grouped.h

@@ -77,8 +77,7 @@ template <
     int KeysPerBlock_,
     bool SingleValueIteration_,
     GroupScheduleMode GroupScheduleMode_,
-    bool AddMask,
-    bool MaskBroadcastRow>
+    bool AddMask>
 struct DefaultFMHAGrouped {
   using scalar_t = scalar_t_;
   using accum_t = float;

@@ -92,7 +91,6 @@ struct DefaultFMHAGrouped {
   using ArchTag = ArchTag_;
   static bool const kIsAligned = isAligned_;
   static bool const kAddMask = AddMask;
-  static bool const kMaskBroadcastRow = MaskBroadcastRow;
   static bool const kSingleValueIteration = SingleValueIteration_;
   static int const kKeysPerBlock = KeysPerBlock_;
   static bool const kMaskIsAligned = maskIsAligned_;

@@ -288,7 +286,6 @@ struct DefaultFMHAGrouped {
       SingleValueIteration_,
       GroupScheduleMode_,
       AddMask,
-      MaskBroadcastRow,
       maskIsAligned_>;
 };
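In this header the change is a pure reduction of the template surface: the row-broadcast flag (MaskBroadcastRow / kMaskBroadcastRow) was a compile-time template parameter, so every kernel variant had to be instantiated for both values, whereas the head-broadcast behaviour this commit introduces is carried as a runtime field (mask_broadcast_head, added in fmha_grouped.h below). A toy illustration of that trade-off follows; it is not Paddle code, and the names KernelCT / KernelRT are invented:

#include <iostream>

// Compile-time flag: each value of kBroadcast is a distinct type, i.e. a
// distinct kernel that a generator must emit and the compiler must build.
template <bool kBroadcast>
struct KernelCT {
  static void Run() { std::cout << "compile-time broadcast=" << kBroadcast << "\n"; }
};

// Runtime flag: a single instantiation; the decision is an ordinary branch.
struct KernelRT {
  static void Run(bool broadcast) {
    std::cout << "runtime broadcast=" << broadcast << "\n";
  }
};

int main() {
  KernelCT<true>::Run();   // two separate instantiations exist in the binary
  KernelCT<false>::Run();
  KernelRT::Run(true);     // one instantiation, flag decided per call
  KernelRT::Run(false);
  return 0;
}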
paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped.h

@@ -72,7 +72,6 @@ template <typename MM0_, ///! Structure for computing P = Q @ K
           /// perform
           bool kAddMask,
           // This is quite faster when mask need broadcast at row axis
-          bool kMaskBroadcastRow,
           bool kMaskIsAligned>
 struct FMHAGrouped {
  public:

@@ -191,6 +190,7 @@ struct FMHAGrouped {
     // Whether causal masking is to be performed
     bool causal;
+    bool mask_broadcast_head;

     // Only used by device-level operator
     GemmCoord* host_problem_sizes;

@@ -224,6 +224,7 @@ struct FMHAGrouped {
           kElementV(0),
           kElementO(0),
           causal(false),
+          mask_broadcast_head(true),
           host_problem_sizes(nullptr) {}

     /// Ctor

@@ -250,6 +251,7 @@ struct FMHAGrouped {
               int64_t kElementV,
               int64_t kElementO,
               bool causal,
+              bool mask_broadcast_head,
               ElementAccumulator scale,
               GemmCoord* host_problem_sizes = nullptr)
         : problem_sizes0(problem_sizes0),

@@ -276,6 +278,7 @@ struct FMHAGrouped {
           kElementV(kElementV),
           kElementO(kElementO),
           causal(causal),
+          mask_broadcast_head(mask_broadcast_head),
           scale(scale),
           host_problem_sizes(host_problem_sizes) {}

@@ -327,6 +330,7 @@ struct FMHAGrouped {
     ElementAccumulator scale;
     bool causal;
+    bool mask_broadcast_head;

     //
     // Methods

@@ -352,6 +356,7 @@ struct FMHAGrouped {
           kElementV(0),
           kElementO(0),
           causal(false),
+          mask_broadcast_head(true),
           scale(0) {}

     explicit CUTLASS_HOST_DEVICE Params(Arguments const &args,

@@ -384,6 +389,7 @@ struct FMHAGrouped {
           kElementV(args.kElementV),
           kElementO(args.kElementO),
           causal(args.causal),
+          mask_broadcast_head(args.mask_broadcast_head),
           scale(args.scale) {}

     // CUTLASS_HOST_DEVICE

@@ -704,6 +710,8 @@ struct FMHAGrouped {
       // apply attention mask if applicable
       if (kAddMask) {
+        const int mask_id =
+            params.mask_broadcast_head ? batch_idx : problem_idx;
         accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(
             params.scale, accum);
         // load mask tile Bij into shared memory

@@ -711,7 +719,7 @@ struct FMHAGrouped {
             {cutlass::layout::RowMajor(params.ldm)},
             // attn_mask_pointer points to matrix of size (n_queries, n_keys)
             // for the relevant batch_id and head_id
-            params.ptr_M + batch_idx * params.kElementM +
+            params.ptr_M + mask_id * params.kElementM +
                 TileParams::query_start(threadblock_idx) * params.ldm +
                 iter_key_start,
             {problem_size_0_m, problem_size_0_n},
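The functional core of the change is in the last two hunks above: when kAddMask is set, the mask pointer is now offset by mask_id rather than batch_idx, and mask_id collapses every head of a batch onto the same mask slice whenever params.mask_broadcast_head is true. Below is a minimal, self-contained sketch of that addressing rule; mask_id, kElementM, and the broadcast condition come from the diff, while the function name and the assumption that problem_idx enumerates (batch, head) pairs in batch-major order are illustrative only:

#include <cstdint>
#include <iostream>

// Offset (in elements) of the mask slice read by one grouped problem.
// Assumption for illustration: problem_idx == batch_idx * num_heads + head_idx.
int64_t MaskSliceOffset(bool mask_broadcast_head,
                        int batch_idx,
                        int problem_idx,
                        int64_t kElementM) {
  const int mask_id = mask_broadcast_head ? batch_idx : problem_idx;
  return static_cast<int64_t>(mask_id) * kElementM;
}

int main() {
  const int num_heads = 4;
  const int64_t kElementM = 128 * 128;  // elements per mask slice (rows * cols)

  // Mask shaped [B, 1, S, D]: all heads of batch 0 read the same slice.
  for (int head = 0; head < num_heads; ++head) {
    std::cout << MaskSliceOffset(true, /*batch_idx=*/0, /*problem_idx=*/head, kElementM) << " ";
  }
  std::cout << "\n";  // prints: 0 0 0 0

  // Mask shaped [B, num_heads, S, D]: each (batch, head) problem has its own slice.
  for (int head = 0; head < num_heads; ++head) {
    std::cout << MaskSliceOffset(false, /*batch_idx=*/0, /*problem_idx=*/head, kElementM) << " ";
  }
  std::cout << "\n";  // prints: 0 16384 32768 49152
  return 0;
}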
paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py

@@ -200,6 +200,7 @@ void {NAME}({CPP_CLASS} default_fmha, Params &params, const phi::GPUContext& ct
       params.ElementV,
       params.ElementO,
       params.causal,
+      params.mask_broadcast_head,
       params.scale,
       problem_sizes1.data());

@@ -234,7 +235,6 @@ class FwdKernel:
     k: int
     single_value_iter: bool
     support_mask: bool = True
-    mask_broadcast: bool = False
     dispatch_cond: Optional[str] = None

     def __post_init__(self) -> None:

@@ -249,7 +249,6 @@ class FwdKernel:
                 0 if self.single_value_iter else 1,
                 self.q,
                 0 if self.mask_aligned else 1,
-                0 if self.mask_broadcast else 1,
             )
         )

     @property

@@ -264,10 +263,6 @@ class FwdKernel:
     def _mask_support_suffix(self) -> str:
         return "sm" if self.support_mask else "usm"

-    @property
-    def _mask_broadcast_suffix(self) -> str:
-        return "mb" if self.mask_broadcast else "mnb"
-
     @property
     def _single_value_suffix(self) -> str:
         return "rf" if self.single_value_iter else "urf"

@@ -289,7 +284,6 @@ class FwdKernel:
                 "true" if self.single_value_iter else "false",
                 "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
                 "true" if self.support_mask else "false",
-                "false",
             ]
         )
         return f"cutlass::gemm::kernel::DefaultFMHAGrouped<{template_args}>"

@@ -297,7 +291,7 @@ class FwdKernel:
     @property
     def impl_group(self) -> str:
         # Maps to file which will contain the implementation
-        return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._mask_broadcast_suffix}_{self._single_value_suffix}_{self.q}x{self.k}"
+        return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._single_value_suffix}_{self.q}x{self.k}"

     @property
     def cpp_impl(self) -> str:

@@ -336,7 +330,6 @@ class FwdKernel:
                         single_value_iter=single_value_iter,
                         support_mask=support_mask,
                         mask_aligned=mask_aligned,
-                        mask_broadcast=False,
                     )
                 )
         return kernels

@@ -490,7 +483,7 @@ struct Params {{
   int64_t ElementO;
   bool causal;
-  bool mask_broadcast_row;
+  bool mask_broadcast_head;
 }};

 __global__ static void get_problem_sizes(const int* seq_lens,
paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.cu

@@ -65,10 +65,11 @@ void MultiHeadAttentionVariableForwardKernel(
   if (mask) {
     // [B, 1, S, D]
     auto mask_tensor = mask.get();
+    int64_t mask_num_heads = mask_tensor.dims()[1];
     params.ldm = mask_tensor.dims()[3];
     params.ElementM = mask_tensor.dims()[2] * mask_tensor.dims()[3];
     params.mask_ptr = mask_tensor.data();
-    params.mask_broadcast_row = false;
+    params.mask_broadcast_head = mask_num_heads == 1 ? true : false;
   }

   bool kernel_launched = false;

@@ -84,10 +85,6 @@ void MultiHeadAttentionVariableForwardKernel(
     if (!mask && KernelType::kAddMask) {
      return;
    }
-    if (KernelType::kMaskBroadcastRow) {
-      // not support mask_broad_cast
-      return;
-    }
    if (mask && reinterpret_cast<uintptr_t>(params.mask_ptr) % 16 == 0 &&
        params.ldm % (16 / sizeof(T)) == 0 && !KernelType::kMaskIsAligned) {
      return;
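On the host side, the broadcast decision is read straight off the mask's head dimension: a mask handed in as [B, 1, S, D] is shared across all heads, any other head count is treated as per-head. A short sketch of that check, using a plain array in place of the Paddle DenseTensor dims() call; the helper name is invented:

#include <array>
#include <cstdint>
#include <iostream>

// dims plays the role of mask_tensor.dims() in the diff; index 1 is the
// mask's head dimension.
bool ShouldBroadcastHeads(const std::array<int64_t, 4>& dims) {
  const int64_t mask_num_heads = dims[1];
  return mask_num_heads == 1;
}

int main() {
  std::cout << ShouldBroadcastHeads({8, 1, 128, 128}) << "\n";   // 1: one mask shared by all heads
  std::cout << ShouldBroadcastHeads({8, 16, 128, 128}) << "\n";  // 0: one mask per head
  return 0;
}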