diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 1dda103c85c42840f93b9a3dce6a78cfe7b1955d..5d2253790ae398e0a948e1286285a9eb0ec5624d 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -818,8 +818,9 @@ inplace : (out_grad -> x_grad) - backward_op : flash_attn_grad - forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) - args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false) + forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false) + optional : attn_mask output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) infer_meta : func : FlashAttnGradInferMeta @@ -829,8 +830,9 @@ data_type: q - backward_op : flash_attn_unpadded_grad - forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) - args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false) + forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false) + optional : attn_mask output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) infer_meta : func : FlashAttnGradInferMeta diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index db83bacbbd60d0060c08fd3d1fb719126421079f..c5bca9f49206878cea2fa5ecf78400048876259f 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -910,9 +910,9 @@ backward : fill_diagonal_tensor_grad - op : flash_attn - args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") + args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") output : Tensor(out), Tensor(softmax), 
Tensor(softmax_lse), Tensor(seed_offset) - optional : fixed_seed_offset + optional : fixed_seed_offset, attn_mask infer_meta : func : FlashAttnInferMeta param : [q, k, v] @@ -923,9 +923,9 @@ backward : flash_attn_grad - op : flash_attn_unpadded - args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") + args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) - optional : fixed_seed_offset + optional : fixed_seed_offset , attn_mask infer_meta : func : FlashAttnInferMeta param : [q, k, v] diff --git a/paddle/phi/kernels/flash_attn_grad_kernel.h b/paddle/phi/kernels/flash_attn_grad_kernel.h index ba3a6020e4545fb471c61daa5bca8ab8a2b176a7..ef5458f4708ebc64b785384ac9977a0c0ba67912 100644 --- a/paddle/phi/kernels/flash_attn_grad_kernel.h +++ b/paddle/phi/kernels/flash_attn_grad_kernel.h @@ -29,6 +29,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, const DenseTensor& out, const DenseTensor& softmax_lse, const DenseTensor& seed_offset, + const paddle::optional& attn_mask, const DenseTensor& dout, int64_t max_seqlen_q, int64_t max_seqlen_k, @@ -47,6 +48,7 @@ void FlashAttnGradKernel(const Context& ctx, const DenseTensor& out, const DenseTensor& softmax_lse, const DenseTensor& seed_offset, + const paddle::optional& attn_mask, const DenseTensor& dout, float dropout, bool causal, diff --git a/paddle/phi/kernels/flash_attn_kernel.h b/paddle/phi/kernels/flash_attn_kernel.h index 296e24202608753c0eedd7ab6715922e74a974d7..ec72d85a0babbc21d25e1ec6373145212738d3d3 100644 --- a/paddle/phi/kernels/flash_attn_kernel.h +++ b/paddle/phi/kernels/flash_attn_kernel.h @@ -28,6 +28,7 @@ void FlashAttnUnpaddedKernel( const DenseTensor& cu_seqlens_q, const DenseTensor& cu_seqlens_k, const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, @@ -47,6 +48,7 @@ void FlashAttnKernel(const Context& ctx, const DenseTensor& k, const DenseTensor& v, const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, float dropout, bool causal, bool return_softmax, diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 1ae6c887aee027fe0dba0fa700ee841e64be6d12..de479cf9adfd28e75a37dcc2e2be7bdb61c52b6f 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -21,17 +21,160 @@ #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/reshape_kernel.h" - -#ifdef PADDLE_WITH_FLASHATTN -#include "paddle/phi/backends/dynload/flashattn.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" -#endif +#include "paddle/phi/kernels/reshape_kernel.h" DECLARE_bool(cudnn_deterministic); namespace phi { +template +void FlashAttnUnpaddedGradImpl(const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& 
cu_seqlens_k, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv) { +#ifdef PADDLE_WITH_FLASHATTN + const cudaStream_t stream = ctx.stream(); + + auto dims = q.dims(); + int64_t total_q = dims[0]; + int64_t num_heads = dims[1]; + int64_t head_size = dims[2]; + + int64_t total_k = k.dims()[0]; + int64_t batch_size = cu_seqlens_q.numel() - 1; + + PADDLE_ENFORCE_NE(causal, + true, + phi::errors::InvalidArgument( + "attn_mask is not nullptr, causal can not be true")); + + PADDLE_ENFORCE_EQ( + head_size == 32 || head_size == 64 || head_size == 128, + true, + phi::errors::InvalidArgument("The head_dim is expected to be either 32, " + "64, or 128, but recieved %d.", + head_size)); + + const int64_t* seed_offset_data = seed_offset.data(); + uint64_t seed = static_cast(seed_offset_data[0]); + uint64_t offset = static_cast(seed_offset_data[1]); + VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset; + + int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16; + DenseTensor dsoftmax = Empty(ctx, {batch_size, num_heads, seqlen_q}); + + const DenseTensor* attn_mask_tensor = attn_mask.get_ptr(); + std::vector mask_dims = GetAttnMaskDims(attn_mask_tensor); + + int fa_num_splits = 0; + bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16; + float fa_with_mask_scale = 1.0f; + bool fa_zero_tensors = false; + + uint64_t workspace_size; + bool succ = phi::dynload::flash_attn_bwd_with_bias_and_mask( + static_cast(q.data()), + static_cast(k.data()), + static_cast(v.data()), + static_cast(dq->data()), + static_cast(dk->data()), + static_cast(dv->data()), + nullptr, // set out to nullptr to calculate workspace size + dout.data(), + static_cast(cu_seqlens_q.data()), + static_cast(cu_seqlens_k.data()), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + fa_with_mask_scale, + fa_zero_tensors, + fa_is_bf16, + fa_num_splits, + static_cast(softmax_lse.data()), + static_cast(dsoftmax.data()), + nullptr, + nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask_tensor ? attn_mask_tensor->data() : nullptr, + nullptr, + mask_dims.data() ? mask_dims.data() : nullptr, + nullptr); + CheckFlashAttnStatus(succ); + + DenseTensor workspace; + if (workspace_size > 0) { + workspace = Empty( + ctx, {static_cast(workspace_size / sizeof(float))}); + } + + succ = phi::dynload::flash_attn_bwd_with_bias_and_mask( + static_cast(q.data()), + static_cast(k.data()), + static_cast(v.data()), + static_cast(dq->data()), + static_cast(dk->data()), + static_cast(dv->data()), + out.data(), // set out to nullptr to calculate workspace size + dout.data(), + static_cast(cu_seqlens_q.data()), + static_cast(cu_seqlens_k.data()), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + fa_with_mask_scale, + fa_zero_tensors, + fa_is_bf16, + fa_num_splits, + static_cast(softmax_lse.data()), + static_cast(dsoftmax.data()), + nullptr, + workspace_size > 0 ? workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask_tensor ? attn_mask_tensor->data() : nullptr, + nullptr, + mask_dims.data() ? 
mask_dims.data() : nullptr, + nullptr); + CheckFlashAttnStatus(succ); + + int64_t q_size = total_q * num_heads * head_size; + ComputeScaleQ(ctx, q_size, scale, dq->data(), dq->data()); +#else + RaiseNotSupportedError(); +#endif +} + template void FlashAttnUnpaddedGradKernel(const Context& ctx, const DenseTensor& q, @@ -42,6 +185,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, const DenseTensor& out, const DenseTensor& softmax_lse, const DenseTensor& seed_offset, + const paddle::optional& attn_mask, const DenseTensor& dout, int64_t max_seqlen_q, int64_t max_seqlen_k, @@ -59,86 +203,102 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, const cudaStream_t stream = ctx.stream(); // q,k,v [total_*, num_heads, head_dim] - auto dims = q.dims(); - const int64_t total_q = dims[0]; - const int batch_size = cu_seqlens_q.numel() - 1; - const int num_heads = dims[1]; - const int head_size_og = dout.dims()[2]; - const int head_size = dims[2]; - const int total_k = k.dims()[0]; - const int num_heads_k = k.dims()[1]; - - // TODO(umiswing): add deterministic in fa2. - // int num_splits = 0; // 0 for an internal heuristic, which is optimal - // if (FLAGS_cudnn_deterministic) { - // num_splits = 1; - // } - - const bool zero_tensors = false; - // TODO(umiswing): add shape check - PADDLE_ENFORCE_EQ( - head_size_og, - head_size, - phi::errors::InvalidArgument( - "flash_attn_bwd receive input with head_size_og == head_size")); + if (attn_mask.get_ptr()) { + FlashAttnUnpaddedGradImpl(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + out, + softmax_lse, + seed_offset, + attn_mask, + dout, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + dq, + dk, + dv); + } else { + const int64_t total_q = dims[0]; + const int64_t batch_size = cu_seqlens_q.numel() - 1; + const int64_t num_heads = dims[1]; + const int64_t head_size_og = dout.dims()[2]; + const int64_t head_size = dims[2]; + const int64_t total_k = k.dims()[0]; + const int64_t num_heads_k = k.dims()[1]; + + // TODO(umiswing): add deterministic in fa2. 
+ // int num_splits = 0; // 0 for an internal heuristic, which is optimal + // if (FLAGS_cudnn_deterministic) { + // num_splits = 1; + // } + + // TODO(umiswing): add shape check + PADDLE_ENFORCE_EQ( + head_size_og, + head_size, + phi::errors::InvalidArgument( + "flash_attn_bwd receive input with head_size_og == head_size")); - FlashAttnBwdParamsV2 params = - FlashAttnBwdParamsV2(ctx, - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads, - num_heads_k, - head_size, - dropout, - scale, - causal, - q.dtype(), - seed_offset.data()); - - VLOG(4) << "FlashAttn bwd seed: " << params.seed - << ", offset: " << params.offset; - - const bool succ = - phi::dynload::flash_attn_varlen_bwd(dout.data(), - q.data(), - k.data(), - v.data(), - out.data(), - params.softmax_d.data(), - softmax_lse.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - params.rng_state.data(), - dq->data(), - dk->data(), - dv->data(), - params.dq_accum.data(), - params.batch_size, - params.max_seqlen_q, - params.max_seqlen_k, - params.seqlen_q_rounded, - params.seqlen_k_rounded, - params.num_heads, - params.num_heads_k, - params.head_size, - params.head_size_rounded, - params.dropout, - params.scale, - params.causal, - params.is_bf16, - stream, - params.seed, - params.offset); - - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + FlashAttnBwdParamsV2 params = + FlashAttnBwdParamsV2(ctx, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + scale, + causal, + q.dtype(), + seed_offset.data()); + + VLOG(10) << "FlashAttn bwd seed: " << params.seed + << ", offset: " << params.offset; + + bool succ = + phi::dynload::flash_attn_varlen_bwd(dout.data(), + q.data(), + k.data(), + v.data(), + out.data(), + params.softmax_d.data(), + softmax_lse.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + params.rng_state.data(), + dq->data(), + dk->data(), + dv->data(), + params.dq_accum.data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.scale, + params.causal, + params.is_bf16, + stream, + params.seed, + params.offset); + CheckFlashAttnStatus(succ); } #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } @@ -150,6 +310,7 @@ void FlashAttnGradKernel(const Context& ctx, const DenseTensor& out, const DenseTensor& softmax_lse, const DenseTensor& seed_offset, + const paddle::optional& attn_mask, const DenseTensor& dout, float dropout, bool causal, @@ -159,14 +320,17 @@ void FlashAttnGradKernel(const Context& ctx, #ifdef PADDLE_WITH_FLASHATTN // q,k,v [batch_size, seq_len, num_heads, head_dim] - auto dims = q.dims(); - const int batch_size = dims[0]; - const int seqlen_q = dims[1]; - const int num_heads = dims[2]; - const int head_size_og = dout.dims()[3]; - const int head_size = dims[3]; - const int seqlen_k = k.dims()[1]; - const int num_heads_k = k.dims()[2]; + const auto& dims = q.dims(); + const int64_t batch_size = dims[0]; + const int64_t seqlen_q = dims[1]; + const int64_t num_heads = dims[2]; + const int64_t head_size_og = dout.dims()[3]; + const int64_t head_size = dims[3]; + const int64_t seqlen_k = k.dims()[1]; + const int64_t num_heads_k = k.dims()[2]; + + const int64_t total_q = batch_size * seqlen_q; + const int64_t total_k = batch_size * seqlen_k; 
// TODO(umiswing): add shape check PADDLE_ENFORCE_EQ( @@ -175,71 +339,98 @@ void FlashAttnGradKernel(const Context& ctx, phi::errors::InvalidArgument( "flash_attn_bwd receive input with head_size_og == head_size")); - VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims() - << "], v[" << v.dims() << "]"; + VLOG(10) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims() + << "], v[" << v.dims() << "]"; const float scale = 1.0f / std::sqrt(head_size); + if (attn_mask.get_ptr()) { + DenseTensor q_t_s, k_t_s, v_t_s; + q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size}); + k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size}); + v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size}); - FlashAttnBwdParamsV2 params = - FlashAttnBwdParamsV2(ctx, - batch_size, - seqlen_q, - seqlen_k, - num_heads, - num_heads_k, - head_size, - dropout, - scale, - causal, - q.dtype(), - seed_offset.data()); + DenseTensor cu_seqlens_q; + DenseTensor cu_seqlens_k; + ArangeNullaryKernel( + ctx, 0, (batch_size + 1) * seqlen_q, seqlen_q, &cu_seqlens_q); + ArangeNullaryKernel( + ctx, 0, (batch_size + 1) * seqlen_k, seqlen_k, &cu_seqlens_k); - ctx.template Alloc(dq); - ctx.template Alloc(dk); - ctx.template Alloc(dv); + FlashAttnUnpaddedGradKernel(ctx, + q_t_s, + k_t_s, + v_t_s, + cu_seqlens_q, + cu_seqlens_k, + out, + softmax_lse, + seed_offset, + attn_mask, + dout, + seqlen_q, + seqlen_k, + scale, + dropout, + causal, + dq, + dk, + dv); + } else { + FlashAttnBwdParamsV2 params = + FlashAttnBwdParamsV2(ctx, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + scale, + causal, + q.dtype(), + seed_offset.data()); - cudaStream_t stream = ctx.stream(); - - VLOG(4) << "FlashAttn bwd seed: " << params.seed - << ", offset: " << params.offset; - - const bool succ = phi::dynload::flash_attn_bwd(dout.data(), - q.data(), - k.data(), - v.data(), - out.data(), - params.softmax_d.data(), - softmax_lse.data(), - params.rng_state.data(), - dq->data(), - dk->data(), - dv->data(), - params.dq_accum.data(), - params.batch_size, - params.max_seqlen_q, - params.max_seqlen_k, - params.seqlen_q_rounded, - params.seqlen_k_rounded, - params.num_heads, - params.num_heads_k, - params.head_size, - params.head_size_rounded, - params.dropout, - params.scale, - params.causal, - params.is_bf16, - stream, - params.seed, - params.offset); + ctx.template Alloc(dq); + ctx.template Alloc(dk); + ctx.template Alloc(dv); - PADDLE_ENFORCE_EQ( - succ, - true, - phi::errors::External("Error in Flash-Attention-2, detail information is", - phi::dynload::flash_attn_error())); + cudaStream_t stream = ctx.stream(); + + VLOG(10) << "FlashAttn bwd seed: " << params.seed + << ", offset: " << params.offset; + + bool succ = phi::dynload::flash_attn_bwd(dout.data(), + q.data(), + k.data(), + v.data(), + out.data(), + params.softmax_d.data(), + softmax_lse.data(), + params.rng_state.data(), + dq->data(), + dk->data(), + dv->data(), + params.dq_accum.data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.scale, + params.causal, + params.is_bf16, + stream, + params.seed, + params.offset); + CheckFlashAttnStatus(succ); + } #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } diff --git 
a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index e943b7bbf78519cc1ff10be93e83c9d1d8302ed9..bcf8791d3c17f695bbb378162c883d4a0b0a49a7 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -15,27 +15,21 @@ #include "paddle/phi/kernels/flash_attn_kernel.h" #include "glog/logging.h" // For VLOG() -#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/data_type.h" -#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/reshape_kernel.h" - -#ifdef PADDLE_WITH_FLASHATTN -#include "paddle/phi/backends/dynload/flashattn.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" -#endif +#include "paddle/phi/kernels/reshape_kernel.h" DECLARE_bool(cudnn_deterministic); namespace phi { template -void FlashAttnUnpaddedKernel( +void FlashAttnWithMaskUnpaddedImpl( const Context& ctx, const DenseTensor& q, const DenseTensor& k, @@ -43,6 +37,7 @@ void FlashAttnUnpaddedKernel( const DenseTensor& cu_seqlens_q, const DenseTensor& cu_seqlens_k, const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, @@ -56,13 +51,174 @@ void FlashAttnUnpaddedKernel( DenseTensor* softmax_lse, DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN + cudaStream_t stream = ctx.stream(); + + auto dims = q.dims(); + int64_t total_q = dims[0]; + int64_t num_heads = dims[1]; + int64_t head_size = dims[2]; + + int64_t total_k = k.dims()[0]; + int64_t batch_size = cu_seqlens_q.numel() - 1; + + PADDLE_ENFORCE_NE(causal, + true, + phi::errors::InvalidArgument( + "attn_mask is not nullptr, causal can not be true")); + + PADDLE_ENFORCE_EQ( + head_size == 32 || head_size == 64 || head_size == 128, + true, + phi::errors::InvalidArgument("The head_dim is expected to be either 32, " + "64, or 128, but recieved %d.", + head_size)); + + // Generate random state for dropout and save for recompute in grad. + auto seed_offset_pair = + GenerateRNGState(ctx, fixed_seed_offset, rng_name, batch_size, num_heads); + uint64_t seed = seed_offset_pair.first; + uint64_t offset = seed_offset_pair.second; + + VLOG(10) << "FlashAttn fwd seed: " << seed << ", offset: " << offset; + + seed_offset->Resize({2}); + int64_t* seed_offset_data = ctx.template HostAlloc(seed_offset); + seed_offset_data[0] = static_cast(seed); + seed_offset_data[1] = static_cast(offset); + + // Allocate memory for softmax_lse and softmax. + int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16; + + softmax_lse->Resize({batch_size, num_heads, seqlen_q}); + ctx.template Alloc(softmax_lse); + + if (return_softmax) { + // may allocate more space than *max_seqlen_k* + int64_t blocksize_c = head_size > 64 ? 
128 : 256; + int64_t seqlen_k = + ((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c; + if (max_seqlen_k <= 128) { + seqlen_k = 128; + } else if (max_seqlen_k <= 256) { + seqlen_k = 256; + } + softmax->Resize({batch_size, num_heads, seqlen_q, seqlen_k}); + ctx.template Alloc(softmax); + } + + // Compute scale Q + int64_t q_size = total_q * num_heads * head_size; + DenseTensor scaled_q = Empty(ctx, {total_q, num_heads, head_size}); + ComputeScaleQ(ctx, q_size, scale, q.data(), scaled_q.data()); + + const DenseTensor* attn_mask_tensor = attn_mask.get_ptr(); + std::vector mask_dims = GetAttnMaskDims(attn_mask_tensor); + + int fa_num_splits = 0; + bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16; + float fa_with_mask_scale = 1.0f; + bool fa_zero_tensors = false; + + uint64_t workspace_size = 0; + bool succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( + static_cast(scaled_q.data()), + static_cast(k.data()), + static_cast(v.data()), + nullptr, // for calculation workspace size + static_cast(cu_seqlens_q.data()), + static_cast(cu_seqlens_k.data()), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + fa_with_mask_scale, + fa_zero_tensors, + fa_is_bf16, + fa_num_splits, + softmax_lse->data(), + nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask_tensor ? attn_mask_tensor->data() : nullptr, + nullptr, + mask_dims.data() ? mask_dims.data() : nullptr, + nullptr); + CheckFlashAttnStatus(succ); + + DenseTensor workspace; + if (workspace_size > 0) { + workspace = Empty( + ctx, {static_cast(workspace_size / sizeof(float))}); + } + succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( + static_cast(scaled_q.data()), + k.data(), + v.data(), + out->data(), // set out to nullptr to calculate workspace size + static_cast(cu_seqlens_q.data()), + static_cast(cu_seqlens_k.data()), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + fa_with_mask_scale, + fa_zero_tensors, + fa_is_bf16, + fa_num_splits, + softmax_lse->data(), + workspace_size > 0 ? workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask_tensor ? attn_mask_tensor->data() : nullptr, + nullptr, + mask_dims.data() ? 
mask_dims.data() : nullptr, + nullptr); + CheckFlashAttnStatus(succ); +#else + RaiseNotSupportedError(); +#endif +} +template +void FlashAttnUnpaddedKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN ctx.template Alloc(out); cudaStream_t stream = ctx.stream(); // q,k,v [total_*, num_heads, head_dim] - auto dims = q.dims(); PADDLE_ENFORCE_EQ( dims.size(), @@ -70,79 +226,97 @@ void FlashAttnUnpaddedKernel( phi::errors::InvalidArgument("flash_attn_raw receive input with dim " "[total_seq_len, num_heads, head_dim]")); - const int64_t total_q = dims[0]; - const int num_heads = dims[1]; - const int head_size = dims[2]; - - const int total_k = k.dims()[0]; - const int num_heads_k = k.dims()[1]; - const int batch_size = cu_seqlens_q.numel() - 1; - - // TODO(umiswing): add deterministic in fa2. - // int num_splits = 0; // 0 for an internal heuristic, which is optimal - // if (FLAGS_cudnn_deterministic) { - // num_splits = 1; - // } - - // TODO(umiswing): add shape check - - FlashAttnFwdParamsV2 params = - FlashAttnFwdParamsV2(ctx, - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads, - num_heads_k, - head_size, - dropout, - scale, - causal, - return_softmax, - q.dtype(), - is_test, - rng_name, - fixed_seed_offset.get_ptr(), - softmax, - softmax_lse, - seed_offset); - - VLOG(4) << "FlashAttn fwd seed: " << params.seed - << ", offset: " << params.offset; - - const bool succ = phi::dynload::flash_attn_varlen_fwd( - q.data(), - k.data(), - v.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - params.rng_state.data(), - out->data(), - params.return_softmax ? softmax->data() : nullptr, - softmax_lse->data(), - params.batch_size, - params.max_seqlen_q, - params.max_seqlen_k, - params.seqlen_q_rounded, - params.seqlen_k_rounded, - params.num_heads, - params.num_heads_k, - params.head_size, - params.head_size_rounded, - params.dropout, - params.scale, - params.causal, - params.return_softmax, - params.is_bf16, - stream, - params.seed, - params.offset); - - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + if (attn_mask.get_ptr()) { + FlashAttnWithMaskUnpaddedImpl(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + attn_mask, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + return_softmax, + is_test, + rng_name, + out, + softmax, + softmax_lse, + seed_offset); + } else { + const int64_t total_q = dims[0]; + const int64_t num_heads = dims[1]; + const int64_t head_size = dims[2]; + + const int64_t total_k = k.dims()[0]; + const int64_t num_heads_k = k.dims()[1]; + const int64_t batch_size = cu_seqlens_q.numel() - 1; + + // TODO(umiswing): add deterministic in fa2. 
+ // int num_splits = 0; // 0 for an internal heuristic, which is optimal + // if (FLAGS_cudnn_deterministic) { + // num_splits = 1; + // } + + // TODO(umiswing): add shape check + + FlashAttnFwdParamsV2 params = FlashAttnFwdParamsV2(ctx, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + scale, + causal, + return_softmax, + q.dtype(), + is_test, + rng_name, + fixed_seed_offset, + softmax, + softmax_lse, + seed_offset); + + VLOG(10) << "FlashAttn fwd seed: " << params.seed + << ", offset: " << params.offset; + + bool succ = phi::dynload::flash_attn_varlen_fwd( + q.data(), + k.data(), + v.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + params.rng_state.data(), + out->data(), + params.return_softmax ? softmax->data() : nullptr, + softmax_lse->data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.scale, + params.causal, + params.return_softmax, + params.is_bf16, + stream, + params.seed, + params.offset); + CheckFlashAttnStatus(succ); } #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } @@ -152,6 +326,7 @@ void FlashAttnKernel(const Context& ctx, const DenseTensor& k, const DenseTensor& v, const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, float dropout, bool causal, bool return_softmax, @@ -163,89 +338,117 @@ void FlashAttnKernel(const Context& ctx, DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN // q,k,v [batch_size, seq_len, num_heads, head_dim] - - auto dims = q.dims(); + const auto& dims = q.dims(); PADDLE_ENFORCE_EQ(dims.size(), 4, phi::errors::InvalidArgument( "flash_attn receive input with dim " "[batch_size, seq_len, num_heads, head_dim]")); - const int batch_size = dims[0]; - const int seqlen_q = dims[1]; - const int num_heads = dims[2]; - const int head_size = dims[3]; - const int seqlen_k = k.dims()[1]; - const int num_heads_k = k.dims()[2]; + const int64_t batch_size = dims[0]; + const int64_t seqlen_q = dims[1]; + const int64_t num_heads = dims[2]; + const int64_t head_size = dims[3]; + const int64_t seqlen_k = k.dims()[1]; + const int64_t num_heads_k = k.dims()[2]; + + const int64_t total_q = batch_size * seqlen_q; + const int64_t total_k = batch_size * seqlen_k; // TODO(umiswing): Add check shape const float scale = 1.0f / std::sqrt(head_size); - FlashAttnFwdParamsV2 params = - FlashAttnFwdParamsV2(ctx, - batch_size, - seqlen_q, - seqlen_k, - num_heads, - num_heads_k, - head_size, - dropout, - scale, - causal, - return_softmax, - q.dtype(), - is_test, - rng_name, - fixed_seed_offset.get_ptr(), - softmax, - softmax_lse, - seed_offset); - - VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims() - << "], v[" << v.dims() << "]"; - - ctx.template Alloc(out); - - cudaStream_t stream = ctx.stream(); - - VLOG(4) << "FlashAttn fwd seed: " << params.seed - << ", offset: " << params.offset; - - bool succ = phi::dynload::flash_attn_fwd( - q.data(), - k.data(), - v.data(), - params.rng_state.data(), - out->data(), - params.return_softmax ? 
params.softmax->data() : nullptr, - params.softmax_lse->data(), - params.batch_size, - params.max_seqlen_q, - params.max_seqlen_k, - params.seqlen_q_rounded, - params.seqlen_k_rounded, - params.num_heads, - params.num_heads_k, - params.head_size, - params.head_size_rounded, - params.dropout, - params.scale, - params.causal, - params.return_softmax, - params.is_bf16, - stream, - params.seed, - params.offset); - - PADDLE_ENFORCE_EQ( - succ, - true, - phi::errors::External("Error in Flash-Attention-2, detail information is", - phi::dynload::flash_attn_error())); + if (attn_mask.get_ptr()) { + DenseTensor q_t_s, k_t_s, v_t_s; + q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size}); + k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size}); + v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size}); + + DenseTensor cu_seqlens_q; + DenseTensor cu_seqlens_k; + ArangeNullaryKernel( + ctx, 0, (batch_size + 1) * seqlen_q, seqlen_q, &cu_seqlens_q); + ArangeNullaryKernel( + ctx, 0, (batch_size + 1) * seqlen_k, seqlen_k, &cu_seqlens_k); + + FlashAttnUnpaddedKernel(ctx, + q_t_s, + k_t_s, + v_t_s, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + attn_mask, + seqlen_q, + seqlen_k, + scale, + dropout, + causal, + return_softmax, + is_test, + rng_name, + out, + softmax, + softmax_lse, + seed_offset); + } else { + FlashAttnFwdParamsV2 params = FlashAttnFwdParamsV2(ctx, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + scale, + causal, + return_softmax, + q.dtype(), + is_test, + rng_name, + fixed_seed_offset, + softmax, + softmax_lse, + seed_offset); + + VLOG(10) << "FlashAttn fwd dims: q[" << q.dims() << "], k[" << k.dims() + << "], v[" << v.dims() << "]"; + VLOG(10) << "FlashAttn fwd seed: " << params.seed + << ", offset: " << params.offset; + + ctx.template Alloc(out); + + cudaStream_t stream = ctx.stream(); + bool succ = phi::dynload::flash_attn_fwd( + q.data(), + k.data(), + v.data(), + params.rng_state.data(), + out->data(), + params.return_softmax ? 
params.softmax->data() : nullptr, + params.softmax_lse->data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.scale, + params.causal, + params.return_softmax, + params.is_bf16, + stream, + params.seed, + params.offset); + CheckFlashAttnStatus(succ); + } #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } diff --git a/paddle/phi/kernels/gpu/flash_attn_utils.h b/paddle/phi/kernels/gpu/flash_attn_utils.h index 62d0f4ec95b37eaf1265b62ae82802fa9caa4297..00ba036df09ba582d2415ddab7f07c8db6ebed56 100644 --- a/paddle/phi/kernels/gpu/flash_attn_utils.h +++ b/paddle/phi/kernels/gpu/flash_attn_utils.h @@ -14,8 +14,43 @@ #pragma once +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/enforce.h" + +#ifdef PADDLE_WITH_FLASHATTN +#include "paddle/phi/backends/dynload/flashattn.h" +#endif + namespace phi { +#ifdef PADDLE_WITH_FLASHATTN +static std::pair GenerateRNGState( + const GPUContext& ctx, + const paddle::optional& fixed_seed_offset, + const std::string& rng_name, + const int64_t batch_size, + const int64_t num_heads) { + if (fixed_seed_offset.get_ptr()) { + const int64_t* fixed_seed_offset_data = + fixed_seed_offset.get_ptr()->data(); + uint64_t seed = static_cast(fixed_seed_offset_data[0]); + uint64_t offset = static_cast(fixed_seed_offset_data[1]); + return std::make_pair(seed, offset); + } else { + uint64_t inc = batch_size * num_heads * 32; + std::pair seed_offset_pair; + if (rng_name != "") { + auto gen = phi::GetRandomSeedGenerator(rng_name); + seed_offset_pair = gen->IncrementOffset(inc); + } else { + auto* gen = ctx.GetGenerator(); + seed_offset_pair = gen->IncrementOffset(inc); + } + return seed_offset_pair; + } +} + template struct FlashAttnFwdParamsV2 { int batch_size; @@ -55,7 +90,7 @@ struct FlashAttnFwdParamsV2 { const DataType q_dtype, const bool is_test, const std::string& rng_name, - const DenseTensor* const fixed_seed_offset_ptr, + const paddle::optional& fixed_seed_offset, DenseTensor* _softmax, DenseTensor* _softmax_lse, DenseTensor* _seed_offset) @@ -78,24 +113,11 @@ struct FlashAttnFwdParamsV2 { // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t // with the same size. 
rng_state = Empty(ctx, {2}); - if (fixed_seed_offset_ptr) { - const int64_t* fixed_seed_offset_data = - fixed_seed_offset_ptr->data(); - seed = static_cast(fixed_seed_offset_data[0]); - offset = static_cast(fixed_seed_offset_data[1]); - } else { - uint64_t inc = batch_size * num_heads * 32; - std::pair seed_offset_pair; - if (rng_name != "") { - auto gen = phi::GetRandomSeedGenerator(rng_name); - seed_offset_pair = gen->IncrementOffset(inc); - } else { - auto* gen = ctx.GetGenerator(); - seed_offset_pair = gen->IncrementOffset(inc); - } - seed = seed_offset_pair.first; - offset = seed_offset_pair.second; - } + + auto seed_offset_pair = GenerateRNGState( + ctx, fixed_seed_offset, rng_name, batch_size, num_heads); + seed = seed_offset_pair.first; + offset = seed_offset_pair.second; seed_offset->Resize({2}); int64_t* seed_offset_data = ctx.template HostAlloc(seed_offset); @@ -178,4 +200,66 @@ struct FlashAttnBwdParamsV2 { ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded}); } }; + +static void CheckFlashAttnStatus(const bool status) { + PADDLE_ENFORCE_EQ(status, + true, + phi::errors::External( + "Error in Flash-Attention, detail information is: %s", + phi::dynload::flash_attn_error())); +} + +template +__global__ void SimleScaleKernel(const T* input, + int64_t numel, + float scale, + T* ouput) { + CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) { + ouput[i] = static_cast(scale * static_cast(input[i])); + } +} + +template +void ComputeScaleQ( + const Context& ctx, int64_t numel, float scale, const T* input, T* output) { + auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1); + SimleScaleKernel<<>>(input, numel, scale, output); +} + +static std::vector GetAttnMaskDims(const DenseTensor* attn_mask) { + std::vector mask_dim_4d; + if (attn_mask) { + const auto& origin_dims = attn_mask->dims(); + auto rank = origin_dims.size(); + PADDLE_ENFORCE_GE( + rank, + 4, + phi::errors::InvalidArgument( + "Teh number of dimenstions of attn_mask is expected to be greater " + "or equal to 4, but recieved %d. The shape of attn_mask is {%s}", + rank, + origin_dims)); + + int64_t first_dim = 1; + for (int i = 0; i < rank - 3; i++) { + first_dim *= origin_dims[i]; + } + mask_dim_4d = {first_dim, + origin_dims[rank - 3], + origin_dims[rank - 2], + origin_dims[rank - 1]}; + } + return mask_dim_4d; +} +#endif + +static void RaiseNotSupportedError() { + PADDLE_THROW( + phi::errors::Unimplemented("FlashAttention is unsupported, please check " + "the GPU compability and CUDA Version.")); +} + } // namespace phi diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index b36bd5d74ec7b6d2a9edd7f806c4cd07cf3503cd..b9077896da6c59f0be8018cc68c29abef249ac66 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -202,6 +202,7 @@ def flash_attention( key, value, fixed_seed_offset, + None, dropout, causal, return_softmax, @@ -358,6 +359,7 @@ def flash_attn_unpadded( cu_seqlens_q, cu_seqlens_k, fixed_seed_offset, + None, max_seqlen_q, max_seqlen_k, scale, @@ -408,7 +410,13 @@ def flash_attn_unpadded( def scaled_dot_product_attention( - query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + training=True, ): r""" The equation is: @@ -442,6 +450,7 @@ def scaled_dot_product_attention( not supported yet. dropout_p(float): The dropout ratio. is_causal(bool): Whether enable causal mode. 
+ training(bool): Whether it is in the training phase Returns: out(Tensor): The attention tensor. @@ -458,6 +467,22 @@ def scaled_dot_product_attention( >>> print(output) >>> # xdoctest: -SKIP """ - assert attn_mask is None, "attn_mask is not supported yet" - out, _ = flash_attention(query, key, value, dropout_p, is_causal) + if attn_mask is None: + out, _ = flash_attention(query, key, value, dropout_p, is_causal) + else: + fixed_seed_offset = (None,) + return_softmax = False + rng_name = "" + out, _ = _C_ops.flash_attn( + query, + key, + value, + fixed_seed_offset, + attn_mask, + dropout_p, + is_causal, + return_softmax, + not training, + rng_name, + ) return out diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py index cc23331eadf56e8d8e875a0a7aec60f4752befb5..6490ce428381bbdb051b9092449c24a35be56238 100644 --- a/test/legacy_test/test_flash_attention.py +++ b/test/legacy_test/test_flash_attention.py @@ -57,6 +57,18 @@ def attention_naive(q, k, v, causal=False): return paddle.transpose(o, [0, 2, 1, 3]) +def attention_naive_with_mask(q, k, v, attn_bias): + qt = paddle.transpose(q, [0, 2, 1, 3]) + kt = paddle.transpose(k, [0, 2, 1, 3]) + vt = paddle.transpose(v, [0, 2, 1, 3]) + scale = 1.0 / np.sqrt(q.shape[-1]) + s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2])) + s = paddle.scale(s, scale) + p = F.softmax(s + attn_bias) + o = paddle.matmul(p, vt) + return paddle.transpose(o, [0, 2, 1, 3]) + + is_sm8x = ( core.is_compiled_with_cuda() and paddle.device.cuda.get_device_capability()[0] == 8 @@ -296,6 +308,64 @@ class TestFlashAttentionAPI(unittest.TestCase): ) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11040 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.3" + "and device's compute capability must be 7.5 or 8.x", +) +class TestFlashAttentionWithMaskAPI(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 128, 8, 32) + self.dtype = 'float16' + self.dropout = 0.0 + self.causal = False + + def test_dot_scale_product(self): + # test dynamic + paddle.disable_static() + + query = np.random.random(self.shape) + key = np.random.random(self.shape) + value = np.random.random(self.shape) + + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + q_ = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + k_ = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + v_ = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + mask_shape = (self.shape[0], 1, self.shape[1], self.shape[1]) + mask = np.random.random(mask_shape) + m = paddle.to_tensor( + mask, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + out = scaled_dot_product_attention( + q, k, v, m, self.dropout, self.causal + ) + out_ = attention_naive_with_mask(q_, k_, v_, m) + out.backward() + out_.backward() + np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03) + + class TestFlashAttentionAPITest1(TestFlashAttentionAPI): def setUp(self): self.place = paddle.CUDAPlace(0) @@ -370,5 +440,14 @@ class TestSDPAttentionAPITest(TestFlashAttentionAPI): self.enable_mem_efficient = False +class 
TestFlashAttentionWithMaskAPITest1(TestFlashAttentionWithMaskAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (8, 1024, 16, 128) + self.dtype = paddle.float16 + self.dropout = 0.0 + self.causal = False + + if __name__ == '__main__': unittest.main()
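
For reference, a minimal usage sketch of the masked path this patch enables, mirroring the call pattern in test_flash_attention.py above. The import path follows the module modified in this diff; the shapes, dtype, and random data are illustrative only, and the mask is an additive bias of shape [batch_size, 1, seq_len, seq_len] as used in the new test.

    import numpy as np
    import paddle
    from paddle.nn.functional.flash_attention import scaled_dot_product_attention

    # q, k, v: [batch_size, seq_len, num_heads, head_dim]. With an attn_mask the
    # kernel path added here requires head_dim to be 32, 64, or 128 and is_causal=False.
    shape = (2, 128, 8, 32)
    q = paddle.to_tensor(np.random.random(shape), dtype='float16', stop_gradient=False)
    k = paddle.to_tensor(np.random.random(shape), dtype='float16', stop_gradient=False)
    v = paddle.to_tensor(np.random.random(shape), dtype='float16', stop_gradient=False)

    # Additive attention bias, one mask per batch shared across heads:
    # [batch_size, 1, seq_len, seq_len], same dtype as q/k/v.
    mask = paddle.to_tensor(np.random.random((2, 1, 128, 128)), dtype='float16')

    out = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    print(out.shape)  # [2, 128, 8, 32]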