未验证 提交 f951832d 编写于 作者: C Chitsing KUI 提交者: GitHub

add flashattn raw kernel (#51383)

上级 3f4917f6
......@@ -518,6 +518,18 @@
param : [q, k, v]
kernel :
func : flash_attn_grad
data_type: q
- backward_op : flash_attn_raw_grad
forward : flash_attn_raw (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnGradInferMeta
param : [q, k, v]
kernel :
func : flash_attn_raw_grad
data_type: q
- backward_op : flip_grad
forward : flip (Tensor x, int[] axis) -> Tensor(out)
......
......@@ -500,8 +500,20 @@
param : [q, k, v]
kernel :
func : flash_attn
data_type : q
backward : flash_attn_grad
- op : flash_attn_raw
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false)
output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
kernel :
func : flash_attn_raw
data_type : q
backward : flash_attn_raw_grad
- op : flip
args : (Tensor x, int[] axis)
output : Tensor (out)
......
......@@ -19,6 +19,26 @@
namespace phi {
template <typename T, typename Context>
void FlashAttnRawGradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);
template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
const DenseTensor& q,
......
......@@ -19,6 +19,24 @@
namespace phi {
template <typename T, typename Context>
void FlashAttnRawKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
DenseTensor* out,
DenseTensor* softmax_lse,
DenseTensor* softmax,
DenseTensor* seed_offset);
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
......
......@@ -28,14 +28,19 @@
namespace phi {
template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
void FlashAttnRawGradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
DenseTensor* dq,
......@@ -49,36 +54,16 @@ void FlashAttnGradKernel(const Context& ctx,
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
// q,k,v [batch_size, seq_len, num_heads, head_dim]
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
DenseTensor q_t_s =
Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
DenseTensor k_t_s =
Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
DenseTensor v_t_s =
Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
// q,k,v [total_*, num_heads, head_dim]
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
float scale = 1.0f / std::sqrt(head_size);
int num_splits = 0; // 0 for an internal heuristic, which is optimal
bool zero_tensors = false;
......@@ -87,15 +72,16 @@ void FlashAttnGradKernel(const Context& ctx,
uint64_t seed = seed_offset_vec[0];
uint64_t offset = seed_offset_vec[1];
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
uint64_t workspace_size;
// calculate workspace size before execution
bool succ = phi::dynload::flash_attn_bwd(
q_t_s.data(),
k_t_s.data(),
v_t_s.data(),
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
......@@ -108,8 +94,8 @@ void FlashAttnGradKernel(const Context& ctx,
batch_size,
num_heads,
head_size,
seq_len_q,
seq_len_k,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
......@@ -134,9 +120,9 @@ void FlashAttnGradKernel(const Context& ctx,
}
succ = phi::dynload::flash_attn_bwd(
q_t_s.data(),
k_t_s.data(),
v_t_s.data(),
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
......@@ -149,8 +135,8 @@ void FlashAttnGradKernel(const Context& ctx,
batch_size,
num_heads,
head_size,
seq_len_q,
seq_len_k,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
......@@ -172,8 +158,83 @@ void FlashAttnGradKernel(const Context& ctx,
#endif
}
template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
float scale = 1.0f / std::sqrt(head_size);
DenseTensor q_t_s =
Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
DenseTensor k_t_s =
Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
DenseTensor v_t_s =
Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
FlashAttnRawGradKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
out,
softmax_lse,
seed_offset,
dout,
seq_len_q,
seq_len_k,
scale,
dropout,
causal,
dq,
dk,
dv);
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(flash_attn_raw_grad,
GPU,
ALL_LAYOUT,
phi::FlashAttnRawGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(7).SetBackend(phi::Backend::CPU); // seed_offset
}
PD_REGISTER_KERNEL(flash_attn_grad,
GPU,
ALL_LAYOUT,
......
......@@ -16,6 +16,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
......@@ -30,10 +31,15 @@
namespace phi {
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
void FlashAttnRawKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
......@@ -47,36 +53,22 @@ void FlashAttnKernel(const Context& ctx,
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
// q,k,v [batch_size, seq_len, num_heads, head_dim]
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
PADDLE_ENFORCE_EQ(
dims.size(),
3,
phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
"[total_seq_len, num_heads, head_dim]"));
int64_t seq_len_k = k.dims()[1];
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
DenseTensor q_t_s =
Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
DenseTensor k_t_s =
Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
DenseTensor v_t_s =
Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
// q,k,v [total_*, num_heads, head_dim]
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
float scale = 1.0f / std::sqrt(head_size);
int num_splits = 0; // 0 for an internal heuristic, which is optimal
bool zero_tensors = false;
......@@ -89,27 +81,33 @@ void FlashAttnKernel(const Context& ctx,
std::vector<int64_t> seed_offset_vec{int64_t(seed), int64_t(offset)};
phi::TensorFromVector<int64_t>(seed_offset_vec, ctx, seed_offset);
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
softmax_lse->Resize({batch_size, num_heads, seq_len_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
// may allocate more space than *seq_len_k*
// may allocate more space than *max_seqlen_k*
int64_t blocksize_c = head_size > 64 ? 128 : 256;
int64_t max_len_k_ =
((seq_len_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
int64_t max_len_k =
seq_len_k <= 128 ? 128 : (seq_len_k <= 256 ? 256 : max_len_k_);
softmax->Resize({batch_size, num_heads, seq_len_q, max_len_k});
int64_t seq_len_k =
((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
if (max_seqlen_k <= 128) {
seq_len_k = 128;
} else if (max_seqlen_k <= 256) {
seq_len_k = 256;
}
softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
ctx.template Alloc<T>(softmax);
}
uint64_t workspace_size;
// TODO(kuizhiqing) pass allocation/empty func in capi to decouple
// calculate workspace size before execution
bool succ =
phi::dynload::flash_attn_fwd(q_t_s.data(),
k_t_s.data(),
v_t_s.data(),
phi::dynload::flash_attn_fwd(q.data(),
k.data(),
v.data(),
nullptr, // for calculation workspace size
cu_seqlens_q.data(),
cu_seqlens_k.data(),
......@@ -118,8 +116,8 @@ void FlashAttnKernel(const Context& ctx,
batch_size,
num_heads,
head_size,
seq_len_q,
seq_len_k,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
......@@ -144,9 +142,9 @@ void FlashAttnKernel(const Context& ctx,
}
succ = phi::dynload::flash_attn_fwd(
q_t_s.data(),
k_t_s.data(),
v_t_s.data(),
q.data(),
k.data(),
v.data(),
out->data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
......@@ -155,8 +153,8 @@ void FlashAttnKernel(const Context& ctx,
batch_size,
num_heads,
head_size,
seq_len_q,
seq_len_k,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
......@@ -178,8 +176,83 @@ void FlashAttnKernel(const Context& ctx,
#endif
}
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
float dropout,
bool causal,
bool return_softmax,
DenseTensor* out,
DenseTensor* softmax_lse,
DenseTensor* softmax,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
PADDLE_ENFORCE_EQ(dims.size(),
4,
phi::errors::InvalidArgument(
"flash_attn receive input with dim "
"[batch_size, seq_len, num_heads, head_dim]"));
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
float scale = 1.0f / std::sqrt(head_size);
DenseTensor q_t_s =
Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
DenseTensor k_t_s =
Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
DenseTensor v_t_s =
Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
FlashAttnRawKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
seq_len_q,
seq_len_k,
scale,
dropout,
causal,
return_softmax,
out,
softmax_lse,
softmax,
seed_offset);
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(flash_attn_raw,
GPU,
ALL_LAYOUT,
phi::FlashAttnRawKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(flash_attn,
GPU,
ALL_LAYOUT,
......
......@@ -61,12 +61,61 @@ class TestFlashAttentionAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 128, 8, 16)
self.blocksize = 2
self.dtype = 'float16'
self.dropout = 0.0
self.causal = False
self.return_softmax = False
def test_raw(self):
print(
f"Test Raw case shape {self.shape} dtype {self.dtype} causal {self.causal}"
)
paddle.disable_static()
query = np.random.random(self.shape)
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
q_ = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
out_ = attention_naive(q_, q_, q_, self.causal)
scale = 1.0 / np.sqrt(q.shape[-1])
bs = self.shape[0]
ms = self.shape[1]
nh = self.shape[2]
hd = self.shape[3]
cu_q = paddle.arange(0, (bs + 1) * ms, ms, dtype='int32')
qq = paddle.reshape(q, [bs * ms, nh, hd])
out, _, _, _ = paddle._C_ops.flash_attn_raw(
qq,
qq,
qq,
cu_q,
cu_q,
ms,
ms,
scale,
self.dropout,
self.causal,
self.return_softmax,
)
out_ = paddle.reshape(out_, [bs * ms, nh, hd])
np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)
out.backward()
out_.backward()
np.testing.assert_allclose(
q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
)
def test_all(self):
print(
f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}"
......@@ -152,7 +201,6 @@ class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 128, 8, 16)
self.blocksize = 2
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
......@@ -163,7 +211,6 @@ class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 256, 8, 16)
self.blocksize = 2
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
......@@ -174,7 +221,6 @@ class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 512, 8, 16)
self.blocksize = 2
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = True
......@@ -185,7 +231,6 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (8, 1024, 16, 128)
self.blocksize = 2
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册