Unverified commit 8ac5a6b6, authored by sneaxiy and committed by GitHub

Fix flash attention bug (#52551)

* fix flash attn

* fix another API
Parent: 7976e2a3
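
Taken together, the changes below thread a new is_test attribute from the Python flash_attention / flash_attn_unpadded APIs through the op YAML definitions and the GPU kernels (so dropout is disabled at inference time), and fix FlashAttnInferMeta so the output's head_dim is taken from v rather than q. A minimal usage sketch of the updated Python API follows; the import path, tensor shapes, and GPU requirement are assumptions for illustration, and only the training argument itself comes from this diff:

    # Sketch only: assumes a GPU build of Paddle with FlashAttention support and
    # that flash_attention lives in paddle.nn.functional.flash_attention.
    import paddle
    from paddle.nn.functional.flash_attention import flash_attention

    # Illustrative shapes: [batch_size, seq_len, num_heads, head_dim]
    q = paddle.randn([2, 128, 8, 64]).astype('float16')
    k = paddle.randn([2, 128, 8, 64]).astype('float16')
    v = paddle.randn([2, 128, 8, 64]).astype('float16')

    # training=False is forwarded as is_test=True, which zeroes dropout inside
    # the kernel, so inference is deterministic.
    out, _ = flash_attention(q, k, v, dropout=0.1, causal=True, training=False)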
@@ -565,7 +565,7 @@
   inplace : (out_grad -> x_grad)

 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
@@ -576,7 +576,7 @@
     data_type: q

 - backward_op : flash_attn_unpadded_grad
-  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
......
@@ -563,7 +563,7 @@
   backward : fill_diagonal_tensor_grad

 - op : flash_attn
-  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false)
+  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   infer_meta :
     func : FlashAttnInferMeta
@@ -575,7 +575,7 @@
   backward : flash_attn_grad

 - op : flash_attn_unpadded
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   infer_meta :
     func : FlashAttnInferMeta
......
@@ -262,7 +262,9 @@ void FlashAttnInferMeta(const MetaTensor& q,
                         MetaTensor* softmax,
                         MetaTensor* softmax_lse,
                         MetaTensor* seed_offset) {
-  out->set_dims(q.dims());
+  auto out_dims = q.dims();
+  out_dims[3] = v.dims()[3];
+  out->set_dims(out_dims);
   out->set_dtype(q.dtype());
   out->set_layout(q.layout());
 }
......
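The FlashAttnInferMeta fix above matters when the value tensor's head_dim differs from the query's: the output keeps q's batch, sequence, and head layout but takes its last dimension from v. A shape-only sketch of that rule in plain Python (the function name and shapes are illustrative):

    def flash_attn_out_shape(q_shape, v_shape):
        # Mirrors the fixed infer-meta rule: out_dims = q.dims(), except that
        # out_dims[3] = v.dims()[3] (the head_dim of the value tensor).
        out_shape = list(q_shape)
        out_shape[3] = v_shape[3]
        return out_shape

    # q: [batch, seq_len_q, num_heads, head_dim_qk]
    # v: [batch, seq_len_k, num_heads, head_dim_v]
    assert flash_attn_out_shape([2, 128, 8, 64], [2, 128, 8, 128]) == [2, 128, 8, 128]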
@@ -32,6 +32,7 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
                              float dropout,
                              bool causal,
                              bool return_softmax,
+                             bool is_test,
                              DenseTensor* out,
                              DenseTensor* softmax,
                              DenseTensor* softmax_lse,
@@ -45,6 +46,7 @@ void FlashAttnKernel(const Context& ctx,
                      float dropout,
                      bool causal,
                      bool return_softmax,
+                     bool is_test,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
......
@@ -43,11 +43,14 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
                              float dropout,
                              bool causal,
                              bool return_softmax,
+                             bool is_test,
                              DenseTensor* out,
                              DenseTensor* softmax,
                              DenseTensor* softmax_lse,
                              DenseTensor* seed_offset) {
 #ifdef PADDLE_WITH_FLASHATTN
+  if (is_test) dropout = 0.0f;
+
   ctx.template Alloc<T>(out);

   cudaStream_t stream = ctx.stream();
@@ -187,6 +190,7 @@ void FlashAttnKernel(const Context& ctx,
                      float dropout,
                      bool causal,
                      bool return_softmax,
+                     bool is_test,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
@@ -237,6 +241,7 @@ void FlashAttnKernel(const Context& ctx,
       dropout,
       causal,
       return_softmax,
+      is_test,
       out,
       softmax,
       softmax_lse,
......
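On the kernel side, is_test has a single effect: dropout is zeroed for inference while the training path is left untouched. The same decision expressed as a small Python sketch (names are illustrative, not kernel code):

    def effective_dropout(dropout, is_test):
        # Mirrors `if (is_test) dropout = 0.0f;` in the GPU kernels: inference
        # keeps every attention weight instead of sampling a dropout mask.
        return 0.0 if is_test else dropout

    assert effective_dropout(0.1, is_test=True) == 0.0
    assert effective_dropout(0.1, is_test=False) == 0.1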
@@ -24,6 +24,7 @@ def flash_attention(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    training=True,
     name=None,
 ):
     r"""
@@ -54,8 +55,9 @@
             [batch_size, seq_len, num_heads, head_dim].
             The dtype can be float61 or bfloat16.
         dropout(float): The dropout ratio.
-        causal(bool): Wether enable causal mode.
-        return_softmax(bool): Wether to return softmax.
+        causal(bool): Whether enable causal mode.
+        return_softmax(bool): Whether to return softmax.
+        training(bool): Whether it is in the training phase.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.
@@ -85,8 +87,9 @@
             dropout,
             causal,
             return_softmax,
+            not training,
         )
-        return result_attention, result_softmax
+        return result_attention, result_softmax if return_softmax else None

     helper = LayerHelper('flash_attn', **locals())
     dtype = helper.input_dtype(input_param_name='q')
@@ -113,9 +116,10 @@
             'dropout': dropout,
             'causal': causal,
             'return_softmax': return_softmax,
+            'is_test': not training,
         },
     )
-    return out, softmax
+    return out, softmax if return_softmax else None


 def flash_attn_unpadded(
@@ -130,6 +134,7 @@ def flash_attn_unpadded(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    training=True,
     name=None,
 ):
     r"""
@@ -167,8 +172,9 @@
         max_seqlen_k(int): Maximum sequence length of key/value in the batch.
         scale(float): The scaling of QK^T before applying softmax.
         dropout(float): The dropout ratio.
-        causal(bool): Wether enable causal mode.
-        return_softmax(bool): Wether to return softmax.
+        causal(bool): Whether enable causal mode.
+        return_softmax(bool): Whether to return softmax.
+        training(bool): Whether it is in the training phase.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.
@@ -203,8 +209,9 @@
             dropout,
             causal,
             return_softmax,
+            not training,
         )
-        return result_attention, result_softmax
+        return result_attention, result_softmax if return_softmax else None

     helper = LayerHelper('flash_attn_unpadded', **locals())
     dtype = helper.input_dtype(input_param_name='q')
@@ -236,6 +243,7 @@
             'dropout': dropout,
             'causal': causal,
             'return_softmax': return_softmax,
+            'is_test': not training,
         },
     )
-    return out, softmax
+    return out, softmax if return_softmax else None
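
Besides wiring up training, both Python entry points now return None as the second value whenever return_softmax=False, instead of an unset softmax tensor. A hedged usage sketch of flash_attn_unpadded after this change; the packed layout, the int32 dtype of cu_seqlens, and the import path are assumptions, while the parameter names follow the docstring above:

    # Sketch only: assumes a GPU build of Paddle with FlashAttention support.
    import paddle
    from paddle.nn.functional.flash_attention import flash_attn_unpadded

    # Two sequences of lengths 3 and 5 packed into one [total_tokens, num_heads, head_dim]
    # tensor, with cumulative sequence offsets marking the boundaries.
    q = k = v = paddle.randn([8, 8, 64]).astype('float16')
    cu_seqlens = paddle.to_tensor([0, 3, 8], dtype='int32')

    out, softmax = flash_attn_unpadded(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=5, max_seqlen_k=5,
        scale=1.0 / 64 ** 0.5,
        dropout=0.1,
        causal=False,
        return_softmax=False,  # with this change the second return value is None
        training=False,        # forwarded as is_test=True; dropout is skipped
    )
    assert softmax is None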