Unverified commit 42e0c6b8, authored by yin wei, committed by GitHub

Add attn_mask support for FlashAttnKernel. (#55969)

* add mask

* add backward

* add enforce info

* update scale

* integrate code

* update enforce

* add enforce eq

* add error type

* update enforce

* add test_flash_attention

* Polish codes and fix compiling errors.

* Set num_splits to 0 for flash-attn with tensor mask.

* Fix the compiling error for non flash-attn case.

---------
Co-authored-by: NLiu Yiqun <liuyiqun01@baidu.com>
Parent 0434b828
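For context, a minimal usage sketch of the new mask path. It mirrors the unit test added by this commit (same import path, tensor shapes, dtype, and the additive-bias semantics of the mask); everything else is illustrative only, requires a GPU build with flash-attention support, and is not part of the diff itself.

import numpy as np
import paddle
from paddle.nn.functional.flash_attention import scaled_dot_product_attention

# q, k, v: [batch_size, seq_len, num_heads, head_dim]; the mask is an additive
# bias that broadcasts to [batch_size, num_heads, seq_len, seq_len].
shape, mask_shape = (2, 128, 8, 32), (2, 1, 128, 128)
q = paddle.to_tensor(np.random.random(shape), dtype='float16', stop_gradient=False)
k = paddle.to_tensor(np.random.random(shape), dtype='float16', stop_gradient=False)
v = paddle.to_tensor(np.random.random(shape), dtype='float16', stop_gradient=False)
mask = paddle.to_tensor(np.random.random(mask_shape), dtype='float16')
out = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out.backward()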
@@ -818,8 +818,9 @@
    inplace : (out_grad -> x_grad)

- backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
-  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false)
+  optional : attn_mask
  output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
  infer_meta :
    func : FlashAttnGradInferMeta
@@ -829,8 +830,9 @@
    data_type: q

- backward_op : flash_attn_unpadded_grad
-  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
+  optional : attn_mask
  output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
  infer_meta :
    func : FlashAttnGradInferMeta
...
@@ -910,9 +910,9 @@
  backward : fill_diagonal_tensor_grad

- op : flash_attn
-  args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
+  args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
-  optional : fixed_seed_offset
+  optional : fixed_seed_offset, attn_mask
  infer_meta :
    func : FlashAttnInferMeta
    param : [q, k, v]
@@ -923,9 +923,9 @@
  backward : flash_attn_grad

- op : flash_attn_unpadded
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
-  optional : fixed_seed_offset
+  optional : fixed_seed_offset, attn_mask
  infer_meta :
    func : FlashAttnInferMeta
    param : [q, k, v]
...
@@ -29,6 +29,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
                                 const DenseTensor& out,
                                 const DenseTensor& softmax_lse,
                                 const DenseTensor& seed_offset,
+                                const paddle::optional<DenseTensor>& attn_mask,
                                 const DenseTensor& dout,
                                 int64_t max_seqlen_q,
                                 int64_t max_seqlen_k,
@@ -47,6 +48,7 @@ void FlashAttnGradKernel(const Context& ctx,
                         const DenseTensor& out,
                         const DenseTensor& softmax_lse,
                         const DenseTensor& seed_offset,
+                        const paddle::optional<DenseTensor>& attn_mask,
                         const DenseTensor& dout,
                         float dropout,
                         bool causal,
...
@@ -28,6 +28,7 @@ void FlashAttnUnpaddedKernel(
    const DenseTensor& cu_seqlens_q,
    const DenseTensor& cu_seqlens_k,
    const paddle::optional<DenseTensor>& fixed_seed_offset,
+   const paddle::optional<DenseTensor>& attn_mask,
    int64_t max_seqlen_q,
    int64_t max_seqlen_k,
    float scale,
@@ -47,6 +48,7 @@ void FlashAttnKernel(const Context& ctx,
                     const DenseTensor& k,
                     const DenseTensor& v,
                     const paddle::optional<DenseTensor>& fixed_seed_offset,
+                    const paddle::optional<DenseTensor>& attn_mask,
                     float dropout,
                     bool causal,
                     bool return_softmax,
...
@@ -21,17 +21,160 @@
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/reshape_kernel.h"
-#ifdef PADDLE_WITH_FLASHATTN
-#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
-#endif
+#include "paddle/phi/kernels/reshape_kernel.h"

DECLARE_bool(cudnn_deterministic);

namespace phi {
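// Backward of the attn_mask-enabled path. flash_attn_bwd_with_bias_and_mask is
// called twice: the first call passes a null output pointer only to query the
// required workspace size, the second call performs the actual backward. The
// forward pre-scales Q, so dq is rescaled by `scale` at the end via ComputeScaleQ.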
template <typename T, typename Context>
void FlashAttnUnpaddedGradImpl(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
const cudaStream_t stream = ctx.stream();
auto dims = q.dims();
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
  PADDLE_ENFORCE_NE(causal,
                    true,
                    phi::errors::InvalidArgument(
                        "attn_mask is not nullptr, so causal can not be true"));

  PADDLE_ENFORCE_EQ(
      head_size == 32 || head_size == 64 || head_size == 128,
      true,
      phi::errors::InvalidArgument("The head_dim is expected to be either 32, "
                                   "64, or 128, but received %d.",
                                   head_size));
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset;
int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seqlen_q});
const DenseTensor* attn_mask_tensor = attn_mask.get_ptr();
std::vector<int64_t> mask_dims = GetAttnMaskDims(attn_mask_tensor);
int fa_num_splits = 0;
bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16;
float fa_with_mask_scale = 1.0f;
bool fa_zero_tensors = false;
uint64_t workspace_size;
bool succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
static_cast<const void*>(q.data()),
static_cast<const void*>(k.data()),
static_cast<const void*>(v.data()),
static_cast<void*>(dq->data()),
static_cast<void*>(dk->data()),
static_cast<void*>(dv->data()),
nullptr, // set out to nullptr to calculate workspace size
dout.data(),
static_cast<const int32_t*>(cu_seqlens_q.data()),
static_cast<const int32_t*>(cu_seqlens_k.data()),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
fa_with_mask_scale,
fa_zero_tensors,
fa_is_bf16,
fa_num_splits,
static_cast<const void*>(softmax_lse.data()),
static_cast<void*>(dsoftmax.data()),
nullptr,
nullptr,
&workspace_size,
stream,
seed,
offset,
attn_mask_tensor ? attn_mask_tensor->data() : nullptr,
nullptr,
mask_dims.data() ? mask_dims.data() : nullptr,
nullptr);
CheckFlashAttnStatus(succ);
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(
ctx, {static_cast<int64_t>(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
static_cast<const void*>(q.data()),
static_cast<const void*>(k.data()),
static_cast<const void*>(v.data()),
static_cast<void*>(dq->data()),
static_cast<void*>(dk->data()),
static_cast<void*>(dv->data()),
      out.data(),  // the real output tensor; the workspace size was computed above
dout.data(),
static_cast<const int32_t*>(cu_seqlens_q.data()),
static_cast<const int32_t*>(cu_seqlens_k.data()),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
fa_with_mask_scale,
fa_zero_tensors,
fa_is_bf16,
fa_num_splits,
static_cast<const void*>(softmax_lse.data()),
static_cast<void*>(dsoftmax.data()),
nullptr,
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset,
attn_mask_tensor ? attn_mask_tensor->data() : nullptr,
nullptr,
mask_dims.data() ? mask_dims.data() : nullptr,
nullptr);
CheckFlashAttnStatus(succ);
int64_t q_size = total_q * num_heads * head_size;
ComputeScaleQ(ctx, q_size, scale, dq->data<T>(), dq->data<T>());
#else
RaiseNotSupportedError();
#endif
}
template <typename T, typename Context>
void FlashAttnUnpaddedGradKernel(const Context& ctx,
                                 const DenseTensor& q,
@@ -42,6 +185,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
                                 const DenseTensor& out,
                                 const DenseTensor& softmax_lse,
                                 const DenseTensor& seed_offset,
+                                const paddle::optional<DenseTensor>& attn_mask,
                                 const DenseTensor& dout,
                                 int64_t max_seqlen_q,
                                 int64_t max_seqlen_k,
@@ -59,86 +203,102 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
  const cudaStream_t stream = ctx.stream();

  // q,k,v [total_*, num_heads, head_dim]
  auto dims = q.dims();
-  const int64_t total_q = dims[0];
-  const int batch_size = cu_seqlens_q.numel() - 1;
-  const int num_heads = dims[1];
-  const int head_size_og = dout.dims()[2];
-  const int head_size = dims[2];
-  const int total_k = k.dims()[0];
-  const int num_heads_k = k.dims()[1];
-  const bool zero_tensors = false;
+  if (attn_mask.get_ptr()) {
+    FlashAttnUnpaddedGradImpl<T, Context>(ctx,
+                                          q,
+                                          k,
+                                          v,
+                                          cu_seqlens_q,
+                                          cu_seqlens_k,
+                                          out,
+                                          softmax_lse,
+                                          seed_offset,
+                                          attn_mask,
+                                          dout,
+                                          max_seqlen_q,
+                                          max_seqlen_k,
+                                          scale,
+                                          dropout,
+                                          causal,
+                                          dq,
+                                          dk,
+                                          dv);
+  } else {
+    const int64_t total_q = dims[0];
+    const int64_t batch_size = cu_seqlens_q.numel() - 1;
+    const int64_t num_heads = dims[1];
+    const int64_t head_size_og = dout.dims()[2];
+    const int64_t head_size = dims[2];
+    const int64_t total_k = k.dims()[0];
+    const int64_t num_heads_k = k.dims()[1];

    // TODO(umiswing): add deterministic in fa2.
    // int num_splits = 0; // 0 for an internal heuristic, which is optimal
    // if (FLAGS_cudnn_deterministic) {
    //   num_splits = 1;
    // }

    // TODO(umiswing): add shape check
    PADDLE_ENFORCE_EQ(
        head_size_og,
        head_size,
        phi::errors::InvalidArgument(
            "flash_attn_bwd receive input with head_size_og == head_size"));

    FlashAttnBwdParamsV2 params =
        FlashAttnBwdParamsV2(ctx,
                             batch_size,
                             max_seqlen_q,
                             max_seqlen_k,
                             num_heads,
                             num_heads_k,
                             head_size,
                             dropout,
                             scale,
                             causal,
                             q.dtype(),
                             seed_offset.data<int64_t>());

-    VLOG(4) << "FlashAttn bwd seed: " << params.seed
-            << ", offset: " << params.offset;
+    VLOG(10) << "FlashAttn bwd seed: " << params.seed
+             << ", offset: " << params.offset;

-    const bool succ =
+    bool succ =
        phi::dynload::flash_attn_varlen_bwd(dout.data(),
                                            q.data(),
                                            k.data(),
                                            v.data(),
                                            out.data(),
                                            params.softmax_d.data(),
                                            softmax_lse.data(),
                                            cu_seqlens_q.data<int32_t>(),
                                            cu_seqlens_k.data<int32_t>(),
                                            params.rng_state.data(),
                                            dq->data(),
                                            dk->data(),
                                            dv->data(),
                                            params.dq_accum.data(),
                                            params.batch_size,
                                            params.max_seqlen_q,
                                            params.max_seqlen_k,
                                            params.seqlen_q_rounded,
                                            params.seqlen_k_rounded,
                                            params.num_heads,
                                            params.num_heads_k,
                                            params.head_size,
                                            params.head_size_rounded,
                                            params.dropout,
                                            params.scale,
                                            params.causal,
                                            params.is_bf16,
                                            stream,
                                            params.seed,
                                            params.offset);
-    if (!succ) {
-      PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
-    }
+    CheckFlashAttnStatus(succ);
+  }
#else
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "FlashAttention is unsupported, please set use_flash_attn to false."));
+  RaiseNotSupportedError();
#endif
}
@@ -150,6 +310,7 @@ void FlashAttnGradKernel(const Context& ctx,
                         const DenseTensor& out,
                         const DenseTensor& softmax_lse,
                         const DenseTensor& seed_offset,
+                        const paddle::optional<DenseTensor>& attn_mask,
                         const DenseTensor& dout,
                         float dropout,
                         bool causal,
@@ -159,14 +320,17 @@ void FlashAttnGradKernel(const Context& ctx,
#ifdef PADDLE_WITH_FLASHATTN
  // q,k,v [batch_size, seq_len, num_heads, head_dim]
-  auto dims = q.dims();
-  const int batch_size = dims[0];
-  const int seqlen_q = dims[1];
-  const int num_heads = dims[2];
-  const int head_size_og = dout.dims()[3];
-  const int head_size = dims[3];
-  const int seqlen_k = k.dims()[1];
-  const int num_heads_k = k.dims()[2];
+  const auto& dims = q.dims();
+  const int64_t batch_size = dims[0];
+  const int64_t seqlen_q = dims[1];
+  const int64_t num_heads = dims[2];
+  const int64_t head_size_og = dout.dims()[3];
+  const int64_t head_size = dims[3];
+  const int64_t seqlen_k = k.dims()[1];
+  const int64_t num_heads_k = k.dims()[2];
+  const int64_t total_q = batch_size * seqlen_q;
+  const int64_t total_k = batch_size * seqlen_k;

  // TODO(umiswing): add shape check
  PADDLE_ENFORCE_EQ(
@@ -175,71 +339,98 @@ void FlashAttnGradKernel(const Context& ctx,
      phi::errors::InvalidArgument(
          "flash_attn_bwd receive input with head_size_og == head_size"));

-  VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
-          << "], v[" << v.dims() << "]";
+  VLOG(10) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
+           << "], v[" << v.dims() << "]";

  const float scale = 1.0f / std::sqrt(head_size);

+  if (attn_mask.get_ptr()) {
+    DenseTensor q_t_s, k_t_s, v_t_s;
+    q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
+    k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
+    v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
+
+    DenseTensor cu_seqlens_q;
+    DenseTensor cu_seqlens_k;
+    ArangeNullaryKernel<int32_t, Context>(
+        ctx, 0, (batch_size + 1) * seqlen_q, seqlen_q, &cu_seqlens_q);
+    ArangeNullaryKernel<int32_t, Context>(
+        ctx, 0, (batch_size + 1) * seqlen_k, seqlen_k, &cu_seqlens_k);
+
+    FlashAttnUnpaddedGradKernel<T, Context>(ctx,
+                                            q_t_s,
+                                            k_t_s,
+                                            v_t_s,
+                                            cu_seqlens_q,
+                                            cu_seqlens_k,
+                                            out,
+                                            softmax_lse,
+                                            seed_offset,
+                                            attn_mask,
+                                            dout,
+                                            seqlen_q,
+                                            seqlen_k,
+                                            scale,
+                                            dropout,
+                                            causal,
+                                            dq,
+                                            dk,
+                                            dv);
+  } else {
    FlashAttnBwdParamsV2 params =
        FlashAttnBwdParamsV2(ctx,
                             batch_size,
                             seqlen_q,
                             seqlen_k,
                             num_heads,
                             num_heads_k,
                             head_size,
                             dropout,
                             scale,
                             causal,
                             q.dtype(),
                             seed_offset.data<int64_t>());

    ctx.template Alloc<T>(dq);
    ctx.template Alloc<T>(dk);
    ctx.template Alloc<T>(dv);

    cudaStream_t stream = ctx.stream();

-    VLOG(4) << "FlashAttn bwd seed: " << params.seed
-            << ", offset: " << params.offset;
+    VLOG(10) << "FlashAttn bwd seed: " << params.seed
+             << ", offset: " << params.offset;

-    const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
+    bool succ = phi::dynload::flash_attn_bwd(dout.data(),
                                             q.data(),
                                             k.data(),
                                             v.data(),
                                             out.data(),
                                             params.softmax_d.data(),
                                             softmax_lse.data(),
                                             params.rng_state.data(),
                                             dq->data(),
                                             dk->data(),
                                             dv->data(),
                                             params.dq_accum.data(),
                                             params.batch_size,
                                             params.max_seqlen_q,
                                             params.max_seqlen_k,
                                             params.seqlen_q_rounded,
                                             params.seqlen_k_rounded,
                                             params.num_heads,
                                             params.num_heads_k,
                                             params.head_size,
                                             params.head_size_rounded,
                                             params.dropout,
                                             params.scale,
                                             params.causal,
                                             params.is_bf16,
                                             stream,
                                             params.seed,
                                             params.offset);
-    PADDLE_ENFORCE_EQ(
-        succ,
-        true,
-        phi::errors::External("Error in Flash-Attention-2, detail information is",
-                              phi::dynload::flash_attn_error()));
+    CheckFlashAttnStatus(succ);
+  }
#else
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "FlashAttention is unsupported, please set use_flash_attn to false."));
+  RaiseNotSupportedError();
#endif
}
...
@@ -15,27 +15,21 @@
#include "paddle/phi/kernels/flash_attn_kernel.h"

#include "glog/logging.h"  // For VLOG()
-#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
-#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/reshape_kernel.h"
-#ifdef PADDLE_WITH_FLASHATTN
-#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
-#endif
+#include "paddle/phi/kernels/reshape_kernel.h"

DECLARE_bool(cudnn_deterministic);

namespace phi {

template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(
+void FlashAttnWithMaskUnpaddedImpl(
    const Context& ctx,
    const DenseTensor& q,
    const DenseTensor& k,
@@ -43,6 +37,7 @@ void FlashAttnUnpaddedKernel(
    const DenseTensor& cu_seqlens_q,
    const DenseTensor& cu_seqlens_k,
    const paddle::optional<DenseTensor>& fixed_seed_offset,
+   const paddle::optional<DenseTensor>& attn_mask,
    int64_t max_seqlen_q,
    int64_t max_seqlen_k,
    float scale,
@@ -56,13 +51,174 @@ void FlashAttnUnpaddedKernel(
    DenseTensor* softmax_lse,
    DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
cudaStream_t stream = ctx.stream();
auto dims = q.dims();
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
  PADDLE_ENFORCE_NE(causal,
                    true,
                    phi::errors::InvalidArgument(
                        "attn_mask is not nullptr, so causal can not be true"));

  PADDLE_ENFORCE_EQ(
      head_size == 32 || head_size == 64 || head_size == 128,
      true,
      phi::errors::InvalidArgument("The head_dim is expected to be either 32, "
                                   "64, or 128, but received %d.",
                                   head_size));
// Generate random state for dropout and save for recompute in grad.
auto seed_offset_pair =
GenerateRNGState(ctx, fixed_seed_offset, rng_name, batch_size, num_heads);
uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;
VLOG(10) << "FlashAttn fwd seed: " << seed << ", offset: " << offset;
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
// Allocate memory for softmax_lse and softmax.
int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
softmax_lse->Resize({batch_size, num_heads, seqlen_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
// may allocate more space than *max_seqlen_k*
int64_t blocksize_c = head_size > 64 ? 128 : 256;
int64_t seqlen_k =
((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
if (max_seqlen_k <= 128) {
seqlen_k = 128;
} else if (max_seqlen_k <= 256) {
seqlen_k = 256;
}
softmax->Resize({batch_size, num_heads, seqlen_q, seqlen_k});
ctx.template Alloc<T>(softmax);
}
// Compute scale Q
int64_t q_size = total_q * num_heads * head_size;
DenseTensor scaled_q = Empty<T>(ctx, {total_q, num_heads, head_size});
ComputeScaleQ(ctx, q_size, scale, q.data<T>(), scaled_q.data<T>());
const DenseTensor* attn_mask_tensor = attn_mask.get_ptr();
std::vector<int64_t> mask_dims = GetAttnMaskDims(attn_mask_tensor);
int fa_num_splits = 0;
bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16;
float fa_with_mask_scale = 1.0f;
bool fa_zero_tensors = false;
uint64_t workspace_size = 0;
bool succ = phi::dynload::flash_attn_fwd_with_bias_and_mask(
static_cast<const void*>(scaled_q.data()),
static_cast<const void*>(k.data()),
static_cast<const void*>(v.data()),
      nullptr,  // pass nullptr here just to query the required workspace size
static_cast<const int32_t*>(cu_seqlens_q.data()),
static_cast<const int32_t*>(cu_seqlens_k.data()),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
fa_with_mask_scale,
fa_zero_tensors,
fa_is_bf16,
fa_num_splits,
softmax_lse->data(),
nullptr,
&workspace_size,
stream,
seed,
offset,
attn_mask_tensor ? attn_mask_tensor->data() : nullptr,
nullptr,
mask_dims.data() ? mask_dims.data() : nullptr,
nullptr);
CheckFlashAttnStatus(succ);
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(
ctx, {static_cast<int64_t>(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_fwd_with_bias_and_mask(
static_cast<const void*>(scaled_q.data()),
k.data(),
v.data(),
      out->data(),  // the real output tensor; the workspace size was computed above
static_cast<const int32_t*>(cu_seqlens_q.data()),
static_cast<const int32_t*>(cu_seqlens_k.data()),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
fa_with_mask_scale,
fa_zero_tensors,
fa_is_bf16,
fa_num_splits,
softmax_lse->data(),
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset,
attn_mask_tensor ? attn_mask_tensor->data() : nullptr,
nullptr,
mask_dims.data() ? mask_dims.data() : nullptr,
nullptr);
CheckFlashAttnStatus(succ);
#else
RaiseNotSupportedError();
#endif
}
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
  ctx.template Alloc<T>(out);

  cudaStream_t stream = ctx.stream();

  // q,k,v [total_*, num_heads, head_dim]
  auto dims = q.dims();
  PADDLE_ENFORCE_EQ(
      dims.size(),
@@ -70,79 +226,97 @@ void FlashAttnUnpaddedKernel(
      phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
                                   "[total_seq_len, num_heads, head_dim]"));

-  const int64_t total_q = dims[0];
-  const int num_heads = dims[1];
-  const int head_size = dims[2];
-  const int total_k = k.dims()[0];
-  const int num_heads_k = k.dims()[1];
-  const int batch_size = cu_seqlens_q.numel() - 1;
+  if (attn_mask.get_ptr()) {
+    FlashAttnWithMaskUnpaddedImpl<T, Context>(ctx,
+                                              q,
+                                              k,
+                                              v,
+                                              cu_seqlens_q,
+                                              cu_seqlens_k,
+                                              fixed_seed_offset,
+                                              attn_mask,
+                                              max_seqlen_q,
+                                              max_seqlen_k,
+                                              scale,
+                                              dropout,
+                                              causal,
+                                              return_softmax,
+                                              is_test,
+                                              rng_name,
+                                              out,
+                                              softmax,
+                                              softmax_lse,
+                                              seed_offset);
+  } else {
+    const int64_t total_q = dims[0];
+    const int64_t num_heads = dims[1];
+    const int64_t head_size = dims[2];
+    const int64_t total_k = k.dims()[0];
+    const int64_t num_heads_k = k.dims()[1];
+    const int64_t batch_size = cu_seqlens_q.numel() - 1;

    // TODO(umiswing): add deterministic in fa2.
    // int num_splits = 0; // 0 for an internal heuristic, which is optimal
    // if (FLAGS_cudnn_deterministic) {
    //   num_splits = 1;
    // }

    // TODO(umiswing): add shape check

    FlashAttnFwdParamsV2<T> params = FlashAttnFwdParamsV2<T>(ctx,
                                                             batch_size,
                                                             max_seqlen_q,
                                                             max_seqlen_k,
                                                             num_heads,
                                                             num_heads_k,
                                                             head_size,
                                                             dropout,
                                                             scale,
                                                             causal,
                                                             return_softmax,
                                                             q.dtype(),
                                                             is_test,
                                                             rng_name,
-                                                            fixed_seed_offset.get_ptr(),
+                                                            fixed_seed_offset,
                                                             softmax,
                                                             softmax_lse,
                                                             seed_offset);

-    VLOG(4) << "FlashAttn fwd seed: " << params.seed
-            << ", offset: " << params.offset;
+    VLOG(10) << "FlashAttn fwd seed: " << params.seed
+             << ", offset: " << params.offset;

-    const bool succ = phi::dynload::flash_attn_varlen_fwd(
+    bool succ = phi::dynload::flash_attn_varlen_fwd(
        q.data(),
        k.data(),
        v.data(),
        cu_seqlens_q.data<int32_t>(),
        cu_seqlens_k.data<int32_t>(),
        params.rng_state.data(),
        out->data(),
        params.return_softmax ? softmax->data() : nullptr,
        softmax_lse->data(),
        params.batch_size,
        params.max_seqlen_q,
        params.max_seqlen_k,
        params.seqlen_q_rounded,
        params.seqlen_k_rounded,
        params.num_heads,
        params.num_heads_k,
        params.head_size,
        params.head_size_rounded,
        params.dropout,
        params.scale,
        params.causal,
        params.return_softmax,
        params.is_bf16,
        stream,
        params.seed,
        params.offset);

-    if (!succ) {
-      PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
-    }
+    CheckFlashAttnStatus(succ);
+  }
#else
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "FlashAttention is unsupported, please set use_flash_attn to false."));
+  RaiseNotSupportedError();
#endif
}
...
@@ -152,6 +326,7 @@ void FlashAttnKernel(const Context& ctx,
                     const DenseTensor& k,
                     const DenseTensor& v,
                     const paddle::optional<DenseTensor>& fixed_seed_offset,
+                    const paddle::optional<DenseTensor>& attn_mask,
                     float dropout,
                     bool causal,
                     bool return_softmax,
@@ -163,89 +338,117 @@ void FlashAttnKernel(const Context& ctx,
                     DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
  // q,k,v [batch_size, seq_len, num_heads, head_dim]
-  auto dims = q.dims();
+  const auto& dims = q.dims();
  PADDLE_ENFORCE_EQ(dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "flash_attn receive input with dim "
                        "[batch_size, seq_len, num_heads, head_dim]"));

-  const int batch_size = dims[0];
-  const int seqlen_q = dims[1];
-  const int num_heads = dims[2];
-  const int head_size = dims[3];
-  const int seqlen_k = k.dims()[1];
-  const int num_heads_k = k.dims()[2];
+  const int64_t batch_size = dims[0];
+  const int64_t seqlen_q = dims[1];
+  const int64_t num_heads = dims[2];
+  const int64_t head_size = dims[3];
+  const int64_t seqlen_k = k.dims()[1];
+  const int64_t num_heads_k = k.dims()[2];
+  const int64_t total_q = batch_size * seqlen_q;
+  const int64_t total_k = batch_size * seqlen_k;

  // TODO(umiswing): Add check shape

  const float scale = 1.0f / std::sqrt(head_size);

+  if (attn_mask.get_ptr()) {
+    DenseTensor q_t_s, k_t_s, v_t_s;
+    q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
+    k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
+    v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
+
+    DenseTensor cu_seqlens_q;
+    DenseTensor cu_seqlens_k;
+    ArangeNullaryKernel<int32_t, Context>(
+        ctx, 0, (batch_size + 1) * seqlen_q, seqlen_q, &cu_seqlens_q);
+    ArangeNullaryKernel<int32_t, Context>(
+        ctx, 0, (batch_size + 1) * seqlen_k, seqlen_k, &cu_seqlens_k);
+
+    FlashAttnUnpaddedKernel<T, Context>(ctx,
+                                        q_t_s,
+                                        k_t_s,
+                                        v_t_s,
+                                        cu_seqlens_q,
+                                        cu_seqlens_k,
+                                        fixed_seed_offset,
+                                        attn_mask,
+                                        seqlen_q,
+                                        seqlen_k,
+                                        scale,
+                                        dropout,
+                                        causal,
+                                        return_softmax,
+                                        is_test,
+                                        rng_name,
+                                        out,
+                                        softmax,
+                                        softmax_lse,
+                                        seed_offset);
+  } else {
    FlashAttnFwdParamsV2<T> params = FlashAttnFwdParamsV2<T>(ctx,
                                                             batch_size,
                                                             seqlen_q,
                                                             seqlen_k,
                                                             num_heads,
                                                             num_heads_k,
                                                             head_size,
                                                             dropout,
                                                             scale,
                                                             causal,
                                                             return_softmax,
                                                             q.dtype(),
                                                             is_test,
                                                             rng_name,
-                                                            fixed_seed_offset.get_ptr(),
+                                                            fixed_seed_offset,
                                                             softmax,
                                                             softmax_lse,
                                                             seed_offset);

-    VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
-            << "], v[" << v.dims() << "]";
-    VLOG(4) << "FlashAttn fwd seed: " << params.seed
-            << ", offset: " << params.offset;
+    VLOG(10) << "FlashAttn fwd dims: q[" << q.dims() << "], k[" << k.dims()
+             << "], v[" << v.dims() << "]";
+    VLOG(10) << "FlashAttn fwd seed: " << params.seed
+             << ", offset: " << params.offset;

    ctx.template Alloc<T>(out);

    cudaStream_t stream = ctx.stream();

    bool succ = phi::dynload::flash_attn_fwd(
        q.data(),
        k.data(),
        v.data(),
        params.rng_state.data(),
        out->data(),
        params.return_softmax ? params.softmax->data() : nullptr,
        params.softmax_lse->data(),
        params.batch_size,
        params.max_seqlen_q,
        params.max_seqlen_k,
        params.seqlen_q_rounded,
        params.seqlen_k_rounded,
        params.num_heads,
        params.num_heads_k,
        params.head_size,
        params.head_size_rounded,
        params.dropout,
        params.scale,
        params.causal,
        params.return_softmax,
        params.is_bf16,
        stream,
        params.seed,
        params.offset);

-    PADDLE_ENFORCE_EQ(
-        succ,
-        true,
-        phi::errors::External("Error in Flash-Attention-2, detail information is",
-                              phi::dynload::flash_attn_error()));
+    CheckFlashAttnStatus(succ);
+  }
#else
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "FlashAttention is unsupported, please set use_flash_attn to false."));
+  RaiseNotSupportedError();
#endif
}
...
...
@@ -14,8 +14,43 @@
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/enforce.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
namespace phi {
#ifdef PADDLE_WITH_FLASHATTN
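// Returns the (seed, offset) pair used for dropout: read from the
// fixed_seed_offset tensor when it is provided, otherwise drawn from the named
// random-seed generator (when rng_name is set) or from the device generator.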
static std::pair<uint64_t, uint64_t> GenerateRNGState(
const GPUContext& ctx,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const std::string& rng_name,
const int64_t batch_size,
const int64_t num_heads) {
if (fixed_seed_offset.get_ptr()) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset.get_ptr()->data<int64_t>();
uint64_t seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
return std::make_pair(seed, offset);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
return seed_offset_pair;
}
}
template <typename T>
struct FlashAttnFwdParamsV2 {
  int batch_size;
@@ -55,7 +90,7 @@ struct FlashAttnFwdParamsV2 {
                          const DataType q_dtype,
                          const bool is_test,
                          const std::string& rng_name,
-                         const DenseTensor* const fixed_seed_offset_ptr,
+                         const paddle::optional<DenseTensor>& fixed_seed_offset,
                          DenseTensor* _softmax,
                          DenseTensor* _softmax_lse,
                          DenseTensor* _seed_offset)
@@ -78,24 +113,11 @@ struct FlashAttnFwdParamsV2 {
    // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
    // with the same size.
    rng_state = Empty<int64_t>(ctx, {2});
-   if (fixed_seed_offset_ptr) {
-     const int64_t* fixed_seed_offset_data =
-         fixed_seed_offset_ptr->data<int64_t>();
-     seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
-     offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
-   } else {
-     uint64_t inc = batch_size * num_heads * 32;
-     std::pair<uint64_t, uint64_t> seed_offset_pair;
-     if (rng_name != "") {
-       auto gen = phi::GetRandomSeedGenerator(rng_name);
-       seed_offset_pair = gen->IncrementOffset(inc);
-     } else {
-       auto* gen = ctx.GetGenerator();
-       seed_offset_pair = gen->IncrementOffset(inc);
-     }
-     seed = seed_offset_pair.first;
-     offset = seed_offset_pair.second;
-   }
+   auto seed_offset_pair = GenerateRNGState(
+       ctx, fixed_seed_offset, rng_name, batch_size, num_heads);
+   seed = seed_offset_pair.first;
+   offset = seed_offset_pair.second;

    seed_offset->Resize({2});
    int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
@@ -178,4 +200,66 @@ struct FlashAttnBwdParamsV2 {
        ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
  }
};
static void CheckFlashAttnStatus(const bool status) {
PADDLE_ENFORCE_EQ(status,
true,
phi::errors::External(
"Error in Flash-Attention, detail information is: %s",
phi::dynload::flash_attn_error()));
}
template <typename T>
__global__ void SimpleScaleKernel(const T* input,
                                  int64_t numel,
                                  float scale,
                                  T* output) {
  CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) {
    output[i] = static_cast<T>(scale * static_cast<float>(input[i]));
  }
}
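// Element-wise scaling helper: writes output[i] = scale * input[i]. The
// with-mask flash-attention entry points are invoked with scale 1.0f, so the
// query (and dq in the backward) is scaled explicitly through this helper.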
template <typename T, typename Context>
void ComputeScaleQ(
const Context& ctx, int64_t numel, float scale, const T* input, T* output) {
auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1);
  SimpleScaleKernel<<<gpu_config.block_per_grid,
gpu_config.thread_per_block,
0,
ctx.stream()>>>(input, numel, scale, output);
}
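// Flattens an attn_mask of rank >= 4 into the 4-D shape expected by the
// with-mask flash-attention kernels: all leading dimensions are folded into
// the first one, e.g. a mask of shape [b, h, 1, s_q, s_k] becomes
// [b * h, 1, s_q, s_k].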
static std::vector<int64_t> GetAttnMaskDims(const DenseTensor* attn_mask) {
std::vector<int64_t> mask_dim_4d;
if (attn_mask) {
const auto& origin_dims = attn_mask->dims();
auto rank = origin_dims.size();
PADDLE_ENFORCE_GE(
rank,
4,
        phi::errors::InvalidArgument(
            "The number of dimensions of attn_mask is expected to be greater "
            "than or equal to 4, but received %d. The shape of attn_mask is {%s}.",
rank,
origin_dims));
int64_t first_dim = 1;
for (int i = 0; i < rank - 3; i++) {
first_dim *= origin_dims[i];
}
mask_dim_4d = {first_dim,
origin_dims[rank - 3],
origin_dims[rank - 2],
origin_dims[rank - 1]};
}
return mask_dim_4d;
}
#endif
static void RaiseNotSupportedError() {
PADDLE_THROW(
      phi::errors::Unimplemented("FlashAttention is unsupported, please check "
                                 "the GPU compatibility and CUDA version."));
}
}  // namespace phi
...
@@ -202,6 +202,7 @@ def flash_attention(
        key,
        value,
        fixed_seed_offset,
+       None,
        dropout,
        causal,
        return_softmax,
...
@@ -358,6 +359,7 @@ def flash_attn_unpadded(
        cu_seqlens_q,
        cu_seqlens_k,
        fixed_seed_offset,
+       None,
        max_seqlen_q,
        max_seqlen_k,
        scale,
...
@@ -408,7 +410,13 @@ def flash_attn_unpadded(
def scaled_dot_product_attention(
-    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    training=True,
):
    r"""
    The equation is:
@@ -442,6 +450,7 @@ def scaled_dot_product_attention(
            not supported yet.
        dropout_p(float): The dropout ratio.
        is_causal(bool): Whether enable causal mode.
+       training(bool): Whether it is in the training phase.
    Returns:
        out(Tensor): The attention tensor.
@@ -458,6 +467,22 @@ def scaled_dot_product_attention(
        >>> print(output)
        >>> # xdoctest: -SKIP
    """
-   assert attn_mask is None, "attn_mask is not supported yet"
-   out, _ = flash_attention(query, key, value, dropout_p, is_causal)
+   if attn_mask is None:
+       out, _ = flash_attention(query, key, value, dropout_p, is_causal)
else:
        fixed_seed_offset = None
return_softmax = False
rng_name = ""
out, _ = _C_ops.flash_attn(
query,
key,
value,
fixed_seed_offset,
attn_mask,
dropout_p,
is_causal,
return_softmax,
not training,
rng_name,
)
    return out
...
@@ -57,6 +57,18 @@ def attention_naive(q, k, v, causal=False):
    return paddle.transpose(o, [0, 2, 1, 3])
def attention_naive_with_mask(q, k, v, attn_bias):
qt = paddle.transpose(q, [0, 2, 1, 3])
kt = paddle.transpose(k, [0, 2, 1, 3])
vt = paddle.transpose(v, [0, 2, 1, 3])
scale = 1.0 / np.sqrt(q.shape[-1])
s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2]))
s = paddle.scale(s, scale)
p = F.softmax(s + attn_bias)
o = paddle.matmul(p, vt)
return paddle.transpose(o, [0, 2, 1, 3])
is_sm8x = (
    core.is_compiled_with_cuda()
    and paddle.device.cuda.get_device_capability()[0] == 8
@@ -296,6 +308,64 @@ class TestFlashAttentionAPI(unittest.TestCase):
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11040
    or not is_sm_supported,
    "core is not compiled with CUDA, or the CUDA version is below 11.4, "
    "or the device's compute capability is not 7.5 or 8.x",
)
class TestFlashAttentionWithMaskAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 128, 8, 32)
self.dtype = 'float16'
self.dropout = 0.0
self.causal = False
def test_dot_scale_product(self):
# test dynamic
paddle.disable_static()
query = np.random.random(self.shape)
key = np.random.random(self.shape)
value = np.random.random(self.shape)
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
q_ = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
k_ = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
v_ = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
mask_shape = (self.shape[0], 1, self.shape[1], self.shape[1])
mask = np.random.random(mask_shape)
m = paddle.to_tensor(
mask, place=self.place, dtype=self.dtype, stop_gradient=False
)
out = scaled_dot_product_attention(
q, k, v, m, self.dropout, self.causal
)
out_ = attention_naive_with_mask(q_, k_, v_, m)
out.backward()
out_.backward()
np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)
class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
@@ -370,5 +440,14 @@ class TestSDPAttentionAPITest(TestFlashAttentionAPI):
        self.enable_mem_efficient = False
class TestFlashAttentionWithMaskAPITest1(TestFlashAttentionWithMaskAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (8, 1024, 16, 128)
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
if __name__ == '__main__':
    unittest.main()