diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index de479cf9adfd28e75a37dcc2e2be7bdb61c52b6f..60a1c54d72678e8d734ad5a4bbae03d4997ac734 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -70,7 +70,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
       phi::errors::InvalidArgument("The head_dim is expected to be either 32, "
                                    "64, or 128, but recieved %d.",
                                    head_size));
-
   const int64_t* seed_offset_data = seed_offset.data<int64_t>();
   uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
   uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
@@ -88,8 +87,13 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
   bool fa_zero_tensors = false;
   uint64_t workspace_size;
+
+  int64_t q_size = total_q * num_heads * head_size;
+  DenseTensor scaled_q = Empty<T>(ctx, {total_q, num_heads, head_size});
+  ComputeScaleQ(ctx, q_size, scale, q.data<T>(), scaled_q.data<T>());
+
   bool succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
-      static_cast<const void*>(q.data()),
+      static_cast<const void*>(scaled_q.data()),
       static_cast<const void*>(k.data()),
       static_cast<const void*>(v.data()),
       static_cast<void*>(dq->data()),
@@ -124,7 +128,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
       mask_dims.data() ? mask_dims.data() : nullptr,
       nullptr);
   CheckFlashAttnStatus(succ);
-
   DenseTensor workspace;
   if (workspace_size > 0) {
     workspace = Empty<float>(
@@ -132,7 +135,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
   }
 
   succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
-      static_cast<const void*>(q.data()),
+      static_cast<const void*>(scaled_q.data()),
       static_cast<const void*>(k.data()),
       static_cast<const void*>(v.data()),
       static_cast<void*>(dq->data()),
@@ -168,7 +171,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
       nullptr);
   CheckFlashAttnStatus(succ);
 
-  int64_t q_size = total_q * num_heads * head_size;
   ComputeScaleQ(ctx, q_size, scale, dq->data<T>(), dq->data<T>());
 #else
   RaiseNotSupportedError();
diff --git a/paddle/phi/kernels/gpu/flash_attn_utils.h b/paddle/phi/kernels/gpu/flash_attn_utils.h
index 00ba036df09ba582d2415ddab7f07c8db6ebed56..03601dae16db12da85371b52739c3e8968b47f84 100644
--- a/paddle/phi/kernels/gpu/flash_attn_utils.h
+++ b/paddle/phi/kernels/gpu/flash_attn_utils.h
@@ -238,7 +238,7 @@ static std::vector<int64_t> GetAttnMaskDims(const DenseTensor* attn_mask) {
       rank,
       4,
       phi::errors::InvalidArgument(
-          "Teh number of dimenstions of attn_mask is expected to be greater "
+          "The number of dimenstions of attn_mask is expected to be greater "
           "or equal to 4, but recieved %d. The shape of attn_mask is {%s}",
           rank,
           origin_dims));
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index ae2ddc2a921a239ae67e8eb35c5bce029f64e365..d6eb44e66e251ad184b979e5df93cfec44d33430 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -417,6 +417,7 @@ def scaled_dot_product_attention(
     dropout_p=0.0,
     is_causal=False,
     training=True,
+    name=None,
 ):
     r"""
     The equation is:
@@ -447,10 +448,12 @@ def scaled_dot_product_attention(
                         The dtype can be float61 or bfloat16.
         attn_mask(Tensor,optional): A float mask of the same type as query,
                         key, value that is added to the attention score.
-                        not supported yet.
         dropout_p(float): The dropout ratio.
         is_causal(bool): Whether enable causal mode.
-        training(bool): Whether it is in the training phase
+        training(bool): Whether it is in the training phase.
+        name(str, optional): The default value is None. Normally there is no need for user
+                        to set this property. For more information, please refer to
+                        :ref:`api_guide_Name`.
 
     Returns:
         out(Tensor): The attention tensor.
@@ -459,13 +462,13 @@ def scaled_dot_product_attention(
     Examples:
         .. code-block:: python
 
-            # required: skiptest
-            >>> # xdoctest: +SKIP()
+
+            >>> # doctest: +SKIP()
             >>> import paddle
             >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
             >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
             >>> print(output)
-            >>> # xdoctest: -SKIP
+            >>> # doctest: -SKIP
     """
     if attn_mask is None:
         out, _ = flash_attention(query, key, value, dropout_p, is_causal)
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index 4088f60570f50c91d9eb5e8df9bdc7fd86cc7db4..979217d7d221a513d744173a222afb98ac0c7e29 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -312,7 +312,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
     not core.is_compiled_with_cuda()
     or get_cuda_version() < 11040
     or not is_sm_supported,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
    "and device's compute capability must be 7.5 or 8.x",
 )
 class TestFlashAttentionWithMaskAPI(unittest.TestCase):
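
Illustration (not part of the patch): with the docstring no longer marking attn_mask as unsupported and the new name argument in place, the updated Python entry point can be driven roughly as below. This is a sketch under assumptions: the 4-D mask layout (batch, num_heads, seq_len, seq_len) is inferred from the rank >= 4 check in GetAttnMaskDims, the head_dim of 32 follows the 32/64/128 enforce in the masked backward path, and running it needs a CUDA build meeting the updated skip condition (CUDA >= 11.4, compute capability 7.5/8.x).

    # Sketch: calling paddle.nn.functional.scaled_dot_product_attention with a mask.
    # Shapes and dtype mirror the docstring example; the mask layout and head_dim
    # choice are assumptions explained above, not taken verbatim from this diff.
    import paddle
    import paddle.nn.functional as F

    q = paddle.rand((1, 128, 2, 32), dtype=paddle.bfloat16)  # (batch, seq_len, num_heads, head_dim)
    k = paddle.rand((1, 128, 2, 32), dtype=paddle.bfloat16)
    v = paddle.rand((1, 128, 2, 32), dtype=paddle.bfloat16)
    # Additive float mask of rank 4; zeros leave the attention scores unchanged.
    mask = paddle.zeros((1, 2, 128, 128), dtype=paddle.bfloat16)

    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, training=True
    )
    print(out.shape)  # expected: [1, 128, 2, 32], same as the query shape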