diff --git a/python/paddle/fluid/tests/unittests/test_flash_attention.py b/python/paddle/fluid/tests/unittests/test_flash_attention.py
index b96c1e3cabd4f1ceffbe40f5892d9a6ba8e123af..8d6af72c3f4e33d2b6df14f9ae6b633dc2fa04ef 100644
--- a/python/paddle/fluid/tests/unittests/test_flash_attention.py
+++ b/python/paddle/fluid/tests/unittests/test_flash_attention.py
@@ -68,6 +68,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = False
+        self.use_sdp_kernel = False
 
     def test_unpadded(self):
         print(
@@ -189,9 +190,19 @@ class TestFlashAttentionAPI(unittest.TestCase):
             value, place=self.place, dtype=self.dtype, stop_gradient=False
         )
 
-        out, _ = flash_attention(
-            q, k, v, self.dropout, self.causal, self.return_softmax
-        )
+        if self.use_sdp_kernel:
+            with paddle.nn.functional.sdp_kernel(
+                enable_math=self.enable_math,
+                enable_flash=self.enable_flash,
+                enable_mem_efficient=self.enable_mem_efficient,
+            ):
+                out, _ = flash_attention(
+                    q, k, v, self.dropout, self.causal, self.return_softmax
+                )
+        else:
+            out, _ = flash_attention(
+                q, k, v, self.dropout, self.causal, self.return_softmax
+            )
         out_ = attention_naive(q_, k_, v_, self.causal)
 
         out.backward()
@@ -220,9 +231,24 @@ class TestFlashAttentionAPI(unittest.TestCase):
                 name="v", shape=self.shape, dtype=self.dtype
             )
 
-            outs, softmax = flash_attention(
-                qs, ks, vs, self.dropout, self.causal, self.return_softmax
-            )
+            if self.use_sdp_kernel:
+                with paddle.nn.functional.sdp_kernel(
+                    enable_math=self.enable_math,
+                    enable_flash=self.enable_flash,
+                    enable_mem_efficient=self.enable_mem_efficient,
+                ):
+                    outs, softmax = flash_attention(
+                        qs,
+                        ks,
+                        vs,
+                        self.dropout,
+                        self.causal,
+                        self.return_softmax,
+                    )
+            else:
+                outs, softmax = flash_attention(
+                    qs, ks, vs, self.dropout, self.causal, self.return_softmax
+                )
 
             exe = fluid.Executor(self.place)
             fetches_result = exe.run(
@@ -247,6 +273,7 @@ class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = False
+        self.use_sdp_kernel = False
 
 
 class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
@@ -257,6 +284,7 @@ class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = True
+        self.use_sdp_kernel = False
 
 
 class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
@@ -267,6 +295,7 @@ class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = True
         self.return_softmax = False
+        self.use_sdp_kernel = False
 
 
 class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
@@ -277,6 +306,21 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = False
+        self.use_sdp_kernel = False
+
+
+class TestMathAttentionAPITest(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 128)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = True
+        self.enable_math = True
+        self.enable_flash = False
+        self.enable_mem_efficient = False
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 2eabd18d394b5c6ab472ff88ab90607d54f0ec35..448a533b82d5c057559e16f72782e4798d8164a0 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -134,6 +134,8 @@
 from .extension import gather_tree  # noqa: F401
 from .extension import temporal_shift  # noqa: F401
 from .sparse_attention import sparse_attention
+from .flash_attention import scaled_dot_product_attention
+from .flash_attention import sdp_kernel
 
 __all__ = [  # noqa
     'celu',
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 78c2dd6e7618af285d5188dc1989f87644cdb160..50e79c1ad17f5c7945a6865dd00bf4a748e6faa8 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -13,8 +13,113 @@
 # limitations under the License.
 
 import paddle
+import paddle.nn.functional as F
 from paddle import _C_ops, in_dynamic_mode
 from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
+
+g_enable_math = None
+g_enable_flash = None
+g_enable_mem_efficient = None
+
+
+@signature_safe_contextmanager
+def sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=True):
+    r"""
+    With the sdp_kernel context manager, different algorithm implementations can
+    be selected for scaled_dot_product_attention.
+    """
+    global g_enable_math, g_enable_flash, g_enable_mem_efficient
+    original_enable_math = g_enable_math
+    original_enable_flash = g_enable_flash
+    original_enable_mem_efficient = g_enable_mem_efficient
+
+    g_enable_math = enable_math
+    g_enable_flash = enable_flash
+    g_enable_mem_efficient = enable_mem_efficient
+    try:
+        yield
+    finally:
+        g_enable_math = original_enable_math
+        g_enable_flash = original_enable_flash
+        g_enable_mem_efficient = original_enable_mem_efficient
+
+
+def _math_attention(
+    query,
+    key,
+    value,
+    dropout_rate=0.0,
+    causal=False,
+    return_softmax=False,
+    training=True,
+):
+    r"""
+    A reference implementation of scaled dot product attention built from
+    basic Paddle operators (matmul, softmax and dropout).
+    """
+    head_dim = query.shape[-1]
+    query = paddle.transpose(query, [0, 2, 1, 3])
+    key = paddle.transpose(key, [0, 2, 1, 3])
+    value = paddle.transpose(value, [0, 2, 1, 3])
+    product = paddle.matmul(
+        x=query * (head_dim**-0.5), y=key, transpose_y=True
+    )
+    weights = (
+        paddle.incubate.softmax_mask_fuse_upper_triangle(product)
+        if causal
+        else F.softmax(product)
+    )
+    if dropout_rate > 0.0:
+        weights = F.dropout(
+            weights, dropout_rate, training=training, mode="upscale_in_train"
+        )
+
+    out = paddle.matmul(weights, value)
+    out = paddle.transpose(out, [0, 2, 1, 3])
+    return out, weights if return_softmax else None
+
+
+def _select_sdp_cuda(head_dim):
+    if head_dim < 128:
+        return "flash_attn"
+    else:
+        return "mem_efficient"
+
+
+def _select_sdp(head_dim):
+    r"""
+    There are currently three implementations of scaled dot product attention.
+    Which one is used depends on the sdp_kernel configuration when it is set,
+    and otherwise on the device and the head dimension of the inputs.
+    """
+    place = paddle.get_device()
+    # sdp_kernel was not used, so pick a backend from the inputs
+    if g_enable_flash is None:
+        if "gpu" not in place:
+            return "math"
+        else:
+            return _select_sdp_cuda(head_dim)
+
+    if (
+        g_enable_math is False
+        and g_enable_flash is False
+        and g_enable_mem_efficient is False
+    ):
+        raise AssertionError(
+            "No available backend for scaled_dot_product_attention was found."
+        )
+
+    if g_enable_math is True:
+        if g_enable_flash is False and g_enable_mem_efficient is False:
+            return "math"
+        if "gpu" not in place:
+            return "math"
+    if g_enable_flash is True and g_enable_mem_efficient is True:
+        return _select_sdp_cuda(head_dim)
+    if g_enable_flash is True:
+        return "flash_attn"
+    return "mem_efficient"
 
 
 def flash_attention(
@@ -84,51 +189,81 @@
             output = paddle.nn.functional.flash_attention(q, q, q, 0.9, False, False)
             print(output)
     """
-    if in_dynamic_mode():
-        (result_attention, result_softmax,) = _C_ops.flash_attn(
-            query,
-            key,
-            value,
-            fixed_seed_offset,
-            dropout,
-            causal,
-            return_softmax,
-            not training,
-            rng_name,
+    head_dim = query.shape[3]
+    sdp_func_name = _select_sdp(head_dim)
+
+    if sdp_func_name == "flash_attn":
+        if in_dynamic_mode():
+            (result_attention, result_softmax,) = _C_ops.flash_attn(
+                query,
+                key,
+                value,
+                fixed_seed_offset,
+                dropout,
+                causal,
+                return_softmax,
+                not training,
+                rng_name,
+            )
+            return result_attention, result_softmax if return_softmax else None
+
+        helper = LayerHelper('flash_attn', **locals())
+        dtype = helper.input_dtype(input_param_name='q')
+        out = helper.create_variable_for_type_inference(dtype)
+        softmax = helper.create_variable_for_type_inference(dtype)
+        softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
+        seed_offset = helper.create_variable_for_type_inference(paddle.int64)
+        inputs = {
+            'q': query,
+            'k': key,
+            'v': value,
+            'fixed_seed_offset': fixed_seed_offset,
+        }
+        outputs = {
+            'out': out,
+            'softmax': softmax,
+            'softmax_lse': softmax_lse,
+            'seed_offset': seed_offset,
+        }
+        helper.append_op(
+            type='flash_attn',
+            inputs=inputs,
+            outputs=outputs,
+            attrs={
+                'dropout': dropout,
+                'causal': causal,
+                'return_softmax': return_softmax,
+                'is_test': not training,
+                'rng_name': rng_name,
+            },
         )
-        return result_attention, result_softmax if return_softmax else None
+        return out, softmax if return_softmax else None
+    else:
+        if sdp_func_name == "mem_efficient":
+            from paddle.incubate.nn.memory_efficient_attention import (
+                memory_efficient_attention,
+            )
 
-    helper = LayerHelper('flash_attn', **locals())
-    dtype = helper.input_dtype(input_param_name='q')
-    out = helper.create_variable_for_type_inference(dtype)
-    softmax = helper.create_variable_for_type_inference(dtype)
-    softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
-    seed_offset = helper.create_variable_for_type_inference(paddle.int64)
-    inputs = {
-        'q': query,
-        'k': key,
-        'v': value,
-        'fixed_seed_offset': fixed_seed_offset,
-    }
-    outputs = {
-        'out': out,
-        'softmax': softmax,
-        'softmax_lse': softmax_lse,
-        'seed_offset': seed_offset,
-    }
-    helper.append_op(
-        type='flash_attn',
-        inputs=inputs,
-        outputs=outputs,
-        attrs={
-            'dropout': dropout,
-            'causal': causal,
-            'return_softmax': return_softmax,
-            'is_test': not training,
-            'rng_name': rng_name,
-        },
-    )
-    return out, softmax if return_softmax else None
+            output = memory_efficient_attention(
+                query,
+                key,
+                value,
+                attn_bias=None,
+                p=dropout,
+                scale=None,
+                training=training,
+            )
+            return output, None
+        else:
+            return _math_attention(
+                query,
+                key,
+                value,
+                dropout_rate=dropout,
+                causal=causal,
+                return_softmax=return_softmax,
+                training=training,
+            )
 
 
 def flash_attn_unpadded(
@@ -264,3 +399,6 @@
         },
     )
     return out, softmax if return_softmax else None
+
+
+scaled_dot_product_attention = flash_attention
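
Usage sketch (not part of the patch): assuming a CUDA-enabled build of Paddle with this diff applied, the snippet below shows how the new sdp_kernel context manager and the scaled_dot_product_attention alias are intended to be called. The input shape and dtype are illustrative, borrowed from TestMathAttentionAPITest; paddle.randn plus astype is used here only to fabricate inputs.

    import paddle
    from paddle.nn.functional import scaled_dot_product_attention, sdp_kernel

    # Illustrative inputs shaped (batch, seq_len, num_heads, head_dim), matching
    # the (8, 1024, 16, 128) float16 case used by TestMathAttentionAPITest.
    q = paddle.randn([8, 1024, 16, 128]).astype(paddle.float16)
    k = paddle.randn([8, 1024, 16, 128]).astype(paddle.float16)
    v = paddle.randn([8, 1024, 16, 128]).astype(paddle.float16)

    # Default selection: _select_sdp picks flash_attn for head_dim < 128 on GPU,
    # mem_efficient otherwise, and math on non-GPU places.
    # Positional arguments are (dropout, causal, return_softmax), as in the tests.
    out, _ = scaled_dot_product_attention(q, k, v, 0.0, False, False)

    # Force the math reference implementation, as TestMathAttentionAPITest does.
    with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
        out_math, _ = scaled_dot_product_attention(q, k, v, 0.0, False, False)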