Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f951832d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f951832d
编写于
3月 10, 2023
作者:
C
Chitsing KUI
提交者:
GitHub
3月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add flashattn raw kernel (#51383)
上级
3f4917f6
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
347 addition
and
106 deletion
+347
-106
paddle/phi/api/yaml/backward.yaml
paddle/phi/api/yaml/backward.yaml
+12
-0
paddle/phi/api/yaml/ops.yaml
paddle/phi/api/yaml/ops.yaml
+12
-0
paddle/phi/kernels/flash_attn_grad_kernel.h
paddle/phi/kernels/flash_attn_grad_kernel.h
+20
-0
paddle/phi/kernels/flash_attn_kernel.h
paddle/phi/kernels/flash_attn_kernel.h
+18
-0
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+110
-49
paddle/phi/kernels/gpu/flash_attn_kernel.cu
paddle/phi/kernels/gpu/flash_attn_kernel.cu
+125
-52
python/paddle/fluid/tests/unittests/test_flash_attention.py
python/paddle/fluid/tests/unittests/test_flash_attention.py
+50
-5
未找到文件。
paddle/phi/api/yaml/backward.yaml
浏览文件 @
f951832d
...
...
@@ -518,6 +518,18 @@
param
:
[
q
,
k
,
v
]
kernel
:
func
:
flash_attn_grad
data_type
:
q
-
backward_op
:
flash_attn_raw_grad
forward
:
flash_attn_raw (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal =
false
, bool return_softmax =
false
) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
args
:
(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal =
false
)
output
:
Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta
:
func
:
FlashAttnGradInferMeta
param
:
[
q
,
k
,
v
]
kernel
:
func
:
flash_attn_raw_grad
data_type
:
q
-
backward_op
:
flip_grad
forward
:
flip (Tensor x, int[] axis) -> Tensor(out)
...
...
paddle/phi/api/yaml/ops.yaml
浏览文件 @
f951832d
...
...
@@ -500,8 +500,20 @@
param
:
[
q
,
k
,
v
]
kernel
:
func
:
flash_attn
data_type
:
q
backward
:
flash_attn_grad
-
op
:
flash_attn_raw
args
:
(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal =
false
, bool return_softmax =
false
)
output
:
Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
infer_meta
:
func
:
FlashAttnInferMeta
param
:
[
q
,
k
,
v
]
kernel
:
func
:
flash_attn_raw
data_type
:
q
backward
:
flash_attn_raw_grad
-
op
:
flip
args
:
(Tensor x, int[] axis)
output
:
Tensor (out)
...
...
paddle/phi/kernels/flash_attn_grad_kernel.h
浏览文件 @
f951832d
...
...
@@ -19,6 +19,26 @@
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
FlashAttnRawGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
const
DenseTensor
&
k
,
const
DenseTensor
&
v
,
const
DenseTensor
&
cu_seqlens_q
,
const
DenseTensor
&
cu_seqlens_k
,
const
DenseTensor
&
out
,
const
DenseTensor
&
softmax_lse
,
const
DenseTensor
&
seed_offset
,
const
DenseTensor
&
dout
,
int64_t
max_seqlen_q
,
int64_t
max_seqlen_k
,
float
scale
,
float
dropout
,
bool
causal
,
DenseTensor
*
dq
,
DenseTensor
*
dk
,
DenseTensor
*
dv
);
template
<
typename
T
,
typename
Context
>
void
FlashAttnGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
...
...
paddle/phi/kernels/flash_attn_kernel.h
浏览文件 @
f951832d
...
...
@@ -19,6 +19,24 @@
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
FlashAttnRawKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
const
DenseTensor
&
k
,
const
DenseTensor
&
v
,
const
DenseTensor
&
cu_seqlens_q
,
const
DenseTensor
&
cu_seqlens_k
,
int64_t
max_seqlen_q
,
int64_t
max_seqlen_k
,
float
scale
,
float
dropout
,
bool
causal
,
bool
return_softmax
,
DenseTensor
*
out
,
DenseTensor
*
softmax_lse
,
DenseTensor
*
softmax
,
DenseTensor
*
seed_offset
);
template
<
typename
T
,
typename
Context
>
void
FlashAttnKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
...
...
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
浏览文件 @
f951832d
...
...
@@ -28,19 +28,24 @@
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
FlashAttnGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
const
DenseTensor
&
k
,
const
DenseTensor
&
v
,
const
DenseTensor
&
out
,
const
DenseTensor
&
softmax_lse
,
const
DenseTensor
&
seed_offset
,
const
DenseTensor
&
dout
,
float
dropout
,
bool
causal
,
DenseTensor
*
dq
,
DenseTensor
*
dk
,
DenseTensor
*
dv
)
{
void
FlashAttnRawGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
const
DenseTensor
&
k
,
const
DenseTensor
&
v
,
const
DenseTensor
&
cu_seqlens_q
,
const
DenseTensor
&
cu_seqlens_k
,
const
DenseTensor
&
out
,
const
DenseTensor
&
softmax_lse
,
const
DenseTensor
&
seed_offset
,
const
DenseTensor
&
dout
,
int64_t
max_seqlen_q
,
int64_t
max_seqlen_k
,
float
scale
,
float
dropout
,
bool
causal
,
DenseTensor
*
dq
,
DenseTensor
*
dk
,
DenseTensor
*
dv
)
{
#ifdef PADDLE_WITH_FLASHATTN
ctx
.
template
Alloc
<
T
>(
dq
);
ctx
.
template
Alloc
<
T
>(
dk
);
...
...
@@ -49,36 +54,16 @@ void FlashAttnGradKernel(const Context& ctx,
cudaStream_t
stream
=
ctx
.
stream
();
bool
is_bf16
=
q
.
dtype
()
==
DataType
::
BFLOAT16
?
true
:
false
;
// q,k,v [
batch_size, seq_len
, num_heads, head_dim]
// q,k,v [
total_*
, num_heads, head_dim]
auto
dims
=
q
.
dims
();
int64_t
batch_size
=
dims
[
0
];
int64_t
seq_len_q
=
dims
[
1
];
int64_t
num_heads
=
dims
[
2
];
int64_t
head_size
=
dims
[
3
];
int64_t
seq_len_k
=
k
.
dims
()[
1
];
int64_t
total_q
=
dims
[
0
];
int64_t
num_heads
=
dims
[
1
];
int64_t
head_size
=
dims
[
2
];
int64_t
total_
q
=
batch_size
*
seq_len_q
;
int64_t
total_k
=
batch_size
*
seq_len_k
;
int64_t
total_
k
=
k
.
dims
()[
0
]
;
int64_t
batch_size
=
cu_seqlens_q
.
numel
()
-
1
;
DenseTensor
q_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
q
,
{
total_q
,
num_heads
,
head_size
});
DenseTensor
k_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
k
,
{
total_k
,
num_heads
,
head_size
});
DenseTensor
v_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
v
,
{
total_k
,
num_heads
,
head_size
});
// q,k,v [total_*, num_heads, head_dim]
DenseTensor
cu_seqlens_q
;
DenseTensor
cu_seqlens_k
;
ArangeNullaryKernel
<
int32_t
,
Context
>
(
ctx
,
0
,
(
batch_size
+
1
)
*
seq_len_q
,
seq_len_q
,
&
cu_seqlens_q
);
ArangeNullaryKernel
<
int32_t
,
Context
>
(
ctx
,
0
,
(
batch_size
+
1
)
*
seq_len_k
,
seq_len_k
,
&
cu_seqlens_k
);
float
scale
=
1.0
f
/
std
::
sqrt
(
head_size
);
int
num_splits
=
0
;
// 0 for an internal heuristic, which is optimal
bool
zero_tensors
=
false
;
...
...
@@ -87,15 +72,16 @@ void FlashAttnGradKernel(const Context& ctx,
uint64_t
seed
=
seed_offset_vec
[
0
];
uint64_t
offset
=
seed_offset_vec
[
1
];
int64_t
seq_len_q
=
((
max_seqlen_q
+
16
-
1
)
/
16
)
*
16
;
DenseTensor
dsoftmax
=
Empty
<
float
>
(
ctx
,
{
batch_size
,
num_heads
,
seq_len_q
});
uint64_t
workspace_size
;
// calculate workspace size before execution
bool
succ
=
phi
::
dynload
::
flash_attn_bwd
(
q
_t_s
.
data
(),
k
_t_s
.
data
(),
v
_t_s
.
data
(),
q
.
data
(),
k
.
data
(),
v
.
data
(),
dq
->
data
(),
dk
->
data
(),
dv
->
data
(),
...
...
@@ -108,8 +94,8 @@ void FlashAttnGradKernel(const Context& ctx,
batch_size
,
num_heads
,
head_size
,
seq_
len_q
,
seq_
len_k
,
max_seq
len_q
,
max_seq
len_k
,
dropout
,
scale
,
zero_tensors
,
...
...
@@ -134,9 +120,9 @@ void FlashAttnGradKernel(const Context& ctx,
}
succ
=
phi
::
dynload
::
flash_attn_bwd
(
q
_t_s
.
data
(),
k
_t_s
.
data
(),
v
_t_s
.
data
(),
q
.
data
(),
k
.
data
(),
v
.
data
(),
dq
->
data
(),
dk
->
data
(),
dv
->
data
(),
...
...
@@ -149,8 +135,8 @@ void FlashAttnGradKernel(const Context& ctx,
batch_size
,
num_heads
,
head_size
,
seq_
len_q
,
seq_
len_k
,
max_seq
len_q
,
max_seq
len_k
,
dropout
,
scale
,
zero_tensors
,
...
...
@@ -172,8 +158,83 @@ void FlashAttnGradKernel(const Context& ctx,
#endif
}
template
<
typename
T
,
typename
Context
>
void
FlashAttnGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
const
DenseTensor
&
k
,
const
DenseTensor
&
v
,
const
DenseTensor
&
out
,
const
DenseTensor
&
softmax_lse
,
const
DenseTensor
&
seed_offset
,
const
DenseTensor
&
dout
,
float
dropout
,
bool
causal
,
DenseTensor
*
dq
,
DenseTensor
*
dk
,
DenseTensor
*
dv
)
{
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto
dims
=
q
.
dims
();
int64_t
batch_size
=
dims
[
0
];
int64_t
seq_len_q
=
dims
[
1
];
int64_t
num_heads
=
dims
[
2
];
int64_t
head_size
=
dims
[
3
];
int64_t
seq_len_k
=
k
.
dims
()[
1
];
int64_t
total_q
=
batch_size
*
seq_len_q
;
int64_t
total_k
=
batch_size
*
seq_len_k
;
float
scale
=
1.0
f
/
std
::
sqrt
(
head_size
);
DenseTensor
q_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
q
,
{
total_q
,
num_heads
,
head_size
});
DenseTensor
k_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
k
,
{
total_k
,
num_heads
,
head_size
});
DenseTensor
v_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
v
,
{
total_k
,
num_heads
,
head_size
});
DenseTensor
cu_seqlens_q
;
DenseTensor
cu_seqlens_k
;
ArangeNullaryKernel
<
int32_t
,
Context
>
(
ctx
,
0
,
(
batch_size
+
1
)
*
seq_len_q
,
seq_len_q
,
&
cu_seqlens_q
);
ArangeNullaryKernel
<
int32_t
,
Context
>
(
ctx
,
0
,
(
batch_size
+
1
)
*
seq_len_k
,
seq_len_k
,
&
cu_seqlens_k
);
FlashAttnRawGradKernel
<
T
,
Context
>
(
ctx
,
q_t_s
,
k_t_s
,
v_t_s
,
cu_seqlens_q
,
cu_seqlens_k
,
out
,
softmax_lse
,
seed_offset
,
dout
,
seq_len_q
,
seq_len_k
,
scale
,
dropout
,
causal
,
dq
,
dk
,
dv
);
#endif
}
}
// namespace phi
PD_REGISTER_KERNEL
(
flash_attn_raw_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
FlashAttnRawGradKernel
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{
kernel
->
InputAt
(
7
).
SetBackend
(
phi
::
Backend
::
CPU
);
// seed_offset
}
PD_REGISTER_KERNEL
(
flash_attn_grad
,
GPU
,
ALL_LAYOUT
,
...
...
paddle/phi/kernels/gpu/flash_attn_kernel.cu
浏览文件 @
f951832d
...
...
@@ -16,6 +16,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
...
...
@@ -30,53 +31,44 @@
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
FlashAttnKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
const
DenseTensor
&
k
,
const
DenseTensor
&
v
,
float
dropout
,
bool
causal
,
bool
return_softmax
,
DenseTensor
*
out
,
DenseTensor
*
softmax_lse
,
DenseTensor
*
softmax
,
DenseTensor
*
seed_offset
)
{
void
FlashAttnRawKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
const
DenseTensor
&
k
,
const
DenseTensor
&
v
,
const
DenseTensor
&
cu_seqlens_q
,
const
DenseTensor
&
cu_seqlens_k
,
int64_t
max_seqlen_q
,
int64_t
max_seqlen_k
,
float
scale
,
float
dropout
,
bool
causal
,
bool
return_softmax
,
DenseTensor
*
out
,
DenseTensor
*
softmax_lse
,
DenseTensor
*
softmax
,
DenseTensor
*
seed_offset
)
{
#ifdef PADDLE_WITH_FLASHATTN
ctx
.
template
Alloc
<
T
>(
out
);
cudaStream_t
stream
=
ctx
.
stream
();
bool
is_bf16
=
q
.
dtype
()
==
DataType
::
BFLOAT16
?
true
:
false
;
// q,k,v [
batch_size, seq_len
, num_heads, head_dim]
// q,k,v [
total_*
, num_heads, head_dim]
auto
dims
=
q
.
dims
();
int64_t
batch_size
=
dims
[
0
];
int64_t
seq_len_q
=
dims
[
1
];
int64_t
num_heads
=
dims
[
2
];
int64_t
head_size
=
dims
[
3
];
PADDLE_ENFORCE_EQ
(
dims
.
size
(),
3
,
phi
::
errors
::
InvalidArgument
(
"flash_attn_raw receive input with dim "
"[total_seq_len, num_heads, head_dim]"
));
int64_t
seq_len_k
=
k
.
dims
()[
1
];
int64_t
total_q
=
dims
[
0
];
int64_t
num_heads
=
dims
[
1
];
int64_t
head_size
=
dims
[
2
];
int64_t
total_
q
=
batch_size
*
seq_len_q
;
int64_t
total_k
=
batch_size
*
seq_len_k
;
int64_t
total_
k
=
k
.
dims
()[
0
]
;
int64_t
batch_size
=
cu_seqlens_q
.
numel
()
-
1
;
DenseTensor
q_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
q
,
{
total_q
,
num_heads
,
head_size
});
DenseTensor
k_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
k
,
{
total_k
,
num_heads
,
head_size
});
DenseTensor
v_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
v
,
{
total_k
,
num_heads
,
head_size
});
// q,k,v [total_*, num_heads, head_dim]
DenseTensor
cu_seqlens_q
;
DenseTensor
cu_seqlens_k
;
ArangeNullaryKernel
<
int32_t
,
Context
>
(
ctx
,
0
,
(
batch_size
+
1
)
*
seq_len_q
,
seq_len_q
,
&
cu_seqlens_q
);
ArangeNullaryKernel
<
int32_t
,
Context
>
(
ctx
,
0
,
(
batch_size
+
1
)
*
seq_len_k
,
seq_len_k
,
&
cu_seqlens_k
);
float
scale
=
1.0
f
/
std
::
sqrt
(
head_size
);
int
num_splits
=
0
;
// 0 for an internal heuristic, which is optimal
bool
zero_tensors
=
false
;
...
...
@@ -89,27 +81,33 @@ void FlashAttnKernel(const Context& ctx,
std
::
vector
<
int64_t
>
seed_offset_vec
{
int64_t
(
seed
),
int64_t
(
offset
)};
phi
::
TensorFromVector
<
int64_t
>
(
seed_offset_vec
,
ctx
,
seed_offset
);
int64_t
seq_len_q
=
((
max_seqlen_q
+
16
-
1
)
/
16
)
*
16
;
softmax_lse
->
Resize
({
batch_size
,
num_heads
,
seq_len_q
});
ctx
.
template
Alloc
<
float
>(
softmax_lse
);
if
(
return_softmax
)
{
// may allocate more space than *
seq_
len_k*
// may allocate more space than *
max_seq
len_k*
int64_t
blocksize_c
=
head_size
>
64
?
128
:
256
;
int64_t
max_len_k_
=
((
seq_len_k
+
blocksize_c
-
1
)
/
blocksize_c
)
*
blocksize_c
;
int64_t
max_len_k
=
seq_len_k
<=
128
?
128
:
(
seq_len_k
<=
256
?
256
:
max_len_k_
);
softmax
->
Resize
({
batch_size
,
num_heads
,
seq_len_q
,
max_len_k
});
int64_t
seq_len_k
=
((
max_seqlen_k
+
blocksize_c
-
1
)
/
blocksize_c
)
*
blocksize_c
;
if
(
max_seqlen_k
<=
128
)
{
seq_len_k
=
128
;
}
else
if
(
max_seqlen_k
<=
256
)
{
seq_len_k
=
256
;
}
softmax
->
Resize
({
batch_size
,
num_heads
,
seq_len_q
,
seq_len_k
});
ctx
.
template
Alloc
<
T
>(
softmax
);
}
uint64_t
workspace_size
;
// TODO(kuizhiqing) pass allocation/empty func in capi to decouple
// calculate workspace size before execution
bool
succ
=
phi
::
dynload
::
flash_attn_fwd
(
q
_t_s
.
data
(),
k
_t_s
.
data
(),
v
_t_s
.
data
(),
phi
::
dynload
::
flash_attn_fwd
(
q
.
data
(),
k
.
data
(),
v
.
data
(),
nullptr
,
// for calculation workspace size
cu_seqlens_q
.
data
(),
cu_seqlens_k
.
data
(),
...
...
@@ -118,8 +116,8 @@ void FlashAttnKernel(const Context& ctx,
batch_size
,
num_heads
,
head_size
,
seq_
len_q
,
seq_
len_k
,
max_seq
len_q
,
max_seq
len_k
,
dropout
,
scale
,
zero_tensors
,
...
...
@@ -144,9 +142,9 @@ void FlashAttnKernel(const Context& ctx,
}
succ
=
phi
::
dynload
::
flash_attn_fwd
(
q
_t_s
.
data
(),
k
_t_s
.
data
(),
v
_t_s
.
data
(),
q
.
data
(),
k
.
data
(),
v
.
data
(),
out
->
data
(),
cu_seqlens_q
.
data
(),
cu_seqlens_k
.
data
(),
...
...
@@ -155,8 +153,8 @@ void FlashAttnKernel(const Context& ctx,
batch_size
,
num_heads
,
head_size
,
seq_
len_q
,
seq_
len_k
,
max_seq
len_q
,
max_seq
len_k
,
dropout
,
scale
,
zero_tensors
,
...
...
@@ -178,8 +176,83 @@ void FlashAttnKernel(const Context& ctx,
#endif
}
template
<
typename
T
,
typename
Context
>
void
FlashAttnKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
const
DenseTensor
&
k
,
const
DenseTensor
&
v
,
float
dropout
,
bool
causal
,
bool
return_softmax
,
DenseTensor
*
out
,
DenseTensor
*
softmax_lse
,
DenseTensor
*
softmax
,
DenseTensor
*
seed_offset
)
{
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto
dims
=
q
.
dims
();
PADDLE_ENFORCE_EQ
(
dims
.
size
(),
4
,
phi
::
errors
::
InvalidArgument
(
"flash_attn receive input with dim "
"[batch_size, seq_len, num_heads, head_dim]"
));
int64_t
batch_size
=
dims
[
0
];
int64_t
seq_len_q
=
dims
[
1
];
int64_t
num_heads
=
dims
[
2
];
int64_t
head_size
=
dims
[
3
];
int64_t
seq_len_k
=
k
.
dims
()[
1
];
int64_t
total_q
=
batch_size
*
seq_len_q
;
int64_t
total_k
=
batch_size
*
seq_len_k
;
float
scale
=
1.0
f
/
std
::
sqrt
(
head_size
);
DenseTensor
q_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
q
,
{
total_q
,
num_heads
,
head_size
});
DenseTensor
k_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
k
,
{
total_k
,
num_heads
,
head_size
});
DenseTensor
v_t_s
=
Reshape
<
T
,
Context
>
(
ctx
,
v
,
{
total_k
,
num_heads
,
head_size
});
DenseTensor
cu_seqlens_q
;
DenseTensor
cu_seqlens_k
;
ArangeNullaryKernel
<
int32_t
,
Context
>
(
ctx
,
0
,
(
batch_size
+
1
)
*
seq_len_q
,
seq_len_q
,
&
cu_seqlens_q
);
ArangeNullaryKernel
<
int32_t
,
Context
>
(
ctx
,
0
,
(
batch_size
+
1
)
*
seq_len_k
,
seq_len_k
,
&
cu_seqlens_k
);
FlashAttnRawKernel
<
T
,
Context
>
(
ctx
,
q_t_s
,
k_t_s
,
v_t_s
,
cu_seqlens_q
,
cu_seqlens_k
,
seq_len_q
,
seq_len_k
,
scale
,
dropout
,
causal
,
return_softmax
,
out
,
softmax_lse
,
softmax
,
seed_offset
);
#endif
}
}
// namespace phi
PD_REGISTER_KERNEL
(
flash_attn_raw
,
GPU
,
ALL_LAYOUT
,
phi
::
FlashAttnRawKernel
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_KERNEL
(
flash_attn
,
GPU
,
ALL_LAYOUT
,
...
...
python/paddle/fluid/tests/unittests/test_flash_attention.py
浏览文件 @
f951832d
...
...
@@ -61,12 +61,61 @@ class TestFlashAttentionAPI(unittest.TestCase):
def
setUp
(
self
):
self
.
place
=
paddle
.
CUDAPlace
(
0
)
self
.
shape
=
(
2
,
128
,
8
,
16
)
self
.
blocksize
=
2
self
.
dtype
=
'float16'
self
.
dropout
=
0.0
self
.
causal
=
False
self
.
return_softmax
=
False
def
test_raw
(
self
):
print
(
f
"Test Raw case shape
{
self
.
shape
}
dtype
{
self
.
dtype
}
causal
{
self
.
causal
}
"
)
paddle
.
disable_static
()
query
=
np
.
random
.
random
(
self
.
shape
)
q
=
paddle
.
to_tensor
(
query
,
place
=
self
.
place
,
dtype
=
self
.
dtype
,
stop_gradient
=
False
)
q_
=
paddle
.
to_tensor
(
query
,
place
=
self
.
place
,
dtype
=
self
.
dtype
,
stop_gradient
=
False
)
out_
=
attention_naive
(
q_
,
q_
,
q_
,
self
.
causal
)
scale
=
1.0
/
np
.
sqrt
(
q
.
shape
[
-
1
])
bs
=
self
.
shape
[
0
]
ms
=
self
.
shape
[
1
]
nh
=
self
.
shape
[
2
]
hd
=
self
.
shape
[
3
]
cu_q
=
paddle
.
arange
(
0
,
(
bs
+
1
)
*
ms
,
ms
,
dtype
=
'int32'
)
qq
=
paddle
.
reshape
(
q
,
[
bs
*
ms
,
nh
,
hd
])
out
,
_
,
_
,
_
=
paddle
.
_C_ops
.
flash_attn_raw
(
qq
,
qq
,
qq
,
cu_q
,
cu_q
,
ms
,
ms
,
scale
,
self
.
dropout
,
self
.
causal
,
self
.
return_softmax
,
)
out_
=
paddle
.
reshape
(
out_
,
[
bs
*
ms
,
nh
,
hd
])
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
out_
,
rtol
=
5e-03
,
atol
=
1e-03
)
out
.
backward
()
out_
.
backward
()
np
.
testing
.
assert_allclose
(
q
.
grad
.
numpy
(),
q_
.
grad
.
numpy
(),
rtol
=
5e-03
,
atol
=
1e-03
)
def
test_all
(
self
):
print
(
f
"Test case shape
{
self
.
shape
}
dtype
{
self
.
dtype
}
causal
{
self
.
causal
}
"
...
...
@@ -152,7 +201,6 @@ class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
def
setUp
(
self
):
self
.
place
=
paddle
.
CUDAPlace
(
0
)
self
.
shape
=
(
2
,
128
,
8
,
16
)
self
.
blocksize
=
2
self
.
dtype
=
paddle
.
float16
self
.
dropout
=
0.0
self
.
causal
=
False
...
...
@@ -163,7 +211,6 @@ class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
def
setUp
(
self
):
self
.
place
=
paddle
.
CUDAPlace
(
0
)
self
.
shape
=
(
2
,
256
,
8
,
16
)
self
.
blocksize
=
2
self
.
dtype
=
paddle
.
float16
self
.
dropout
=
0.0
self
.
causal
=
False
...
...
@@ -174,7 +221,6 @@ class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
def
setUp
(
self
):
self
.
place
=
paddle
.
CUDAPlace
(
0
)
self
.
shape
=
(
2
,
512
,
8
,
16
)
self
.
blocksize
=
2
self
.
dtype
=
paddle
.
float16
self
.
dropout
=
0.0
self
.
causal
=
True
...
...
@@ -185,7 +231,6 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
def
setUp
(
self
):
self
.
place
=
paddle
.
CUDAPlace
(
0
)
self
.
shape
=
(
8
,
1024
,
16
,
128
)
self
.
blocksize
=
2
self
.
dtype
=
paddle
.
float16
self
.
dropout
=
0.0
self
.
causal
=
False
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录