Unverified · Commit 42e0c6b8 · Authored by: yin wei · Committed by: GitHub

Add attn_mask support for FlashAttnKernel. (#55969)

* add mask

* add backward

* add enforce info

* update scale

* integrate code

* update enforce

* add enforce eq

* add error type

* update enforce

* add test_flash_attention

* Polish codes and fix compiling errors.

* Set num_splits to 0 for flash-attn with tensor mask.

* Fix the compiling error for non flash-attn case.

---------
Co-authored-by: Liu Yiqun <liuyiqun01@baidu.com>
Parent: 0434b828
......@@ -818,8 +818,9 @@
inplace : (out_grad -> x_grad)
- backward_op : flash_attn_grad
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false)
optional : attn_mask
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnGradInferMeta
......@@ -829,8 +830,9 @@
data_type: q
- backward_op : flash_attn_unpadded_grad
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
optional : attn_mask
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnGradInferMeta
......
......@@ -910,9 +910,9 @@
backward : fill_diagonal_tensor_grad
- op : flash_attn
args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
optional : fixed_seed_offset, attn_mask
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......@@ -923,9 +923,9 @@
backward : flash_attn_grad
- op : flash_attn_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
optional : fixed_seed_offset , attn_mask
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......
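The op definitions above add an optional `attn_mask` input immediately after `fixed_seed_offset` in both `flash_attn` and `flash_attn_unpadded`, so every caller of the eager bindings now passes one extra positional slot (None when no mask is used); the backward ops gain the same slot between `seed_offset` and `out_grad`. A minimal sketch of the new call order, assuming a CUDA build with FlashAttention enabled and purely illustrative shapes:

```python
import paddle
from paddle import _C_ops

# Illustrative shapes: [batch_size, seq_len, num_heads, head_dim].
q = paddle.randn([2, 128, 8, 32], dtype='float16')
k = paddle.randn([2, 128, 8, 32], dtype='float16')
v = paddle.randn([2, 128, 8, 32], dtype='float16')
# Additive bias on the pre-softmax scores; broadcastable over heads.
mask = paddle.zeros([2, 1, 128, 128], dtype='float16')

out, _ = _C_ops.flash_attn(
    q, k, v,
    None,    # fixed_seed_offset (optional)
    mask,    # attn_mask (optional; pass None to keep the old behavior)
    0.0,     # dropout
    False,   # causal -- must stay False when attn_mask is given
    False,   # return_softmax
    False,   # is_test
    "",      # rng_name
)
```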
......@@ -29,6 +29,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
......@@ -47,6 +48,7 @@ void FlashAttnGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
float dropout,
bool causal,
......
......@@ -28,6 +28,7 @@ void FlashAttnUnpaddedKernel(
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
......@@ -47,6 +48,7 @@ void FlashAttnKernel(const Context& ctx,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
float dropout,
bool causal,
bool return_softmax,
......
......@@ -21,17 +21,160 @@
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
#include "paddle/phi/kernels/reshape_kernel.h"
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
void FlashAttnUnpaddedGradImpl(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
const cudaStream_t stream = ctx.stream();
auto dims = q.dims();
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
PADDLE_ENFORCE_NE(causal,
true,
phi::errors::InvalidArgument(
"attn_mask is not nullptr, causal can not be true"));
PADDLE_ENFORCE_EQ(
head_size == 32 || head_size == 64 || head_size == 128,
true,
phi::errors::InvalidArgument("The head_dim is expected to be either 32, "
"64, or 128, but recieved %d.",
head_size));
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset;
int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seqlen_q});
const DenseTensor* attn_mask_tensor = attn_mask.get_ptr();
std::vector<int64_t> mask_dims = GetAttnMaskDims(attn_mask_tensor);
int fa_num_splits = 0;
bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16;
float fa_with_mask_scale = 1.0f;
bool fa_zero_tensors = false;
uint64_t workspace_size;
bool succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
static_cast<const void*>(q.data()),
static_cast<const void*>(k.data()),
static_cast<const void*>(v.data()),
static_cast<void*>(dq->data()),
static_cast<void*>(dk->data()),
static_cast<void*>(dv->data()),
nullptr, // set out to nullptr to calculate workspace size
dout.data(),
static_cast<const int32_t*>(cu_seqlens_q.data()),
static_cast<const int32_t*>(cu_seqlens_k.data()),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
fa_with_mask_scale,
fa_zero_tensors,
fa_is_bf16,
fa_num_splits,
static_cast<const void*>(softmax_lse.data()),
static_cast<void*>(dsoftmax.data()),
nullptr,
nullptr,
&workspace_size,
stream,
seed,
offset,
attn_mask_tensor ? attn_mask_tensor->data() : nullptr,
nullptr,
mask_dims.data() ? mask_dims.data() : nullptr,
nullptr);
CheckFlashAttnStatus(succ);
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(
ctx, {static_cast<int64_t>(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
static_cast<const void*>(q.data()),
static_cast<const void*>(k.data()),
static_cast<const void*>(v.data()),
static_cast<void*>(dq->data()),
static_cast<void*>(dk->data()),
static_cast<void*>(dv->data()),
out.data(), // pass the real out tensor for the actual backward computation
dout.data(),
static_cast<const int32_t*>(cu_seqlens_q.data()),
static_cast<const int32_t*>(cu_seqlens_k.data()),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
fa_with_mask_scale,
fa_zero_tensors,
fa_is_bf16,
fa_num_splits,
static_cast<const void*>(softmax_lse.data()),
static_cast<void*>(dsoftmax.data()),
nullptr,
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset,
attn_mask_tensor ? attn_mask_tensor->data() : nullptr,
nullptr,
mask_dims.data() ? mask_dims.data() : nullptr,
nullptr);
CheckFlashAttnStatus(succ);
int64_t q_size = total_q * num_heads * head_size;
ComputeScaleQ(ctx, q_size, scale, dq->data<T>(), dq->data<T>());
#else
RaiseNotSupportedError();
#endif
}
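A note on the final `ComputeScaleQ` call: the masked forward (see `FlashAttnWithMaskUnpaddedImpl` below) feeds the kernel a pre-scaled copy of `q` and sets `fa_with_mask_scale = 1.0f`, so the backward above receives the gradient with respect to that scaled query and must multiply it by `scale` once more to recover the gradient with respect to the original `q`. A small NumPy check of that chain-rule step, with illustrative shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))           # original query block
k = rng.standard_normal((4, 8))
scale = 1.0 / np.sqrt(q.shape[-1])

# Forward computes s = (scale * q) @ k.T, i.e. the kernel sees scaled_q.
scaled_q = scale * q
s = scaled_q @ k.T

# Pretend the kernel hands back dL/d(scaled_q) for some upstream gradient dL/ds.
ds = rng.standard_normal(s.shape)
d_scaled_q = ds @ k                        # gradient w.r.t. scaled_q

# Chain rule: scaled_q = scale * q  =>  dL/dq = scale * dL/d(scaled_q).
dq = scale * d_scaled_q

# Reference: differentiate s = scale * (q @ k.T) directly w.r.t. q.
dq_ref = scale * (ds @ k)
np.testing.assert_allclose(dq, dq_ref, rtol=1e-12, atol=1e-12)
```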
template <typename T, typename Context>
void FlashAttnUnpaddedGradKernel(const Context& ctx,
const DenseTensor& q,
......@@ -42,6 +185,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
......@@ -59,86 +203,102 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
const cudaStream_t stream = ctx.stream();
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
const int64_t total_q = dims[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = dims[1];
const int head_size_og = dout.dims()[2];
const int head_size = dims[2];
const int total_k = k.dims()[0];
const int num_heads_k = k.dims()[1];
// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
const bool zero_tensors = false;
// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
head_size_og,
head_size,
phi::errors::InvalidArgument(
"flash_attn_bwd receive input with head_size_og == head_size"));
if (attn_mask.get_ptr()) {
FlashAttnUnpaddedGradImpl<T, Context>(ctx,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
out,
softmax_lse,
seed_offset,
attn_mask,
dout,
max_seqlen_q,
max_seqlen_k,
scale,
dropout,
causal,
dq,
dk,
dv);
} else {
const int64_t total_q = dims[0];
const int64_t batch_size = cu_seqlens_q.numel() - 1;
const int64_t num_heads = dims[1];
const int64_t head_size_og = dout.dims()[2];
const int64_t head_size = dims[2];
const int64_t total_k = k.dims()[0];
const int64_t num_heads_k = k.dims()[1];
// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
head_size_og,
head_size,
phi::errors::InvalidArgument(
"flash_attn_bwd receive input with head_size_og == head_size"));
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(ctx,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
q.dtype(),
seed_offset.data<int64_t>());
VLOG(4) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
const bool succ =
phi::dynload::flash_attn_varlen_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(ctx,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
q.dtype(),
seed_offset.data<int64_t>());
VLOG(10) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
bool succ =
phi::dynload::flash_attn_varlen_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
CheckFlashAttnStatus(succ);
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
RaiseNotSupportedError();
#endif
}
......@@ -150,6 +310,7 @@ void FlashAttnGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
float dropout,
bool causal,
......@@ -159,14 +320,17 @@ void FlashAttnGradKernel(const Context& ctx,
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
const int batch_size = dims[0];
const int seqlen_q = dims[1];
const int num_heads = dims[2];
const int head_size_og = dout.dims()[3];
const int head_size = dims[3];
const int seqlen_k = k.dims()[1];
const int num_heads_k = k.dims()[2];
const auto& dims = q.dims();
const int64_t batch_size = dims[0];
const int64_t seqlen_q = dims[1];
const int64_t num_heads = dims[2];
const int64_t head_size_og = dout.dims()[3];
const int64_t head_size = dims[3];
const int64_t seqlen_k = k.dims()[1];
const int64_t num_heads_k = k.dims()[2];
const int64_t total_q = batch_size * seqlen_q;
const int64_t total_k = batch_size * seqlen_k;
// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
......@@ -175,71 +339,98 @@ void FlashAttnGradKernel(const Context& ctx,
phi::errors::InvalidArgument(
"flash_attn_bwd receive input with head_size_og == head_size"));
VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
VLOG(10) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
const float scale = 1.0f / std::sqrt(head_size);
if (attn_mask.get_ptr()) {
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(ctx,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
q.dtype(),
seed_offset.data<int64_t>());
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seqlen_q, seqlen_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seqlen_k, seqlen_k, &cu_seqlens_k);
ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
FlashAttnUnpaddedGradKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
out,
softmax_lse,
seed_offset,
attn_mask,
dout,
seqlen_q,
seqlen_k,
scale,
dropout,
causal,
dq,
dk,
dv);
} else {
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(ctx,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
q.dtype(),
seed_offset.data<int64_t>());
cudaStream_t stream = ctx.stream();
VLOG(4) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
PADDLE_ENFORCE_EQ(
succ,
true,
phi::errors::External("Error in Flash-Attention-2, detail information is",
phi::dynload::flash_attn_error()));
cudaStream_t stream = ctx.stream();
VLOG(10) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
bool succ = phi::dynload::flash_attn_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
CheckFlashAttnStatus(succ);
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
RaiseNotSupportedError();
#endif
}
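When `attn_mask` is set, the dense `[batch_size, seq_len, num_heads, head_dim]` tensors above are routed through the varlen kernel: `ShareDataWith(...).Resize(...)` views them as `[total, num_heads, head_dim]` without copying, and the cumulative sequence offsets are a simple arithmetic progression built by `ArangeNullaryKernel`, since every sequence in the batch has the same length. A NumPy sketch of those two steps, with illustrative sizes:

```python
import numpy as np

batch_size, seqlen, num_heads, head_dim = 2, 128, 8, 32
q = np.zeros((batch_size, seqlen, num_heads, head_dim), dtype=np.float16)

# ShareDataWith + Resize is just a reshape view for contiguous data.
q_unpadded = q.reshape(batch_size * seqlen, num_heads, head_dim)

# cu_seqlens = arange(0, (batch_size + 1) * seqlen, seqlen) marks where each
# sequence starts and ends in the flattened [total, ...] layout.
cu_seqlens_q = np.arange(0, (batch_size + 1) * seqlen, seqlen, dtype=np.int32)
print(cu_seqlens_q)  # [  0 128 256] for this example
```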
......
......@@ -15,27 +15,21 @@
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
#include "paddle/phi/kernels/reshape_kernel.h"
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(
void FlashAttnWithMaskUnpaddedImpl(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
......@@ -43,6 +37,7 @@ void FlashAttnUnpaddedKernel(
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
......@@ -56,13 +51,174 @@ void FlashAttnUnpaddedKernel(
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
cudaStream_t stream = ctx.stream();
auto dims = q.dims();
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
PADDLE_ENFORCE_NE(causal,
true,
phi::errors::InvalidArgument(
"attn_mask is not nullptr, causal can not be true"));
PADDLE_ENFORCE_EQ(
head_size == 32 || head_size == 64 || head_size == 128,
true,
phi::errors::InvalidArgument("The head_dim is expected to be either 32, "
"64, or 128, but recieved %d.",
head_size));
// Generate random state for dropout and save for recompute in grad.
auto seed_offset_pair =
GenerateRNGState(ctx, fixed_seed_offset, rng_name, batch_size, num_heads);
uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;
VLOG(10) << "FlashAttn fwd seed: " << seed << ", offset: " << offset;
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
// Allocate memory for softmax_lse and softmax.
int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
softmax_lse->Resize({batch_size, num_heads, seqlen_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
// may allocate more space than *max_seqlen_k*
int64_t blocksize_c = head_size > 64 ? 128 : 256;
int64_t seqlen_k =
((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
if (max_seqlen_k <= 128) {
seqlen_k = 128;
} else if (max_seqlen_k <= 256) {
seqlen_k = 256;
}
softmax->Resize({batch_size, num_heads, seqlen_q, seqlen_k});
ctx.template Alloc<T>(softmax);
}
// Compute scale Q
int64_t q_size = total_q * num_heads * head_size;
DenseTensor scaled_q = Empty<T>(ctx, {total_q, num_heads, head_size});
ComputeScaleQ(ctx, q_size, scale, q.data<T>(), scaled_q.data<T>());
const DenseTensor* attn_mask_tensor = attn_mask.get_ptr();
std::vector<int64_t> mask_dims = GetAttnMaskDims(attn_mask_tensor);
int fa_num_splits = 0;
bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16;
float fa_with_mask_scale = 1.0f;
bool fa_zero_tensors = false;
uint64_t workspace_size = 0;
bool succ = phi::dynload::flash_attn_fwd_with_bias_and_mask(
static_cast<const void*>(scaled_q.data()),
static_cast<const void*>(k.data()),
static_cast<const void*>(v.data()),
nullptr, // pass nullptr for out to query the workspace size
static_cast<const int32_t*>(cu_seqlens_q.data()),
static_cast<const int32_t*>(cu_seqlens_k.data()),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
fa_with_mask_scale,
fa_zero_tensors,
fa_is_bf16,
fa_num_splits,
softmax_lse->data(),
nullptr,
&workspace_size,
stream,
seed,
offset,
attn_mask_tensor ? attn_mask_tensor->data() : nullptr,
nullptr,
mask_dims.data() ? mask_dims.data() : nullptr,
nullptr);
CheckFlashAttnStatus(succ);
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(
ctx, {static_cast<int64_t>(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_fwd_with_bias_and_mask(
static_cast<const void*>(scaled_q.data()),
k.data(),
v.data(),
out->data(), // pass the real out tensor for the actual forward computation
static_cast<const int32_t*>(cu_seqlens_q.data()),
static_cast<const int32_t*>(cu_seqlens_k.data()),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
fa_with_mask_scale,
fa_zero_tensors,
fa_is_bf16,
fa_num_splits,
softmax_lse->data(),
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset,
attn_mask_tensor ? attn_mask_tensor->data() : nullptr,
nullptr,
mask_dims.data() ? mask_dims.data() : nullptr,
nullptr);
CheckFlashAttnStatus(succ);
#else
RaiseNotSupportedError();
#endif
}
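The masked forward pads its auxiliary buffers instead of using the raw lengths: `softmax_lse` is allocated with `max_seqlen_q` rounded up to a multiple of 16, and the optional `softmax` output rounds `max_seqlen_k` up to the kernel's column block size (128 when `head_size > 64`, otherwise 256), clamping short sequences to 128 or 256. A small sketch of those size computations, with illustrative values:

```python
def rounded_buffer_sizes(max_seqlen_q, max_seqlen_k, head_size):
    # softmax_lse rows are padded to a multiple of 16.
    seqlen_q = ((max_seqlen_q + 16 - 1) // 16) * 16
    # The returned softmax pads columns to the kernel's column block size.
    blocksize_c = 128 if head_size > 64 else 256
    seqlen_k = ((max_seqlen_k + blocksize_c - 1) // blocksize_c) * blocksize_c
    if max_seqlen_k <= 128:
        seqlen_k = 128
    elif max_seqlen_k <= 256:
        seqlen_k = 256
    return seqlen_q, seqlen_k

print(rounded_buffer_sizes(100, 100, 32))    # (112, 128)
print(rounded_buffer_sizes(300, 300, 128))   # (304, 384)
```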
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
ctx.template Alloc<T>(out);
cudaStream_t stream = ctx.stream();
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
PADDLE_ENFORCE_EQ(
dims.size(),
......@@ -70,79 +226,97 @@ void FlashAttnUnpaddedKernel(
phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
"[total_seq_len, num_heads, head_dim]"));
const int64_t total_q = dims[0];
const int num_heads = dims[1];
const int head_size = dims[2];
const int total_k = k.dims()[0];
const int num_heads_k = k.dims()[1];
const int batch_size = cu_seqlens_q.numel() - 1;
// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
// TODO(umiswing): add shape check
FlashAttnFwdParamsV2<T> params =
FlashAttnFwdParamsV2<T>(ctx,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
return_softmax,
q.dtype(),
is_test,
rng_name,
fixed_seed_offset.get_ptr(),
softmax,
softmax_lse,
seed_offset);
VLOG(4) << "FlashAttn fwd seed: " << params.seed
<< ", offset: " << params.offset;
const bool succ = phi::dynload::flash_attn_varlen_fwd(
q.data(),
k.data(),
v.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
out->data(),
params.return_softmax ? softmax->data() : nullptr,
softmax_lse->data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.return_softmax,
params.is_bf16,
stream,
params.seed,
params.offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
if (attn_mask.get_ptr()) {
FlashAttnWithMaskUnpaddedImpl<T, Context>(ctx,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
attn_mask,
max_seqlen_q,
max_seqlen_k,
scale,
dropout,
causal,
return_softmax,
is_test,
rng_name,
out,
softmax,
softmax_lse,
seed_offset);
} else {
const int64_t total_q = dims[0];
const int64_t num_heads = dims[1];
const int64_t head_size = dims[2];
const int64_t total_k = k.dims()[0];
const int64_t num_heads_k = k.dims()[1];
const int64_t batch_size = cu_seqlens_q.numel() - 1;
// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
// TODO(umiswing): add shape check
FlashAttnFwdParamsV2<T> params = FlashAttnFwdParamsV2<T>(ctx,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
return_softmax,
q.dtype(),
is_test,
rng_name,
fixed_seed_offset,
softmax,
softmax_lse,
seed_offset);
VLOG(10) << "FlashAttn fwd seed: " << params.seed
<< ", offset: " << params.offset;
bool succ = phi::dynload::flash_attn_varlen_fwd(
q.data(),
k.data(),
v.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
out->data(),
params.return_softmax ? softmax->data() : nullptr,
softmax_lse->data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.return_softmax,
params.is_bf16,
stream,
params.seed,
params.offset);
CheckFlashAttnStatus(succ);
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
RaiseNotSupportedError();
#endif
}
......@@ -152,6 +326,7 @@ void FlashAttnKernel(const Context& ctx,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
float dropout,
bool causal,
bool return_softmax,
......@@ -163,89 +338,117 @@ void FlashAttnKernel(const Context& ctx,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
const auto& dims = q.dims();
PADDLE_ENFORCE_EQ(dims.size(),
4,
phi::errors::InvalidArgument(
"flash_attn receive input with dim "
"[batch_size, seq_len, num_heads, head_dim]"));
const int batch_size = dims[0];
const int seqlen_q = dims[1];
const int num_heads = dims[2];
const int head_size = dims[3];
const int seqlen_k = k.dims()[1];
const int num_heads_k = k.dims()[2];
const int64_t batch_size = dims[0];
const int64_t seqlen_q = dims[1];
const int64_t num_heads = dims[2];
const int64_t head_size = dims[3];
const int64_t seqlen_k = k.dims()[1];
const int64_t num_heads_k = k.dims()[2];
const int64_t total_q = batch_size * seqlen_q;
const int64_t total_k = batch_size * seqlen_k;
// TODO(umiswing): Add check shape
const float scale = 1.0f / std::sqrt(head_size);
FlashAttnFwdParamsV2<T> params =
FlashAttnFwdParamsV2<T>(ctx,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
return_softmax,
q.dtype(),
is_test,
rng_name,
fixed_seed_offset.get_ptr(),
softmax,
softmax_lse,
seed_offset);
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
ctx.template Alloc<T>(out);
cudaStream_t stream = ctx.stream();
VLOG(4) << "FlashAttn fwd seed: " << params.seed
<< ", offset: " << params.offset;
bool succ = phi::dynload::flash_attn_fwd(
q.data(),
k.data(),
v.data(),
params.rng_state.data(),
out->data(),
params.return_softmax ? params.softmax->data() : nullptr,
params.softmax_lse->data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.return_softmax,
params.is_bf16,
stream,
params.seed,
params.offset);
PADDLE_ENFORCE_EQ(
succ,
true,
phi::errors::External("Error in Flash-Attention-2, detail information is",
phi::dynload::flash_attn_error()));
if (attn_mask.get_ptr()) {
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seqlen_q, seqlen_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seqlen_k, seqlen_k, &cu_seqlens_k);
FlashAttnUnpaddedKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
attn_mask,
seqlen_q,
seqlen_k,
scale,
dropout,
causal,
return_softmax,
is_test,
rng_name,
out,
softmax,
softmax_lse,
seed_offset);
} else {
FlashAttnFwdParamsV2<T> params = FlashAttnFwdParamsV2<T>(ctx,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
return_softmax,
q.dtype(),
is_test,
rng_name,
fixed_seed_offset,
softmax,
softmax_lse,
seed_offset);
VLOG(10) << "FlashAttn fwd dims: q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
VLOG(10) << "FlashAttn fwd seed: " << params.seed
<< ", offset: " << params.offset;
ctx.template Alloc<T>(out);
cudaStream_t stream = ctx.stream();
bool succ = phi::dynload::flash_attn_fwd(
q.data(),
k.data(),
v.data(),
params.rng_state.data(),
out->data(),
params.return_softmax ? params.softmax->data() : nullptr,
params.softmax_lse->data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.return_softmax,
params.is_bf16,
stream,
params.seed,
params.offset);
CheckFlashAttnStatus(succ);
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
RaiseNotSupportedError();
#endif
}
......
......@@ -14,8 +14,43 @@
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/enforce.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
namespace phi {
#ifdef PADDLE_WITH_FLASHATTN
static std::pair<uint64_t, uint64_t> GenerateRNGState(
const GPUContext& ctx,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const std::string& rng_name,
const int64_t batch_size,
const int64_t num_heads) {
if (fixed_seed_offset.get_ptr()) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset.get_ptr()->data<int64_t>();
uint64_t seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
return std::make_pair(seed, offset);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
return seed_offset_pair;
}
}
template <typename T>
struct FlashAttnFwdParamsV2 {
int batch_size;
......@@ -55,7 +90,7 @@ struct FlashAttnFwdParamsV2 {
const DataType q_dtype,
const bool is_test,
const std::string& rng_name,
const DenseTensor* const fixed_seed_offset_ptr,
const paddle::optional<DenseTensor>& fixed_seed_offset,
DenseTensor* _softmax,
DenseTensor* _softmax_lse,
DenseTensor* _seed_offset)
......@@ -78,24 +113,11 @@ struct FlashAttnFwdParamsV2 {
// (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
// with the same size.
rng_state = Empty<int64_t>(ctx, {2});
if (fixed_seed_offset_ptr) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset_ptr->data<int64_t>();
seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
}
auto seed_offset_pair = GenerateRNGState(
ctx, fixed_seed_offset, rng_name, batch_size, num_heads);
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
......@@ -178,4 +200,66 @@ struct FlashAttnBwdParamsV2 {
ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
}
};
static void CheckFlashAttnStatus(const bool status) {
PADDLE_ENFORCE_EQ(status,
true,
phi::errors::External(
"Error in Flash-Attention, detail information is: %s",
phi::dynload::flash_attn_error()));
}
template <typename T>
__global__ void SimpleScaleKernel(const T* input,
int64_t numel,
float scale,
T* output) {
CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) {
output[i] = static_cast<T>(scale * static_cast<float>(input[i]));
}
}
template <typename T, typename Context>
void ComputeScaleQ(
const Context& ctx, int64_t numel, float scale, const T* input, T* output) {
auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1);
SimpleScaleKernel<<<gpu_config.block_per_grid,
gpu_config.thread_per_block,
0,
ctx.stream()>>>(input, numel, scale, output);
}
static std::vector<int64_t> GetAttnMaskDims(const DenseTensor* attn_mask) {
std::vector<int64_t> mask_dim_4d;
if (attn_mask) {
const auto& origin_dims = attn_mask->dims();
auto rank = origin_dims.size();
PADDLE_ENFORCE_GE(
rank,
4,
phi::errors::InvalidArgument(
"Teh number of dimenstions of attn_mask is expected to be greater "
"or equal to 4, but recieved %d. The shape of attn_mask is {%s}",
rank,
origin_dims));
int64_t first_dim = 1;
for (int i = 0; i < rank - 3; i++) {
first_dim *= origin_dims[i];
}
mask_dim_4d = {first_dim,
origin_dims[rank - 3],
origin_dims[rank - 2],
origin_dims[rank - 1]};
}
return mask_dim_4d;
}
#endif
static void RaiseNotSupportedError() {
PADDLE_THROW(
phi::errors::Unimplemented("FlashAttention is unsupported, please check "
"the GPU compability and CUDA Version."));
}
} // namespace phi
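`GetAttnMaskDims` collapses a mask of rank >= 4 into the four dimensions the flash-attn C API expects by folding every axis before the last three into a single leading dimension. A NumPy mirror of that folding, for illustration only:

```python
import numpy as np

def get_attn_mask_dims(mask):
    """Mirror of GetAttnMaskDims: fold all leading axes into the first dim."""
    dims = mask.shape
    rank = len(dims)
    assert rank >= 4, "attn_mask must have at least 4 dimensions"
    first_dim = int(np.prod(dims[: rank - 3]))
    return [first_dim, dims[-3], dims[-2], dims[-1]]

mask = np.zeros((2, 3, 1, 128, 128))   # e.g. [batch, extra, heads, sq, sk]
print(get_attn_mask_dims(mask))        # [6, 1, 128, 128]
```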
......@@ -202,6 +202,7 @@ def flash_attention(
key,
value,
fixed_seed_offset,
None,
dropout,
causal,
return_softmax,
......@@ -358,6 +359,7 @@ def flash_attn_unpadded(
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
None,
max_seqlen_q,
max_seqlen_k,
scale,
......@@ -408,7 +410,13 @@ def flash_attn_unpadded(
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
training=True,
):
r"""
The equation is:
......@@ -442,6 +450,7 @@ def scaled_dot_product_attention(
not supported yet.
dropout_p(float): The dropout ratio.
is_causal(bool): Whether enable causal mode.
training(bool): Whether the computation is in the training phase.
Returns:
out(Tensor): The attention tensor.
......@@ -458,6 +467,22 @@ def scaled_dot_product_attention(
>>> print(output)
>>> # xdoctest: -SKIP
"""
assert attn_mask is None, "attn_mask is not supported yet"
out, _ = flash_attention(query, key, value, dropout_p, is_causal)
if attn_mask is None:
out, _ = flash_attention(query, key, value, dropout_p, is_causal)
else:
fixed_seed_offset = None
return_softmax = False
rng_name = ""
out, _ = _C_ops.flash_attn(
query,
key,
value,
fixed_seed_offset,
attn_mask,
dropout_p,
is_causal,
return_softmax,
not training,
rng_name,
)
return out
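With this change `scaled_dot_product_attention` no longer rejects a non-None `attn_mask`: the mask is forwarded to `_C_ops.flash_attn` as an additive bias on the pre-softmax scores, and `training=False` maps to the op's `is_test` flag. A minimal usage sketch, assuming the import path of the surrounding module and a CUDA build with FlashAttention:

```python
import paddle
from paddle.nn.functional.flash_attention import scaled_dot_product_attention

# [batch_size, seq_len, num_heads, head_dim]; the mask broadcasts over heads.
q = paddle.randn([2, 128, 8, 32], dtype='float16')
k = paddle.randn([2, 128, 8, 32], dtype='float16')
v = paddle.randn([2, 128, 8, 32], dtype='float16')
mask = paddle.zeros([2, 1, 128, 128], dtype='float16')  # 0 everywhere = no masking

# is_causal must stay False when attn_mask is provided.
out = scaled_dot_product_attention(q, k, v, attn_mask=mask,
                                   dropout_p=0.0, is_causal=False)
print(out.shape)  # [2, 128, 8, 32]
```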
......@@ -57,6 +57,18 @@ def attention_naive(q, k, v, causal=False):
return paddle.transpose(o, [0, 2, 1, 3])
def attention_naive_with_mask(q, k, v, attn_bias):
qt = paddle.transpose(q, [0, 2, 1, 3])
kt = paddle.transpose(k, [0, 2, 1, 3])
vt = paddle.transpose(v, [0, 2, 1, 3])
scale = 1.0 / np.sqrt(q.shape[-1])
s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2]))
s = paddle.scale(s, scale)
p = F.softmax(s + attn_bias)
o = paddle.matmul(p, vt)
return paddle.transpose(o, [0, 2, 1, 3])
is_sm8x = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 8
......@@ -296,6 +308,64 @@ class TestFlashAttentionAPI(unittest.TestCase):
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11040
or not is_sm_supported,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
"and device's compute capability must be 7.5 or 8.x",
)
class TestFlashAttentionWithMaskAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 128, 8, 32)
self.dtype = 'float16'
self.dropout = 0.0
self.causal = False
def test_dot_scale_product(self):
# test dynamic
paddle.disable_static()
query = np.random.random(self.shape)
key = np.random.random(self.shape)
value = np.random.random(self.shape)
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
q_ = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
k_ = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
v_ = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
mask_shape = (self.shape[0], 1, self.shape[1], self.shape[1])
mask = np.random.random(mask_shape)
m = paddle.to_tensor(
mask, place=self.place, dtype=self.dtype, stop_gradient=False
)
out = scaled_dot_product_attention(
q, k, v, m, self.dropout, self.causal
)
out_ = attention_naive_with_mask(q_, k_, v_, m)
out.backward()
out_.backward()
np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)
class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
......@@ -370,5 +440,14 @@ class TestSDPAttentionAPITest(TestFlashAttentionAPI):
self.enable_mem_efficient = False
class TestFlashAttentionWithMaskAPITest1(TestFlashAttentionWithMaskAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (8, 1024, 16, 128)
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
if __name__ == '__main__':
unittest.main()