Unverified commit 0b778bdc, authored by Chitsing KUI, committed by GitHub

rename flash_attn_raw to flash_attn_unpadded (#51704)

* rename flash_attn_raw to flash_attn_unpadded

* fix static api

* fix static return
Parent 86bf8274
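In user-facing terms, flash_attn_raw becomes flash_attn_unpadded: the variable-length ("unpadded") attention entry point, which takes packed q/k/v plus cumulative sequence lengths instead of a padded 4-D batch. A minimal dynamic-mode sketch of the renamed API, pieced together from the test and wrapper changes below (the shapes, scale, and single-sequence cu_seqlens are illustrative; assumes a CUDA build with FlashAttention compiled in):

```python
import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

# Unpadded layout: q/k/v are [total_seq_len, num_heads, head_dim];
# cu_seqlens marks where each sequence starts/ends in the packed dim.
q = paddle.rand((128, 2, 16), dtype=paddle.float16)
cu = paddle.arange(0, 129, 128, dtype='int32')  # one sequence of length 128
scale = 1.0 / (16 ** 0.5)  # 1/sqrt(head_dim)

# Was paddle._C_ops.flash_attn_raw(...); now a public two-output wrapper.
out, softmax = flash_attn_unpadded(q, q, q, cu, cu, 128, 128, scale)
```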
@@ -540,7 +540,7 @@
inplace : (out_grad -> x_grad)
- backward_op : flash_attn_grad
- forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+ forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
@@ -550,15 +550,15 @@
func : flash_attn_grad
data_type: q
- - backward_op : flash_attn_raw_grad
- forward : flash_attn_raw (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+ - backward_op : flash_attn_unpadded_grad
+ forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnGradInferMeta
param : [q, k, v]
kernel :
- func : flash_attn_raw_grad
+ func : flash_attn_unpadded_grad
data_type: q
- backward_op : flip_grad
@@ -530,25 +530,27 @@
- op : flash_attn
args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false)
- output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+ output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
kernel :
func : flash_attn
data_type : q
intermediate : softmax_lse, seed_offset
backward : flash_attn_grad
- - op : flash_attn_raw
+ - op : flash_attn_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false)
- output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+ output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
kernel :
- func : flash_attn_raw
+ func : flash_attn_unpadded
data_type : q
- backward : flash_attn_raw_grad
+ intermediate : softmax_lse, seed_offset
+ backward : flash_attn_unpadded_grad
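Two changes ride along with the rename in these op definitions: the output order now lists softmax before softmax_lse, and softmax_lse/seed_offset are tagged intermediate so they sit last and stay internal. The practical effect, as reflected in the Python wrapper further down, is a two-value return from the eager op. A sketch of the new calling convention (shapes illustrative; `_C_ops` is Paddle's generated eager-ops namespace):

```python
import paddle

# flash_attn consumes the padded 4-D layout [batch, seq_len, heads, head_dim].
q = k = v = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)

# dropout=0.0, causal=False, return_softmax=False; the two intermediate
# outputs (softmax_lse, seed_offset) are no longer part of the return.
out, softmax = paddle._C_ops.flash_attn(q, k, v, 0.0, False, False)
```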
- op : flip
args : (Tensor x, int[] axis)
@@ -259,8 +259,8 @@ void FlashAttnInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
- MetaTensor* softmax_lse,
MetaTensor* softmax,
+ MetaTensor* softmax_lse,
MetaTensor* seed_offset) {
out->set_dims(q.dims());
out->set_dtype(q.dtype());
@@ -67,8 +67,8 @@ void FlashAttnInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
- MetaTensor* softmax_lse,
MetaTensor* softmax,
+ MetaTensor* softmax_lse,
MetaTensor* seed_offset);
void InstanceNormInferMeta(const MetaTensor& x,
@@ -20,24 +20,24 @@
namespace phi {
template <typename T, typename Context>
- void FlashAttnRawGradKernel(const Context& ctx,
- const DenseTensor& q,
- const DenseTensor& k,
- const DenseTensor& v,
- const DenseTensor& cu_seqlens_q,
- const DenseTensor& cu_seqlens_k,
- const DenseTensor& out,
- const DenseTensor& softmax_lse,
- const DenseTensor& seed_offset,
- const DenseTensor& dout,
- int64_t max_seqlen_q,
- int64_t max_seqlen_k,
- float scale,
- float dropout,
- bool causal,
- DenseTensor* dq,
- DenseTensor* dk,
- DenseTensor* dv);
+ void FlashAttnUnpaddedGradKernel(const Context& ctx,
+ const DenseTensor& q,
+ const DenseTensor& k,
+ const DenseTensor& v,
+ const DenseTensor& cu_seqlens_q,
+ const DenseTensor& cu_seqlens_k,
+ const DenseTensor& out,
+ const DenseTensor& softmax_lse,
+ const DenseTensor& seed_offset,
+ const DenseTensor& dout,
+ int64_t max_seqlen_q,
+ int64_t max_seqlen_k,
+ float scale,
+ float dropout,
+ bool causal,
+ DenseTensor* dq,
+ DenseTensor* dk,
+ DenseTensor* dv);
template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
@@ -20,22 +20,22 @@
namespace phi {
template <typename T, typename Context>
- void FlashAttnRawKernel(const Context& ctx,
- const DenseTensor& q,
- const DenseTensor& k,
- const DenseTensor& v,
- const DenseTensor& cu_seqlens_q,
- const DenseTensor& cu_seqlens_k,
- int64_t max_seqlen_q,
- int64_t max_seqlen_k,
- float scale,
- float dropout,
- bool causal,
- bool return_softmax,
- DenseTensor* out,
- DenseTensor* softmax_lse,
- DenseTensor* softmax,
- DenseTensor* seed_offset);
+ void FlashAttnUnpaddedKernel(const Context& ctx,
+ const DenseTensor& q,
+ const DenseTensor& k,
+ const DenseTensor& v,
+ const DenseTensor& cu_seqlens_q,
+ const DenseTensor& cu_seqlens_k,
+ int64_t max_seqlen_q,
+ int64_t max_seqlen_k,
+ float scale,
+ float dropout,
+ bool causal,
+ bool return_softmax,
+ DenseTensor* out,
+ DenseTensor* softmax,
+ DenseTensor* softmax_lse,
+ DenseTensor* seed_offset);
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
@@ -46,8 +46,8 @@ void FlashAttnKernel(const Context& ctx,
bool causal,
bool return_softmax,
DenseTensor* out,
- DenseTensor* softmax_lse,
DenseTensor* softmax,
+ DenseTensor* softmax_lse,
DenseTensor* seed_offset);
} // namespace phi
@@ -28,24 +28,24 @@
namespace phi {
template <typename T, typename Context>
- void FlashAttnRawGradKernel(const Context& ctx,
- const DenseTensor& q,
- const DenseTensor& k,
- const DenseTensor& v,
- const DenseTensor& cu_seqlens_q,
- const DenseTensor& cu_seqlens_k,
- const DenseTensor& out,
- const DenseTensor& softmax_lse,
- const DenseTensor& seed_offset,
- const DenseTensor& dout,
- int64_t max_seqlen_q,
- int64_t max_seqlen_k,
- float scale,
- float dropout,
- bool causal,
- DenseTensor* dq,
- DenseTensor* dk,
- DenseTensor* dv) {
+ void FlashAttnUnpaddedGradKernel(const Context& ctx,
+ const DenseTensor& q,
+ const DenseTensor& k,
+ const DenseTensor& v,
+ const DenseTensor& cu_seqlens_q,
+ const DenseTensor& cu_seqlens_k,
+ const DenseTensor& out,
+ const DenseTensor& softmax_lse,
+ const DenseTensor& seed_offset,
+ const DenseTensor& dout,
+ int64_t max_seqlen_q,
+ int64_t max_seqlen_k,
+ float scale,
+ float dropout,
+ bool causal,
+ DenseTensor* dq,
+ DenseTensor* dk,
+ DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
@@ -202,34 +202,34 @@ void FlashAttnGradKernel(const Context& ctx,
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
- FlashAttnRawGradKernel<T, Context>(ctx,
- q_t_s,
- k_t_s,
- v_t_s,
- cu_seqlens_q,
- cu_seqlens_k,
- out,
- softmax_lse,
- seed_offset,
- dout,
- seq_len_q,
- seq_len_k,
- scale,
- dropout,
- causal,
- dq,
- dk,
- dv);
+ FlashAttnUnpaddedGradKernel<T, Context>(ctx,
+ q_t_s,
+ k_t_s,
+ v_t_s,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ out,
+ softmax_lse,
+ seed_offset,
+ dout,
+ seq_len_q,
+ seq_len_k,
+ scale,
+ dropout,
+ causal,
+ dq,
+ dk,
+ dv);
#endif
}
} // namespace phi
- PD_REGISTER_KERNEL(flash_attn_raw_grad,
+ PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
GPU,
ALL_LAYOUT,
- phi::FlashAttnRawGradKernel,
+ phi::FlashAttnUnpaddedGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(7).SetBackend(phi::Backend::CPU); // seed_offset
@@ -31,22 +31,22 @@
namespace phi {
template <typename T, typename Context>
- void FlashAttnRawKernel(const Context& ctx,
- const DenseTensor& q,
- const DenseTensor& k,
- const DenseTensor& v,
- const DenseTensor& cu_seqlens_q,
- const DenseTensor& cu_seqlens_k,
- int64_t max_seqlen_q,
- int64_t max_seqlen_k,
- float scale,
- float dropout,
- bool causal,
- bool return_softmax,
- DenseTensor* out,
- DenseTensor* softmax_lse,
- DenseTensor* softmax,
- DenseTensor* seed_offset) {
+ void FlashAttnUnpaddedKernel(const Context& ctx,
+ const DenseTensor& q,
+ const DenseTensor& k,
+ const DenseTensor& v,
+ const DenseTensor& cu_seqlens_q,
+ const DenseTensor& cu_seqlens_k,
+ int64_t max_seqlen_q,
+ int64_t max_seqlen_k,
+ float scale,
+ float dropout,
+ bool causal,
+ bool return_softmax,
+ DenseTensor* out,
+ DenseTensor* softmax,
+ DenseTensor* softmax_lse,
+ DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
ctx.template Alloc<T>(out);
@@ -185,8 +185,8 @@ void FlashAttnKernel(const Context& ctx,
bool causal,
bool return_softmax,
DenseTensor* out,
- DenseTensor* softmax_lse,
DenseTensor* softmax,
+ DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
@@ -224,32 +224,32 @@ void FlashAttnKernel(const Context& ctx,
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
- FlashAttnRawKernel<T, Context>(ctx,
- q_t_s,
- k_t_s,
- v_t_s,
- cu_seqlens_q,
- cu_seqlens_k,
- seq_len_q,
- seq_len_k,
- scale,
- dropout,
- causal,
- return_softmax,
- out,
- softmax_lse,
- softmax,
- seed_offset);
+ FlashAttnUnpaddedKernel<T, Context>(ctx,
+ q_t_s,
+ k_t_s,
+ v_t_s,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ seq_len_q,
+ seq_len_k,
+ scale,
+ dropout,
+ causal,
+ return_softmax,
+ out,
+ softmax,
+ softmax_lse,
+ seed_offset);
#endif
}
} // namespace phi
- PD_REGISTER_KERNEL(flash_attn_raw,
+ PD_REGISTER_KERNEL(flash_attn_unpadded,
GPU,
ALL_LAYOUT,
- phi::FlashAttnRawKernel,
+ phi::FlashAttnUnpaddedKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
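As the GPU kernels above show, the fixed-length flash_attn path is now a thin adapter over the unpadded kernel: it flattens [batch_size, seq_len, num_heads, head_dim] to [batch_size * seq_len, num_heads, head_dim] and synthesizes uniform cumulative sequence lengths via ArangeNullaryKernel. A Python sketch of that reduction (the helper name is mine, not Paddle's):

```python
import paddle

def as_unpadded_inputs(q, k, v):
    """Mirror what FlashAttnKernel does before calling FlashAttnUnpaddedKernel."""
    bs, sl, nh, hd = q.shape
    # arange(0, (bs + 1) * sl, sl) -> [0, sl, 2*sl, ..., bs*sl]:
    # every sequence in a fixed-length batch has the same length sl.
    cu_seqlens = paddle.arange(0, (bs + 1) * sl, sl, dtype='int32')
    flat = lambda t: t.reshape([bs * sl, nh, hd])
    return flat(q), flat(k), flat(v), cu_seqlens, sl  # sl doubles as max_seqlen
```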
@@ -22,7 +22,10 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.nn.functional as F
- from paddle.nn.functional.flash_attention import flash_attention
+ from paddle.nn.functional.flash_attention import (
+ flash_attention,
+ flash_attn_unpadded,
+ )
def get_cuda_version():
@@ -66,9 +69,9 @@ class TestFlashAttentionAPI(unittest.TestCase):
self.causal = False
self.return_softmax = False
- def test_raw(self):
+ def test_unpadded(self):
print(
f"Test Raw case shape {self.shape} dtype {self.dtype} causal {self.causal}"
f"Test unpadded case shape {self.shape} dtype {self.dtype} causal {self.causal}"
)
paddle.disable_static()
@@ -92,7 +95,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
cu_q = paddle.arange(0, (bs + 1) * ms, ms, dtype='int32')
qq = paddle.reshape(q, [bs * ms, nh, hd])
- out, _, _, _ = paddle._C_ops.flash_attn_raw(
+ out, _ = flash_attn_unpadded(
qq,
qq,
qq,
@@ -116,6 +119,45 @@ class TestFlashAttentionAPI(unittest.TestCase):
q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
)
# test static
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
qs = paddle.static.data(
name="q", shape=self.shape, dtype=self.dtype
)
cu_q = paddle.arange(0, (bs + 1) * ms, ms, dtype='int32')
qs = paddle.reshape(qs, [bs * ms, nh, hd])
outs, softmax = flash_attn_unpadded(
qs,
qs,
qs,
cu_q,
cu_q,
ms,
ms,
scale,
self.dropout,
self.causal,
self.return_softmax,
)
exe = fluid.Executor(self.place)
fetches_result = exe.run(
feed={
"q": query.astype('float16'),
"k": query.astype('float16'),
"v": query.astype('float16'),
},
fetch_list=[outs],
)
np.testing.assert_allclose(
fetches_result[0], out_, rtol=5e-03, atol=1e-03
)
def test_all(self):
print(
f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}"
@@ -78,12 +78,7 @@ def flash_attention(
print(output)
"""
if in_dynamic_mode():
- (
- result_attention,
- result_softmax_lse,
- result_softmax,
- seed_offset,
- ) = _C_ops.flash_attn(
+ (result_attention, result_softmax,) = _C_ops.flash_attn(
query,
key,
value,
@@ -121,3 +116,126 @@ def flash_attention(
},
)
return out, softmax
def flash_attn_unpadded(
query,
key,
value,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
scale,
dropout=0.0,
causal=False,
return_softmax=False,
name=None,
):
r"""
The equation is:
.. math::
result = softmax(\frac{QK^T}{\sqrt{d}}) V
where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
The dimensions of the three parameters are the same.
``d`` represents the size of the last dimension of the three parameters.
Warning:
This API only supports inputs with dtype float16 and bfloat16.
Args:
query(Tensor): The query tensor in the Attention module.
3-D tensor with shape:
[total_seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
key(Tensor): The key tensor in the Attention module.
3-D tensor with shape:
[total_seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
value(Tensor): The value tensor in the Attention module.
3-D tensor with shape:
[total_seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
used to index query.
cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,
used to index key and value.
max_seqlen_q(int): Maximum sequence length of query in the batch.
max_seqlen_k(int): Maximum sequence length of key/value in the batch.
scale(float): The scaling of QK^T before applying softmax.
dropout(float): The dropout ratio.
causal(bool): Whether to enable causal mode.
return_softmax(bool): Whether to return the softmax tensor.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
Returns:
out(Tensor): The attention tensor.
3-D tensor with shape: [total_seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
softmax(Tensor): The softmax tensor. None if return_softmax is False.
Examples:
.. code-block:: python
# required: skiptest
import paddle
q = paddle.rand((128, 2, 16), dtype=paddle.float16)
cu = paddle.arange(0, 129, 128, dtype='int32')
scale = 1.0 / (16 ** 0.5)
output, _ = paddle.nn.functional.flash_attn_unpadded(q, q, q, cu, cu, 128, 128, scale)
print(output)
"""
if in_dynamic_mode():
(result_attention, result_softmax,) = _C_ops.flash_attn_unpadded(
query,
key,
value,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
scale,
dropout,
causal,
return_softmax,
)
return result_attention, result_softmax
helper = LayerHelper('flash_attn_unpadded', **locals())
dtype = helper.input_dtype(input_param_name='q')
out = helper.create_variable_for_type_inference(dtype)
softmax = helper.create_variable_for_type_inference(dtype)
softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
seed_offset = helper.create_variable_for_type_inference(paddle.int64)
inputs = {
'q': query,
'k': key,
'v': value,
'cu_seqlens_q': cu_seqlens_q,
'cu_seqlens_k': cu_seqlens_k,
}
outputs = {
'out': out,
'softmax': softmax,
'softmax_lse': softmax_lse,
'seed_offset': seed_offset,
}
helper.append_op(
type='flash_attn_unpadded',
inputs=inputs,
outputs=outputs,
attrs={
'max_seqlen_q': max_seqlen_q,
'max_seqlen_k': max_seqlen_k,
'scale': scale,
'dropout': dropout,
'causal': causal,
'return_softmax': return_softmax,
},
)
return out, softmax
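Worth noting for users of the new wrapper: cu_seqlens_q/cu_seqlens_k are the prefix sums of the per-sequence lengths (starting at 0), which is what lets a ragged batch be packed with no padding. A hypothetical variable-length example (the lengths and shapes here are made up, not from the PR):

```python
import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

# Three sequences of lengths 3, 5, and 2 packed back to back
# along dim 0, so total_seq_len = 10 and there is no padding.
cu_seqlens = paddle.to_tensor([0, 3, 8, 10], dtype='int32')

num_heads, head_dim = 2, 16
q = paddle.rand((10, num_heads, head_dim), dtype=paddle.float16)

out, _ = flash_attn_unpadded(
    q, q, q,
    cu_seqlens, cu_seqlens,
    5, 5,  # max_seqlen_q / max_seqlen_k = max(lengths)
    1.0 / (head_dim ** 0.5),
)
```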