Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1509a036
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1509a036
编写于
8月 15, 2023
作者:
Y
yinwei
提交者:
GitHub
8月 15, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add flash attention backward grad check (#56249)
--------- Co-authored-by:
N
tianhaodongbd
<
tianhaodong@baidu.com
>
上级
a26a3a60
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
17 addition
and
12 deletion
+17
-12
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+7
-5
paddle/phi/kernels/gpu/flash_attn_utils.h
paddle/phi/kernels/gpu/flash_attn_utils.h
+1
-1
python/paddle/nn/functional/flash_attention.py
python/paddle/nn/functional/flash_attention.py
+8
-5
test/legacy_test/test_flash_attention.py
test/legacy_test/test_flash_attention.py
+1
-1
未找到文件。
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
浏览文件 @
1509a036
...
...
@@ -70,7 +70,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
phi
::
errors
::
InvalidArgument
(
"The head_dim is expected to be either 32, "
"64, or 128, but recieved %d."
,
head_size
));
const
int64_t
*
seed_offset_data
=
seed_offset
.
data
<
int64_t
>
();
uint64_t
seed
=
static_cast
<
uint64_t
>
(
seed_offset_data
[
0
]);
uint64_t
offset
=
static_cast
<
uint64_t
>
(
seed_offset_data
[
1
]);
...
...
@@ -88,8 +87,13 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
bool
fa_zero_tensors
=
false
;
uint64_t
workspace_size
;
int64_t
q_size
=
total_q
*
num_heads
*
head_size
;
DenseTensor
scaled_q
=
Empty
<
T
>
(
ctx
,
{
total_q
,
num_heads
,
head_size
});
ComputeScaleQ
(
ctx
,
q_size
,
scale
,
q
.
data
<
T
>
(),
scaled_q
.
data
<
T
>
());
bool
succ
=
phi
::
dynload
::
flash_attn_bwd_with_bias_and_mask
(
static_cast
<
const
void
*>
(
q
.
data
()),
static_cast
<
const
void
*>
(
scaled_q
.
data
<
T
>
()),
static_cast
<
const
void
*>
(
k
.
data
()),
static_cast
<
const
void
*>
(
v
.
data
()),
static_cast
<
void
*>
(
dq
->
data
()),
...
...
@@ -124,7 +128,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
mask_dims
.
data
()
?
mask_dims
.
data
()
:
nullptr
,
nullptr
);
CheckFlashAttnStatus
(
succ
);
DenseTensor
workspace
;
if
(
workspace_size
>
0
)
{
workspace
=
Empty
<
float
>
(
...
...
@@ -132,7 +135,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
}
succ
=
phi
::
dynload
::
flash_attn_bwd_with_bias_and_mask
(
static_cast
<
const
void
*>
(
q
.
data
()),
static_cast
<
const
void
*>
(
scaled_q
.
data
<
T
>
()),
static_cast
<
const
void
*>
(
k
.
data
()),
static_cast
<
const
void
*>
(
v
.
data
()),
static_cast
<
void
*>
(
dq
->
data
()),
...
...
@@ -168,7 +171,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
nullptr
);
CheckFlashAttnStatus
(
succ
);
int64_t
q_size
=
total_q
*
num_heads
*
head_size
;
ComputeScaleQ
(
ctx
,
q_size
,
scale
,
dq
->
data
<
T
>
(),
dq
->
data
<
T
>
());
#else
RaiseNotSupportedError
();
...
...
paddle/phi/kernels/gpu/flash_attn_utils.h
浏览文件 @
1509a036
...
...
@@ -238,7 +238,7 @@ static std::vector<int64_t> GetAttnMaskDims(const DenseTensor* attn_mask) {
rank
,
4
,
phi
::
errors
::
InvalidArgument
(
"T
eh
number of dimenstions of attn_mask is expected to be greater "
"T
he
number of dimenstions of attn_mask is expected to be greater "
"or equal to 4, but recieved %d. The shape of attn_mask is {%s}"
,
rank
,
origin_dims
));
...
...
python/paddle/nn/functional/flash_attention.py
浏览文件 @
1509a036
...
...
@@ -417,6 +417,7 @@ def scaled_dot_product_attention(
dropout_p
=
0.0
,
is_causal
=
False
,
training
=
True
,
name
=
None
,
):
r
"""
The equation is:
...
...
@@ -447,10 +448,12 @@ def scaled_dot_product_attention(
The dtype can be float61 or bfloat16.
attn_mask(Tensor,optional): A float mask of the same type as query,
key, value that is added to the attention score.
not supported yet.
dropout_p(float): The dropout ratio.
is_causal(bool): Whether enable causal mode.
training(bool): Whether it is in the training phase
training(bool): Whether it is in the training phase.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
Returns:
out(Tensor): The attention tensor.
...
...
@@ -459,13 +462,13 @@ def scaled_dot_product_attention(
Examples:
.. code-block:: python
# required: skiptest
>>> #
x
doctest: +SKIP()
>>> # doctest: +SKIP()
>>> import paddle
>>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
>>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
>>> print(output)
>>> #
x
doctest: -SKIP
>>> # doctest: -SKIP
"""
if
attn_mask
is
None
:
out
,
_
=
flash_attention
(
query
,
key
,
value
,
dropout_p
,
is_causal
)
...
...
test/legacy_test/test_flash_attention.py
浏览文件 @
1509a036
...
...
@@ -312,7 +312,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
not
core
.
is_compiled_with_cuda
()
or
get_cuda_version
()
<
11040
or
not
is_sm_supported
,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.
3
"
"core is not compiled with CUDA and cuda version need larger than or equal to 11.
4
"
"and device's compute capability must be 7.5 or 8.x"
,
)
class
TestFlashAttentionWithMaskAPI
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录