PaddlePaddle / Paddle
Commit 6839a7b9 (unverified)
Authored Aug 29, 2023 by lzy · Committed by GitHub on Aug 29, 2023
make variable_length_memory_efficient_attention supports mask_broadcast_heads (#56673)
Parent: f5d9981e
Showing 4 changed files with 16 additions and 21 deletions (+16, −21):
  paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/default_fmha_grouped.h  (+1, −4)
  paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped.h  (+10, −2)
  paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py  (+3, −10)
  paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.cu  (+2, −5)
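For context: the change drops the never-exercised compile-time flag MaskBroadcastRow / kMaskBroadcastRow and instead carries a runtime flag mask_broadcast_head, so a mask whose head dimension is 1 can be shared by every head of a batch instead of each (batch, head) problem reading its own slice. A minimal sketch of that selection follows; the struct and function names are illustrative, not Paddle symbols, and problem_idx is assumed to enumerate (batch, head) pairs.

#include <cstdint>

// Sketch only (names are illustrative, not Paddle symbols).
// The mask is logically [batch, mask_heads, seq_q, seq_k]; when mask_heads == 1
// the same slice is shared by every head of the batch.
struct MaskView {
  const float* data;
  int64_t mask_heads;          // 1 => broadcast across heads
  int64_t elements_per_slice;  // seq_q * seq_k
};

// Mirrors the kernel's "mask_id = mask_broadcast_head ? batch_idx : problem_idx"
// followed by "ptr_M + mask_id * kElementM". problem_idx is assumed to be
// batch_idx * num_heads + head_idx.
inline const float* mask_slice(const MaskView& m, int batch_idx, int problem_idx) {
  const bool mask_broadcast_head = (m.mask_heads == 1);
  const int64_t mask_id = mask_broadcast_head ? batch_idx : problem_idx;
  return m.data + mask_id * m.elements_per_slice;
}

The runtime flag itself is set from the mask shape (mask_num_heads == 1) in variable_length_memory_efficient_attention.cu, the last diff below.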
paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/default_fmha_grouped.h
@@ -77,8 +77,7 @@ template <
     int KeysPerBlock_,
     bool SingleValueIteration_,
     GroupScheduleMode GroupScheduleMode_,
-    bool AddMask,
-    bool MaskBroadcastRow>
+    bool AddMask>
 struct DefaultFMHAGrouped {
   using scalar_t = scalar_t_;
   using accum_t = float;
@@ -92,7 +91,6 @@ struct DefaultFMHAGrouped {
   using ArchTag = ArchTag_;
   static bool const kIsAligned = isAligned_;
   static bool const kAddMask = AddMask;
-  static bool const kMaskBroadcastRow = MaskBroadcastRow;
   static bool const kSingleValueIteration = SingleValueIteration_;
   static int const kKeysPerBlock = KeysPerBlock_;
   static bool const kMaskIsAligned = maskIsAligned_;
@@ -288,7 +286,6 @@ struct DefaultFMHAGrouped {
       SingleValueIteration_,
       GroupScheduleMode_,
       AddMask,
-      MaskBroadcastRow,
       maskIsAligned_>;
 };
paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped.h
@@ -72,7 +72,6 @@ template <typename MM0_,  ///! Structure for computing P = Q @ K
                           ///  perform
     bool kAddMask,
     // This is quite faster when mask need broadcast at row axis
-    bool kMaskBroadcastRow,
     bool kMaskIsAligned>
 struct FMHAGrouped {
  public:
@@ -191,6 +190,7 @@ struct FMHAGrouped {
     // Whether causal masking is to be performed
     bool causal;
+    bool mask_broadcast_head;

     // Only used by device-level operator
     GemmCoord *host_problem_sizes;
@@ -224,6 +224,7 @@ struct FMHAGrouped {
           kElementV(0),
           kElementO(0),
           causal(false),
+          mask_broadcast_head(true),
           host_problem_sizes(nullptr) {}

     /// Ctor
@@ -250,6 +251,7 @@ struct FMHAGrouped {
         int64_t kElementV,
         int64_t kElementO,
         bool causal,
+        bool mask_broadcast_head,
         ElementAccumulator scale,
         GemmCoord *host_problem_sizes = nullptr)
         : problem_sizes0(problem_sizes0),
@@ -276,6 +278,7 @@ struct FMHAGrouped {
           kElementV(kElementV),
           kElementO(kElementO),
           causal(causal),
+          mask_broadcast_head(mask_broadcast_head),
           scale(scale),
           host_problem_sizes(host_problem_sizes) {}
@@ -327,6 +330,7 @@ struct FMHAGrouped {
     ElementAccumulator scale;
     bool causal;
+    bool mask_broadcast_head;

     //
     // Methods
@@ -352,6 +356,7 @@ struct FMHAGrouped {
           kElementV(0),
           kElementO(0),
           causal(false),
+          mask_broadcast_head(true),
           scale(0) {}

     explicit CUTLASS_HOST_DEVICE Params(Arguments const &args,
@@ -384,6 +389,7 @@ struct FMHAGrouped {
           kElementV(args.kElementV),
           kElementO(args.kElementO),
           causal(args.causal),
+          mask_broadcast_head(args.mask_broadcast_head),
           scale(args.scale) {}

     // CUTLASS_HOST_DEVICE
@@ -704,6 +710,8 @@ struct FMHAGrouped {
       // apply attention mask if applicable
       if (kAddMask) {
+        const int mask_id =
+            params.mask_broadcast_head ? batch_idx : problem_idx;
         accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(
             params.scale, accum);
         // load mask tile Bij into shared memory
@@ -711,7 +719,7 @@ struct FMHAGrouped {
             {cutlass::layout::RowMajor(params.ldm)},
             // attn_mask_pointer points to matrix of size (n_queries, n_keys)
             // for the relevant batch_id and head_id
-            params.ptr_M + batch_idx * params.kElementM +
+            params.ptr_M + mask_id * params.kElementM +
                 TileParams::query_start(threadblock_idx) * params.ldm +
                 iter_key_start,
             {problem_size_0_m, problem_size_0_n},
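The mask_id chosen in the hunk above only moves the base of the mask iterator; the within-slice offsets (query_start * ldm + iter_key_start) are unchanged. A hedged sketch of the resulting address computation, with illustrative parameter names standing in for params.ptr_M, params.kElementM and params.ldm:

#include <cstdint>

// Sketch: base address of the mask tile a threadblock loads in the grouped
// FMHA kernel. mask_id is batch_idx when the mask is broadcast across heads,
// otherwise the per-(batch, head) problem index. Names are illustrative.
template <typename T>
const T* mask_tile_base(const T* ptr_M,        // mask base pointer
                        int64_t kElementM,     // elements per mask slice (seq_q * seq_k)
                        int64_t ldm,           // row stride of one slice
                        int mask_id,
                        int query_start,       // first query row of this tile
                        int iter_key_start) {  // first key column of this tile
  return ptr_M + static_cast<int64_t>(mask_id) * kElementM +
         static_cast<int64_t>(query_start) * ldm + iter_key_start;
}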
paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py
@@ -200,6 +200,7 @@ void {NAME}({CPP_CLASS} default_fmha, Params &params, const phi::GPUContext& ctx) {{
       params.ElementV,
       params.ElementO,
       params.causal,
+      params.mask_broadcast_head,
       params.scale,
       problem_sizes1.data());
@@ -234,7 +235,6 @@ class FwdKernel:
     k: int
     single_value_iter: bool
     support_mask: bool = True
-    mask_broadcast: bool = False
     dispatch_cond: Optional[str] = None

     def __post_init__(self) -> None:
@@ -249,7 +249,6 @@ class FwdKernel:
             0 if self.single_value_iter else 1,
             self.q,
             0 if self.mask_aligned else 1,
-            0 if self.mask_broadcast else 1,
         )

     @property
@@ -264,10 +263,6 @@ class FwdKernel:
     def _mask_support_suffix(self) -> str:
         return "sm" if self.support_mask else "usm"

     @property
-    def _mask_broadcast_suffix(self) -> str:
-        return "mb" if self.mask_broadcast else "mnb"
-
-    @property
     def _single_value_suffix(self) -> str:
         return "rf" if self.single_value_iter else "urf"
@@ -289,7 +284,6 @@ class FwdKernel:
                 "true" if self.single_value_iter else "false",
                 "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
                 "true" if self.support_mask else "false",
-                "false",
             ]
         )
         return f"cutlass::gemm::kernel::DefaultFMHAGrouped<{template_args}>"
@@ -297,7 +291,7 @@ class FwdKernel:
     @property
     def impl_group(self) -> str:
         # Maps to file which will contain the implementation
-        return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._mask_broadcast_suffix}_{self._single_value_suffix}_{self.q}x{self.k}"
+        return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._single_value_suffix}_{self.q}x{self.k}"
@@ -336,7 +330,6 @@ class FwdKernel:
                         single_value_iter=single_value_iter,
                         support_mask=support_mask,
                         mask_aligned=mask_aligned,
-                        mask_broadcast=False,
                     )
                 )
         return kernels
@@ -490,7 +483,7 @@ struct Params {{
   int64_t ElementO;
   bool causal;
-  bool mask_broadcast_row;
+  bool mask_broadcast_head;
 }};

 __global__ static void get_problem_sizes(const int* seq_lens,
paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.cu
@@ -65,10 +65,11 @@ void MultiHeadAttentionVariableForwardKernel(
   if (mask) {
     // [B, 1, S, D]
     auto mask_tensor = mask.get();
+    int64_t mask_num_heads = mask_tensor.dims()[1];
     params.ldm = mask_tensor.dims()[3];
     params.ElementM = mask_tensor.dims()[2] * mask_tensor.dims()[3];
     params.mask_ptr = mask_tensor.data();
-    params.mask_broadcast_row = false;
+    params.mask_broadcast_head = mask_num_heads == 1 ? true : false;
   }

   bool kernel_launched = false;
@@ -84,10 +85,6 @@ void MultiHeadAttentionVariableForwardKernel(
     if (!mask && KernelType::kAddMask) {
       return;
     }
-    if (KernelType::kMaskBroadcastRow) {
-      // not support mask_broad_cast
-      return;
-    }
     if (mask && reinterpret_cast<uintptr_t>(params.mask_ptr) % 16 == 0 &&
         params.ldm % (16 / sizeof(T)) == 0 && !KernelType::kMaskIsAligned) {
       return;
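The context lines above also show the dispatch pattern the launcher relies on: every generated specialization checks its compile-time traits (such as KernelType::kAddMask and KernelType::kMaskIsAligned) against the runtime parameters and returns early when they do not fit, so only a compatible kernel is launched. A rough sketch of that pattern under assumed names (LaunchParams and try_launch are illustrative, not Paddle symbols):

#include <cstdint>

// Illustrative runtime parameters for the dispatch decision.
struct LaunchParams {
  const void* mask_ptr;  // null when no mask is supplied
  int64_t ldm;           // mask row stride
};

// Sketch of the "skip incompatible specializations" dispatch used by the
// generated launcher: trait flags are template parameters, and a candidate
// returns without launching if they do not match the runtime situation.
template <typename T, bool kAddMask, bool kMaskIsAligned>
bool try_launch(const LaunchParams& params, bool& kernel_launched) {
  if (kernel_launched) return false;
  // A mask-consuming specialization cannot run without a mask.
  if (params.mask_ptr == nullptr && kAddMask) return false;
  // Prefer the aligned variant when the mask pointer and stride allow it.
  const bool aligned =
      reinterpret_cast<std::uintptr_t>(params.mask_ptr) % 16 == 0 &&
      params.ldm % (16 / sizeof(T)) == 0;
  if (params.mask_ptr != nullptr && aligned && !kMaskIsAligned) return false;
  // ... launch the CUDA kernel here ...
  kernel_launched = true;
  return true;
}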