Unverified commit 1509a036, authored by yinwei, committed by GitHub

Add flash attention backward grad check (#56249)

---------
Co-authored-by: tianhaodongbd <tianhaodong@baidu.com>
Parent: a26a3a60
......@@ -70,7 +70,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
phi::errors::InvalidArgument("The head_dim is expected to be either 32, "
"64, or 128, but recieved %d.",
head_size));
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
......@@ -88,8 +87,13 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
bool fa_zero_tensors = false;
uint64_t workspace_size;
int64_t q_size = total_q * num_heads * head_size;
DenseTensor scaled_q = Empty<T>(ctx, {total_q, num_heads, head_size});
ComputeScaleQ(ctx, q_size, scale, q.data<T>(), scaled_q.data<T>());
bool succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
static_cast<const void*>(q.data()),
static_cast<const void*>(scaled_q.data<T>()),
static_cast<const void*>(k.data()),
static_cast<const void*>(v.data()),
static_cast<void*>(dq->data()),
......@@ -124,7 +128,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
mask_dims.data() ? mask_dims.data() : nullptr,
nullptr);
CheckFlashAttnStatus(succ);
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(
......@@ -132,7 +135,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
}
succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
static_cast<const void*>(q.data()),
static_cast<const void*>(scaled_q.data<T>()),
static_cast<const void*>(k.data()),
static_cast<const void*>(v.data()),
static_cast<void*>(dq->data()),
......@@ -168,7 +171,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
nullptr);
CheckFlashAttnStatus(succ);
int64_t q_size = total_q * num_heads * head_size;
ComputeScaleQ(ctx, q_size, scale, dq->data<T>(), dq->data<T>());
#else
RaiseNotSupportedError();
......
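The hunks above build a scaled copy of the query, `scaled_q`, via `ComputeScaleQ` and pass it to both `flash_attn_bwd_with_bias_and_mask` calls, instead of passing the raw `q` and rescaling `dq` afterwards. Folding the softmax scale into Q is a valid place to apply it for the score computation because the matmul is linear. A tiny NumPy check of that identity, not taken from the commit and with illustrative shapes:

```python
# Standalone sketch (not from the commit): checks the identity
# (scale * Q) @ K^T == scale * (Q @ K^T), which is why the softmax scale can be
# folded into the query before the matmul. Shapes are illustrative.
import numpy as np

rng = np.random.default_rng(0)
head_dim = 64
q = rng.standard_normal((128, head_dim)).astype(np.float32)  # one head: [seq_len, head_dim]
k = rng.standard_normal((128, head_dim)).astype(np.float32)
scale = 1.0 / np.sqrt(head_dim)

scores_scaled_q = (scale * q) @ k.T   # scale folded into the query, as the new code path does
scores_scaled_s = scale * (q @ k.T)   # scale applied to the raw attention scores
np.testing.assert_allclose(scores_scaled_q, scores_scaled_s, rtol=1e-5, atol=1e-5)
```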
......@@ -238,7 +238,7 @@ static std::vector<int64_t> GetAttnMaskDims(const DenseTensor* attn_mask) {
rank,
4,
phi::errors::InvalidArgument(
"Teh number of dimenstions of attn_mask is expected to be greater "
"The number of dimenstions of attn_mask is expected to be greater "
"or equal to 4, but recieved %d. The shape of attn_mask is {%s}",
rank,
origin_dims));
......
......@@ -417,6 +417,7 @@ def scaled_dot_product_attention(
dropout_p=0.0,
is_causal=False,
training=True,
name=None,
):
r"""
The equation is:
......@@ -447,10 +448,12 @@ def scaled_dot_product_attention(
The dtype can be float16 or bfloat16.
attn_mask(Tensor,optional): A float mask of the same type as query,
key, value that is added to the attention score.
not supported yet.
dropout_p(float): The dropout ratio.
is_causal(bool): Whether enable causal mode.
training(bool): Whether it is in the training phase
training(bool): Whether it is in the training phase.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
Returns:
out(Tensor): The attention tensor.
......@@ -459,13 +462,13 @@ def scaled_dot_product_attention(
Examples:
.. code-block:: python
# required: skiptest
>>> # xdoctest: +SKIP()
>>> # doctest: +SKIP()
>>> import paddle
>>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
>>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
>>> print(output)
>>> # xdoctest: -SKIP
>>> # doctest: -SKIP
"""
if attn_mask is None:
out, _ = flash_attention(query, key, value, dropout_p, is_causal)
......
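Based on the signature and docstring in the hunk above, here is a small usage sketch with keyword arguments, including the newly documented `name` parameter. It assumes a GPU build of Paddle with FlashAttention support and is illustrative rather than taken from the repository:

```python
import paddle
import paddle.nn.functional as F

# Shapes follow the docstring example: [batch_size, seq_len, num_heads, head_dim].
q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)

out = F.scaled_dot_product_attention(
    query=q,
    key=q,
    value=q,
    attn_mask=None,   # None takes the plain flash_attention path shown above
    dropout_p=0.0,
    is_causal=False,
    training=True,
    name=None,        # newly documented optional argument
)
print(out.shape)      # same shape as the query: [1, 128, 2, 16]
```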
......@@ -312,7 +312,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
not core.is_compiled_with_cuda()
or get_cuda_version() < 11040
or not is_sm_supported,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
"core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
"and device's compute capability must be 7.5 or 8.x",
)
class TestFlashAttentionWithMaskAPI(unittest.TestCase):
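The test hunk above only corrects the skip message (CUDA >= 11.4). Since the commit title concerns a backward grad check, here is a minimal, hypothetical sketch of such a check: it compares query gradients from `scaled_dot_product_attention` against a naive softmax-attention reference. The helper `naive_attention` and the tolerances are assumptions for illustration, not the repository's test code, and running it requires a CUDA build with FlashAttention support (compute capability 7.5 or 8.x):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

def naive_attention(q, k, v):
    # Reference: softmax(Q K^T / sqrt(head_dim)) V on [batch, seq_len, num_heads, head_dim].
    qt, kt, vt = (paddle.transpose(x, [0, 2, 1, 3]) for x in (q, k, v))
    scale = 1.0 / (q.shape[-1] ** 0.5)
    scores = F.softmax(paddle.matmul(qt, kt, transpose_y=True) * scale, axis=-1)
    out = paddle.matmul(scores, vt)
    return paddle.transpose(out, [0, 2, 1, 3])

paddle.seed(2023)
shape = (1, 128, 2, 32)  # batch_size, seq_len, num_heads, head_dim
q_np = np.random.uniform(-1.0, 1.0, shape).astype("float16")

grads = []
for attn in (lambda a, b, c: F.scaled_dot_product_attention(a, b, c, None, 0.0, False),
             naive_attention):
    q = paddle.to_tensor(q_np, stop_gradient=False)
    out = attn(q, q, q)
    out.sum().backward()
    grads.append(q.grad.astype("float32").numpy())

# Loose fp16 tolerances; the exact thresholds here are illustrative.
np.testing.assert_allclose(grads[0], grads[1], rtol=5e-3, atol=5e-3)
```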
......