Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1d23e0bb
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1d23e0bb
编写于
5月 06, 2023
作者:
Z
zhangkaihuo
提交者:
GitHub
5月 06, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherry-pick]add flash randomness control and add scaled_dot_product_attention (#53518)
att, cherry-pick: #52902 #53113
上级
39b704c1
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
458 addition
and
93 deletion
+458
-93
cmake/external/flashattn.cmake
cmake/external/flashattn.cmake
+1
-1
paddle/phi/api/yaml/backward.yaml
paddle/phi/api/yaml/backward.yaml
+2
-2
paddle/phi/api/yaml/ops.yaml
paddle/phi/api/yaml/ops.yaml
+4
-2
paddle/phi/kernels/flash_attn_kernel.h
paddle/phi/kernels/flash_attn_kernel.h
+22
-17
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+13
-0
paddle/phi/kernels/gpu/flash_attn_kernel.cu
paddle/phi/kernels/gpu/flash_attn_kernel.cu
+66
-26
python/paddle/distributed/auto_parallel/operators/__init__.py
...on/paddle/distributed/auto_parallel/operators/__init__.py
+1
-0
python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py
...le/distributed/auto_parallel/operators/dist_flash_attn.py
+103
-0
python/paddle/fluid/tests/unittests/test_flash_attention.py
python/paddle/fluid/tests/unittests/test_flash_attention.py
+50
-6
python/paddle/nn/functional/__init__.py
python/paddle/nn/functional/__init__.py
+2
-0
python/paddle/nn/functional/flash_attention.py
python/paddle/nn/functional/flash_attention.py
+194
-39
未找到文件。
cmake/external/flashattn.cmake
浏览文件 @
1d23e0bb
...
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
...
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set
(
FLASHATTN_SOURCE_SUBDIR csrc/flash_attn
)
set
(
FLASHATTN_SOURCE_SUBDIR csrc/flash_attn
)
set
(
FLASHATTN_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/flashattn
)
set
(
FLASHATTN_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/flashattn
)
set
(
FLASHATTN_REPOSITORY
${
GIT_URL
}
/PaddlePaddle/flash-attention.git
)
set
(
FLASHATTN_REPOSITORY
${
GIT_URL
}
/PaddlePaddle/flash-attention.git
)
set
(
FLASHATTN_TAG
f0edf243a813a65d05c75fcb331b2a95faf96bbc
)
set
(
FLASHATTN_TAG
5ff4bbf56ad066750407c4aef16ac740ebda0717
)
set
(
FLASHATTN_INCLUDE_DIR
set
(
FLASHATTN_INCLUDE_DIR
"
${
FLASHATTN_INSTALL_DIR
}
/include"
"
${
FLASHATTN_INSTALL_DIR
}
/include"
...
...
paddle/phi/api/yaml/backward.yaml
浏览文件 @
1d23e0bb
...
@@ -617,7 +617,7 @@
...
@@ -617,7 +617,7 @@
inplace
:
(out_grad -> x_grad)
inplace
:
(out_grad -> x_grad)
-
backward_op
:
flash_attn_grad
-
backward_op
:
flash_attn_grad
forward
:
flash_attn (Tensor q, Tensor k, Tensor v,
float dropout = 0.0, bool causal =
false
, bool return_softmax =
false
, bool is_test =
false
) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
forward
:
flash_attn (Tensor q, Tensor k, Tensor v,
Tensor fixed_seed_offset, float dropout = 0.0, bool causal =
false
, bool return_softmax =
false
, bool is_test =
false
, str rng_name = ""
) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args
:
(Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal =
false
)
args
:
(Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal =
false
)
output
:
Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
output
:
Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta
:
infer_meta
:
...
@@ -628,7 +628,7 @@
...
@@ -628,7 +628,7 @@
data_type
:
q
data_type
:
q
-
backward_op
:
flash_attn_unpadded_grad
-
backward_op
:
flash_attn_unpadded_grad
forward
:
flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k,
int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal =
false
, bool return_softmax =
false
, bool is_test =
false
) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
forward
:
flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k,
Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal =
false
, bool return_softmax =
false
, bool is_test =
false
, str rng_name = ""
) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args
:
(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal =
false
)
args
:
(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal =
false
)
output
:
Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
output
:
Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta
:
infer_meta
:
...
...
paddle/phi/api/yaml/ops.yaml
浏览文件 @
1d23e0bb
...
@@ -678,8 +678,9 @@
...
@@ -678,8 +678,9 @@
backward
:
fill_diagonal_tensor_grad
backward
:
fill_diagonal_tensor_grad
-
op
:
flash_attn
-
op
:
flash_attn
args
:
(Tensor q, Tensor k, Tensor v,
float dropout = 0.0, bool causal =
false
, bool return_softmax =
false
, bool is_test =
false
)
args
:
(Tensor q, Tensor k, Tensor v,
Tensor fixed_seed_offset, float dropout = 0.0, bool causal =
false
, bool return_softmax =
false
, bool is_test =
false
, str rng_name = ""
)
output
:
Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
output
:
Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional
:
fixed_seed_offset
infer_meta
:
infer_meta
:
func
:
FlashAttnInferMeta
func
:
FlashAttnInferMeta
param
:
[
q
,
k
,
v
]
param
:
[
q
,
k
,
v
]
...
@@ -690,8 +691,9 @@
...
@@ -690,8 +691,9 @@
backward
:
flash_attn_grad
backward
:
flash_attn_grad
-
op
:
flash_attn_unpadded
-
op
:
flash_attn_unpadded
args
:
(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k,
int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal =
false
, bool return_softmax =
false
, bool is_test =
false
)
args
:
(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k,
Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal =
false
, bool return_softmax =
false
, bool is_test =
false
, str rng_name = ""
)
output
:
Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
output
:
Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional
:
fixed_seed_offset
infer_meta
:
infer_meta
:
func
:
FlashAttnInferMeta
func
:
FlashAttnInferMeta
param
:
[
q
,
k
,
v
]
param
:
[
q
,
k
,
v
]
...
...
paddle/phi/kernels/flash_attn_kernel.h
浏览文件 @
1d23e0bb
...
@@ -20,33 +20,38 @@
...
@@ -20,33 +20,38 @@
namespace
phi
{
namespace
phi
{
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
FlashAttnUnpaddedKernel
(
const
Context
&
ctx
,
void
FlashAttnUnpaddedKernel
(
const
DenseTensor
&
q
,
const
Context
&
ctx
,
const
DenseTensor
&
k
,
const
DenseTensor
&
q
,
const
DenseTensor
&
v
,
const
DenseTensor
&
k
,
const
DenseTensor
&
cu_seqlens_q
,
const
DenseTensor
&
v
,
const
DenseTensor
&
cu_seqlens_k
,
const
DenseTensor
&
cu_seqlens_q
,
int64_t
max_seqlen_q
,
const
DenseTensor
&
cu_seqlens_k
,
int64_t
max_seqlen_k
,
const
paddle
::
optional
<
DenseTensor
>&
fixed_seed_offset
,
float
scale
,
int64_t
max_seqlen_q
,
float
dropout
,
int64_t
max_seqlen_k
,
bool
causal
,
float
scale
,
bool
return_softmax
,
float
dropout
,
bool
is_test
,
bool
causal
,
DenseTensor
*
out
,
bool
return_softmax
,
DenseTensor
*
softmax
,
bool
is_test
,
DenseTensor
*
softmax_lse
,
const
std
::
string
&
rng_name
,
DenseTensor
*
seed_offset
);
DenseTensor
*
out
,
DenseTensor
*
softmax
,
DenseTensor
*
softmax_lse
,
DenseTensor
*
seed_offset
);
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
FlashAttnKernel
(
const
Context
&
ctx
,
void
FlashAttnKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
q
,
const
DenseTensor
&
q
,
const
DenseTensor
&
k
,
const
DenseTensor
&
k
,
const
DenseTensor
&
v
,
const
DenseTensor
&
v
,
const
paddle
::
optional
<
DenseTensor
>&
fixed_seed_offset
,
float
dropout
,
float
dropout
,
bool
causal
,
bool
causal
,
bool
return_softmax
,
bool
return_softmax
,
bool
is_test
,
bool
is_test
,
const
std
::
string
&
rng_name
,
DenseTensor
*
out
,
DenseTensor
*
out
,
DenseTensor
*
softmax
,
DenseTensor
*
softmax
,
DenseTensor
*
softmax_lse
,
DenseTensor
*
softmax_lse
,
...
...
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
浏览文件 @
1d23e0bb
...
@@ -13,8 +13,10 @@
...
@@ -13,8 +13,10 @@
// limitations under the License.
// limitations under the License.
#include "paddle/phi/kernels/flash_attn_grad_kernel.h"
#include "paddle/phi/kernels/flash_attn_grad_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/arange_kernel.h"
...
@@ -25,6 +27,8 @@
...
@@ -25,6 +27,8 @@
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
#endif
DECLARE_bool
(
cudnn_deterministic
);
namespace
phi
{
namespace
phi
{
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
...
@@ -65,12 +69,18 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
...
@@ -65,12 +69,18 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
int64_t
batch_size
=
cu_seqlens_q
.
numel
()
-
1
;
int64_t
batch_size
=
cu_seqlens_q
.
numel
()
-
1
;
int
num_splits
=
0
;
// 0 for an internal heuristic, which is optimal
int
num_splits
=
0
;
// 0 for an internal heuristic, which is optimal
if
(
FLAGS_cudnn_deterministic
)
{
num_splits
=
1
;
}
bool
zero_tensors
=
false
;
bool
zero_tensors
=
false
;
const
int64_t
*
seed_offset_data
=
seed_offset
.
data
<
int64_t
>
();
const
int64_t
*
seed_offset_data
=
seed_offset
.
data
<
int64_t
>
();
uint64_t
seed
=
static_cast
<
uint64_t
>
(
seed_offset_data
[
0
]);
uint64_t
seed
=
static_cast
<
uint64_t
>
(
seed_offset_data
[
0
]);
uint64_t
offset
=
static_cast
<
uint64_t
>
(
seed_offset_data
[
1
]);
uint64_t
offset
=
static_cast
<
uint64_t
>
(
seed_offset_data
[
1
]);
VLOG
(
4
)
<<
"FlashAttn bwd seed: "
<<
seed
<<
", offset: "
<<
offset
<<
", num_splits:"
<<
num_splits
;
int64_t
seq_len_q
=
((
max_seqlen_q
+
16
-
1
)
/
16
)
*
16
;
int64_t
seq_len_q
=
((
max_seqlen_q
+
16
-
1
)
/
16
)
*
16
;
DenseTensor
dsoftmax
=
Empty
<
float
>
(
ctx
,
{
batch_size
,
num_heads
,
seq_len_q
});
DenseTensor
dsoftmax
=
Empty
<
float
>
(
ctx
,
{
batch_size
,
num_heads
,
seq_len_q
});
...
@@ -187,6 +197,9 @@ void FlashAttnGradKernel(const Context& ctx,
...
@@ -187,6 +197,9 @@ void FlashAttnGradKernel(const Context& ctx,
float
scale
=
1.0
f
/
std
::
sqrt
(
head_size
);
float
scale
=
1.0
f
/
std
::
sqrt
(
head_size
);
VLOG
(
4
)
<<
"FlashAttn bwd dims q["
<<
q
.
dims
()
<<
"], k["
<<
k
.
dims
()
<<
"], v["
<<
v
.
dims
()
<<
"]"
;
DenseTensor
q_t_s
,
k_t_s
,
v_t_s
;
DenseTensor
q_t_s
,
k_t_s
,
v_t_s
;
q_t_s
.
ShareDataWith
(
q
).
Resize
({
total_q
,
num_heads
,
head_size
});
q_t_s
.
ShareDataWith
(
q
).
Resize
({
total_q
,
num_heads
,
head_size
});
k_t_s
.
ShareDataWith
(
k
).
Resize
({
total_k
,
num_heads
,
head_size
});
k_t_s
.
ShareDataWith
(
k
).
Resize
({
total_k
,
num_heads
,
head_size
});
...
...
paddle/phi/kernels/gpu/flash_attn_kernel.cu
浏览文件 @
1d23e0bb
...
@@ -14,12 +14,13 @@
...
@@ -14,12 +14,13 @@
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
...
@@ -28,26 +29,31 @@
...
@@ -28,26 +29,31 @@
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
#endif
DECLARE_bool
(
cudnn_deterministic
);
namespace
phi
{
namespace
phi
{
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
FlashAttnUnpaddedKernel
(
const
Context
&
ctx
,
void
FlashAttnUnpaddedKernel
(
const
DenseTensor
&
q
,
const
Context
&
ctx
,
const
DenseTensor
&
k
,
const
DenseTensor
&
q
,
const
DenseTensor
&
v
,
const
DenseTensor
&
k
,
const
DenseTensor
&
cu_seqlens_q
,
const
DenseTensor
&
v
,
const
DenseTensor
&
cu_seqlens_k
,
const
DenseTensor
&
cu_seqlens_q
,
int64_t
max_seqlen_q
,
const
DenseTensor
&
cu_seqlens_k
,
int64_t
max_seqlen_k
,
const
paddle
::
optional
<
DenseTensor
>&
fixed_seed_offset
,
float
scale
,
int64_t
max_seqlen_q
,
float
dropout
,
int64_t
max_seqlen_k
,
bool
causal
,
float
scale
,
bool
return_softmax
,
float
dropout
,
bool
is_test
,
bool
causal
,
DenseTensor
*
out
,
bool
return_softmax
,
DenseTensor
*
softmax
,
bool
is_test
,
DenseTensor
*
softmax_lse
,
const
std
::
string
&
rng_name
,
DenseTensor
*
seed_offset
)
{
DenseTensor
*
out
,
DenseTensor
*
softmax
,
DenseTensor
*
softmax_lse
,
DenseTensor
*
seed_offset
)
{
#ifdef PADDLE_WITH_FLASHATTN
#ifdef PADDLE_WITH_FLASHATTN
if
(
is_test
)
dropout
=
0.0
f
;
if
(
is_test
)
dropout
=
0.0
f
;
...
@@ -73,17 +79,38 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
...
@@ -73,17 +79,38 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
int64_t
batch_size
=
cu_seqlens_q
.
numel
()
-
1
;
int64_t
batch_size
=
cu_seqlens_q
.
numel
()
-
1
;
int
num_splits
=
0
;
// 0 for an internal heuristic, which is optimal
int
num_splits
=
0
;
// 0 for an internal heuristic, which is optimal
if
(
FLAGS_cudnn_deterministic
)
{
num_splits
=
1
;
}
bool
zero_tensors
=
false
;
bool
zero_tensors
=
false
;
auto
gen
=
ctx
.
GetGenerator
();
uint64_t
seed
;
uint64_t
inc
=
batch_size
*
num_heads
*
32
;
uint64_t
offset
;
auto
seed_offset_pair
=
gen
->
IncrementOffset
(
inc
);
if
(
fixed_seed_offset
.
get_ptr
())
{
const
int64_t
*
fixed_seed_offset_data
=
fixed_seed_offset
.
get_ptr
()
->
data
<
int64_t
>
();
seed
=
static_cast
<
uint64_t
>
(
fixed_seed_offset_data
[
0
]);
offset
=
static_cast
<
uint64_t
>
(
fixed_seed_offset_data
[
1
]);
}
else
{
uint64_t
inc
=
batch_size
*
num_heads
*
32
;
std
::
pair
<
uint64_t
,
uint64_t
>
seed_offset_pair
;
if
(
rng_name
!=
""
)
{
auto
gen
=
phi
::
GetRandomSeedGenerator
(
rng_name
);
seed_offset_pair
=
gen
->
IncrementOffset
(
inc
);
}
else
{
auto
*
gen
=
ctx
.
GetGenerator
();
seed_offset_pair
=
gen
->
IncrementOffset
(
inc
);
}
seed
=
seed_offset_pair
.
first
;
offset
=
seed_offset_pair
.
second
;
}
uint64_t
seed
=
seed_offset_pair
.
first
;
VLOG
(
4
)
<<
"FlashAttn fwd seed: "
<<
seed
<<
", offset: "
<<
offset
uint64_t
offset
=
seed_offset_pair
.
second
;
<<
", num_splits:"
<<
num_splits
;
seed_offset
->
Resize
({
2
});
seed_offset
->
Resize
({
2
});
auto
*
seed_offset_data
=
ctx
.
template
HostAlloc
<
int64_t
>(
seed_offset
);
int64_t
*
seed_offset_data
=
ctx
.
template
HostAlloc
<
int64_t
>(
seed_offset
);
seed_offset_data
[
0
]
=
static_cast
<
int64_t
>
(
seed
);
seed_offset_data
[
0
]
=
static_cast
<
int64_t
>
(
seed
);
seed_offset_data
[
1
]
=
static_cast
<
int64_t
>
(
offset
);
seed_offset_data
[
1
]
=
static_cast
<
int64_t
>
(
offset
);
...
@@ -187,10 +214,12 @@ void FlashAttnKernel(const Context& ctx,
...
@@ -187,10 +214,12 @@ void FlashAttnKernel(const Context& ctx,
const
DenseTensor
&
q
,
const
DenseTensor
&
q
,
const
DenseTensor
&
k
,
const
DenseTensor
&
k
,
const
DenseTensor
&
v
,
const
DenseTensor
&
v
,
const
paddle
::
optional
<
DenseTensor
>&
fixed_seed_offset
,
float
dropout
,
float
dropout
,
bool
causal
,
bool
causal
,
bool
return_softmax
,
bool
return_softmax
,
bool
is_test
,
bool
is_test
,
const
std
::
string
&
rng_name
,
DenseTensor
*
out
,
DenseTensor
*
out
,
DenseTensor
*
softmax
,
DenseTensor
*
softmax
,
DenseTensor
*
softmax_lse
,
DenseTensor
*
softmax_lse
,
...
@@ -217,6 +246,9 @@ void FlashAttnKernel(const Context& ctx,
...
@@ -217,6 +246,9 @@ void FlashAttnKernel(const Context& ctx,
float
scale
=
1.0
f
/
std
::
sqrt
(
head_size
);
float
scale
=
1.0
f
/
std
::
sqrt
(
head_size
);
VLOG
(
4
)
<<
"FlashAttn fwd dims q["
<<
q
.
dims
()
<<
"], k["
<<
k
.
dims
()
<<
"], v["
<<
v
.
dims
()
<<
"]"
;
DenseTensor
q_t_s
,
k_t_s
,
v_t_s
;
DenseTensor
q_t_s
,
k_t_s
,
v_t_s
;
q_t_s
.
ShareDataWith
(
q
).
Resize
({
total_q
,
num_heads
,
head_size
});
q_t_s
.
ShareDataWith
(
q
).
Resize
({
total_q
,
num_heads
,
head_size
});
k_t_s
.
ShareDataWith
(
k
).
Resize
({
total_k
,
num_heads
,
head_size
});
k_t_s
.
ShareDataWith
(
k
).
Resize
({
total_k
,
num_heads
,
head_size
});
...
@@ -235,6 +267,7 @@ void FlashAttnKernel(const Context& ctx,
...
@@ -235,6 +267,7 @@ void FlashAttnKernel(const Context& ctx,
v_t_s
,
v_t_s
,
cu_seqlens_q
,
cu_seqlens_q
,
cu_seqlens_k
,
cu_seqlens_k
,
fixed_seed_offset
,
seq_len_q
,
seq_len_q
,
seq_len_k
,
seq_len_k
,
scale
,
scale
,
...
@@ -242,6 +275,7 @@ void FlashAttnKernel(const Context& ctx,
...
@@ -242,6 +275,7 @@ void FlashAttnKernel(const Context& ctx,
causal
,
causal
,
return_softmax
,
return_softmax
,
is_test
,
is_test
,
rng_name
,
out
,
out
,
softmax
,
softmax
,
softmax_lse
,
softmax_lse
,
...
@@ -257,11 +291,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded,
...
@@ -257,11 +291,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
FlashAttnUnpaddedKernel
,
phi
::
FlashAttnUnpaddedKernel
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
phi
::
dtype
::
bfloat16
)
{
kernel
->
InputAt
(
5
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
// fixed_seed_offset
}
PD_REGISTER_KERNEL
(
flash_attn
,
PD_REGISTER_KERNEL
(
flash_attn
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
FlashAttnKernel
,
phi
::
FlashAttnKernel
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
phi
::
dtype
::
bfloat16
)
{
kernel
->
InputAt
(
3
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
// fixed_seed_offset
}
python/paddle/distributed/auto_parallel/operators/__init__.py
浏览文件 @
1d23e0bb
...
@@ -38,3 +38,4 @@ from . import dist_shape
...
@@ -38,3 +38,4 @@ from . import dist_shape
from
.
import
dist_assign
from
.
import
dist_assign
from
.
import
dist_scale
from
.
import
dist_scale
from
.
import
dist_dropout
from
.
import
dist_dropout
from
.
import
dist_flash_attn
python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py
0 → 100644
浏览文件 @
1d23e0bb
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import
logging
from
...utils.log_utils
import
get_logger
_logger
=
get_logger
(
logging
.
INFO
)
from
..random
import
determinate_rng
,
is_enable_auto_rand_ctrl
from
.common
import
(
DistributedOperatorImplContainer
,
register_distributed_operator_impl
,
register_distributed_operator_impl_container
,
)
from
.dist_eltwise
import
DistributedDefaultImpl0
,
DistributedElementwiseImpl0
class
DistributedFlashAttn
(
DistributedOperatorImplContainer
):
def
__init__
(
self
,
op_type
):
super
().
__init__
(
op_type
)
register_distributed_operator_impl_container
(
DistributedFlashAttn
(
"flash_attn"
))
# Dist FlashAttn with Random Control
class
DistributedFlashAttnImpl0
(
DistributedElementwiseImpl0
):
def
__init__
(
self
,
name
):
super
().
__init__
(
name
)
self
.
_forward_implemented
=
True
self
.
_backward_implemented
=
True
def
is_input_compatible
(
self
,
dist_op
):
return
True
def
is_output_compatible
(
self
,
dist_op
):
return
True
def
is_auto_compatible
(
self
,
dist_op
):
return
True
@
staticmethod
def
forward
(
ctx
,
*
args
,
**
kwargs
):
dist_op_context
=
ctx
.
dist_op_context
main_block
=
dist_op_context
.
work_block
startup_block
=
dist_op_context
.
startup_block
src_op
=
dist_op_context
.
cur_src_op
rank_id
=
dist_op_context
.
rank_id
op_dist_attr
=
ctx
.
get_op_dist_attr_for_program
(
src_op
)
if
(
is_enable_auto_rand_ctrl
()
and
not
op_dist_attr
.
is_recompute
and
rank_id
in
op_dist_attr
.
process_mesh
.
process_ids
):
assert
(
op_dist_attr
is
not
None
),
f
"forward op [
{
str
(
src_op
)
}
] don't have dist attribute !"
if
(
len
(
kwargs
.
get
(
'fixed_seed_offset'
,
[]))
>
0
or
len
(
src_op
.
input
(
"fixed_seed_offset"
))
>
0
):
# TODO(kuizhiqing) recompute should go here
pass
else
:
# determinate rng
q_var
=
main_block
.
_var_recursive
(
kwargs
[
'q'
][
0
])
k_var
=
main_block
.
_var_recursive
(
kwargs
[
'k'
][
0
])
q_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
q_var
.
name
)
k_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
k_var
.
name
)
process_mesh
=
op_dist_attr
.
process_mesh
dims_mapping
=
q_dims_mapping
[:
3
]
+
[
q_dims_mapping
[
2
]]
rng_name
=
determinate_rng
(
rank_id
,
dims_mapping
,
process_mesh
)
assert
rng_name
is
not
None
and
rng_name
!=
""
src_op
.
_set_attr
(
'rng_name'
,
rng_name
)
DistributedDefaultImpl0
.
forward
(
ctx
,
*
args
,
**
kwargs
)
@
staticmethod
def
backward
(
ctx
,
*
args
,
**
kwargs
):
# dropout backward is deterministic by mask, and not need for random state control
DistributedDefaultImpl0
.
backward
(
ctx
,
*
args
,
**
kwargs
)
register_distributed_operator_impl
(
"flash_attn"
,
DistributedFlashAttnImpl0
(
"random_control"
)
)
python/paddle/fluid/tests/unittests/test_flash_attention.py
浏览文件 @
1d23e0bb
...
@@ -68,6 +68,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
...
@@ -68,6 +68,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
self
.
dropout
=
0.0
self
.
dropout
=
0.0
self
.
causal
=
False
self
.
causal
=
False
self
.
return_softmax
=
False
self
.
return_softmax
=
False
self
.
use_sdp_kernel
=
False
def
test_unpadded
(
self
):
def
test_unpadded
(
self
):
print
(
print
(
...
@@ -189,9 +190,19 @@ class TestFlashAttentionAPI(unittest.TestCase):
...
@@ -189,9 +190,19 @@ class TestFlashAttentionAPI(unittest.TestCase):
value
,
place
=
self
.
place
,
dtype
=
self
.
dtype
,
stop_gradient
=
False
value
,
place
=
self
.
place
,
dtype
=
self
.
dtype
,
stop_gradient
=
False
)
)
out
,
_
=
flash_attention
(
if
self
.
use_sdp_kernel
:
q
,
k
,
v
,
self
.
dropout
,
self
.
causal
,
self
.
return_softmax
with
paddle
.
nn
.
functional
.
sdp_kernel
(
)
enable_math
=
self
.
enable_math
,
enable_flash
=
self
.
enable_flash
,
enable_mem_efficient
=
self
.
enable_mem_efficient
,
):
out
,
_
=
flash_attention
(
q
,
k
,
v
,
self
.
dropout
,
self
.
causal
,
self
.
return_softmax
)
else
:
out
,
_
=
flash_attention
(
q
,
k
,
v
,
self
.
dropout
,
self
.
causal
,
self
.
return_softmax
)
out_
=
attention_naive
(
q_
,
k_
,
v_
,
self
.
causal
)
out_
=
attention_naive
(
q_
,
k_
,
v_
,
self
.
causal
)
out
.
backward
()
out
.
backward
()
...
@@ -220,9 +231,24 @@ class TestFlashAttentionAPI(unittest.TestCase):
...
@@ -220,9 +231,24 @@ class TestFlashAttentionAPI(unittest.TestCase):
name
=
"v"
,
shape
=
self
.
shape
,
dtype
=
self
.
dtype
name
=
"v"
,
shape
=
self
.
shape
,
dtype
=
self
.
dtype
)
)
outs
,
softmax
=
flash_attention
(
if
self
.
use_sdp_kernel
:
qs
,
ks
,
vs
,
self
.
dropout
,
self
.
causal
,
self
.
return_softmax
with
paddle
.
nn
.
functional
.
sdp_kernel
(
)
enable_math
=
self
.
enable_math
,
enable_flash
=
self
.
enable_flash
,
enable_mem_efficient
=
self
.
enable_mem_efficient
,
):
outs
,
softmax
=
flash_attention
(
qs
,
ks
,
vs
,
self
.
dropout
,
self
.
causal
,
self
.
return_softmax
,
)
else
:
outs
,
softmax
=
flash_attention
(
qs
,
ks
,
vs
,
self
.
dropout
,
self
.
causal
,
self
.
return_softmax
)
exe
=
fluid
.
Executor
(
self
.
place
)
exe
=
fluid
.
Executor
(
self
.
place
)
fetches_result
=
exe
.
run
(
fetches_result
=
exe
.
run
(
...
@@ -247,6 +273,7 @@ class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
...
@@ -247,6 +273,7 @@ class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
self
.
dropout
=
0.0
self
.
dropout
=
0.0
self
.
causal
=
False
self
.
causal
=
False
self
.
return_softmax
=
False
self
.
return_softmax
=
False
self
.
use_sdp_kernel
=
False
class
TestFlashAttentionAPITest2
(
TestFlashAttentionAPI
):
class
TestFlashAttentionAPITest2
(
TestFlashAttentionAPI
):
...
@@ -257,6 +284,7 @@ class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
...
@@ -257,6 +284,7 @@ class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
self
.
dropout
=
0.0
self
.
dropout
=
0.0
self
.
causal
=
False
self
.
causal
=
False
self
.
return_softmax
=
True
self
.
return_softmax
=
True
self
.
use_sdp_kernel
=
False
class
TestFlashAttentionAPITest3
(
TestFlashAttentionAPI
):
class
TestFlashAttentionAPITest3
(
TestFlashAttentionAPI
):
...
@@ -267,6 +295,7 @@ class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
...
@@ -267,6 +295,7 @@ class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
self
.
dropout
=
0.0
self
.
dropout
=
0.0
self
.
causal
=
True
self
.
causal
=
True
self
.
return_softmax
=
False
self
.
return_softmax
=
False
self
.
use_sdp_kernel
=
False
class
TestFlashAttentionAPITest4
(
TestFlashAttentionAPI
):
class
TestFlashAttentionAPITest4
(
TestFlashAttentionAPI
):
...
@@ -277,6 +306,21 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
...
@@ -277,6 +306,21 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
self
.
dropout
=
0.0
self
.
dropout
=
0.0
self
.
causal
=
False
self
.
causal
=
False
self
.
return_softmax
=
False
self
.
return_softmax
=
False
self
.
use_sdp_kernel
=
False
class
TestMathAttentionAPITest
(
TestFlashAttentionAPI
):
def
setUp
(
self
):
self
.
place
=
paddle
.
CUDAPlace
(
0
)
self
.
shape
=
(
8
,
1024
,
16
,
128
)
self
.
dtype
=
paddle
.
float16
self
.
dropout
=
0.0
self
.
causal
=
False
self
.
return_softmax
=
False
self
.
use_sdp_kernel
=
True
self
.
enable_math
=
True
self
.
enable_flash
=
False
self
.
enable_mem_efficient
=
False
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/nn/functional/__init__.py
浏览文件 @
1d23e0bb
...
@@ -134,6 +134,8 @@ from .extension import gather_tree # noqa: F401
...
@@ -134,6 +134,8 @@ from .extension import gather_tree # noqa: F401
from
.extension
import
temporal_shift
# noqa: F401
from
.extension
import
temporal_shift
# noqa: F401
from
.sparse_attention
import
sparse_attention
from
.sparse_attention
import
sparse_attention
from
.flash_attention
import
scaled_dot_product_attention
from
.flash_attention
import
sdp_kernel
__all__
=
[
# noqa
__all__
=
[
# noqa
'celu'
,
'celu'
,
...
...
python/paddle/nn/functional/flash_attention.py
浏览文件 @
1d23e0bb
...
@@ -13,8 +13,113 @@
...
@@ -13,8 +13,113 @@
# limitations under the License.
# limitations under the License.
import
paddle
import
paddle
import
paddle.nn.functional
as
F
from
paddle
import
_C_ops
,
in_dynamic_mode
from
paddle
import
_C_ops
,
in_dynamic_mode
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.wrapped_decorator
import
signature_safe_contextmanager
g_enable_math
=
None
g_enable_flash
=
None
g_enable_mem_efficient
=
None
@
signature_safe_contextmanager
def
sdp_kernel
(
enable_math
=
False
,
enable_flash
=
True
,
enable_mem_efficient
=
True
):
r
"""
With the sdp_kernel context manager, different algorithm implementations can
be selected for scaled_dot_product_attention.
"""
global
g_enable_math
,
g_enable_flash
,
g_enable_mem_efficient
original_enable_math
=
g_enable_math
original_enable_flash
=
g_enable_math
original_enable_mem_efficient
=
g_enable_mem_efficient
g_enable_math
=
enable_math
g_enable_flash
=
enable_flash
g_enable_mem_efficient
=
enable_mem_efficient
try
:
yield
finally
:
g_enable_math
=
original_enable_math
g_enable_flash
=
original_enable_flash
g_enable_mem_efficient
=
original_enable_mem_efficient
def
_math_attention
(
query
,
key
,
value
,
dropout_rate
=
0.0
,
causal
=
False
,
return_softmax
=
False
,
training
=
True
,
):
r
"""
This is a basic implementation of scaled dot product attention composed of
combinations of fundamental components.
"""
head_dim
=
query
.
shape
[
-
1
]
query
=
paddle
.
transpose
(
query
,
[
0
,
2
,
1
,
3
])
key
=
paddle
.
transpose
(
key
,
[
0
,
2
,
1
,
3
])
value
=
paddle
.
transpose
(
value
,
[
0
,
2
,
1
,
3
])
product
=
paddle
.
matmul
(
x
=
query
*
(
head_dim
**-
0.5
),
y
=
key
,
transpose_y
=
True
)
weights
=
(
paddle
.
incubate
.
softmax_mask_fuse_upper_triangle
(
product
)
if
causal
else
F
.
softmax
(
product
)
)
if
dropout_rate
>
0.0
:
weights
=
F
.
dropout
(
weights
,
dropout_rate
,
training
=
training
,
mode
=
"upscale_in_train"
)
out
=
paddle
.
matmul
(
weights
,
value
)
out
=
paddle
.
transpose
(
out
,
[
0
,
2
,
1
,
3
])
return
out
,
weights
if
return_softmax
else
None
def
_select_sdp_cuda
(
head_dim
):
if
head_dim
<
128
:
return
"flash_attn"
else
:
return
"mem_efficient"
def
_select_sdp
(
head_dim
):
r
"""
There are currently three different implementation options available for
scaled dot product attention, and the chosen approach depends on whether it
is determined by the sdp_kernel configuration or specified through input values.
"""
place
=
paddle
.
get_device
()
# not use sdp_kernel
if
g_enable_flash
is
None
:
if
"gpu"
not
in
place
:
return
"math"
else
:
return
_select_sdp_cuda
(
head_dim
)
if
(
g_enable_math
is
False
and
g_enable_flash
is
False
and
g_enable_mem_efficient
is
False
):
raise
AssertionError
(
"No available backend for scaled_dot_product_attention was found."
)
if
g_enable_math
is
True
:
if
g_enable_flash
is
False
and
g_enable_mem_efficient
is
False
:
return
"math"
if
"gpu"
not
in
place
:
return
"math"
if
g_enable_flash
is
True
and
g_enable_mem_efficient
is
True
:
return
_select_sdp_cuda
(
head_dim
)
if
g_enable_flash
is
True
:
return
"flash_attn"
return
"mem_efficient"
def
flash_attention
(
def
flash_attention
(
...
@@ -24,6 +129,9 @@ def flash_attention(
...
@@ -24,6 +129,9 @@ def flash_attention(
dropout
=
0.0
,
dropout
=
0.0
,
causal
=
False
,
causal
=
False
,
return_softmax
=
False
,
return_softmax
=
False
,
*
,
fixed_seed_offset
=
None
,
rng_name
=
""
,
training
=
True
,
training
=
True
,
name
=
None
,
name
=
None
,
):
):
...
@@ -57,7 +165,9 @@ def flash_attention(
...
@@ -57,7 +165,9 @@ def flash_attention(
dropout(float): The dropout ratio.
dropout(float): The dropout ratio.
causal(bool): Whether enable causal mode.
causal(bool): Whether enable causal mode.
return_softmax(bool): Whether to return softmax.
return_softmax(bool): Whether to return softmax.
fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask.
training(bool): Whether it is in the training phase.
training(bool): Whether it is in the training phase.
rng_name(str): The name to select Generator.
name(str, optional): The default value is None. Normally there is no need for user
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
:ref:`api_guide_Name`.
...
@@ -79,47 +189,81 @@ def flash_attention(
...
@@ -79,47 +189,81 @@ def flash_attention(
output = paddle.nn.functional.flash_attention(q, q, q, 0.9, False, False)
output = paddle.nn.functional.flash_attention(q, q, q, 0.9, False, False)
print(output)
print(output)
"""
"""
if
in_dynamic_mode
():
head_dim
=
query
.
shape
[
3
]
(
result_attention
,
result_softmax
,)
=
_C_ops
.
flash_attn
(
sdp_func_name
=
_select_sdp
(
head_dim
)
query
,
key
,
if
sdp_func_name
==
"flash_attn"
:
value
,
if
in_dynamic_mode
():
dropout
,
(
result_attention
,
result_softmax
,)
=
_C_ops
.
flash_attn
(
causal
,
query
,
return_softmax
,
key
,
not
training
,
value
,
fixed_seed_offset
,
dropout
,
causal
,
return_softmax
,
not
training
,
rng_name
,
)
return
result_attention
,
result_softmax
if
return_softmax
else
None
helper
=
LayerHelper
(
'flash_attn'
,
**
locals
())
dtype
=
helper
.
input_dtype
(
input_param_name
=
'q'
)
out
=
helper
.
create_variable_for_type_inference
(
dtype
)
softmax
=
helper
.
create_variable_for_type_inference
(
dtype
)
softmax_lse
=
helper
.
create_variable_for_type_inference
(
paddle
.
float32
)
seed_offset
=
helper
.
create_variable_for_type_inference
(
paddle
.
int64
)
inputs
=
{
'q'
:
query
,
'k'
:
key
,
'v'
:
value
,
'fixed_seed_offset'
:
fixed_seed_offset
,
}
outputs
=
{
'out'
:
out
,
'softmax'
:
softmax
,
'softmax_lse'
:
softmax_lse
,
'seed_offset'
:
seed_offset
,
}
helper
.
append_op
(
type
=
'flash_attn'
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
{
'dropout'
:
dropout
,
'causal'
:
causal
,
'return_softmax'
:
return_softmax
,
'is_test'
:
not
training
,
'rng_name'
:
rng_name
,
},
)
)
return
result_attention
,
result_softmax
if
return_softmax
else
None
return
out
,
softmax
if
return_softmax
else
None
else
:
if
sdp_func_name
==
"mem_efficient"
:
from
paddle.incubate.nn.memory_efficient_attention
import
(
memory_efficient_attention
,
)
helper
=
LayerHelper
(
'flash_attn'
,
**
locals
())
output
=
memory_efficient_attention
(
dtype
=
helper
.
input_dtype
(
input_param_name
=
'q'
)
query
,
out
=
helper
.
create_variable_for_type_inference
(
dtype
)
key
,
softmax
=
helper
.
create_variable_for_type_inference
(
dtype
)
value
,
softmax_lse
=
helper
.
create_variable_for_type_inference
(
paddle
.
float32
)
attn_bias
=
None
,
seed_offset
=
helper
.
create_variable_for_type_inference
(
paddle
.
int64
)
p
=
dropout
,
inputs
=
{
scale
=
None
,
'q'
:
query
,
training
=
training
,
'k'
:
key
,
)
'v'
:
value
,
return
output
,
None
}
else
:
outputs
=
{
return
_math_attention
(
'out'
:
out
,
query
,
'softmax'
:
softmax
,
key
,
'softmax_lse'
:
softmax_lse
,
value
,
'seed_offset'
:
seed_offset
,
dropout_rate
=
dropout
,
}
causal
=
causal
,
helper
.
append_op
(
return_softmax
=
return_softmax
,
type
=
'flash_attn'
,
training
=
training
,
inputs
=
inputs
,
)
outputs
=
outputs
,
attrs
=
{
'dropout'
:
dropout
,
'causal'
:
causal
,
'return_softmax'
:
return_softmax
,
'is_test'
:
not
training
,
},
)
return
out
,
softmax
if
return_softmax
else
None
def
flash_attn_unpadded
(
def
flash_attn_unpadded
(
...
@@ -134,6 +278,8 @@ def flash_attn_unpadded(
...
@@ -134,6 +278,8 @@ def flash_attn_unpadded(
dropout
=
0.0
,
dropout
=
0.0
,
causal
=
False
,
causal
=
False
,
return_softmax
=
False
,
return_softmax
=
False
,
fixed_seed_offset
=
None
,
rng_name
=
""
,
training
=
True
,
training
=
True
,
name
=
None
,
name
=
None
,
):
):
...
@@ -174,6 +320,8 @@ def flash_attn_unpadded(
...
@@ -174,6 +320,8 @@ def flash_attn_unpadded(
dropout(float): The dropout ratio.
dropout(float): The dropout ratio.
causal(bool): Whether enable causal mode.
causal(bool): Whether enable causal mode.
return_softmax(bool): Whether to return softmax.
return_softmax(bool): Whether to return softmax.
fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask.
rng_name(str): The name to select Generator.
training(bool): Whether it is in the training phase.
training(bool): Whether it is in the training phase.
name(str, optional): The default value is None. Normally there is no need for user
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
to set this property. For more information, please refer to
...
@@ -203,6 +351,7 @@ def flash_attn_unpadded(
...
@@ -203,6 +351,7 @@ def flash_attn_unpadded(
value
,
value
,
cu_seqlens_q
,
cu_seqlens_q
,
cu_seqlens_k
,
cu_seqlens_k
,
fixed_seed_offset
,
max_seqlen_q
,
max_seqlen_q
,
max_seqlen_k
,
max_seqlen_k
,
scale
,
scale
,
...
@@ -210,6 +359,7 @@ def flash_attn_unpadded(
...
@@ -210,6 +359,7 @@ def flash_attn_unpadded(
causal
,
causal
,
return_softmax
,
return_softmax
,
not
training
,
not
training
,
rng_name
,
)
)
return
result_attention
,
result_softmax
if
return_softmax
else
None
return
result_attention
,
result_softmax
if
return_softmax
else
None
...
@@ -225,6 +375,7 @@ def flash_attn_unpadded(
...
@@ -225,6 +375,7 @@ def flash_attn_unpadded(
'v'
:
value
,
'v'
:
value
,
'cu_seqlens_q'
:
cu_seqlens_q
,
'cu_seqlens_q'
:
cu_seqlens_q
,
'cu_seqlens_k'
:
cu_seqlens_k
,
'cu_seqlens_k'
:
cu_seqlens_k
,
'fixed_seed_offset'
:
fixed_seed_offset
,
}
}
outputs
=
{
outputs
=
{
'out'
:
out
,
'out'
:
out
,
...
@@ -244,6 +395,10 @@ def flash_attn_unpadded(
...
@@ -244,6 +395,10 @@ def flash_attn_unpadded(
'causal'
:
causal
,
'causal'
:
causal
,
'return_softmax'
:
return_softmax
,
'return_softmax'
:
return_softmax
,
'is_test'
:
not
training
,
'is_test'
:
not
training
,
'rng_name'
:
rng_name
,
},
},
)
)
return
out
,
softmax
if
return_softmax
else
None
return
out
,
softmax
if
return_softmax
else
None
scaled_dot_product_attention
=
flash_attention
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录