Commit 00ac8014 (unverified)
Authored on Apr 20, 2023 by Chitsing KUI; committed via GitHub on Apr 20, 2023.
[FlashAttn] add flash randomness control (#52902)

* add flash randomness control
* fix VLOG undefined

Parent: 67c6cfe0

Showing 9 changed files with 229 additions and 48 deletions (+229 -48).
cmake/external/flashattn.cmake                                          +1   -1
paddle/phi/api/yaml/backward.yaml                                       +2   -2
paddle/phi/api/yaml/ops.yaml                                            +4   -2
paddle/phi/kernels/flash_attn_kernel.h                                  +22  -17
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu                        +13  -0
paddle/phi/kernels/gpu/flash_attn_kernel.cu                             +66  -26
python/paddle/distributed/auto_parallel/operators/__init__.py           +1   -0
python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py    +103 -0
python/paddle/nn/functional/flash_attention.py                          +17  -0
cmake/external/flashattn.cmake

@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
 set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
-set(FLASHATTN_TAG f0edf243a813a65d05c75fcb331b2a95faf96bbc)
+set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
 set(FLASHATTN_INCLUDE_DIR "${FLASHATTN_INSTALL_DIR}/include"
paddle/phi/api/yaml/backward.yaml

@@ -617,7 +617,7 @@
   inplace : (out_grad -> x_grad)

 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :

@@ -628,7 +628,7 @@
   data_type : q

 - backward_op : flash_attn_unpadded_grad
-  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
paddle/phi/api/yaml/ops.yaml

@@ -678,8 +678,9 @@
   backward : fill_diagonal_tensor_grad

 - op : flash_attn
-  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]

@@ -690,8 +691,9 @@
   backward : flash_attn_grad

 - op : flash_attn_unpadded
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
paddle/phi/kernels/flash_attn_kernel.h

@@ -20,33 +20,38 @@
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(const Context& ctx,
-                             const DenseTensor& q,
-                             const DenseTensor& k,
-                             const DenseTensor& v,
-                             const DenseTensor& cu_seqlens_q,
-                             const DenseTensor& cu_seqlens_k,
-                             int64_t max_seqlen_q,
-                             int64_t max_seqlen_k,
-                             float scale,
-                             float dropout,
-                             bool causal,
-                             bool return_softmax,
-                             bool is_test,
-                             DenseTensor* out,
-                             DenseTensor* softmax,
-                             DenseTensor* softmax_lse,
-                             DenseTensor* seed_offset);
+void FlashAttnUnpaddedKernel(const Context& ctx,
+                             const DenseTensor& q,
+                             const DenseTensor& k,
+                             const DenseTensor& v,
+                             const DenseTensor& cu_seqlens_q,
+                             const DenseTensor& cu_seqlens_k,
+                             const paddle::optional<DenseTensor>& fixed_seed_offset,
+                             int64_t max_seqlen_q,
+                             int64_t max_seqlen_k,
+                             float scale,
+                             float dropout,
+                             bool causal,
+                             bool return_softmax,
+                             bool is_test,
+                             const std::string& rng_name,
+                             DenseTensor* out,
+                             DenseTensor* softmax,
+                             DenseTensor* softmax_lse,
+                             DenseTensor* seed_offset);

 template <typename T, typename Context>
 void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
                      const DenseTensor& k,
                      const DenseTensor& v,
+                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
                      bool causal,
                      bool return_softmax,
                      bool is_test,
+                     const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu

@@ -13,8 +13,10 @@
 // limitations under the License.

 #include "paddle/phi/kernels/flash_attn_grad_kernel.h"

+#include "glog/logging.h"  // For VLOG()
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"

@@ -25,6 +27,8 @@
 #include "paddle/phi/backends/dynload/flashattn.h"
 #endif

+DECLARE_bool(cudnn_deterministic);
+
 namespace phi {

 template <typename T, typename Context>

@@ -65,12 +69,18 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   int64_t batch_size = cu_seqlens_q.numel() - 1;

   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  if (FLAGS_cudnn_deterministic) {
+    num_splits = 1;
+  }
   bool zero_tensors = false;

   const int64_t* seed_offset_data = seed_offset.data<int64_t>();
   uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
   uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);

+  VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
+          << ", num_splits:" << num_splits;
+
   int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;

   DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});

@@ -187,6 +197,9 @@ void FlashAttnGradKernel(const Context& ctx,
   float scale = 1.0f / std::sqrt(head_size);

+  VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
+
   DenseTensor q_t_s, k_t_s, v_t_s;
   q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
   k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
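Note: the deterministic path added above reuses the existing cudnn_deterministic flag rather than introducing a new switch. A minimal sketch of how a user might flip it from Python, assuming the usual paddle.set_flags entry point (not part of this diff):

import paddle

# With the branch added in flash_attn_grad_kernel.cu, enabling this flag forces
# num_splits = 1 in the FlashAttention backward (deterministic reduction) instead
# of letting the kernel pick an internal heuristic.
paddle.set_flags({'FLAGS_cudnn_deterministic': True})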
paddle/phi/kernels/gpu/flash_attn_kernel.cu

@@ -14,12 +14,13 @@
 #include "paddle/phi/kernels/flash_attn_kernel.h"

+#include "glog/logging.h"  // For VLOG()
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/reshape_kernel.h"

@@ -28,26 +29,31 @@
 #include "paddle/phi/backends/dynload/flashattn.h"
 #endif

+DECLARE_bool(cudnn_deterministic);
+
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(const Context& ctx,
-                             const DenseTensor& q,
-                             const DenseTensor& k,
-                             const DenseTensor& v,
-                             const DenseTensor& cu_seqlens_q,
-                             const DenseTensor& cu_seqlens_k,
-                             int64_t max_seqlen_q,
-                             int64_t max_seqlen_k,
-                             float scale,
-                             float dropout,
-                             bool causal,
-                             bool return_softmax,
-                             bool is_test,
-                             DenseTensor* out,
-                             DenseTensor* softmax,
-                             DenseTensor* softmax_lse,
-                             DenseTensor* seed_offset) {
+void FlashAttnUnpaddedKernel(const Context& ctx,
+                             const DenseTensor& q,
+                             const DenseTensor& k,
+                             const DenseTensor& v,
+                             const DenseTensor& cu_seqlens_q,
+                             const DenseTensor& cu_seqlens_k,
+                             const paddle::optional<DenseTensor>& fixed_seed_offset,
+                             int64_t max_seqlen_q,
+                             int64_t max_seqlen_k,
+                             float scale,
+                             float dropout,
+                             bool causal,
+                             bool return_softmax,
+                             bool is_test,
+                             const std::string& rng_name,
+                             DenseTensor* out,
+                             DenseTensor* softmax,
+                             DenseTensor* softmax_lse,
+                             DenseTensor* seed_offset) {
 #ifdef PADDLE_WITH_FLASHATTN
   if (is_test) dropout = 0.0f;

@@ -73,17 +79,38 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
   int64_t batch_size = cu_seqlens_q.numel() - 1;

   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  if (FLAGS_cudnn_deterministic) {
+    num_splits = 1;
+  }
   bool zero_tensors = false;

-  auto gen = ctx.GetGenerator();
-  uint64_t inc = batch_size * num_heads * 32;
-  auto seed_offset_pair = gen->IncrementOffset(inc);
-
-  uint64_t seed = seed_offset_pair.first;
-  uint64_t offset = seed_offset_pair.second;
+  uint64_t seed;
+  uint64_t offset;
+
+  if (fixed_seed_offset.get_ptr()) {
+    const int64_t* fixed_seed_offset_data =
+        fixed_seed_offset.get_ptr()->data<int64_t>();
+    seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
+    offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
+  } else {
+    uint64_t inc = batch_size * num_heads * 32;
+    std::pair<uint64_t, uint64_t> seed_offset_pair;
+    if (rng_name != "") {
+      auto gen = phi::GetRandomSeedGenerator(rng_name);
+      seed_offset_pair = gen->IncrementOffset(inc);
+    } else {
+      auto* gen = ctx.GetGenerator();
+      seed_offset_pair = gen->IncrementOffset(inc);
+    }
+    seed = seed_offset_pair.first;
+    offset = seed_offset_pair.second;
+  }
+
+  VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
+          << ", num_splits:" << num_splits;

   seed_offset->Resize({2});
-  auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
+  int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
   seed_offset_data[0] = static_cast<int64_t>(seed);
   seed_offset_data[1] = static_cast<int64_t>(offset);

@@ -187,10 +214,12 @@ void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
                      const DenseTensor& k,
                      const DenseTensor& v,
+                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
                      bool causal,
                      bool return_softmax,
                      bool is_test,
+                     const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,

@@ -217,6 +246,9 @@ void FlashAttnKernel(const Context& ctx,
   float scale = 1.0f / std::sqrt(head_size);

+  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
+
   DenseTensor q_t_s, k_t_s, v_t_s;
   q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
   k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});

@@ -235,6 +267,7 @@ void FlashAttnKernel(const Context& ctx,
       v_t_s,
       cu_seqlens_q,
       cu_seqlens_k,
+      fixed_seed_offset,
       seq_len_q,
       seq_len_k,
       scale,

@@ -242,6 +275,7 @@ void FlashAttnKernel(const Context& ctx,
       causal,
       return_softmax,
       is_test,
+      rng_name,
       out,
       softmax,
       softmax_lse,

@@ -257,11 +291,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded,
                    ALL_LAYOUT,
                    phi::FlashAttnUnpaddedKernel,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);  // fixed_seed_offset
+}

 PD_REGISTER_KERNEL(flash_attn,
                    GPU,
                    ALL_LAYOUT,
                    phi::FlashAttnKernel,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);  // fixed_seed_offset
+}
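Note: the forward kernel now chooses the dropout seed/offset with a three-way priority: an explicitly supplied fixed_seed_offset tensor wins, otherwise a named generator selected by rng_name is consulted, and only then does the kernel fall back to the device's default generator. A minimal Python sketch of that control flow, with illustrative stand-in names (this is not Paddle API, just a mirror of the C++ logic above):

class _Gen:
    """Stand-in for a phi Generator holding a (seed, offset) state."""

    def __init__(self, seed):
        self.seed, self.offset = seed, 0

    def increment_offset(self, inc):
        # Hand out the current state, then advance the offset,
        # loosely mirroring Generator::IncrementOffset(inc).
        cur = (self.seed, self.offset)
        self.offset += inc
        return cur


def choose_seed_offset(fixed_seed_offset, rng_name, default_gen, named_gens,
                       batch_size, num_heads):
    """Illustrative mirror of the seed/offset selection in FlashAttnUnpaddedKernel."""
    if fixed_seed_offset is not None:
        # Caller pinned the randomness: use the [seed, offset] tensor verbatim.
        return int(fixed_seed_offset[0]), int(fixed_seed_offset[1])
    inc = batch_size * num_heads * 32   # same increment the kernel uses
    gen = named_gens[rng_name] if rng_name else default_gen
    return gen.increment_offset(inc)    # becomes the seed_offset output of shape [2]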
python/paddle/distributed/auto_parallel/operators/__init__.py

@@ -38,3 +38,4 @@ from . import dist_shape
 from . import dist_assign
 from . import dist_scale
 from . import dist_dropout
+from . import dist_flash_attn
python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py (new file, mode 100644)

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import logging

from ...utils.log_utils import get_logger

_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from .common import (
    DistributedOperatorImplContainer,
    register_distributed_operator_impl,
    register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0


class DistributedFlashAttn(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn"))


# Dist FlashAttn with Random Control
class DistributedFlashAttnImpl0(DistributedElementwiseImpl0):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        return True

    def is_output_compatible(self, dist_op):
        return True

    def is_auto_compatible(self, dist_op):
        return True

    @staticmethod
    def forward(ctx, *args, **kwargs):
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)

        if (
            is_enable_auto_rand_ctrl()
            and not op_dist_attr.is_recompute
            and rank_id in op_dist_attr.process_mesh.process_ids
        ):
            assert (
                op_dist_attr is not None
            ), f"forward op [{str(src_op)}] don't have dist attribute !"

            if (
                len(kwargs.get('fixed_seed_offset', [])) > 0
                or len(src_op.input("fixed_seed_offset")) > 0
            ):
                # TODO(kuizhiqing) recompute should go here
                pass
            else:
                # determinate rng
                q_var = main_block._var_recursive(kwargs['q'][0])
                k_var = main_block._var_recursive(kwargs['k'][0])
                q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
                k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)
                process_mesh = op_dist_attr.process_mesh
                dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]

                rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
                assert rng_name is not None and rng_name != ""

                src_op._set_attr('rng_name', rng_name)

        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        # dropout backward is deterministic by mask, and not need for random state control
        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "flash_attn", DistributedFlashAttnImpl0("random_control")
)
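Note: the distributed implementation above does not seed anything itself; it only decides whether an rng_name attribute should be attached to the flash_attn op. A condensed Python sketch of that decision, with determinate_rng treated as an opaque helper from the auto-parallel random-control module (names other than those in the file above are illustrative):

def pick_rng_name(kwargs, src_op, op_dist_attr, rank_id, determinate_rng, main_block):
    """Condensed view of DistributedFlashAttnImpl0.forward's random-control branch."""
    if (
        len(kwargs.get('fixed_seed_offset', [])) > 0
        or len(src_op.input("fixed_seed_offset")) > 0
    ):
        # The caller already pinned the randomness; nothing to derive.
        return None
    q_var = main_block._var_recursive(kwargs['q'][0])
    q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
    # Partition signature for the dropout rng: q's first three dims, with dim 2 repeated.
    dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]
    rng_name = determinate_rng(rank_id, dims_mapping, op_dist_attr.process_mesh)
    assert rng_name
    return rng_name  # later stored via src_op._set_attr('rng_name', rng_name)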
python/paddle/nn/functional/flash_attention.py

@@ -24,6 +24,9 @@ def flash_attention(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    *,
+    fixed_seed_offset=None,
+    rng_name="",
     training=True,
     name=None,
 ):

@@ -57,7 +60,9 @@ def flash_attention(
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask.
         training(bool): Whether it is in the training phase.
+        rng_name(str): The name to select Generator.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.

@@ -84,10 +89,12 @@ def flash_attention(
             query,
             key,
             value,
+            fixed_seed_offset,
             dropout,
             causal,
             return_softmax,
             not training,
+            rng_name,
         )
         return result_attention, result_softmax if return_softmax else None

@@ -101,6 +108,7 @@ def flash_attention(
         'q': query,
         'k': key,
         'v': value,
+        'fixed_seed_offset': fixed_seed_offset,
     }
     outputs = {
         'out': out,

@@ -117,6 +125,7 @@ def flash_attention(
             'causal': causal,
             'return_softmax': return_softmax,
             'is_test': not training,
+            'rng_name': rng_name,
         },
     )
     return out, softmax if return_softmax else None

@@ -134,6 +143,8 @@ def flash_attn_unpadded(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    fixed_seed_offset=None,
+    rng_name="",
     training=True,
     name=None,
 ):

@@ -174,6 +185,8 @@ def flash_attn_unpadded(
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask.
+        rng_name(str): The name to select Generator.
         training(bool): Whether it is in the training phase.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to

@@ -203,6 +216,7 @@ def flash_attn_unpadded(
             value,
             cu_seqlens_q,
             cu_seqlens_k,
+            fixed_seed_offset,
             max_seqlen_q,
             max_seqlen_k,
             scale,

@@ -210,6 +224,7 @@ def flash_attn_unpadded(
             causal,
             return_softmax,
             not training,
+            rng_name,
         )
         return result_attention, result_softmax if return_softmax else None

@@ -225,6 +240,7 @@ def flash_attn_unpadded(
         'v': value,
         'cu_seqlens_q': cu_seqlens_q,
         'cu_seqlens_k': cu_seqlens_k,
+        'fixed_seed_offset': fixed_seed_offset,
     }
     outputs = {
         'out': out,

@@ -244,6 +260,7 @@ def flash_attn_unpadded(
             'causal': causal,
             'return_softmax': return_softmax,
             'is_test': not training,
+            'rng_name': rng_name,
         },
     )
     return out, softmax if return_softmax else None
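Note: at the Python level the new arguments are keyword-only for flash_attention. A minimal usage sketch of the two control modes, assuming a GPU build with FlashAttention enabled and an illustrative [batch, seqlen, num_heads, head_dim] layout for q/k/v (the layout is an assumption, not something this diff states):

import paddle
from paddle.nn.functional.flash_attention import flash_attention

q = paddle.randn([2, 128, 8, 64], dtype='float16')
k = paddle.randn([2, 128, 8, 64], dtype='float16')
v = paddle.randn([2, 128, 8, 64], dtype='float16')

# Mode 1: pin the dropout mask to a fixed (seed, offset) pair so repeated runs
# reproduce the same randomness.
fixed = paddle.to_tensor([2023, 0], dtype='int64')
out, _ = flash_attention(
    q, k, v, dropout=0.1, causal=True, fixed_seed_offset=fixed, training=True
)

# Mode 2: leave fixed_seed_offset unset and name a generator instead; the kernel
# then draws the seed/offset from that named generator. The name here is purely
# illustrative and must refer to a generator that was registered beforehand,
# which is what the auto-parallel DistributedFlashAttnImpl0 above arranges.
out, _ = flash_attention(
    q, k, v, dropout=0.1, causal=True, rng_name="local_seed", training=True
)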