Unverified commit b19dfb8c authored by zhenhailiu, committed by GitHub

Add scaled_dot_product_attention api (#55242)

Parent commit: ef29468e
@@ -261,4 +261,5 @@ __all__ = [  # noqa
    'multi_margin_loss',
    'soft_margin_loss',
    'gaussian_nll_loss',
    'scaled_dot_product_attention',
]
@@ -407,4 +407,57 @@ def flash_attn_unpadded(
    return out, softmax if return_softmax else None


scaled_dot_product_attention = flash_attention
def scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
):
    r"""
    The equation is:

    .. math::

        result = softmax\left(\frac{QK^T}{\sqrt{d}}\right)V

    where ``Q``, ``K``, and ``V`` represent the three input parameters of the
    attention module, all with the same dimensions, and ``d`` is the size of
    the last dimension of the three parameters.

    Warning:
        This API only supports inputs with dtype float16 and bfloat16.

    Args:
        query(Tensor): The query tensor in the attention module.
            4-D tensor with shape [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.
        key(Tensor): The key tensor in the attention module.
            4-D tensor with shape [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.
        value(Tensor): The value tensor in the attention module.
            4-D tensor with shape [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.
        attn_mask(Tensor, optional): A float mask of the same dtype as query,
            key, and value that is added to the attention score.
            Not supported yet.
        dropout_p(float): The dropout ratio.
        is_causal(bool): Whether to enable causal mode.

    Returns:
        out(Tensor): The attention tensor.
            4-D tensor with shape [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.

    Examples:
        .. code-block:: python

            # required: skiptest
            >>> # xdoctest: +SKIP()
            >>> import paddle
            >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
            >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
            >>> print(output)
            >>> # xdoctest: -SKIP
    """
    assert attn_mask is None, "attn_mask is not supported yet"
    out, _ = flash_attention(query, key, value, dropout_p, is_causal)
    return out
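Note (not part of this diff): the fused call above implements the formula from the docstring. For readers who want to sanity-check the formula, the same computation without dropout or causal masking can be written with plain Paddle ops. The sketch below is illustrative only, and the helper name naive_attention_reference is hypothetical.

import paddle
import paddle.nn.functional as F

def naive_attention_reference(q, k, v):
    # q, k, v: [batch_size, seq_len, num_heads, head_dim]
    # Move heads next to batch so matmul contracts over head_dim.
    qt = paddle.transpose(q, [0, 2, 1, 3])
    kt = paddle.transpose(k, [0, 2, 1, 3])
    vt = paddle.transpose(v, [0, 2, 1, 3])
    d = qt.shape[-1]
    # Scaled scores softmax(Q K^T / sqrt(d)), shape [batch, heads, seq, seq].
    scores = paddle.matmul(qt, kt, transpose_y=True) / (d**0.5)
    probs = F.softmax(scores, axis=-1)
    # Weighted sum of values, then back to [batch_size, seq_len, num_heads, head_dim].
    out = paddle.matmul(probs, vt)
    return paddle.transpose(out, [0, 2, 1, 3])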
@@ -25,6 +25,7 @@ from paddle.fluid import core
from paddle.nn.functional.flash_attention import (
    flash_attention,
    flash_attn_unpadded,
    scaled_dot_product_attention,
)
@@ -85,6 +86,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
        self.causal = False
        self.return_softmax = False
        self.use_sdp_kernel = False
        self.use_sdp_api = False

    def test_unpadded(self):
        print(
@@ -212,9 +214,15 @@ class TestFlashAttentionAPI(unittest.TestCase):
                enable_flash=self.enable_flash,
                enable_mem_efficient=self.enable_mem_efficient,
            ):
                out, _ = flash_attention(
                    q, k, v, self.dropout, self.causal, self.return_softmax
                )
                if self.use_sdp_api:
                    out = scaled_dot_product_attention(
                        q, k, v, None, self.dropout, self.causal
                    )
                else:
                    out, _ = flash_attention(
                        q, k, v, self.dropout, self.causal, self.return_softmax
                    )
        else:
            out, _ = flash_attention(
                q, k, v, self.dropout, self.causal, self.return_softmax
@@ -253,14 +261,19 @@ class TestFlashAttentionAPI(unittest.TestCase):
                enable_flash=self.enable_flash,
                enable_mem_efficient=self.enable_mem_efficient,
            ):
                outs, softmax = flash_attention(
                    qs,
                    ks,
                    vs,
                    self.dropout,
                    self.causal,
                    self.return_softmax,
                )
                if self.use_sdp_api:
                    outs = scaled_dot_product_attention(
                        qs, ks, vs, None, self.dropout, self.causal
                    )
                else:
                    outs, softmax = flash_attention(
                        qs,
                        ks,
                        vs,
                        self.dropout,
                        self.causal,
                        self.return_softmax,
                    )
        else:
            outs, softmax = flash_attention(
                qs, ks, vs, self.dropout, self.causal, self.return_softmax
@@ -334,6 +347,22 @@ class TestMathAttentionAPITest(TestFlashAttentionAPI):
        self.causal = False
        self.return_softmax = False
        self.use_sdp_kernel = True
        self.use_sdp_api = False
        self.enable_math = True
        self.enable_flash = False
        self.enable_mem_efficient = False


class TestSDPAttentionAPITest(TestFlashAttentionAPI):
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (8, 1024, 16, 128)
        self.dtype = paddle.float16
        self.dropout = 0.0
        self.causal = False
        self.return_softmax = False
        self.use_sdp_kernel = True
        self.use_sdp_api = True
        self.enable_math = True
        self.enable_flash = False
        self.enable_mem_efficient = False
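For context, the new TestSDPAttentionAPITest case drives the math kernel through scaled_dot_product_attention with the configuration above. A minimal standalone usage sketch, assuming a CUDA device that supports the underlying flash-attention kernel and using shapes that mirror the test, might look like this; the values are illustrative only.

import paddle
from paddle.nn.functional.flash_attention import scaled_dot_product_attention

# Shapes mirror TestSDPAttentionAPITest: [batch_size, seq_len, num_heads, head_dim].
q = paddle.rand((8, 1024, 16, 128)).astype(paddle.float16)
k = paddle.rand((8, 1024, 16, 128)).astype(paddle.float16)
v = paddle.rand((8, 1024, 16, 128)).astype(paddle.float16)

# attn_mask must be None for now (the function asserts this); dropout_p=0.0
# and is_causal=False match the test defaults.
out = scaled_dot_product_attention(q, k, v, None, 0.0, False)
print(out.shape)  # [8, 1024, 16, 128]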