Unverified commit cc9a7688, authored by umiswing, committed by GitHub

[cherry-pick] Integration flash attention 2 (#56015)

* [FlashAttn] add flash randomness control (#52902)

* add flash randomness control

* fix VLOG undefined

* [WIP] Integration flash attention 2 (#55758)

* Work for fa-2 padded fwd. Code to be cleaned.

* Work for fa2 unpadded fwd.

* Work for padded bwd; dk gets a small diff with np.random.seed(0).

* Passes Paddle's unit tests, except for returning softmax without dropout.

* Clean code.

* Modify interface.

* Clean code and add some check.

* Easy compile for dev.

* Fix ci.

* Fix ci-build.

* Add std c++17 option again.

* Limit max job when compiling fa2.

* Remove const_cast

* Add fwd params, to be cleaned.

* Clean code.

* Add bwd params.

* Clean code.

* Add enforce.

* Use v2.0.4

* Pass RNG state to fa2 capi

* Fix review.

* Add assert

* Skip compile for sm less than 80.

---------
Co-authored-by: Chitsing KUI <kuizhiqing@msn.com>
Parent commit: 8d3a9882
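
For orientation before the diffs below, here is a minimal usage sketch of the updated Python flash_attention API with the new fixed_seed_offset and rng_name arguments introduced by this change. The import path, tensor shapes, and dtype are illustrative assumptions rather than values taken from this patch, and running it requires a GPU and CUDA version covered by the compile/skip conditions added below (SM 8.x or SM 90, CUDA >= 11.4).

import paddle
from paddle.nn.functional.flash_attention import flash_attention

paddle.seed(2023)  # global seed; dropout randomness can also be pinned via fixed_seed_offset
# q, k, v: [batch_size, seq_len, num_heads, head_dim], float16 or bfloat16
q = paddle.randn([2, 128, 8, 64], dtype='float16')
k = paddle.randn([2, 128, 8, 64], dtype='float16')
v = paddle.randn([2, 128, 8, 64], dtype='float16')

# fixed_seed_offset and rng_name are the new keyword-only arguments;
# passing None / "" keeps the previous default behavior.
out, softmax = flash_attention(
    q, k, v,
    dropout=0.1,
    causal=True,
    return_softmax=False,
    fixed_seed_offset=None,
    rng_name="",
    training=True,
)
print(out.shape)  # [2, 128, 8, 64]; softmax is None when return_softmax=False
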
......@@ -17,10 +17,10 @@ include(ExternalProject)
add_definitions(-DPADDLE_WITH_FLASHATTN)
set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_SOURCE_SUBDIR csrc)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
set(FLASHATTN_TAG b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a)
set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
......@@ -62,7 +62,7 @@ else()
set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()
......@@ -93,6 +93,8 @@ ExternalProject_Add(
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_JOB_POOL_COMPILE:STRING=compile
-DCMAKE_JOB_POOLS:STRING=compile=4
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
......
......@@ -512,10 +512,15 @@ if(WITH_GPU
list(APPEND third_party_deps extern_cutlass)
set(WITH_CUTLASS ON)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.2)
include(external/flashattn)
list(APPEND third_party_deps extern_flashattn)
set(WITH_FLASHATTN ON)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
foreach(arch ${NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
include(external/flashattn)
list(APPEND third_party_deps extern_flashattn)
set(WITH_FLASHATTN ON)
break()
endif()
endforeach()
endif()
endif()
......
......@@ -617,7 +617,7 @@
inplace : (out_grad -> x_grad)
- backward_op : flash_attn_grad
forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
......@@ -628,7 +628,7 @@
data_type: q
- backward_op : flash_attn_unpadded_grad
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
......
......@@ -678,8 +678,9 @@
backward : fill_diagonal_tensor_grad
- op : flash_attn
args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......@@ -690,8 +691,9 @@
backward : flash_attn_grad
- op : flash_attn_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......
......@@ -43,9 +43,13 @@ extern void* flashattn_dso_handle;
#define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
DYNAMIC_LOAD_FLASHATTN_WRAP(__name)
#define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_bwd); \
#define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_varlen_fwd); \
__macro(flash_attn_bwd); \
__macro(flash_attn_varlen_bwd); \
__macro(flash_attn_fwd_with_bias_and_mask); \
__macro(flash_attn_bwd_with_bias_and_mask); \
__macro(flash_attn_error);
FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);
......
......@@ -20,33 +20,38 @@
namespace phi {
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
void FlashAttnUnpaddedKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
......
......@@ -25,6 +25,7 @@
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
DECLARE_bool(cudnn_deterministic);
......@@ -55,116 +56,89 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
const cudaStream_t stream = ctx.stream();
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
bool zero_tensors = false;
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
uint64_t workspace_size;
// calculate workspace size before execution
bool succ = phi::dynload::flash_attn_bwd(
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
nullptr, // for calculation workspace size
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
const int64_t total_q = dims[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = dims[1];
const int head_size_og = dout.dims()[2];
const int head_size = dims[2];
const int total_k = k.dims()[0];
const int num_heads_k = k.dims()[1];
// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
const bool zero_tensors = false;
// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
head_size_og,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
nullptr,
&workspace_size,
stream,
seed,
offset);
phi::errors::InvalidArgument(
"flash_attn_bwd receive input with head_size_og == head_size"));
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(ctx,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
q.dtype(),
seed_offset.data<int64_t>());
VLOG(4) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
const bool succ =
phi::dynload::flash_attn_varlen_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_bwd(
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
out.data(),
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......@@ -186,52 +160,86 @@ void FlashAttnGradKernel(const Context& ctx,
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
float scale = 1.0f / std::sqrt(head_size);
const int batch_size = dims[0];
const int seqlen_q = dims[1];
const int num_heads = dims[2];
const int head_size_og = dout.dims()[3];
const int head_size = dims[3];
const int seqlen_k = k.dims()[1];
const int num_heads_k = k.dims()[2];
// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
head_size_og,
head_size,
phi::errors::InvalidArgument(
"flash_attn_bwd receive input with head_size_og == head_size"));
VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
const float scale = 1.0f / std::sqrt(head_size);
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(ctx,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
q.dtype(),
seed_offset.data<int64_t>());
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
FlashAttnUnpaddedGradKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
out,
softmax_lse,
seed_offset,
dout,
seq_len_q,
seq_len_k,
scale,
dropout,
causal,
dq,
dk,
dv);
cudaStream_t stream = ctx.stream();
VLOG(4) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
PADDLE_ENFORCE_EQ(
succ,
true,
phi::errors::External("Error in Flash-Attention-2, detail information is",
phi::dynload::flash_attn_error()));
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......
......@@ -21,13 +21,13 @@
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
DECLARE_bool(cudnn_deterministic);
......@@ -35,30 +35,31 @@ DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
void FlashAttnUnpaddedKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
if (is_test) dropout = 0.0f;
ctx.template Alloc<T>(out);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
// q,k,v [total_*, num_heads, head_dim]
......@@ -69,126 +70,79 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
"[total_seq_len, num_heads, head_dim]"));
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
bool zero_tensors = false;
auto gen = ctx.GetGenerator();
uint64_t inc = batch_size * num_heads * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;
VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
seed_offset->Resize({2});
auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
softmax_lse->Resize({batch_size, num_heads, seq_len_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
// may allocate more space than *max_seqlen_k*
int64_t blocksize_c = head_size > 64 ? 128 : 256;
int64_t seq_len_k =
((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
if (max_seqlen_k <= 128) {
seq_len_k = 128;
} else if (max_seqlen_k <= 256) {
seq_len_k = 256;
}
softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
ctx.template Alloc<T>(softmax);
}
uint64_t workspace_size;
// TODO(kuizhiqing) pass allocation/empty func in capi to decouple
// calculate workspace size before execution
bool succ =
phi::dynload::flash_attn_fwd(q.data(),
k.data(),
v.data(),
nullptr, // for calculation workspace size
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_fwd(
const int64_t total_q = dims[0];
const int num_heads = dims[1];
const int head_size = dims[2];
const int total_k = k.dims()[0];
const int num_heads_k = k.dims()[1];
const int batch_size = cu_seqlens_q.numel() - 1;
// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
// TODO(umiswing): add shape check
FlashAttnFwdParamsV2<T> params =
FlashAttnFwdParamsV2<T>(ctx,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
return_softmax,
q.dtype(),
is_test,
rng_name,
fixed_seed_offset.get_ptr(),
softmax,
softmax_lse,
seed_offset);
VLOG(4) << "FlashAttn fwd seed: " << params.seed
<< ", offset: " << params.offset;
const bool succ = phi::dynload::flash_attn_varlen_fwd(
q.data(),
k.data(),
v.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
out->data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
params.return_softmax ? softmax->data() : nullptr,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.return_softmax,
params.is_bf16,
stream,
seed,
offset);
params.seed,
params.offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......@@ -197,10 +151,12 @@ void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
......@@ -215,51 +171,81 @@ void FlashAttnKernel(const Context& ctx,
"flash_attn receive input with dim "
"[batch_size, seq_len, num_heads, head_dim]"));
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
const int batch_size = dims[0];
const int seqlen_q = dims[1];
const int num_heads = dims[2];
const int head_size = dims[3];
const int seqlen_k = k.dims()[1];
const int num_heads_k = k.dims()[2];
// TODO(umiswing): Add check shape
const float scale = 1.0f / std::sqrt(head_size);
FlashAttnFwdParamsV2<T> params =
FlashAttnFwdParamsV2<T>(ctx,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
return_softmax,
q.dtype(),
is_test,
rng_name,
fixed_seed_offset.get_ptr(),
softmax,
softmax_lse,
seed_offset);
int64_t seq_len_k = k.dims()[1];
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
ctx.template Alloc<T>(out);
float scale = 1.0f / std::sqrt(head_size);
cudaStream_t stream = ctx.stream();
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
VLOG(4) << "FlashAttn fwd seed: " << params.seed
<< ", offset: " << params.offset;
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
FlashAttnUnpaddedKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
seq_len_q,
seq_len_k,
scale,
dropout,
causal,
return_softmax,
is_test,
out,
softmax,
softmax_lse,
seed_offset);
bool succ = phi::dynload::flash_attn_fwd(
q.data(),
k.data(),
v.data(),
params.rng_state.data(),
out->data(),
params.return_softmax ? params.softmax->data() : nullptr,
params.softmax_lse->data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.return_softmax,
params.is_bf16,
stream,
params.seed,
params.offset);
PADDLE_ENFORCE_EQ(
succ,
true,
phi::errors::External("Error in Flash-Attention-2, detail information is",
phi::dynload::flash_attn_error()));
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......@@ -270,11 +256,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded,
ALL_LAYOUT,
phi::FlashAttnUnpaddedKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->InputAt(5).SetBackend(
phi::Backend::ALL_BACKEND); // fixed_seed_offset
}
PD_REGISTER_KERNEL(flash_attn,
GPU,
ALL_LAYOUT,
phi::FlashAttnKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->InputAt(3).SetBackend(
phi::Backend::ALL_BACKEND); // fixed_seed_offset
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
template <typename T>
struct FlashAttnFwdParamsV2 {
int batch_size;
// for padded kernel, max_seqlen_q and seqlen_q is the same.
int64_t max_seqlen_q;
// for padded kernel, max_seqlen_k and seqlen_k is the same.
int64_t max_seqlen_k;
int seqlen_q_rounded;
int seqlen_k_rounded;
int num_heads;
int num_heads_k;
int head_size;
int head_size_rounded;
float dropout;
float scale;
bool causal;
bool return_softmax;
bool is_bf16;
uint64_t seed;
uint64_t offset;
DenseTensor rng_state;
DenseTensor* softmax;
DenseTensor* softmax_lse;
DenseTensor* seed_offset;
FlashAttnFwdParamsV2(const GPUContext& ctx,
const int _batch_size,
const int64_t _max_seqlen_q,
const int64_t _max_seqlen_k,
const int _num_heads,
const int _num_heads_k,
const int _head_size,
const float _dropout,
const float _scale,
const bool _causal,
const bool _return_softmax,
const DataType q_dtype,
const bool is_test,
const std::string& rng_name,
const DenseTensor* const fixed_seed_offset_ptr,
DenseTensor* _softmax,
DenseTensor* _softmax_lse,
DenseTensor* _seed_offset)
: batch_size(_batch_size),
max_seqlen_q(_max_seqlen_q),
max_seqlen_k(_max_seqlen_k),
num_heads(_num_heads),
num_heads_k(_num_heads),
head_size(_head_size),
scale(_scale),
dropout(_dropout),
causal(_causal),
return_softmax(_return_softmax),
softmax(_softmax),
softmax_lse(_softmax_lse),
seed_offset(_seed_offset) {
dropout = is_test ? 0.0f : _dropout;
is_bf16 = q_dtype == DataType::BFLOAT16;
// (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
// with the same size.
rng_state = Empty<int64_t>(ctx, {2});
if (fixed_seed_offset_ptr) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset_ptr->data<int64_t>();
seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
}
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
head_size_rounded = round_multiple(head_size, 32);
seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
softmax_lse->Resize({batch_size, num_heads, max_seqlen_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
softmax->Resize(
{batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded});
ctx.template Alloc<T>(softmax);
}
}
};
struct FlashAttnBwdParamsV2 {
int batch_size;
int64_t max_seqlen_q;
int64_t max_seqlen_k;
int seqlen_q_rounded;
int seqlen_k_rounded;
int num_heads;
int num_heads_k;
int head_size;
int head_size_rounded;
float dropout;
float scale;
bool causal;
bool is_bf16;
uint64_t seed;
uint64_t offset;
DenseTensor softmax_d;
DenseTensor dq_accum;
DenseTensor rng_state;
FlashAttnBwdParamsV2(const GPUContext& ctx,
const int _batch_size,
const int64_t _max_seqlen_q,
const int64_t _max_seqlen_k,
const int _num_heads,
const int _num_heads_k,
const int _head_size,
const float _dropout,
const float _scale,
const bool _causal,
const DataType q_dtype,
const int64_t* seed_offset_data)
: batch_size(_batch_size),
max_seqlen_q(_max_seqlen_q),
max_seqlen_k(_max_seqlen_k),
num_heads(_num_heads),
num_heads_k(_num_heads_k),
head_size(_head_size),
dropout(_dropout),
scale(_scale),
causal(_causal) {
is_bf16 = q_dtype == DataType::BFLOAT16;
seed = static_cast<uint64_t>(seed_offset_data[0]);
offset = static_cast<uint64_t>(seed_offset_data[1]);
// (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
// with the same size.
rng_state = Empty<int64_t>(ctx, {2});
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
head_size_rounded = round_multiple(head_size, 32);
seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
softmax_d = Empty<float>(ctx, {batch_size, num_heads, seqlen_q_rounded});
dq_accum = Empty<float>(
ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
}
};
} // namespace phi
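
As a small aside, the rounding used by FlashAttnFwdParamsV2 and FlashAttnBwdParamsV2 above to size the scratch tensors can be summarized in a few lines; only the (x + m - 1) / m * m arithmetic mirrors the C++ lambda, the concrete numbers are assumed examples.

# round_multiple(x, m) pads x up to the next multiple of m, mirroring the
# C++ lambda in flash_attn_utils.h above.
def round_multiple(x, m):
    return (x + m - 1) // m * m

head_size, max_seqlen_q, max_seqlen_k = 64, 500, 500    # assumed example values
head_size_rounded = round_multiple(head_size, 32)       # 64
seqlen_q_rounded = round_multiple(max_seqlen_q, 128)    # 512
seqlen_k_rounded = round_multiple(max_seqlen_k, 128)    # 512
# softmax_d scratch: [batch_size, num_heads, seqlen_q_rounded]
# dq_accum scratch:  [batch_size, num_heads, seqlen_q_rounded, head_size_rounded]
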
......@@ -37,3 +37,4 @@ from . import dist_shape
from . import dist_assign
from . import dist_scale
from . import dist_dropout
from . import dist_flash_attn
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import logging
from ...utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from .common import (
DistributedOperatorImplContainer,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
class DistributedFlashAttn(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)
register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn"))
# Dist FlashAttn with Random Control
class DistributedFlashAttnImpl0(DistributedElementwiseImpl0):
def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
return True
def is_auto_compatible(self, dist_op):
return True
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if (
is_enable_auto_rand_ctrl()
and not op_dist_attr.is_recompute
and rank_id in op_dist_attr.process_mesh.process_ids
):
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
if (
len(kwargs.get('fixed_seed_offset', [])) > 0
or len(src_op.input("fixed_seed_offset")) > 0
):
# TODO(kuizhiqing) recompute should go here
pass
else:
# determinate rng
q_var = main_block._var_recursive(kwargs['q'][0])
k_var = main_block._var_recursive(kwargs['k'][0])
q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)
process_mesh = op_dist_attr.process_mesh
dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]
rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
assert rng_name is not None and rng_name != ""
src_op._set_attr('rng_name', rng_name)
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
# dropout backward is deterministic by mask, and not need for random state control
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"flash_attn", DistributedFlashAttnImpl0("random_control")
)
......@@ -56,9 +56,27 @@ def attention_naive(q, k, v, causal=False):
return paddle.transpose(o, [0, 2, 1, 3])
is_sm8x = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 8
and paddle.device.cuda.get_device_capability()[1] >= 0
)
is_sm90 = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 9
and paddle.device.cuda.get_device_capability()[1] == 0
)
is_sm_supported = is_sm8x or is_sm90
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
not core.is_compiled_with_cuda()
or get_cuda_version() < 11040
or not is_sm_supported,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
"and device's compute capability must be 8.x or 90",
)
class TestFlashAttentionAPI(unittest.TestCase):
def setUp(self):
......
......@@ -24,6 +24,9 @@ def flash_attention(
dropout=0.0,
causal=False,
return_softmax=False,
*,
fixed_seed_offset=None,
rng_name="",
training=True,
name=None,
):
......@@ -57,7 +60,9 @@ def flash_attention(
dropout(float): The dropout ratio.
causal(bool): Whether enable causal mode.
return_softmax(bool): Whether to return softmax.
fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask.
training(bool): Whether it is in the training phase.
rng_name(str): The name to select Generator.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
......@@ -84,10 +89,12 @@ def flash_attention(
query,
key,
value,
fixed_seed_offset,
dropout,
causal,
return_softmax,
not training,
rng_name,
)
return result_attention, result_softmax if return_softmax else None
......@@ -101,6 +108,7 @@ def flash_attention(
'q': query,
'k': key,
'v': value,
'fixed_seed_offset': fixed_seed_offset,
}
outputs = {
'out': out,
......@@ -117,6 +125,7 @@ def flash_attention(
'causal': causal,
'return_softmax': return_softmax,
'is_test': not training,
'rng_name': rng_name,
},
)
return out, softmax if return_softmax else None
......@@ -134,6 +143,8 @@ def flash_attn_unpadded(
dropout=0.0,
causal=False,
return_softmax=False,
fixed_seed_offset=None,
rng_name="",
training=True,
name=None,
):
......@@ -174,6 +185,8 @@ def flash_attn_unpadded(
dropout(float): The dropout ratio.
causal(bool): Whether enable causal mode.
return_softmax(bool): Whether to return softmax.
fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask.
rng_name(str): The name to select Generator.
training(bool): Whether it is in the training phase.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
......@@ -203,6 +216,7 @@ def flash_attn_unpadded(
value,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
max_seqlen_q,
max_seqlen_k,
scale,
......@@ -210,6 +224,7 @@ def flash_attn_unpadded(
causal,
return_softmax,
not training,
rng_name,
)
return result_attention, result_softmax if return_softmax else None
......@@ -225,6 +240,7 @@ def flash_attn_unpadded(
'v': value,
'cu_seqlens_q': cu_seqlens_q,
'cu_seqlens_k': cu_seqlens_k,
'fixed_seed_offset': fixed_seed_offset,
}
outputs = {
'out': out,
......@@ -244,6 +260,7 @@ def flash_attn_unpadded(
'causal': causal,
'return_softmax': return_softmax,
'is_test': not training,
'rng_name': rng_name,
},
)
return out, softmax if return_softmax else None
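
For completeness, a similar sketch for the variable-length entry point shown above. The packed [total_tokens, num_heads, head_dim] layout, the int32 cumulative sequence lengths, and the keyword names follow the flash_attn_unpadded op definition in this patch; the concrete sequence lengths, dtype, and import path are assumed for illustration.

import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

seqlens = [128, 64]                  # two variable-length sequences packed together
total_tokens = sum(seqlens)
num_heads, head_dim = 8, 64
q = paddle.randn([total_tokens, num_heads, head_dim], dtype='float16')
k = paddle.randn([total_tokens, num_heads, head_dim], dtype='float16')
v = paddle.randn([total_tokens, num_heads, head_dim], dtype='float16')
# cumulative sequence lengths, int32, shape [batch_size + 1]
cu_seqlens = paddle.to_tensor([0, 128, 192], dtype='int32')

out, _ = flash_attn_unpadded(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    scale=1.0 / (head_dim ** 0.5),
    dropout=0.0,
    causal=False,
    return_softmax=False,
    fixed_seed_offset=None,   # new in this change
    rng_name="",              # new in this change
)
print(out.shape)  # [192, 8, 64]
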