diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 6108fbb5081ab6f5829e66bab8bbab5b92a2bf97..5921d4514d7c6aad0d80262bec495bac99fa0702 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -518,6 +518,18 @@
     param : [q, k, v]
   kernel :
     func : flash_attn_grad
+    data_type: q
+
+- backward_op : flash_attn_raw_grad
+  forward : flash_attn_raw (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
+  output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
+  infer_meta :
+    func : FlashAttnGradInferMeta
+    param : [q, k, v]
+  kernel :
+    func : flash_attn_raw_grad
+    data_type: q
 
 - backward_op : flip_grad
   forward : flip (Tensor x, int[] axis) -> Tensor(out)
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index a59f47d011248e3cb113b7ce0e9e73a606c3736b..fa14ab29d2a85bd6e42143fdb1b3cf2fb9cc732f 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -500,8 +500,20 @@
     param : [q, k, v]
   kernel :
     func : flash_attn
+    data_type : q
   backward : flash_attn_grad
 
+- op : flash_attn_raw
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false)
+  output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
+  infer_meta :
+    func : FlashAttnInferMeta
+    param : [q, k, v]
+  kernel :
+    func : flash_attn_raw
+    data_type : q
+  backward : flash_attn_raw_grad
+
 - op : flip
   args : (Tensor x, int[] axis)
   output : Tensor (out)
diff --git a/paddle/phi/kernels/flash_attn_grad_kernel.h b/paddle/phi/kernels/flash_attn_grad_kernel.h
index 92ec093b27a4bbd887842cf93d3a90823a557608..d22ddb0ef18406b591efc1998d4d7ed146657a6f 100644
--- a/paddle/phi/kernels/flash_attn_grad_kernel.h
+++ b/paddle/phi/kernels/flash_attn_grad_kernel.h
@@ -19,6 +19,26 @@
 
 namespace phi {
 
+template <typename T, typename Context>
+void FlashAttnRawGradKernel(const Context& ctx,
+                            const DenseTensor& q,
+                            const DenseTensor& k,
+                            const DenseTensor& v,
+                            const DenseTensor& cu_seqlens_q,
+                            const DenseTensor& cu_seqlens_k,
+                            const DenseTensor& out,
+                            const DenseTensor& softmax_lse,
+                            const DenseTensor& seed_offset,
+                            const DenseTensor& dout,
+                            int64_t max_seqlen_q,
+                            int64_t max_seqlen_k,
+                            float scale,
+                            float dropout,
+                            bool causal,
+                            DenseTensor* dq,
+                            DenseTensor* dk,
+                            DenseTensor* dv);
+
 template <typename T, typename Context>
 void FlashAttnGradKernel(const Context& ctx,
                          const DenseTensor& q,
diff --git a/paddle/phi/kernels/flash_attn_kernel.h b/paddle/phi/kernels/flash_attn_kernel.h
index 6a633d13b249972649d8bd9e675190e27fa3d8cd..dd6db04d45cd5b243603dc2911b863c8007c3453 100644
--- a/paddle/phi/kernels/flash_attn_kernel.h
+++ b/paddle/phi/kernels/flash_attn_kernel.h
@@ -19,6 +19,24 @@
 
 namespace phi {
 
+template <typename T, typename Context>
+void FlashAttnRawKernel(const Context& ctx,
+                        const DenseTensor& q,
+                        const DenseTensor& k,
+                        const DenseTensor& v,
+                        const DenseTensor& cu_seqlens_q,
+                        const DenseTensor& cu_seqlens_k,
+                        int64_t max_seqlen_q,
+                        int64_t max_seqlen_k,
+                        float scale,
+                        float dropout,
+                        bool causal,
+                        bool return_softmax,
+                        DenseTensor* out,
+                        DenseTensor* softmax_lse,
+                        DenseTensor* softmax,
+                        DenseTensor* seed_offset);
+
 template <typename T, typename Context>
 void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index 127d51562e56bf8af247e3d7050919449f50f304..038557d9feb211cd1db0fa1be901f9faa9cc8dd1 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -28,19 +28,24 @@
 namespace phi {
 
 template <typename T, typename Context>
-void FlashAttnGradKernel(const Context& ctx,
-                         const DenseTensor& q,
-                         const DenseTensor& k,
-                         const DenseTensor& v,
-                         const DenseTensor& out,
-                         const DenseTensor& softmax_lse,
-                         const DenseTensor& seed_offset,
-                         const DenseTensor& dout,
-                         float dropout,
-                         bool causal,
-                         DenseTensor* dq,
-                         DenseTensor* dk,
-                         DenseTensor* dv) {
+void FlashAttnRawGradKernel(const Context& ctx,
+                            const DenseTensor& q,
+                            const DenseTensor& k,
+                            const DenseTensor& v,
+                            const DenseTensor& cu_seqlens_q,
+                            const DenseTensor& cu_seqlens_k,
+                            const DenseTensor& out,
+                            const DenseTensor& softmax_lse,
+                            const DenseTensor& seed_offset,
+                            const DenseTensor& dout,
+                            int64_t max_seqlen_q,
+                            int64_t max_seqlen_k,
+                            float scale,
+                            float dropout,
+                            bool causal,
+                            DenseTensor* dq,
+                            DenseTensor* dk,
+                            DenseTensor* dv) {
 #ifdef PADDLE_WITH_FLASHATTN
   ctx.template Alloc<T>(dq);
   ctx.template Alloc<T>(dk);
@@ -49,36 +54,16 @@ void FlashAttnGradKernel(const Context& ctx,
 
   cudaStream_t stream = ctx.stream();
   bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
 
-  // q,k,v [batch_size, seq_len, num_heads, head_dim]
+  // q,k,v [total_*, num_heads, head_dim]
   auto dims = q.dims();
-  int64_t batch_size = dims[0];
-  int64_t seq_len_q = dims[1];
-  int64_t num_heads = dims[2];
-  int64_t head_size = dims[3];
-
-  int64_t seq_len_k = k.dims()[1];
+  int64_t total_q = dims[0];
+  int64_t num_heads = dims[1];
+  int64_t head_size = dims[2];
 
-  int64_t total_q = batch_size * seq_len_q;
-  int64_t total_k = batch_size * seq_len_k;
+  int64_t total_k = k.dims()[0];
+  int64_t batch_size = cu_seqlens_q.numel() - 1;
 
-  DenseTensor q_t_s =
-      Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
-  DenseTensor k_t_s =
-      Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
-  DenseTensor v_t_s =
-      Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
-
-  // q,k,v [total_*, num_heads, head_dim]
-
-  DenseTensor cu_seqlens_q;
-  DenseTensor cu_seqlens_k;
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
-
-  float scale = 1.0f / std::sqrt(head_size);
 
   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
   bool zero_tensors = false;
@@ -87,15 +72,16 @@ void FlashAttnGradKernel(const Context& ctx,
   uint64_t seed = seed_offset_vec[0];
   uint64_t offset = seed_offset_vec[1];
 
+  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
   DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
 
   uint64_t workspace_size;
 
   // calculate workspace size before execution
   bool succ = phi::dynload::flash_attn_bwd(
-      q_t_s.data(),
-      k_t_s.data(),
-      v_t_s.data(),
+      q.data(),
+      k.data(),
+      v.data(),
       dq->data(),
       dk->data(),
       dv->data(),
@@ -108,8 +94,8 @@ void FlashAttnGradKernel(const Context& ctx,
       batch_size,
       num_heads,
      head_size,
-      seq_len_q,
-      seq_len_k,
+      max_seqlen_q,
+      max_seqlen_k,
       dropout,
       scale,
       zero_tensors,
@@ -134,9 +120,9 @@ void FlashAttnGradKernel(const Context& ctx,
   }
 
   succ = phi::dynload::flash_attn_bwd(
-      q_t_s.data(),
-      k_t_s.data(),
-      v_t_s.data(),
+      q.data(),
+      k.data(),
+      v.data(),
       dq->data(),
       dk->data(),
       dv->data(),
@@ -149,8 +135,8 @@ void FlashAttnGradKernel(const Context& ctx,
       batch_size,
       num_heads,
       head_size,
-      seq_len_q,
-      seq_len_k,
+      max_seqlen_q,
+      max_seqlen_k,
       dropout,
       scale,
       zero_tensors,
@@ -172,8 +158,83 @@ void FlashAttnGradKernel(const Context& ctx,
 #endif
 }
 
+template <typename T, typename Context>
+void FlashAttnGradKernel(const Context& ctx,
+                         const DenseTensor& q,
+                         const DenseTensor& k,
+                         const DenseTensor& v,
+                         const DenseTensor& out,
+                         const DenseTensor& softmax_lse,
+                         const DenseTensor& seed_offset,
+                         const DenseTensor& dout,
+                         float dropout,
+                         bool causal,
+                         DenseTensor* dq,
+                         DenseTensor* dk,
+                         DenseTensor* dv) {
+#ifdef PADDLE_WITH_FLASHATTN
+  // q,k,v [batch_size, seq_len, num_heads, head_dim]
+
+  auto dims = q.dims();
+  int64_t batch_size = dims[0];
+  int64_t seq_len_q = dims[1];
+  int64_t num_heads = dims[2];
+  int64_t head_size = dims[3];
+
+  int64_t seq_len_k = k.dims()[1];
+
+  int64_t total_q = batch_size * seq_len_q;
+  int64_t total_k = batch_size * seq_len_k;
+
+  float scale = 1.0f / std::sqrt(head_size);
+
+  DenseTensor q_t_s =
+      Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
+  DenseTensor k_t_s =
+      Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
+  DenseTensor v_t_s =
+      Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
+
+  DenseTensor cu_seqlens_q;
+  DenseTensor cu_seqlens_k;
+  ArangeNullaryKernel<int32_t, Context>(
+      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
+  ArangeNullaryKernel<int32_t, Context>(
+      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
+
+  FlashAttnRawGradKernel<T, Context>(ctx,
+                                     q_t_s,
+                                     k_t_s,
+                                     v_t_s,
+                                     cu_seqlens_q,
+                                     cu_seqlens_k,
+                                     out,
+                                     softmax_lse,
+                                     seed_offset,
+                                     dout,
+                                     seq_len_q,
+                                     seq_len_k,
+                                     scale,
+                                     dropout,
+                                     causal,
+                                     dq,
+                                     dk,
+                                     dv);
+
+#endif
+}
+
 }  // namespace phi
 
+PD_REGISTER_KERNEL(flash_attn_raw_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FlashAttnRawGradKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(7).SetBackend(phi::Backend::CPU);  // seed_offset
+}
+
 PD_REGISTER_KERNEL(flash_attn_grad,
                    GPU,
                    ALL_LAYOUT,
diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
index 19079a3573f4ed78612ed5d7c05d6fbc8007f07f..ef8bd2a98d15e11d9b68596c7f0a0e72570d1b57 100644
--- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -16,6 +16,7 @@
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 
@@ -30,53 +31,44 @@
 namespace phi {
 
 template <typename T, typename Context>
-void FlashAttnKernel(const Context& ctx,
-                     const DenseTensor& q,
-                     const DenseTensor& k,
-                     const DenseTensor& v,
-                     float dropout,
-                     bool causal,
-                     bool return_softmax,
-                     DenseTensor* out,
-                     DenseTensor* softmax_lse,
-                     DenseTensor* softmax,
-                     DenseTensor* seed_offset) {
+void FlashAttnRawKernel(const Context& ctx,
+                        const DenseTensor& q,
+                        const DenseTensor& k,
+                        const DenseTensor& v,
+                        const DenseTensor& cu_seqlens_q,
+                        const DenseTensor& cu_seqlens_k,
+                        int64_t max_seqlen_q,
+                        int64_t max_seqlen_k,
+                        float scale,
+                        float dropout,
+                        bool causal,
+                        bool return_softmax,
+                        DenseTensor* out,
+                        DenseTensor* softmax_lse,
+                        DenseTensor* softmax,
+                        DenseTensor* seed_offset) {
 #ifdef PADDLE_WITH_FLASHATTN
   ctx.template Alloc<T>(out);
 
   cudaStream_t stream = ctx.stream();
   bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
 
-  // q,k,v [batch_size, seq_len, num_heads, head_dim]
+  // q,k,v [total_*, num_heads, head_dim]
   auto dims = q.dims();
-  int64_t batch_size = dims[0];
-  int64_t seq_len_q = dims[1];
-  int64_t num_heads = dims[2];
-  int64_t head_size = dims[3];
+  PADDLE_ENFORCE_EQ(
+      dims.size(),
+      3,
+      phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
+                                   "[total_seq_len, num_heads, head_dim]"));
 
-  int64_t seq_len_k = k.dims()[1];
+  int64_t total_q = dims[0];
+  int64_t num_heads = dims[1];
+  int64_t head_size = dims[2];
 
-  int64_t total_q = batch_size * seq_len_q;
-  int64_t total_k = batch_size * seq_len_k;
+  int64_t total_k = k.dims()[0];
+  int64_t batch_size = cu_seqlens_q.numel() - 1;
 
-  DenseTensor q_t_s =
-      Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
-  DenseTensor k_t_s =
-      Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
-  DenseTensor v_t_s =
-      Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
-
-  // q,k,v [total_*, num_heads, head_dim]
-
-  DenseTensor cu_seqlens_q;
-  DenseTensor cu_seqlens_k;
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
-
-  float scale = 1.0f / std::sqrt(head_size);
 
   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
   bool zero_tensors = false;
@@ -89,27 +81,33 @@ void FlashAttnKernel(const Context& ctx,
   std::vector<int64_t> seed_offset_vec{int64_t(seed), int64_t(offset)};
   phi::TensorFromVector<int64_t>(seed_offset_vec, ctx, seed_offset);
 
+  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
+
   softmax_lse->Resize({batch_size, num_heads, seq_len_q});
   ctx.template Alloc<float>(softmax_lse);
 
   if (return_softmax) {
-    // may allocate more space than *seq_len_k*
+    // may allocate more space than *max_seqlen_k*
     int64_t blocksize_c = head_size > 64 ? 128 : 256;
-    int64_t max_len_k_ =
-        ((seq_len_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
-    int64_t max_len_k =
-        seq_len_k <= 128 ? 128 : (seq_len_k <= 256 ? 256 : max_len_k_);
-    softmax->Resize({batch_size, num_heads, seq_len_q, max_len_k});
+    int64_t seq_len_k =
+        ((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
+    if (max_seqlen_k <= 128) {
+      seq_len_k = 128;
+    } else if (max_seqlen_k <= 256) {
+      seq_len_k = 256;
+    }
+    softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
     ctx.template Alloc<T>(softmax);
   }
 
   uint64_t workspace_size;
 
+  // TODO(kuizhiqing) pass allocation/empty func in capi to decouple
   // calculate workspace size before execution
   bool succ =
-      phi::dynload::flash_attn_fwd(q_t_s.data(),
-                                   k_t_s.data(),
-                                   v_t_s.data(),
+      phi::dynload::flash_attn_fwd(q.data(),
+                                   k.data(),
+                                   v.data(),
                                    nullptr,  // for calculation workspace size
                                    cu_seqlens_q.data(),
                                    cu_seqlens_k.data(),
@@ -118,8 +116,8 @@ void FlashAttnKernel(const Context& ctx,
                                    batch_size,
                                    num_heads,
                                    head_size,
-                                   seq_len_q,
-                                   seq_len_k,
+                                   max_seqlen_q,
+                                   max_seqlen_k,
                                    dropout,
                                    scale,
                                    zero_tensors,
@@ -144,9 +142,9 @@ void FlashAttnKernel(const Context& ctx,
   }
 
   succ = phi::dynload::flash_attn_fwd(
-      q_t_s.data(),
-      k_t_s.data(),
-      v_t_s.data(),
+      q.data(),
+      k.data(),
+      v.data(),
       out->data(),
       cu_seqlens_q.data(),
       cu_seqlens_k.data(),
@@ -155,8 +153,8 @@ void FlashAttnKernel(const Context& ctx,
       batch_size,
      num_heads,
       head_size,
-      seq_len_q,
-      seq_len_k,
+      max_seqlen_q,
+      max_seqlen_k,
       dropout,
       scale,
       zero_tensors,
@@ -178,8 +176,83 @@ void FlashAttnKernel(const Context& ctx,
 #endif
 }
 
+template <typename T, typename Context>
+void FlashAttnKernel(const Context& ctx,
+                     const DenseTensor& q,
+                     const DenseTensor& k,
+                     const DenseTensor& v,
+                     float dropout,
+                     bool causal,
+                     bool return_softmax,
+                     DenseTensor* out,
+                     DenseTensor* softmax_lse,
+                     DenseTensor* softmax,
+                     DenseTensor* seed_offset) {
+#ifdef PADDLE_WITH_FLASHATTN
+  // q,k,v [batch_size, seq_len, num_heads, head_dim]
+
+  auto dims = q.dims();
+  PADDLE_ENFORCE_EQ(dims.size(),
+                    4,
+                    phi::errors::InvalidArgument(
+                        "flash_attn receive input with dim "
+                        "[batch_size, seq_len, num_heads, head_dim]"));
+
+  int64_t batch_size = dims[0];
+  int64_t seq_len_q = dims[1];
+  int64_t num_heads = dims[2];
+  int64_t head_size = dims[3];
+
+  int64_t seq_len_k = k.dims()[1];
+
+  int64_t total_q = batch_size * seq_len_q;
+  int64_t total_k = batch_size * seq_len_k;
+
+  float scale = 1.0f / std::sqrt(head_size);
+
+  DenseTensor q_t_s =
+      Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
+  DenseTensor k_t_s =
+      Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
+  DenseTensor v_t_s =
+      Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
+
+  DenseTensor cu_seqlens_q;
+  DenseTensor cu_seqlens_k;
+  ArangeNullaryKernel<int32_t, Context>(
+      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
+  ArangeNullaryKernel<int32_t, Context>(
+      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
+
+  FlashAttnRawKernel<T, Context>(ctx,
+                                 q_t_s,
+                                 k_t_s,
+                                 v_t_s,
+                                 cu_seqlens_q,
+                                 cu_seqlens_k,
+                                 seq_len_q,
+                                 seq_len_k,
+                                 scale,
+                                 dropout,
+                                 causal,
+                                 return_softmax,
+                                 out,
+                                 softmax_lse,
+                                 softmax,
+                                 seed_offset);
+
+#endif
+}
+
 }  // namespace phi
 
+PD_REGISTER_KERNEL(flash_attn_raw,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FlashAttnRawKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+
 PD_REGISTER_KERNEL(flash_attn,
                    GPU,
                    ALL_LAYOUT,
diff --git a/python/paddle/fluid/tests/unittests/test_flash_attention.py b/python/paddle/fluid/tests/unittests/test_flash_attention.py
index 1b3593c74a4ef8511f4476dd40d102453dec5edf..223a17c797b7225803b8037e8316fbb50f213361 100644
--- a/python/paddle/fluid/tests/unittests/test_flash_attention.py
+++ b/python/paddle/fluid/tests/unittests/test_flash_attention.py
@@ -61,12 +61,61 @@ class TestFlashAttentionAPI(unittest.TestCase):
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (2, 128, 8, 16)
-        self.blocksize = 2
         self.dtype = 'float16'
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = False
 
+    def test_raw(self):
+        print(
+            f"Test Raw case shape {self.shape} dtype {self.dtype} causal {self.causal}"
+        )
+
+        paddle.disable_static()
+
+        query = np.random.random(self.shape)
+        q = paddle.to_tensor(
+            query, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+        q_ = paddle.to_tensor(
+            query, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+
+        out_ = attention_naive(q_, q_, q_, self.causal)
+
+        scale = 1.0 / np.sqrt(q.shape[-1])
+
+        bs = self.shape[0]
+        ms = self.shape[1]
+        nh = self.shape[2]
+        hd = self.shape[3]
+        cu_q = paddle.arange(0, (bs + 1) * ms, ms, dtype='int32')
+
+        qq = paddle.reshape(q, [bs * ms, nh, hd])
+        out, _, _, _ = paddle._C_ops.flash_attn_raw(
+            qq,
+            qq,
+            qq,
+            cu_q,
+            cu_q,
+            ms,
+            ms,
+            scale,
+            self.dropout,
+            self.causal,
+            self.return_softmax,
+        )
+        out_ = paddle.reshape(out_, [bs * ms, nh, hd])
+
+        np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)
+
+        out.backward()
+        out_.backward()
+
+        np.testing.assert_allclose(
+            q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
+        )
+
     def test_all(self):
         print(
             f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}"
@@ -152,7 +201,6 @@ class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (2, 128, 8, 16)
-        self.blocksize = 2
         self.dtype = paddle.float16
         self.dropout = 0.0
         self.causal = False
@@ -163,7 +211,6 @@ class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (2, 256, 8, 16)
-        self.blocksize = 2
         self.dtype = paddle.float16
         self.dropout = 0.0
         self.causal = False
@@ -174,7 +221,6 @@ class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (2, 512, 8, 16)
-        self.blocksize = 2
         self.dtype = paddle.float16
         self.dropout = 0.0
         self.causal = True
@@ -185,7 +231,6 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (8, 1024, 16, 128)
-        self.blocksize = 2
         self.dtype = paddle.float16
         self.dropout = 0.0
         self.causal = False
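
Usage note: a minimal sketch of calling the new varlen entry point directly, modeled on the test_raw case added above. It assumes a CUDA build of Paddle with FlashAttention enabled; the packed self-attention setup and the variable names (bs, seqlen, cu_seqlens) are illustrative only.

import numpy as np
import paddle

# Packed (varlen) layout: q/k/v are [total_tokens, num_heads, head_dim] and
# cu_seqlens_* hold the int32 prefix sums of the per-sequence lengths.
bs, seqlen, num_heads, head_dim = 2, 128, 8, 16
q = paddle.to_tensor(
    np.random.random([bs * seqlen, num_heads, head_dim]),
    place=paddle.CUDAPlace(0),
    dtype='float16',
    stop_gradient=False,
)
cu_seqlens = paddle.arange(0, (bs + 1) * seqlen, seqlen, dtype='int32')
scale = 1.0 / np.sqrt(head_dim)

out, softmax_lse, softmax, seed_offset = paddle._C_ops.flash_attn_raw(
    q, q, q,                 # self-attention: reuse q as k and v
    cu_seqlens, cu_seqlens,  # cumulative sequence lengths for q and k
    seqlen, seqlen,          # max_seqlen_q, max_seqlen_k
    scale,                   # softmax scaling, typically 1/sqrt(head_dim)
    0.0,                     # dropout probability
    False,                   # causal
    False,                   # return_softmax
)
out.backward()               # gradients are produced by flash_attn_raw_grad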