Unverified · Commit 7f9ab2bd authored by zhangkaihuo, committed by GitHub

Add scaled_dot_product_attention (#53113)

Parent decc4c38
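This commit adds a paddle.nn.functional.sdp_kernel context manager for choosing the attention backend (math, flash, or memory-efficient), a composed _math_attention fallback, a _select_sdp dispatch rule based on the configuration, device, and head_dim, and exposes scaled_dot_product_attention as an alias of flash_attention. Below is a minimal usage sketch of the new API, for illustration only; the tensor shape and the float32/CPU setting are assumptions, not part of the commit.

import paddle
import paddle.nn.functional as F

# (batch_size, seq_len, num_heads, head_dim), the layout used by the tests in this commit
q = paddle.rand((2, 128, 8, 64), dtype='float32')

# Force the composed "math" implementation; without the context manager the
# flash or memory-efficient kernel would be selected from head_dim on GPU.
with F.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
    out, _ = F.scaled_dot_product_attention(q, q, q, 0.0, False, False)

print(out.shape)  # [2, 128, 8, 64]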
@@ -68,6 +68,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = False
+        self.use_sdp_kernel = False
 
     def test_unpadded(self):
         print(
@@ -189,9 +190,19 @@ class TestFlashAttentionAPI(unittest.TestCase):
             value, place=self.place, dtype=self.dtype, stop_gradient=False
         )
 
-        out, _ = flash_attention(
-            q, k, v, self.dropout, self.causal, self.return_softmax
-        )
+        if self.use_sdp_kernel:
+            with paddle.nn.functional.sdp_kernel(
+                enable_math=self.enable_math,
+                enable_flash=self.enable_flash,
+                enable_mem_efficient=self.enable_mem_efficient,
+            ):
+                out, _ = flash_attention(
+                    q, k, v, self.dropout, self.causal, self.return_softmax
+                )
+        else:
+            out, _ = flash_attention(
+                q, k, v, self.dropout, self.causal, self.return_softmax
+            )
         out_ = attention_naive(q_, k_, v_, self.causal)
 
         out.backward()
@@ -220,9 +231,24 @@ class TestFlashAttentionAPI(unittest.TestCase):
                 name="v", shape=self.shape, dtype=self.dtype
             )
 
-            outs, softmax = flash_attention(
-                qs, ks, vs, self.dropout, self.causal, self.return_softmax
-            )
+            if self.use_sdp_kernel:
+                with paddle.nn.functional.sdp_kernel(
+                    enable_math=self.enable_math,
+                    enable_flash=self.enable_flash,
+                    enable_mem_efficient=self.enable_mem_efficient,
+                ):
+                    outs, softmax = flash_attention(
+                        qs,
+                        ks,
+                        vs,
+                        self.dropout,
+                        self.causal,
+                        self.return_softmax,
+                    )
+            else:
+                outs, softmax = flash_attention(
+                    qs, ks, vs, self.dropout, self.causal, self.return_softmax
+                )
 
             exe = fluid.Executor(self.place)
             fetches_result = exe.run(
@@ -247,6 +273,7 @@ class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = False
+        self.use_sdp_kernel = False
 
 
 class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
@@ -257,6 +284,7 @@ class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = True
+        self.use_sdp_kernel = False
 
 
 class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
@@ -267,6 +295,7 @@ class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = True
         self.return_softmax = False
+        self.use_sdp_kernel = False
 
 
 class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
@@ -277,6 +306,21 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = False
+        self.use_sdp_kernel = False
+
+
+class TestMathAttentionAPITest(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 128)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = True
+        self.enable_math = True
+        self.enable_flash = False
+        self.enable_mem_efficient = False
 
 
 if __name__ == '__main__':
...
@@ -134,6 +134,8 @@ from .extension import gather_tree  # noqa: F401
 from .extension import temporal_shift  # noqa: F401
 
 from .sparse_attention import sparse_attention
+from .flash_attention import scaled_dot_product_attention
+from .flash_attention import sdp_kernel
 
 __all__ = [  # noqa
     'celu',
...
@@ -13,8 +13,113 @@
 # limitations under the License.
 
 import paddle
+import paddle.nn.functional as F
 from paddle import _C_ops, in_dynamic_mode
 from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
+
+g_enable_math = None
+g_enable_flash = None
+g_enable_mem_efficient = None
+
+
+@signature_safe_contextmanager
+def sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=True):
+    r"""
+    With the sdp_kernel context manager, different algorithm implementations can
+    be selected for scaled_dot_product_attention.
+    """
+    global g_enable_math, g_enable_flash, g_enable_mem_efficient
+    original_enable_math = g_enable_math
+    original_enable_flash = g_enable_flash
+    original_enable_mem_efficient = g_enable_mem_efficient
+
+    g_enable_math = enable_math
+    g_enable_flash = enable_flash
+    g_enable_mem_efficient = enable_mem_efficient
+    try:
+        yield
+    finally:
+        g_enable_math = original_enable_math
+        g_enable_flash = original_enable_flash
+        g_enable_mem_efficient = original_enable_mem_efficient
+
+
+def _math_attention(
+    query,
+    key,
+    value,
+    dropout_rate=0.0,
+    causal=False,
+    return_softmax=False,
+    training=True,
+):
+    r"""
+    A basic implementation of scaled dot product attention composed of
+    fundamental Paddle operators.
+    """
+    head_dim = query.shape[-1]
+    query = paddle.transpose(query, [0, 2, 1, 3])
+    key = paddle.transpose(key, [0, 2, 1, 3])
+    value = paddle.transpose(value, [0, 2, 1, 3])
+    product = paddle.matmul(
+        x=query * (head_dim**-0.5), y=key, transpose_y=True
+    )
+    weights = (
+        paddle.incubate.softmax_mask_fuse_upper_triangle(product)
+        if causal
+        else F.softmax(product)
+    )
+    if dropout_rate > 0.0:
+        weights = F.dropout(
+            weights, dropout_rate, training=training, mode="upscale_in_train"
+        )
+
+    out = paddle.matmul(weights, value)
+    out = paddle.transpose(out, [0, 2, 1, 3])
+    return out, weights if return_softmax else None
+
+
+def _select_sdp_cuda(head_dim):
+    if head_dim < 128:
+        return "flash_attn"
+    else:
+        return "mem_efficient"
+
+
+def _select_sdp(head_dim):
+    r"""
+    There are currently three implementations of scaled dot product attention;
+    which one is chosen depends on the sdp_kernel configuration, the device,
+    and the input head_dim.
+    """
+    place = paddle.get_device()
+
+    # sdp_kernel is not being used
+    if g_enable_flash is None:
+        if "gpu" not in place:
+            return "math"
+        else:
+            return _select_sdp_cuda(head_dim)
+
+    if (
+        g_enable_math is False
+        and g_enable_flash is False
+        and g_enable_mem_efficient is False
+    ):
+        raise AssertionError(
+            "No available backend for scaled_dot_product_attention was found."
+        )
+
+    if g_enable_math is True:
+        if g_enable_flash is False and g_enable_mem_efficient is False:
+            return "math"
+        if "gpu" not in place:
+            return "math"
+    if g_enable_flash is True and g_enable_mem_efficient is True:
+        return _select_sdp_cuda(head_dim)
+    if g_enable_flash is True:
+        return "flash_attn"
+    return "mem_efficient"
 
 
 def flash_attention(
@@ -84,51 +189,81 @@ def flash_attention(
             output = paddle.nn.functional.flash_attention(q, q, q, 0.9, False, False)
             print(output)
     """
-    if in_dynamic_mode():
-        (result_attention, result_softmax,) = _C_ops.flash_attn(
-            query,
-            key,
-            value,
-            fixed_seed_offset,
-            dropout,
-            causal,
-            return_softmax,
-            not training,
-            rng_name,
-        )
-        return result_attention, result_softmax if return_softmax else None
-
-    helper = LayerHelper('flash_attn', **locals())
-    dtype = helper.input_dtype(input_param_name='q')
-    out = helper.create_variable_for_type_inference(dtype)
-    softmax = helper.create_variable_for_type_inference(dtype)
-    softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
-    seed_offset = helper.create_variable_for_type_inference(paddle.int64)
-    inputs = {
-        'q': query,
-        'k': key,
-        'v': value,
-        'fixed_seed_offset': fixed_seed_offset,
-    }
-    outputs = {
-        'out': out,
-        'softmax': softmax,
-        'softmax_lse': softmax_lse,
-        'seed_offset': seed_offset,
-    }
-    helper.append_op(
-        type='flash_attn',
-        inputs=inputs,
-        outputs=outputs,
-        attrs={
-            'dropout': dropout,
-            'causal': causal,
-            'return_softmax': return_softmax,
-            'is_test': not training,
-            'rng_name': rng_name,
-        },
-    )
-    return out, softmax if return_softmax else None
+    head_dim = query.shape[3]
+    sdp_func_name = _select_sdp(head_dim)
+
+    if sdp_func_name == "flash_attn":
+        if in_dynamic_mode():
+            (result_attention, result_softmax,) = _C_ops.flash_attn(
+                query,
+                key,
+                value,
+                fixed_seed_offset,
+                dropout,
+                causal,
+                return_softmax,
+                not training,
+                rng_name,
+            )
+            return result_attention, result_softmax if return_softmax else None
+
+        helper = LayerHelper('flash_attn', **locals())
+        dtype = helper.input_dtype(input_param_name='q')
+        out = helper.create_variable_for_type_inference(dtype)
+        softmax = helper.create_variable_for_type_inference(dtype)
+        softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
+        seed_offset = helper.create_variable_for_type_inference(paddle.int64)
+        inputs = {
+            'q': query,
+            'k': key,
+            'v': value,
+            'fixed_seed_offset': fixed_seed_offset,
+        }
+        outputs = {
+            'out': out,
+            'softmax': softmax,
+            'softmax_lse': softmax_lse,
+            'seed_offset': seed_offset,
+        }
+        helper.append_op(
+            type='flash_attn',
+            inputs=inputs,
+            outputs=outputs,
+            attrs={
+                'dropout': dropout,
+                'causal': causal,
+                'return_softmax': return_softmax,
+                'is_test': not training,
+                'rng_name': rng_name,
+            },
+        )
+        return out, softmax if return_softmax else None
+    else:
+        if sdp_func_name == "mem_efficient":
+            from paddle.incubate.nn.memory_efficient_attention import (
+                memory_efficient_attention,
+            )
+
+            output = memory_efficient_attention(
+                query,
+                key,
+                value,
+                attn_bias=None,
+                p=dropout,
+                scale=None,
+                training=training,
+            )
+            return output, None
+        else:
+            return _math_attention(
+                query,
+                key,
+                value,
+                dropout_rate=dropout,
+                causal=causal,
+                return_softmax=return_softmax,
+                training=training,
+            )
 
 
 def flash_attn_unpadded(
@@ -264,3 +399,6 @@ def flash_attn_unpadded(
         },
     )
     return out, softmax if return_softmax else None
+
+
+scaled_dot_product_attention = flash_attention
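For reference, when no sdp_kernel context is active the dispatch added above falls back to _select_sdp_cuda on GPU, which picks the flash kernel for head_dim below 128 and the memory-efficient kernel otherwise, while non-GPU places use the math implementation. A small sketch exercising that rule directly; _select_sdp_cuda is a private helper, so importing it like this is for illustration only.

from paddle.nn.functional.flash_attention import _select_sdp_cuda

print(_select_sdp_cuda(64))   # "flash_attn": head_dim < 128
print(_select_sdp_cuda(128))  # "mem_efficient": head_dim >= 128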