Unverified · Commit 8ac5a6b6 · Author: sneaxiy · Committed by: GitHub

Fix flash attention bug (#52551)

* fix flash attn

* fix another API

Parent commit: 7976e2a3
@@ -565,7 +565,7 @@
   inplace : (out_grad -> x_grad)

 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
@@ -576,7 +576,7 @@
     data_type: q

 - backward_op : flash_attn_unpadded_grad
-  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
...
@@ -563,7 +563,7 @@
   backward : fill_diagonal_tensor_grad

 - op : flash_attn
-  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false)
+  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   infer_meta :
     func : FlashAttnInferMeta
@@ -575,7 +575,7 @@
   backward : flash_attn_grad

 - op : flash_attn_unpadded
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   infer_meta :
     func : FlashAttnInferMeta
...
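Both op definitions above gain a trailing bool is_test = false attribute, while the matching backward ops only update the recorded forward signature, so the gradient kernels take no new argument. For orientation, a hypothetical Python mirror of the updated forward argument order (illustration only; the real entry points are code-generated from these yaml files, and these function names are made up):

def flash_attn_op_args(q, k, v, dropout=0.0, causal=False,
                       return_softmax=False, is_test=False):
    # Mirrors the flash_attn op args; the op returns (out, softmax, softmax_lse, seed_offset).
    ...


def flash_attn_unpadded_op_args(q, k, v, cu_seqlens_q, cu_seqlens_k,
                                max_seqlen_q, max_seqlen_k, scale,
                                dropout=0.0, causal=False,
                                return_softmax=False, is_test=False):
    # Mirrors the flash_attn_unpadded op args; same four outputs.
    ...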
@@ -262,7 +262,9 @@ void FlashAttnInferMeta(const MetaTensor& q,
                         MetaTensor* softmax,
                         MetaTensor* softmax_lse,
                         MetaTensor* seed_offset) {
-  out->set_dims(q.dims());
+  auto out_dims = q.dims();
+  out_dims[3] = v.dims()[3];
+  out->set_dims(out_dims);
   out->set_dtype(q.dtype());
   out->set_layout(q.layout());
 }
...
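The InferMeta change makes the inferred output shape take its last dimension from v instead of q, which matters whenever the value head dimension differs from the query/key head dimension. A small sketch of the resulting shape rule, in Python; the concrete shapes are illustrative only:

# Shape rule after the fix.
q_dims = [2, 128, 8, 64]    # [batch, seqlen_q, num_heads, head_dim_qk]
v_dims = [2, 128, 8, 128]   # [batch, seqlen_k, num_heads, head_dim_v]

out_dims = list(q_dims)     # before the fix, out simply copied q's dims
out_dims[3] = v_dims[3]     # now the last dim follows v's head dimension
assert out_dims == [2, 128, 8, 128]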
@@ -32,6 +32,7 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
                              float dropout,
                              bool causal,
                              bool return_softmax,
+                             bool is_test,
                              DenseTensor* out,
                              DenseTensor* softmax,
                              DenseTensor* softmax_lse,
@@ -45,6 +46,7 @@ void FlashAttnKernel(const Context& ctx,
                      float dropout,
                      bool causal,
                      bool return_softmax,
+                     bool is_test,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
...
@@ -43,11 +43,14 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
                              float dropout,
                              bool causal,
                              bool return_softmax,
+                             bool is_test,
                              DenseTensor* out,
                              DenseTensor* softmax,
                              DenseTensor* softmax_lse,
                              DenseTensor* seed_offset) {
 #ifdef PADDLE_WITH_FLASHATTN
+  if (is_test) dropout = 0.0f;
+
   ctx.template Alloc<T>(out);

   cudaStream_t stream = ctx.stream();
@@ -187,6 +190,7 @@ void FlashAttnKernel(const Context& ctx,
                      float dropout,
                      bool causal,
                      bool return_softmax,
+                     bool is_test,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
@@ -237,6 +241,7 @@ void FlashAttnKernel(const Context& ctx,
       dropout,
       causal,
       return_softmax,
+      is_test,
       out,
       softmax,
       softmax_lse,
...
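Both CUDA kernels now accept is_test and force the dropout probability to zero when it is set, so inference runs become deterministic with respect to dropout without the caller having to change the dropout argument. The intended semantics, sketched in Python for illustration:

def effective_dropout(dropout, is_test):
    # Mirrors the added `if (is_test) dropout = 0.0f;` in the CUDA kernels.
    return 0.0 if is_test else dropout

assert effective_dropout(0.1, is_test=True) == 0.0
assert effective_dropout(0.1, is_test=False) == 0.1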
@@ -24,6 +24,7 @@ def flash_attention(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    training=True,
     name=None,
 ):
     r"""
@@ -54,8 +55,9 @@
             [batch_size, seq_len, num_heads, head_dim].
             The dtype can be float61 or bfloat16.
         dropout(float): The dropout ratio.
-        causal(bool): Wether enable causal mode.
-        return_softmax(bool): Wether to return softmax.
+        causal(bool): Whether enable causal mode.
+        return_softmax(bool): Whether to return softmax.
+        training(bool): Whether it is in the training phase.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.
@@ -85,8 +87,9 @@ def flash_attention(
             dropout,
             causal,
             return_softmax,
+            not training,
         )
-        return result_attention, result_softmax
+        return result_attention, result_softmax if return_softmax else None

     helper = LayerHelper('flash_attn', **locals())
     dtype = helper.input_dtype(input_param_name='q')
@@ -113,9 +116,10 @@ def flash_attention(
             'dropout': dropout,
             'causal': causal,
             'return_softmax': return_softmax,
+            'is_test': not training,
         },
     )
-    return out, softmax
+    return out, softmax if return_softmax else None


 def flash_attn_unpadded(
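With the new training flag, the Python wrapper forwards not training as is_test, and the softmax output is now None unless return_softmax=True. A hedged usage sketch (assumes a GPU build with flash attention support and that the function lives in paddle.nn.functional.flash_attention; shapes and dtype are illustrative):

import paddle
from paddle.nn.functional.flash_attention import flash_attention

# Illustrative shapes: [batch_size, seq_len, num_heads, head_dim].
q = paddle.randn([2, 128, 8, 64]).astype('float16')
k = paddle.randn([2, 128, 8, 64]).astype('float16')
v = paddle.randn([2, 128, 8, 64]).astype('float16')

# training=False is passed through as is_test=True, which disables dropout
# inside the kernel.
out, softmax = flash_attention(
    q, k, v, dropout=0.1, causal=True, return_softmax=False, training=False
)
assert softmax is None  # softmax is only returned when return_softmax=True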
@@ -130,6 +134,7 @@ def flash_attn_unpadded(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    training=True,
     name=None,
 ):
     r"""
@@ -167,8 +172,9 @@ def flash_attn_unpadded(
         max_seqlen_k(int): Maximum sequence length of key/value in the batch.
         scale(float): The scaling of QK^T before applying softmax.
         dropout(float): The dropout ratio.
-        causal(bool): Wether enable causal mode.
-        return_softmax(bool): Wether to return softmax.
+        causal(bool): Whether enable causal mode.
+        return_softmax(bool): Whether to return softmax.
+        training(bool): Whether it is in the training phase.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.
@@ -203,8 +209,9 @@
             dropout,
             causal,
             return_softmax,
+            not training,
        )
-        return result_attention, result_softmax
+        return result_attention, result_softmax if return_softmax else None

     helper = LayerHelper('flash_attn_unpadded', **locals())
     dtype = helper.input_dtype(input_param_name='q')
@@ -236,6 +243,7 @@ def flash_attn_unpadded(
             'dropout': dropout,
             'causal': causal,
             'return_softmax': return_softmax,
+            'is_test': not training,
         },
     )
-    return out, softmax
+    return out, softmax if return_softmax else None
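flash_attn_unpadded gets the same training flag and conditional softmax return. A hedged sketch of calling it on variable-length sequences packed into a single tensor (the cu_seqlens layout and dtypes follow the docstring conventions; the concrete values are illustrative):

import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

# Two sequences of lengths 3 and 5 packed into [total_tokens, num_heads, head_dim].
total_tokens, num_heads, head_dim = 8, 8, 64
q = paddle.randn([total_tokens, num_heads, head_dim]).astype('float16')
k = paddle.randn([total_tokens, num_heads, head_dim]).astype('float16')
v = paddle.randn([total_tokens, num_heads, head_dim]).astype('float16')
cu_seqlens = paddle.to_tensor([0, 3, 8], dtype='int32')  # cumulative sequence lengths

out, softmax = flash_attn_unpadded(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=5, max_seqlen_k=5,
    scale=head_dim ** -0.5,
    dropout=0.0, causal=False, return_softmax=False,
    training=False,  # forwarded as is_test=True, so dropout stays disabled
)
assert softmax is None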