未验证 提交 7f9ab2bd 编写于 作者: Z zhangkaihuo 提交者: GitHub

Add scaled_dot_product_attention (#53113)

上级 decc4c38
......@@ -68,6 +68,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = False
def test_unpadded(self):
print(
......@@ -189,6 +190,16 @@ class TestFlashAttentionAPI(unittest.TestCase):
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
if self.use_sdp_kernel:
with paddle.nn.functional.sdp_kernel(
enable_math=self.enable_math,
enable_flash=self.enable_flash,
enable_mem_efficient=self.enable_mem_efficient,
):
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
)
else:
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
)
......@@ -220,6 +231,21 @@ class TestFlashAttentionAPI(unittest.TestCase):
name="v", shape=self.shape, dtype=self.dtype
)
if self.use_sdp_kernel:
with paddle.nn.functional.sdp_kernel(
enable_math=self.enable_math,
enable_flash=self.enable_flash,
enable_mem_efficient=self.enable_mem_efficient,
):
outs, softmax = flash_attention(
qs,
ks,
vs,
self.dropout,
self.causal,
self.return_softmax,
)
else:
outs, softmax = flash_attention(
qs, ks, vs, self.dropout, self.causal, self.return_softmax
)
......@@ -247,6 +273,7 @@ class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = False
class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
......@@ -257,6 +284,7 @@ class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
self.dropout = 0.0
self.causal = False
self.return_softmax = True
self.use_sdp_kernel = False
class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
......@@ -267,6 +295,7 @@ class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
self.dropout = 0.0
self.causal = True
self.return_softmax = False
self.use_sdp_kernel = False
class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
......@@ -277,6 +306,21 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = False
class TestMathAttentionAPITest(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (8, 1024, 16, 128)
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = True
self.enable_math = True
self.enable_flash = False
self.enable_mem_efficient = False
if __name__ == '__main__':
......
......@@ -134,6 +134,8 @@ from .extension import gather_tree # noqa: F401
from .extension import temporal_shift # noqa: F401
from .sparse_attention import sparse_attention
from .flash_attention import scaled_dot_product_attention
from .flash_attention import sdp_kernel
__all__ = [ # noqa
'celu',
......
......@@ -13,8 +13,113 @@
# limitations under the License.
import paddle
import paddle.nn.functional as F
from paddle import _C_ops, in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
g_enable_math = None
g_enable_flash = None
g_enable_mem_efficient = None
@signature_safe_contextmanager
def sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=True):
r"""
With the sdp_kernel context manager, different algorithm implementations can
be selected for scaled_dot_product_attention.
"""
global g_enable_math, g_enable_flash, g_enable_mem_efficient
original_enable_math = g_enable_math
original_enable_flash = g_enable_math
original_enable_mem_efficient = g_enable_mem_efficient
g_enable_math = enable_math
g_enable_flash = enable_flash
g_enable_mem_efficient = enable_mem_efficient
try:
yield
finally:
g_enable_math = original_enable_math
g_enable_flash = original_enable_flash
g_enable_mem_efficient = original_enable_mem_efficient
def _math_attention(
query,
key,
value,
dropout_rate=0.0,
causal=False,
return_softmax=False,
training=True,
):
r"""
This is a basic implementation of scaled dot product attention composed of
combinations of fundamental components.
"""
head_dim = query.shape[-1]
query = paddle.transpose(query, [0, 2, 1, 3])
key = paddle.transpose(key, [0, 2, 1, 3])
value = paddle.transpose(value, [0, 2, 1, 3])
product = paddle.matmul(
x=query * (head_dim**-0.5), y=key, transpose_y=True
)
weights = (
paddle.incubate.softmax_mask_fuse_upper_triangle(product)
if causal
else F.softmax(product)
)
if dropout_rate > 0.0:
weights = F.dropout(
weights, dropout_rate, training=training, mode="upscale_in_train"
)
out = paddle.matmul(weights, value)
out = paddle.transpose(out, [0, 2, 1, 3])
return out, weights if return_softmax else None
def _select_sdp_cuda(head_dim):
if head_dim < 128:
return "flash_attn"
else:
return "mem_efficient"
def _select_sdp(head_dim):
r"""
There are currently three different implementation options available for
scaled dot product attention, and the chosen approach depends on whether it
is determined by the sdp_kernel configuration or specified through input values.
"""
place = paddle.get_device()
# not use sdp_kernel
if g_enable_flash is None:
if "gpu" not in place:
return "math"
else:
return _select_sdp_cuda(head_dim)
if (
g_enable_math is False
and g_enable_flash is False
and g_enable_mem_efficient is False
):
raise AssertionError(
"No available backend for scaled_dot_product_attention was found."
)
if g_enable_math is True:
if g_enable_flash is False and g_enable_mem_efficient is False:
return "math"
if "gpu" not in place:
return "math"
if g_enable_flash is True and g_enable_mem_efficient is True:
return _select_sdp_cuda(head_dim)
if g_enable_flash is True:
return "flash_attn"
return "mem_efficient"
def flash_attention(
......@@ -84,6 +189,10 @@ def flash_attention(
output = paddle.nn.functional.flash_attention(q, q, q, 0.9, False, False)
print(output)
"""
head_dim = query.shape[3]
sdp_func_name = _select_sdp(head_dim)
if sdp_func_name == "flash_attn":
if in_dynamic_mode():
(result_attention, result_softmax,) = _C_ops.flash_attn(
query,
......@@ -129,6 +238,32 @@ def flash_attention(
},
)
return out, softmax if return_softmax else None
else:
if sdp_func_name == "mem_efficient":
from paddle.incubate.nn.memory_efficient_attention import (
memory_efficient_attention,
)
output = memory_efficient_attention(
query,
key,
value,
attn_bias=None,
p=dropout,
scale=None,
training=training,
)
return output, None
else:
return _math_attention(
query,
key,
value,
dropout_rate=dropout,
causal=causal,
return_softmax=return_softmax,
training=training,
)
def flash_attn_unpadded(
......@@ -264,3 +399,6 @@ def flash_attn_unpadded(
},
)
return out, softmax if return_softmax else None
scaled_dot_product_attention = flash_attention
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册