Unverified · Commit 42e0c6b8, authored by yin wei, committed by GitHub

Add attn_mask support for FlashAttnKernel. (#55969)

* add mask

* add backward

* add enforce info

* update scale

* integrate code

* update enforce

* add enforce eq

* add error type

* update enforce

* add test_flash_attention

* Polish codes and fix compiling errors.

* Set num_splits to 0 for flash-attn with tensor mask.

* Fix the compiling error for non flash-attn case.

---------
Co-authored-by: NLiu Yiqun <liuyiqun01@baidu.com>
Parent 0434b828
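
For context, a minimal usage sketch of the Python API this commit extends. This is a sketch only: the import path, shapes, and values are assumptions based on the diff below, and it requires a CUDA build with flash-attn support.

```python
import numpy as np
import paddle
# Assumed import path; scaled_dot_product_attention lives in the
# flash_attention module touched by this diff.
from paddle.nn.functional.flash_attention import scaled_dot_product_attention

# q/k/v: (batch, seqlen, num_heads, head_dim); the additive mask is
# broadcast over heads, e.g. (batch, 1, seqlen, seqlen).
shape = (2, 128, 8, 32)
q = paddle.to_tensor(np.random.random(shape), dtype='float16')
k = paddle.to_tensor(np.random.random(shape), dtype='float16')
v = paddle.to_tensor(np.random.random(shape), dtype='float16')
mask = paddle.to_tensor(np.random.random((2, 1, 128, 128)), dtype='float16')

# Passing attn_mask exercises the new flash-attn-with-mask path.
out = scaled_dot_product_attention(
    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
print(out.shape)  # [2, 128, 8, 32]
```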
......@@ -818,8 +818,9 @@
inplace : (out_grad -> x_grad)
- backward_op : flash_attn_grad
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false)
optional : attn_mask
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnGradInferMeta
......@@ -829,8 +830,9 @@
data_type: q
- backward_op : flash_attn_unpadded_grad
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
optional : attn_mask
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnGradInferMeta
......
......@@ -910,9 +910,9 @@
backward : fill_diagonal_tensor_grad
- op : flash_attn
args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
optional : fixed_seed_offset, attn_mask
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......@@ -923,9 +923,9 @@
backward : flash_attn_grad
- op : flash_attn_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
optional : fixed_seed_offset, attn_mask
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......
......@@ -29,6 +29,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
......@@ -47,6 +48,7 @@ void FlashAttnGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
float dropout,
bool causal,
......
......@@ -28,6 +28,7 @@ void FlashAttnUnpaddedKernel(
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
......@@ -47,6 +48,7 @@ void FlashAttnKernel(const Context& ctx,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
float dropout,
bool causal,
bool return_softmax,
......
......@@ -14,8 +14,43 @@
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/enforce.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
namespace phi {
#ifdef PADDLE_WITH_FLASHATTN
static std::pair<uint64_t, uint64_t> GenerateRNGState(
const GPUContext& ctx,
const paddle::optional<DenseTensor>& fixed_seed_offset,
const std::string& rng_name,
const int64_t batch_size,
const int64_t num_heads) {
if (fixed_seed_offset.get_ptr()) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset.get_ptr()->data<int64_t>();
uint64_t seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
return std::make_pair(seed, offset);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
return seed_offset_pair;
}
}
template <typename T>
struct FlashAttnFwdParamsV2 {
int batch_size;
......@@ -55,7 +90,7 @@ struct FlashAttnFwdParamsV2 {
const DataType q_dtype,
const bool is_test,
const std::string& rng_name,
const DenseTensor* const fixed_seed_offset_ptr,
const paddle::optional<DenseTensor>& fixed_seed_offset,
DenseTensor* _softmax,
DenseTensor* _softmax_lse,
DenseTensor* _seed_offset)
......@@ -78,24 +113,11 @@ struct FlashAttnFwdParamsV2 {
// (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
// with the same size.
rng_state = Empty<int64_t>(ctx, {2});
if (fixed_seed_offset_ptr) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset_ptr->data<int64_t>();
seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
}
auto seed_offset_pair = GenerateRNGState(
ctx, fixed_seed_offset, rng_name, batch_size, num_heads);
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
......@@ -178,4 +200,66 @@ struct FlashAttnBwdParamsV2 {
ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
}
};
static void CheckFlashAttnStatus(const bool status) {
PADDLE_ENFORCE_EQ(status,
true,
phi::errors::External(
"Error in Flash-Attention, detail information is: %s",
phi::dynload::flash_attn_error()));
}
template <typename T>
__global__ void SimpleScaleKernel(const T* input,
int64_t numel,
float scale,
T* output) {
CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) {
output[i] = static_cast<T>(scale * static_cast<float>(input[i]));
}
}
template <typename T, typename Context>
void ComputeScaleQ(
const Context& ctx, int64_t numel, float scale, const T* input, T* output) {
auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1);
SimpleScaleKernel<<<gpu_config.block_per_grid,
gpu_config.thread_per_block,
0,
ctx.stream()>>>(input, numel, scale, output);
}
static std::vector<int64_t> GetAttnMaskDims(const DenseTensor* attn_mask) {
std::vector<int64_t> mask_dim_4d;
if (attn_mask) {
const auto& origin_dims = attn_mask->dims();
auto rank = origin_dims.size();
PADDLE_ENFORCE_GE(
rank,
4,
phi::errors::InvalidArgument(
"Teh number of dimenstions of attn_mask is expected to be greater "
"or equal to 4, but recieved %d. The shape of attn_mask is {%s}",
rank,
origin_dims));
int64_t first_dim = 1;
for (int i = 0; i < rank - 3; i++) {
first_dim *= origin_dims[i];
}
mask_dim_4d = {first_dim,
origin_dims[rank - 3],
origin_dims[rank - 2],
origin_dims[rank - 1]};
}
return mask_dim_4d;
}
#endif
static void RaiseNotSupportedError() {
PADDLE_THROW(
phi::errors::Unimplemented("FlashAttention is unsupported, please check "
"the GPU compability and CUDA Version."));
}
} // namespace phi
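
For illustration, a rough Python equivalent of the mask-shape handling in GetAttnMaskDims above (a hypothetical helper, not a Paddle API): the mask must have rank >= 4, and every leading dimension beyond the last three is folded into a single first dimension.

```python
import numpy as np

def get_attn_mask_dims(mask_shape):
    # Mirrors GetAttnMaskDims: require rank >= 4, then collapse all
    # dimensions except the last three into the first dimension.
    rank = len(mask_shape)
    assert rank >= 4, f"attn_mask rank must be >= 4, got {rank}"
    first_dim = int(np.prod(mask_shape[: rank - 3], dtype=np.int64))
    return [first_dim, *mask_shape[-3:]]

print(get_attn_mask_dims((2, 3, 8, 128, 128)))  # [6, 8, 128, 128]
```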
......@@ -202,6 +202,7 @@ def flash_attention(
key,
value,
fixed_seed_offset,
None,
dropout,
causal,
return_softmax,
......@@ -358,6 +359,7 @@ def flash_attn_unpadded(
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
None,
max_seqlen_q,
max_seqlen_k,
scale,
......@@ -408,7 +410,13 @@ def flash_attn_unpadded(
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
training=True,
):
r"""
The equation is:
......@@ -442,6 +450,7 @@ def scaled_dot_product_attention(
not supported yet.
dropout_p(float): The dropout ratio.
is_causal(bool): Whether enable causal mode.
training(bool): Whether it is in the training phase.
Returns:
out(Tensor): The attention tensor.
......@@ -458,6 +467,22 @@ def scaled_dot_product_attention(
>>> print(output)
>>> # xdoctest: -SKIP
"""
assert attn_mask is None, "attn_mask is not supported yet"
out, _ = flash_attention(query, key, value, dropout_p, is_causal)
if attn_mask is None:
out, _ = flash_attention(query, key, value, dropout_p, is_causal)
else:
fixed_seed_offset = None
return_softmax = False
rng_name = ""
out, _, _, _ = _C_ops.flash_attn(
query,
key,
value,
fixed_seed_offset,
attn_mask,
dropout_p,
is_causal,
return_softmax,
not training,
rng_name,
)
return out
......@@ -57,6 +57,18 @@ def attention_naive(q, k, v, causal=False):
return paddle.transpose(o, [0, 2, 1, 3])
def attention_naive_with_mask(q, k, v, attn_bias):
qt = paddle.transpose(q, [0, 2, 1, 3])
kt = paddle.transpose(k, [0, 2, 1, 3])
vt = paddle.transpose(v, [0, 2, 1, 3])
scale = 1.0 / np.sqrt(q.shape[-1])
s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2]))
s = paddle.scale(s, scale)
p = F.softmax(s + attn_bias)
o = paddle.matmul(p, vt)
return paddle.transpose(o, [0, 2, 1, 3])
is_sm8x = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 8
......@@ -296,6 +308,64 @@ class TestFlashAttentionAPI(unittest.TestCase):
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11040
or not is_sm_supported,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
"and device's compute capability must be 7.5 or 8.x",
)
class TestFlashAttentionWithMaskAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 128, 8, 32)
self.dtype = 'float16'
self.dropout = 0.0
self.causal = False
def test_dot_scale_product(self):
# test dynamic
paddle.disable_static()
query = np.random.random(self.shape)
key = np.random.random(self.shape)
value = np.random.random(self.shape)
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
q_ = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
k_ = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
v_ = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
mask_shape = (self.shape[0], 1, self.shape[1], self.shape[1])
mask = np.random.random(mask_shape)
m = paddle.to_tensor(
mask, place=self.place, dtype=self.dtype, stop_gradient=False
)
out = scaled_dot_product_attention(
q, k, v, m, self.dropout, self.causal
)
out_ = attention_naive_with_mask(q_, k_, v_, m)
out.backward()
out_.backward()
np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)
class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
......@@ -370,5 +440,14 @@ class TestSDPAttentionAPITest(TestFlashAttentionAPI):
self.enable_mem_efficient = False
class TestFlashAttentionWithMaskAPITest1(TestFlashAttentionWithMaskAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (8, 1024, 16, 128)
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
if __name__ == '__main__':
unittest.main()