From 1509a0369fa6dc6df88a3a58559625009c308d90 Mon Sep 17 00:00:00 2001
From: yinwei <1871465933@qq.com>
Date: Tue, 15 Aug 2023 16:35:18 +0800
Subject: [PATCH] Add flash attention backward grad check (#56249)

---------

Co-authored-by: tianhaodongbd
---
 paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu | 12 +++++++-----
 paddle/phi/kernels/gpu/flash_attn_utils.h        |  2 +-
 python/paddle/nn/functional/flash_attention.py   | 13 ++++++++-----
 test/legacy_test/test_flash_attention.py         |  2 +-
 4 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index de479cf9adf..60a1c54d726 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -70,7 +70,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
       phi::errors::InvalidArgument("The head_dim is expected to be either 32, "
                                    "64, or 128, but recieved %d.",
                                    head_size));
-
   const int64_t* seed_offset_data = seed_offset.data();
   uint64_t seed = static_cast(seed_offset_data[0]);
   uint64_t offset = static_cast(seed_offset_data[1]);
@@ -88,8 +87,13 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
   bool fa_zero_tensors = false;

   uint64_t workspace_size;
+
+  int64_t q_size = total_q * num_heads * head_size;
+  DenseTensor scaled_q = Empty(ctx, {total_q, num_heads, head_size});
+  ComputeScaleQ(ctx, q_size, scale, q.data(), scaled_q.data());
+
   bool succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
-      static_cast(q.data()),
+      static_cast(scaled_q.data()),
       static_cast(k.data()),
       static_cast(v.data()),
       static_cast(dq->data()),
@@ -124,7 +128,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
       mask_dims.data() ? mask_dims.data() : nullptr,
       nullptr);
   CheckFlashAttnStatus(succ);
-
   DenseTensor workspace;
   if (workspace_size > 0) {
     workspace = Empty(
         ctx, {int64_t(workspace_size / sizeof(float))});
   }

   succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
-      static_cast(q.data()),
+      static_cast(scaled_q.data()),
       static_cast(k.data()),
       static_cast(v.data()),
       static_cast(dq->data()),
@@ -168,7 +171,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
       nullptr);
   CheckFlashAttnStatus(succ);

-  int64_t q_size = total_q * num_heads * head_size;
   ComputeScaleQ(ctx, q_size, scale, dq->data(), dq->data());
 #else
   RaiseNotSupportedError();
diff --git a/paddle/phi/kernels/gpu/flash_attn_utils.h b/paddle/phi/kernels/gpu/flash_attn_utils.h
index 00ba036df09..03601dae16d 100644
--- a/paddle/phi/kernels/gpu/flash_attn_utils.h
+++ b/paddle/phi/kernels/gpu/flash_attn_utils.h
@@ -238,7 +238,7 @@ static std::vector GetAttnMaskDims(const DenseTensor* attn_mask) {
       rank,
       4,
       phi::errors::InvalidArgument(
-          "Teh number of dimenstions of attn_mask is expected to be greater "
+          "The number of dimenstions of attn_mask is expected to be greater "
           "or equal to 4, but recieved %d. The shape of attn_mask is {%s}",
           rank,
           origin_dims));
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index ae2ddc2a921..d6eb44e66e2 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -417,6 +417,7 @@ def scaled_dot_product_attention(
     dropout_p=0.0,
     is_causal=False,
     training=True,
+    name=None,
 ):
     r"""
     The equation is:
@@ -447,10 +448,12 @@ def scaled_dot_product_attention(
             The dtype can be float16 or bfloat16.
         attn_mask(Tensor,optional): A float mask of the same type as query,
                 key, value that is added to the attention score.
-                not supported yet.
         dropout_p(float): The dropout ratio.
         is_causal(bool): Whether enable causal mode.
-        training(bool): Whether it is in the training phase
+        training(bool): Whether it is in the training phase.
+        name(str, optional): The default value is None. Normally there is no need for user
+                to set this property. For more information, please refer to
+                :ref:`api_guide_Name`.

     Returns:
         out(Tensor): The attention tensor.
@@ -459,13 +462,13 @@ def scaled_dot_product_attention(
     Examples:
         .. code-block:: python

-            # required: skiptest
-            >>> # xdoctest: +SKIP()
+
+            >>> # doctest: +SKIP()
             >>> import paddle
             >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
             >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
             >>> print(output)
-            >>> # xdoctest: -SKIP
+            >>> # doctest: -SKIP
     """
     if attn_mask is None:
         out, _ = flash_attention(query, key, value, dropout_p, is_causal)
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index 4088f60570f..979217d7d22 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -312,7 +312,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
     not core.is_compiled_with_cuda()
     or get_cuda_version() < 11040
     or not is_sm_supported,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
     "and device's compute capability must be 7.5 or 8.x",
 )
 class TestFlashAttentionWithMaskAPI(unittest.TestCase):
--
GitLab
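
For context, the patch's own test changes live in test/legacy_test/test_flash_attention.py; the sketch below is only an illustrative, minimal version of the kind of backward "grad check" this PR adds, written against the public scaled_dot_product_attention API. It is not the PR's test code: the helper names, shapes, and tolerances are assumptions, and it requires a CUDA build of Paddle with flash attention support (compute capability 7.5 or 8.x).

```python
# Illustrative sketch only (not part of the PR): compare dQ/dK/dV from
# paddle.nn.functional.scaled_dot_product_attention against a naive
# softmax(Q K^T / sqrt(d)) V reference. Shapes and tolerances are assumptions.
import numpy as np
import paddle
import paddle.nn.functional as F


def naive_attention(q, k, v):
    # q, k, v: [batch, seq_len, num_heads, head_dim]
    scale = 1.0 / (q.shape[-1] ** 0.5)
    qt = paddle.transpose(q, [0, 2, 1, 3])  # -> [batch, num_heads, seq_len, head_dim]
    kt = paddle.transpose(k, [0, 2, 1, 3])
    vt = paddle.transpose(v, [0, 2, 1, 3])
    scores = paddle.matmul(qt * scale, kt, transpose_y=True)
    probs = F.softmax(scores, axis=-1)
    out = paddle.matmul(probs, vt)
    return paddle.transpose(out, [0, 2, 1, 3])  # back to [batch, seq_len, num_heads, head_dim]


def check_flash_attention_backward(shape=(1, 128, 2, 32), dtype="float16",
                                   rtol=5e-3, atol=5e-3):
    # Same random data feeds both paths, as two independent leaf tensors each.
    data = [np.random.rand(*shape).astype("float32") for _ in range(3)]
    fa_in = [paddle.to_tensor(a, dtype=dtype, stop_gradient=False) for a in data]
    ref_in = [paddle.to_tensor(a, dtype=dtype, stop_gradient=False) for a in data]

    # dropout_p=0.0 and is_causal=False so the two paths are comparable.
    out_fa = F.scaled_dot_product_attention(*fa_in, None, 0.0, False)
    out_ref = naive_attention(*ref_in)

    grad_out = paddle.ones_like(out_fa)
    out_fa.backward(grad_out)
    out_ref.backward(grad_out)

    # Check the input gradients produced by the flash attention backward kernel.
    for got, want, name in zip(fa_in, ref_in, "qkv"):
        np.testing.assert_allclose(
            got.grad.astype("float32").numpy(),
            want.grad.astype("float32").numpy(),
            rtol=rtol, atol=atol,
            err_msg=f"d{name} mismatch",
        )


if __name__ == "__main__":
    check_flash_attention_backward()
    print("flash attention backward matches the naive reference")
```

A check along these lines is what the kernel change above is meant to satisfy: the grad kernel now passes a pre-scaled copy of q (scaled_q) into flash_attn_bwd_with_bias_and_mask and rescales dq afterwards, so the backward gradients stay consistent with the scaled forward pass.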