Unverified commit 1509a036, authored by yinwei, committed by GitHub

Add flash attention backward grad check (#56249)

---------
Co-authored-by: tianhaodongbd <tianhaodong@baidu.com>
Parent a26a3a60
@@ -70,7 +70,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
          phi::errors::InvalidArgument("The head_dim is expected to be either 32, "
                                       "64, or 128, but recieved %d.",
                                       head_size));
    const int64_t* seed_offset_data = seed_offset.data<int64_t>();
    uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
    uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
@@ -88,8 +87,13 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
    bool fa_zero_tensors = false;
    uint64_t workspace_size;
+   int64_t q_size = total_q * num_heads * head_size;
+   DenseTensor scaled_q = Empty<T>(ctx, {total_q, num_heads, head_size});
+   ComputeScaleQ(ctx, q_size, scale, q.data<T>(), scaled_q.data<T>());
    bool succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
-       static_cast<const void*>(q.data()),
+       static_cast<const void*>(scaled_q.data<T>()),
        static_cast<const void*>(k.data()),
        static_cast<const void*>(v.data()),
        static_cast<void*>(dq->data()),
@@ -124,7 +128,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
        mask_dims.data() ? mask_dims.data() : nullptr,
        nullptr);
    CheckFlashAttnStatus(succ);
    DenseTensor workspace;
    if (workspace_size > 0) {
      workspace = Empty<float>(
@@ -132,7 +135,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
    }
    succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
-       static_cast<const void*>(q.data()),
+       static_cast<const void*>(scaled_q.data<T>()),
        static_cast<const void*>(k.data()),
        static_cast<const void*>(v.data()),
        static_cast<void*>(dq->data()),
@@ -168,7 +171,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
        nullptr);
    CheckFlashAttnStatus(succ);
-   int64_t q_size = total_q * num_heads * head_size;
    ComputeScaleQ(ctx, q_size, scale, dq->data<T>(), dq->data<T>());
  #else
    RaiseNotSupportedError();
......
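Taken together, the change above pre-scales the query once (scaled_q = q * scale) before both calls to flash_attn_bwd_with_bias_and_mask, and still rescales dq by scale at the end. A minimal NumPy sketch (illustration only; softmax, num_grad, and the shapes below are made-up helpers, not Paddle code) of the chain-rule identity this relies on, dL/dq = dL/d(scaled_q) * scale:

import numpy as np

# Chain rule: if scaled_q = q * scale, then dL/dq = dL/d(scaled_q) * scale.
# The kernel hands scaled_q to the backward pass and multiplies the returned
# gradient by scale, which is what the final ComputeScaleQ on dq does above.
def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def num_grad(f, x, eps=1e-6):
    # central finite differences, elementwise
    g = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    for _ in it:
        d = np.zeros_like(x)
        d[it.multi_index] = eps
        g[it.multi_index] = (f(x + d) - f(x - d)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 16)) for _ in range(3))
scale = 1.0 / np.sqrt(16)

loss_wrt_q = lambda qq: (softmax((qq * scale) @ k.T) @ v).sum()
loss_wrt_scaled_q = lambda sq: (softmax(sq @ k.T) @ v).sum()

dq = num_grad(loss_wrt_q, q)
d_scaled_q = num_grad(loss_wrt_scaled_q, q * scale)
assert np.allclose(dq, d_scaled_q * scale, atol=1e-5)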
@@ -238,7 +238,7 @@ static std::vector<int64_t> GetAttnMaskDims(const DenseTensor* attn_mask) {
        rank,
        4,
        phi::errors::InvalidArgument(
-           "Teh number of dimenstions of attn_mask is expected to be greater "
+           "The number of dimenstions of attn_mask is expected to be greater "
            "or equal to 4, but recieved %d. The shape of attn_mask is {%s}",
            rank,
            origin_dims));
......
@@ -417,6 +417,7 @@ def scaled_dot_product_attention(
     dropout_p=0.0,
     is_causal=False,
     training=True,
+    name=None,
 ):
     r"""
     The equation is:
@@ -447,10 +448,12 @@ def scaled_dot_product_attention(
            The dtype can be float61 or bfloat16.
        attn_mask(Tensor,optional): A float mask of the same type as query,
            key, value that is added to the attention score.
+           not supported yet.
        dropout_p(float): The dropout ratio.
        is_causal(bool): Whether enable causal mode.
-       training(bool): Whether it is in the training phase
+       training(bool): Whether it is in the training phase.
+       name(str, optional): The default value is None. Normally there is no need for user
+           to set this property. For more information, please refer to
+           :ref:`api_guide_Name`.
    Returns:
        out(Tensor): The attention tensor.
@@ -459,13 +462,13 @@ def scaled_dot_product_attention(
    Examples:
        .. code-block:: python
-           # required: skiptest
-           >>> # xdoctest: +SKIP()
+           >>> # doctest: +SKIP()
            >>> import paddle
            >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
            >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
            >>> print(output)
-           >>> # xdoctest: -SKIP
+           >>> # doctest: -SKIP
    """
    if attn_mask is None:
        out, _ = flash_attention(query, key, value, dropout_p, is_causal)
......
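The commit title refers to checking the backward gradients of this API. A hedged sketch of such a check, comparing the fused kernel's gradient for q against a plain-op reference (requires a GPU build with FlashAttention support; naive_sdpa, the shapes, and the tolerances below are illustrative assumptions, not taken from the repository's test file):

import numpy as np
import paddle
import paddle.nn.functional as F

def naive_sdpa(q, k, v):
    # plain-op reference: softmax(q @ k^T / sqrt(head_dim)) @ v
    # q/k/v are laid out as (batch, seq_len, num_heads, head_dim)
    qt, kt, vt = (paddle.transpose(x, [0, 2, 1, 3]) for x in (q, k, v))
    scale = 1.0 / (q.shape[-1] ** 0.5)
    s = paddle.matmul(qt * scale, kt, transpose_y=True)
    o = paddle.matmul(F.softmax(s, axis=-1), vt)
    return paddle.transpose(o, [0, 2, 1, 3])

np_q = np.random.rand(1, 128, 2, 32).astype('float32')
q = paddle.to_tensor(np_q, dtype='float16', stop_gradient=False)
q_ref = paddle.to_tensor(np_q, dtype='float16', stop_gradient=False)

# dropout_p=0.0 so the fused kernel and the reference are deterministic
out = F.scaled_dot_product_attention(q, q, q, None, 0.0, False)
out_ref = naive_sdpa(q_ref, q_ref, q_ref)

out.sum().backward()
out_ref.sum().backward()
print(paddle.allclose(q.grad.astype('float32'),
                      q_ref.grad.astype('float32'),
                      rtol=1e-2, atol=1e-2))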
@@ -312,7 +312,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
    not core.is_compiled_with_cuda()
    or get_cuda_version() < 11040
    or not is_sm_supported,
-   "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+   "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
    "and device's compute capability must be 7.5 or 8.x",
)
class TestFlashAttentionWithMaskAPI(unittest.TestCase):
......
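For reference, the corrected message agrees with the numeric guard: assuming the test helper get_cuda_version() encodes the toolkit version as major * 1000 + minor * 10 (an assumption about that helper; its definition is not shown in this diff), CUDA 11.4 maps to 11040. A tiny sketch:

# Hypothetical encoding assumed for get_cuda_version(): major * 1000 + minor * 10.
def encode_cuda_version(major: int, minor: int) -> int:
    return major * 1000 + minor * 10

assert encode_cuda_version(11, 4) == 11040  # CUDA 11.4 meets a "< 11040" skip guard
assert encode_cuda_version(11, 3) == 11030  # CUDA 11.3 falls below it and is skipped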