Unverified commit cc9a7688, authored by umiswing, committed by GitHub

[cherry-pick] Integration flash attention 2 (#56015)

* [FlashAttn] add flash randomness control (#52902)

* add flash randomness control

* fix VLOG undefined

* [WIP] Integration flash attention 2 (#55758)

* Work for fa-2 padded fwd. Code to be cleaned.

* Work for fa2 unpadded fwd.

* Work for padded-bwd; dk gets a small diff with np.random.seed(0)

* Pass Paddle's unit tests, except for returning softmax without dropout.

* Clean code.

* Modify interface.

* Clean code and add some check.

* Easy compile for dev.

* Fix ci.

* Fix ci-build.

* Add std c++17 option again.

* Limit max job when compiling fa2.

* Remove const_cast

* Add fwd params, to be cleaned.

* Clean code.

* Add bwd params.

* Clean code.

* Add enforce.

* Use v2.0.4

* Pass RNG state to fa2 capi

* Fix review.

* Add assert

* Skip compile for sm less than 80.

---------
Co-authored-by: Chitsing KUI <kuizhiqing@msn.com>
Parent 8d3a9882
@@ -17,10 +17,10 @@ include(ExternalProject)
 add_definitions(-DPADDLE_WITH_FLASHATTN)

 set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
-set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
+set(FLASHATTN_SOURCE_SUBDIR csrc)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
-set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
+set(FLASHATTN_TAG b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a)

 set(FLASHATTN_INCLUDE_DIR
     "${FLASHATTN_INSTALL_DIR}/include"
@@ -62,7 +62,7 @@ else()
   set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
   set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
   set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
-  set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS})
+  set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
   set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
   set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
 endif()
@@ -93,6 +93,8 @@ ExternalProject_Add(
     -DBUILD_SHARED=ON
     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
     -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
+    -DCMAKE_JOB_POOL_COMPILE:STRING=compile
+    -DCMAKE_JOB_POOLS:STRING=compile=4
     ${EXTERNAL_OPTIONAL_ARGS}
   CMAKE_CACHE_ARGS
     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
......
@@ -512,10 +512,15 @@ if(WITH_GPU
     list(APPEND third_party_deps extern_cutlass)
     set(WITH_CUTLASS ON)
   endif()

-  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.2)
-    include(external/flashattn)
-    list(APPEND third_party_deps extern_flashattn)
-    set(WITH_FLASHATTN ON)
+  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
+    foreach(arch ${NVCC_ARCH_BIN})
+      if(${arch} GREATER_EQUAL 80)
+        include(external/flashattn)
+        list(APPEND third_party_deps extern_flashattn)
+        set(WITH_FLASHATTN ON)
+        break()
+      endif()
+    endforeach()
   endif()
 endif()
......
@@ -617,7 +617,7 @@
     inplace : (out_grad -> x_grad)

 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
@@ -628,7 +628,7 @@
     data_type: q

 - backward_op : flash_attn_unpadded_grad
-  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
......
@@ -678,8 +678,9 @@
   backward : fill_diagonal_tensor_grad

 - op : flash_attn
-  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
@@ -690,8 +691,9 @@
   backward : flash_attn_grad

 - op : flash_attn_unpadded
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
......
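For reference, the yaml entries above drive Paddle's op code generation, so the dygraph entry point takes its arguments in exactly this order: the dense inputs first (q, k, v, fixed_seed_offset for flash_attn), then the attributes (dropout, causal, return_softmax, is_test, rng_name). A hedged illustration of the generated call, with None standing in for the optional fixed_seed_offset:

# Sketch only: `_C_ops` is Paddle's generated eager-op module; the positional
# order mirrors the op definition above.
out, softmax, softmax_lse, seed_offset = _C_ops.flash_attn(
    q, k, v, None, 0.0, False, False, False, ""
)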
@@ -43,9 +43,13 @@ extern void* flashattn_dso_handle;
 #define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
   DYNAMIC_LOAD_FLASHATTN_WRAP(__name)

-#define FLASHATTN_ROUTINE_EACH(__macro) \
-  __macro(flash_attn_fwd);              \
-  __macro(flash_attn_bwd);              \
+#define FLASHATTN_ROUTINE_EACH(__macro)       \
+  __macro(flash_attn_fwd);                    \
+  __macro(flash_attn_varlen_fwd);             \
+  __macro(flash_attn_bwd);                    \
+  __macro(flash_attn_varlen_bwd);             \
+  __macro(flash_attn_fwd_with_bias_and_mask); \
+  __macro(flash_attn_bwd_with_bias_and_mask); \
   __macro(flash_attn_error);

 FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);
......
@@ -20,33 +20,38 @@
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(const Context& ctx,
-                             const DenseTensor& q,
-                             const DenseTensor& k,
-                             const DenseTensor& v,
-                             const DenseTensor& cu_seqlens_q,
-                             const DenseTensor& cu_seqlens_k,
-                             int64_t max_seqlen_q,
-                             int64_t max_seqlen_k,
-                             float scale,
-                             float dropout,
-                             bool causal,
-                             bool return_softmax,
-                             bool is_test,
-                             DenseTensor* out,
-                             DenseTensor* softmax,
-                             DenseTensor* softmax_lse,
-                             DenseTensor* seed_offset);
+void FlashAttnUnpaddedKernel(
+    const Context& ctx,
+    const DenseTensor& q,
+    const DenseTensor& k,
+    const DenseTensor& v,
+    const DenseTensor& cu_seqlens_q,
+    const DenseTensor& cu_seqlens_k,
+    const paddle::optional<DenseTensor>& fixed_seed_offset,
+    int64_t max_seqlen_q,
+    int64_t max_seqlen_k,
+    float scale,
+    float dropout,
+    bool causal,
+    bool return_softmax,
+    bool is_test,
+    const std::string& rng_name,
+    DenseTensor* out,
+    DenseTensor* softmax,
+    DenseTensor* softmax_lse,
+    DenseTensor* seed_offset);

 template <typename T, typename Context>
 void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
                      const DenseTensor& k,
                      const DenseTensor& v,
+                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
                      bool causal,
                      bool return_softmax,
                      bool is_test,
+                     const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
......
@@ -25,6 +25,7 @@
 #ifdef PADDLE_WITH_FLASHATTN
 #include "paddle/phi/backends/dynload/flashattn.h"
+#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #endif

 DECLARE_bool(cudnn_deterministic);
@@ -55,116 +56,89 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   ctx.template Alloc<T>(dk);
   ctx.template Alloc<T>(dv);

-  cudaStream_t stream = ctx.stream();
-  bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
+  const cudaStream_t stream = ctx.stream();

   // q,k,v [total_*, num_heads, head_dim]
   auto dims = q.dims();
-  int64_t total_q = dims[0];
-  int64_t num_heads = dims[1];
-  int64_t head_size = dims[2];
-
-  int64_t total_k = k.dims()[0];
-  int64_t batch_size = cu_seqlens_q.numel() - 1;
-
-  int num_splits = 0;  // 0 for an internal heuristic, which is optimal
-  if (FLAGS_cudnn_deterministic) {
-    num_splits = 1;
-  }
-  bool zero_tensors = false;
-
-  const int64_t* seed_offset_data = seed_offset.data<int64_t>();
-  uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
-  uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
-
-  VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
-          << ", num_splits:" << num_splits;
-
-  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
-  DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
-
-  uint64_t workspace_size;
-
-  // calculate workspace size before execution
-  bool succ = phi::dynload::flash_attn_bwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      nullptr,  // for calculation workspace size
-      dout.data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
-      head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      const_cast<float*>(softmax_lse.data<float>()),
-      dsoftmax.data(),
-      nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
-  if (!succ) {
-    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
-  }
-
-  DenseTensor workspace;
-  if (workspace_size > 0) {
-    workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
-  }
-
-  succ = phi::dynload::flash_attn_bwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      out.data(),
-      dout.data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
-      head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      const_cast<float*>(softmax_lse.data<float>()),
-      dsoftmax.data(),
-      workspace_size > 0 ? workspace.data() : nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
+  const int64_t total_q = dims[0];
+  const int batch_size = cu_seqlens_q.numel() - 1;
+  const int num_heads = dims[1];
+  const int head_size_og = dout.dims()[2];
+  const int head_size = dims[2];
+  const int total_k = k.dims()[0];
+  const int num_heads_k = k.dims()[1];
+
+  // TODO(umiswing): add deterministic in fa2.
+  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  // if (FLAGS_cudnn_deterministic) {
+  //   num_splits = 1;
+  // }
+
+  const bool zero_tensors = false;
+
+  // TODO(umiswing): add shape check
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
+      head_size,
+      phi::errors::InvalidArgument(
+          "flash_attn_bwd receive input with head_size_og == head_size"));
+
+  FlashAttnBwdParamsV2 params =
+      FlashAttnBwdParamsV2(ctx,
+                           batch_size,
+                           max_seqlen_q,
+                           max_seqlen_k,
+                           num_heads,
+                           num_heads_k,
+                           head_size,
+                           dropout,
+                           scale,
+                           causal,
+                           q.dtype(),
+                           seed_offset.data<int64_t>());
+
+  VLOG(4) << "FlashAttn bwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  const bool succ =
+      phi::dynload::flash_attn_varlen_bwd(dout.data(),
+                                          q.data(),
+                                          k.data(),
+                                          v.data(),
+                                          out.data(),
+                                          params.softmax_d.data(),
+                                          softmax_lse.data(),
+                                          cu_seqlens_q.data<int32_t>(),
+                                          cu_seqlens_k.data<int32_t>(),
+                                          params.rng_state.data(),
+                                          dq->data(),
+                                          dk->data(),
+                                          dv->data(),
+                                          params.dq_accum.data(),
+                                          params.batch_size,
+                                          params.max_seqlen_q,
+                                          params.max_seqlen_k,
+                                          params.seqlen_q_rounded,
+                                          params.seqlen_k_rounded,
+                                          params.num_heads,
+                                          params.num_heads_k,
+                                          params.head_size,
+                                          params.head_size_rounded,
+                                          params.dropout,
+                                          params.scale,
+                                          params.causal,
+                                          params.is_bf16,
+                                          stream,
+                                          params.seed,
+                                          params.offset);
   if (!succ) {
     PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
   }
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
@@ -186,52 +160,86 @@ void FlashAttnGradKernel(const Context& ctx,
   // q,k,v [batch_size, seq_len, num_heads, head_dim]
   auto dims = q.dims();
-  int64_t batch_size = dims[0];
-  int64_t seq_len_q = dims[1];
-  int64_t num_heads = dims[2];
-  int64_t head_size = dims[3];
-
-  int64_t seq_len_k = k.dims()[1];
-
-  int64_t total_q = batch_size * seq_len_q;
-  int64_t total_k = batch_size * seq_len_k;
-
-  float scale = 1.0f / std::sqrt(head_size);
+  const int batch_size = dims[0];
+  const int seqlen_q = dims[1];
+  const int num_heads = dims[2];
+  const int head_size_og = dout.dims()[3];
+  const int head_size = dims[3];
+  const int seqlen_k = k.dims()[1];
+  const int num_heads_k = k.dims()[2];
+
+  // TODO(umiswing): add shape check
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
+      head_size,
+      phi::errors::InvalidArgument(
+          "flash_attn_bwd receive input with head_size_og == head_size"));

   VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
           << "], v[" << v.dims() << "]";

-  DenseTensor q_t_s, k_t_s, v_t_s;
-  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
-  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
-  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
-
-  DenseTensor cu_seqlens_q;
-  DenseTensor cu_seqlens_k;
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
-
-  FlashAttnUnpaddedGradKernel<T, Context>(ctx,
-                                          q_t_s,
-                                          k_t_s,
-                                          v_t_s,
-                                          cu_seqlens_q,
-                                          cu_seqlens_k,
-                                          out,
-                                          softmax_lse,
-                                          seed_offset,
-                                          dout,
-                                          seq_len_q,
-                                          seq_len_k,
-                                          scale,
-                                          dropout,
-                                          causal,
-                                          dq,
-                                          dk,
-                                          dv);
+  const float scale = 1.0f / std::sqrt(head_size);
+
+  FlashAttnBwdParamsV2 params =
+      FlashAttnBwdParamsV2(ctx,
+                           batch_size,
+                           seqlen_q,
+                           seqlen_k,
+                           num_heads,
+                           num_heads_k,
+                           head_size,
+                           dropout,
+                           scale,
+                           causal,
+                           q.dtype(),
+                           seed_offset.data<int64_t>());
+
+  ctx.template Alloc<T>(dq);
+  ctx.template Alloc<T>(dk);
+  ctx.template Alloc<T>(dv);
+
+  cudaStream_t stream = ctx.stream();
+
+  VLOG(4) << "FlashAttn bwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
+                                                 q.data(),
+                                                 k.data(),
+                                                 v.data(),
+                                                 out.data(),
+                                                 params.softmax_d.data(),
+                                                 softmax_lse.data(),
+                                                 params.rng_state.data(),
+                                                 dq->data(),
+                                                 dk->data(),
+                                                 dv->data(),
+                                                 params.dq_accum.data(),
+                                                 params.batch_size,
+                                                 params.max_seqlen_q,
+                                                 params.max_seqlen_k,
+                                                 params.seqlen_q_rounded,
+                                                 params.seqlen_k_rounded,
+                                                 params.num_heads,
+                                                 params.num_heads_k,
+                                                 params.head_size,
+                                                 params.head_size_rounded,
+                                                 params.dropout,
+                                                 params.scale,
+                                                 params.causal,
+                                                 params.is_bf16,
+                                                 stream,
+                                                 params.seed,
+                                                 params.offset);
+
+  PADDLE_ENFORCE_EQ(
+      succ,
+      true,
+      phi::errors::External("Error in Flash-Attention-2, detail information is",
+                            phi::dynload::flash_attn_error()));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
......
@@ -21,13 +21,13 @@
 #include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/reshape_kernel.h"

 #ifdef PADDLE_WITH_FLASHATTN
 #include "paddle/phi/backends/dynload/flashattn.h"
+#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #endif

 DECLARE_bool(cudnn_deterministic);
@@ -35,30 +35,31 @@ DECLARE_bool(cudnn_deterministic);
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(const Context& ctx,
-                             const DenseTensor& q,
-                             const DenseTensor& k,
-                             const DenseTensor& v,
-                             const DenseTensor& cu_seqlens_q,
-                             const DenseTensor& cu_seqlens_k,
-                             int64_t max_seqlen_q,
-                             int64_t max_seqlen_k,
-                             float scale,
-                             float dropout,
-                             bool causal,
-                             bool return_softmax,
-                             bool is_test,
-                             DenseTensor* out,
-                             DenseTensor* softmax,
-                             DenseTensor* softmax_lse,
-                             DenseTensor* seed_offset) {
+void FlashAttnUnpaddedKernel(
+    const Context& ctx,
+    const DenseTensor& q,
+    const DenseTensor& k,
+    const DenseTensor& v,
+    const DenseTensor& cu_seqlens_q,
+    const DenseTensor& cu_seqlens_k,
+    const paddle::optional<DenseTensor>& fixed_seed_offset,
+    int64_t max_seqlen_q,
+    int64_t max_seqlen_k,
+    float scale,
+    float dropout,
+    bool causal,
+    bool return_softmax,
+    bool is_test,
+    const std::string& rng_name,
+    DenseTensor* out,
+    DenseTensor* softmax,
+    DenseTensor* softmax_lse,
+    DenseTensor* seed_offset) {
 #ifdef PADDLE_WITH_FLASHATTN
-  if (is_test) dropout = 0.0f;
-
   ctx.template Alloc<T>(out);

   cudaStream_t stream = ctx.stream();
-  bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;

   // q,k,v [total_*, num_heads, head_dim]
@@ -69,126 +70,79 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
       phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
                                    "[total_seq_len, num_heads, head_dim]"));

-  int64_t total_q = dims[0];
-  int64_t num_heads = dims[1];
-  int64_t head_size = dims[2];
-
-  int64_t total_k = k.dims()[0];
-  int64_t batch_size = cu_seqlens_q.numel() - 1;
-
-  int num_splits = 0;  // 0 for an internal heuristic, which is optimal
-  if (FLAGS_cudnn_deterministic) {
-    num_splits = 1;
-  }
-  bool zero_tensors = false;
-
-  auto gen = ctx.GetGenerator();
-  uint64_t inc = batch_size * num_heads * 32;
-  auto seed_offset_pair = gen->IncrementOffset(inc);
-
-  uint64_t seed = seed_offset_pair.first;
-  uint64_t offset = seed_offset_pair.second;
-
-  VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
-          << ", num_splits:" << num_splits;
-
-  seed_offset->Resize({2});
-  auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
-  seed_offset_data[0] = static_cast<int64_t>(seed);
-  seed_offset_data[1] = static_cast<int64_t>(offset);
-
-  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
-
-  softmax_lse->Resize({batch_size, num_heads, seq_len_q});
-  ctx.template Alloc<float>(softmax_lse);
-
-  if (return_softmax) {
-    // may allocate more space than *max_seqlen_k*
-    int64_t blocksize_c = head_size > 64 ? 128 : 256;
-    int64_t seq_len_k =
-        ((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
-    if (max_seqlen_k <= 128) {
-      seq_len_k = 128;
-    } else if (max_seqlen_k <= 256) {
-      seq_len_k = 256;
-    }
-    softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
-    ctx.template Alloc<T>(softmax);
-  }
-
-  uint64_t workspace_size;
-
-  // TODO(kuizhiqing) pass allocation/empty func in capi to decouple
-  // calculate workspace size before execution
-  bool succ =
-      phi::dynload::flash_attn_fwd(q.data(),
-                                   k.data(),
-                                   v.data(),
-                                   nullptr,  // for calculation workspace size
-                                   cu_seqlens_q.data(),
-                                   cu_seqlens_k.data(),
-                                   total_q,
-                                   total_k,
-                                   batch_size,
-                                   num_heads,
-                                   head_size,
-                                   max_seqlen_q,
-                                   max_seqlen_k,
-                                   dropout,
-                                   scale,
-                                   zero_tensors,
-                                   causal,
-                                   is_bf16,
-                                   num_splits,
-                                   softmax_lse->data(),
-                                   return_softmax ? softmax->data() : nullptr,
-                                   nullptr,
-                                   &workspace_size,
-                                   stream,
-                                   seed,
-                                   offset);
-  if (!succ) {
-    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
-  }
-
-  DenseTensor workspace;
-  if (workspace_size > 0) {
-    workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
-  }
-
-  succ = phi::dynload::flash_attn_fwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      out->data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
-      head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      softmax_lse->data(),
-      return_softmax ? softmax->data() : nullptr,
-      workspace_size > 0 ? workspace.data() : nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
+  const int64_t total_q = dims[0];
+  const int num_heads = dims[1];
+  const int head_size = dims[2];
+  const int total_k = k.dims()[0];
+  const int num_heads_k = k.dims()[1];
+  const int batch_size = cu_seqlens_q.numel() - 1;
+
+  // TODO(umiswing): add deterministic in fa2.
+  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  // if (FLAGS_cudnn_deterministic) {
+  //   num_splits = 1;
+  // }
+
+  // TODO(umiswing): add shape check
+
+  FlashAttnFwdParamsV2<T> params =
+      FlashAttnFwdParamsV2<T>(ctx,
+                              batch_size,
+                              max_seqlen_q,
+                              max_seqlen_k,
+                              num_heads,
+                              num_heads_k,
+                              head_size,
+                              dropout,
+                              scale,
+                              causal,
+                              return_softmax,
+                              q.dtype(),
+                              is_test,
+                              rng_name,
+                              fixed_seed_offset.get_ptr(),
+                              softmax,
+                              softmax_lse,
+                              seed_offset);
+
+  VLOG(4) << "FlashAttn fwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  const bool succ = phi::dynload::flash_attn_varlen_fwd(
+      q.data(),
+      k.data(),
+      v.data(),
+      cu_seqlens_q.data<int32_t>(),
+      cu_seqlens_k.data<int32_t>(),
+      params.rng_state.data(),
+      out->data(),
+      params.return_softmax ? softmax->data() : nullptr,
+      softmax_lse->data(),
+      params.batch_size,
+      params.max_seqlen_q,
+      params.max_seqlen_k,
+      params.seqlen_q_rounded,
+      params.seqlen_k_rounded,
+      params.num_heads,
+      params.num_heads_k,
+      params.head_size,
+      params.head_size_rounded,
+      params.dropout,
+      params.scale,
+      params.causal,
+      params.return_softmax,
+      params.is_bf16,
+      stream,
+      params.seed,
+      params.offset);
   if (!succ) {
     PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
   }
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
@@ -197,10 +151,12 @@ void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
                      const DenseTensor& k,
                      const DenseTensor& v,
+                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
                      bool causal,
                      bool return_softmax,
                      bool is_test,
+                     const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
@@ -215,51 +171,81 @@ void FlashAttnKernel(const Context& ctx,
                          "flash_attn receive input with dim "
                          "[batch_size, seq_len, num_heads, head_dim]"));

-  int64_t batch_size = dims[0];
-  int64_t seq_len_q = dims[1];
-  int64_t num_heads = dims[2];
-  int64_t head_size = dims[3];
-
-  int64_t seq_len_k = k.dims()[1];
-
-  int64_t total_q = batch_size * seq_len_q;
-  int64_t total_k = batch_size * seq_len_k;
-
-  float scale = 1.0f / std::sqrt(head_size);
-
-  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
-          << "], v[" << v.dims() << "]";
-
-  DenseTensor q_t_s, k_t_s, v_t_s;
-  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
-  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
-  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
-
-  DenseTensor cu_seqlens_q;
-  DenseTensor cu_seqlens_k;
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
-
-  FlashAttnUnpaddedKernel<T, Context>(ctx,
-                                      q_t_s,
-                                      k_t_s,
-                                      v_t_s,
-                                      cu_seqlens_q,
-                                      cu_seqlens_k,
-                                      seq_len_q,
-                                      seq_len_k,
-                                      scale,
-                                      dropout,
-                                      causal,
-                                      return_softmax,
-                                      is_test,
-                                      out,
-                                      softmax,
-                                      softmax_lse,
-                                      seed_offset);
+  const int batch_size = dims[0];
+  const int seqlen_q = dims[1];
+  const int num_heads = dims[2];
+  const int head_size = dims[3];
+  const int seqlen_k = k.dims()[1];
+  const int num_heads_k = k.dims()[2];
+
+  // TODO(umiswing): Add check shape
+
+  const float scale = 1.0f / std::sqrt(head_size);
+
+  FlashAttnFwdParamsV2<T> params =
+      FlashAttnFwdParamsV2<T>(ctx,
+                              batch_size,
+                              seqlen_q,
+                              seqlen_k,
+                              num_heads,
+                              num_heads_k,
+                              head_size,
+                              dropout,
+                              scale,
+                              causal,
+                              return_softmax,
+                              q.dtype(),
+                              is_test,
+                              rng_name,
+                              fixed_seed_offset.get_ptr(),
+                              softmax,
+                              softmax_lse,
+                              seed_offset);
+
+  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
+
+  ctx.template Alloc<T>(out);
+
+  cudaStream_t stream = ctx.stream();
+
+  VLOG(4) << "FlashAttn fwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  bool succ = phi::dynload::flash_attn_fwd(
+      q.data(),
+      k.data(),
+      v.data(),
+      params.rng_state.data(),
+      out->data(),
+      params.return_softmax ? params.softmax->data() : nullptr,
+      params.softmax_lse->data(),
+      params.batch_size,
+      params.max_seqlen_q,
+      params.max_seqlen_k,
+      params.seqlen_q_rounded,
+      params.seqlen_k_rounded,
+      params.num_heads,
+      params.num_heads_k,
+      params.head_size,
+      params.head_size_rounded,
+      params.dropout,
+      params.scale,
+      params.causal,
+      params.return_softmax,
+      params.is_bf16,
+      stream,
+      params.seed,
+      params.offset);
+
+  PADDLE_ENFORCE_EQ(
+      succ,
+      true,
+      phi::errors::External("Error in Flash-Attention-2, detail information is",
+                            phi::dynload::flash_attn_error()));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
@@ -270,11 +256,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded,
                    GPU,
                    ALL_LAYOUT,
                    phi::FlashAttnUnpaddedKernel,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(5).SetBackend(
+      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
+}

 PD_REGISTER_KERNEL(flash_attn,
                    GPU,
                    ALL_LAYOUT,
                    phi::FlashAttnKernel,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(3).SetBackend(
+      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
+}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
template <typename T>
struct FlashAttnFwdParamsV2 {
int batch_size;
// For the padded kernel, max_seqlen_q and seqlen_q are the same.
int64_t max_seqlen_q;
// For the padded kernel, max_seqlen_k and seqlen_k are the same.
int64_t max_seqlen_k;
int seqlen_q_rounded;
int seqlen_k_rounded;
int num_heads;
int num_heads_k;
int head_size;
int head_size_rounded;
float dropout;
float scale;
bool causal;
bool return_softmax;
bool is_bf16;
uint64_t seed;
uint64_t offset;
DenseTensor rng_state;
DenseTensor* softmax;
DenseTensor* softmax_lse;
DenseTensor* seed_offset;
FlashAttnFwdParamsV2(const GPUContext& ctx,
const int _batch_size,
const int64_t _max_seqlen_q,
const int64_t _max_seqlen_k,
const int _num_heads,
const int _num_heads_k,
const int _head_size,
const float _dropout,
const float _scale,
const bool _causal,
const bool _return_softmax,
const DataType q_dtype,
const bool is_test,
const std::string& rng_name,
const DenseTensor* const fixed_seed_offset_ptr,
DenseTensor* _softmax,
DenseTensor* _softmax_lse,
DenseTensor* _seed_offset)
: batch_size(_batch_size),
max_seqlen_q(_max_seqlen_q),
max_seqlen_k(_max_seqlen_k),
num_heads(_num_heads),
num_heads_k(_num_heads_k),
head_size(_head_size),
scale(_scale),
dropout(_dropout),
causal(_causal),
return_softmax(_return_softmax),
softmax(_softmax),
softmax_lse(_softmax_lse),
seed_offset(_seed_offset) {
dropout = is_test ? 0.0f : _dropout;
is_bf16 = q_dtype == DataType::BFLOAT16;
// (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
// with the same size.
rng_state = Empty<int64_t>(ctx, {2});
if (fixed_seed_offset_ptr) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset_ptr->data<int64_t>();
seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
}
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
head_size_rounded = round_multiple(head_size, 32);
seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
softmax_lse->Resize({batch_size, num_heads, max_seqlen_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
softmax->Resize(
{batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded});
ctx.template Alloc<T>(softmax);
}
}
};
struct FlashAttnBwdParamsV2 {
int batch_size;
int64_t max_seqlen_q;
int64_t max_seqlen_k;
int seqlen_q_rounded;
int seqlen_k_rounded;
int num_heads;
int num_heads_k;
int head_size;
int head_size_rounded;
float dropout;
float scale;
bool causal;
bool is_bf16;
uint64_t seed;
uint64_t offset;
DenseTensor softmax_d;
DenseTensor dq_accum;
DenseTensor rng_state;
FlashAttnBwdParamsV2(const GPUContext& ctx,
const int _batch_size,
const int64_t _max_seqlen_q,
const int64_t _max_seqlen_k,
const int _num_heads,
const int _num_heads_k,
const int _head_size,
const float _dropout,
const float _scale,
const bool _causal,
const DataType q_dtype,
const int64_t* seed_offset_data)
: batch_size(_batch_size),
max_seqlen_q(_max_seqlen_q),
max_seqlen_k(_max_seqlen_k),
num_heads(_num_heads),
num_heads_k(_num_heads_k),
head_size(_head_size),
dropout(_dropout),
scale(_scale),
causal(_causal) {
is_bf16 = q_dtype == DataType::BFLOAT16;
seed = static_cast<uint64_t>(seed_offset_data[0]);
offset = static_cast<uint64_t>(seed_offset_data[1]);
// (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
// with the same size.
rng_state = Empty<int64_t>(ctx, {2});
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
head_size_rounded = round_multiple(head_size, 32);
seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
softmax_d = Empty<float>(ctx, {batch_size, num_heads, seqlen_q_rounded});
dq_accum = Empty<float>(
ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
}
};
} // namespace phi
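The two structs above centralize the bookkeeping the old kernels did inline. In particular, FlashAttnFwdParamsV2 decides where the dropout randomness comes from (an explicit fixed_seed_offset beats a generator selected by rng_name, which beats the device's default generator), writes the chosen pair back into seed_offset so the backward pass can replay the same mask, and rounds head_size to a multiple of 32 and the sequence lengths to a multiple of 128. A small Python sketch of that logic, with illustrative names only (this is not a Paddle API):

def pick_seed_offset(fixed_seed_offset, rng_name, named_generators,
                     default_generator, batch_size, num_heads):
    # Priority mirrors FlashAttnFwdParamsV2: explicit pair > named generator > default.
    if fixed_seed_offset is not None:
        return int(fixed_seed_offset[0]), int(fixed_seed_offset[1])
    increment = batch_size * num_heads * 32
    gen = named_generators[rng_name] if rng_name else default_generator
    return gen.increment_offset(increment)


def round_multiple(x, m):
    # head_size_rounded = round_multiple(head_size, 32)
    # seqlen_*_rounded  = round_multiple(max_seqlen_*, 128)
    return (x + m - 1) // m * m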
@@ -37,3 +37,4 @@ from . import dist_shape
 from . import dist_assign
 from . import dist_scale
 from . import dist_dropout
+from . import dist_flash_attn
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import logging
from ...utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from .common import (
DistributedOperatorImplContainer,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
class DistributedFlashAttn(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)
register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn"))
# Dist FlashAttn with Random Control
class DistributedFlashAttnImpl0(DistributedElementwiseImpl0):
def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
return True
def is_auto_compatible(self, dist_op):
return True
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if (
is_enable_auto_rand_ctrl()
and not op_dist_attr.is_recompute
and rank_id in op_dist_attr.process_mesh.process_ids
):
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
if (
len(kwargs.get('fixed_seed_offset', [])) > 0
or len(src_op.input("fixed_seed_offset")) > 0
):
# TODO(kuizhiqing) recompute should go here
pass
else:
# determinate rng
q_var = main_block._var_recursive(kwargs['q'][0])
k_var = main_block._var_recursive(kwargs['k'][0])
q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)
process_mesh = op_dist_attr.process_mesh
dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]
rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
assert rng_name is not None and rng_name != ""
src_op._set_attr('rng_name', rng_name)
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
# dropout backward is deterministic by mask, and not need for random state control
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"flash_attn", DistributedFlashAttnImpl0("random_control")
)
@@ -56,9 +56,27 @@ def attention_naive(q, k, v, causal=False):
     return paddle.transpose(o, [0, 2, 1, 3])


+is_sm8x = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 8
+    and paddle.device.cuda.get_device_capability()[1] >= 0
+)
+
+is_sm90 = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 9
+    and paddle.device.cuda.get_device_capability()[1] == 0
+)
+
+is_sm_supported = is_sm8x or is_sm90
+
+
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
+    "and device's compute capability must be 8.x or 90",
 )
 class TestFlashAttentionAPI(unittest.TestCase):
     def setUp(self):
......
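The new skip condition above is the hardware floor for FlashAttention-2: CUDA 11.4+ and an SM 8.x or SM 9.0 GPU, matching the build-time arch gate added in third_party.cmake. A minimal sketch of the same runtime check, assuming only paddle.device.cuda.get_device_capability() as used by the test (the CUDA version part is handled by the test's own get_cuda_version() helper):

import paddle

def supports_flash_attn_v2():
    if not paddle.is_compiled_with_cuda():
        return False
    major, minor = paddle.device.cuda.get_device_capability()
    # FA2 kernels are built only for SM 8.x and SM 90.
    return major == 8 or (major == 9 and minor == 0)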
@@ -24,6 +24,9 @@ def flash_attention(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    *,
+    fixed_seed_offset=None,
+    rng_name="",
     training=True,
     name=None,
 ):
@@ -57,7 +60,9 @@
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor, optional): A fixed (seed, offset) pair used to generate the dropout mask.
         training(bool): Whether it is in the training phase.
+        rng_name(str): The name of the Generator used to seed dropout.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.
@@ -84,10 +89,12 @@
             query,
             key,
             value,
+            fixed_seed_offset,
             dropout,
             causal,
             return_softmax,
             not training,
+            rng_name,
         )
         return result_attention, result_softmax if return_softmax else None
@@ -101,6 +108,7 @@
        'q': query,
        'k': key,
        'v': value,
+        'fixed_seed_offset': fixed_seed_offset,
    }
    outputs = {
        'out': out,
@@ -117,6 +125,7 @@
            'causal': causal,
            'return_softmax': return_softmax,
            'is_test': not training,
+            'rng_name': rng_name,
        },
    )
    return out, softmax if return_softmax else None
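A hedged usage sketch of the new knobs on flash_attention, assuming a CUDA build, half-precision inputs of shape [batch, seq_len, num_heads, head_dim], and that fixed_seed_offset is a CPU int64 tensor holding [seed, offset] (the kernel reads the pair on the host):

import paddle
from paddle.nn.functional.flash_attention import flash_attention

q = paddle.randn([2, 128, 8, 64]).astype('float16')
k = paddle.randn([2, 128, 8, 64]).astype('float16')
v = paddle.randn([2, 128, 8, 64]).astype('float16')

# Pin the dropout randomness; two calls with the same pair should produce
# identical dropout masks and therefore identical outputs.
fixed = paddle.to_tensor([2023, 0], dtype='int64', place=paddle.CPUPlace())
out1, _ = flash_attention(q, k, v, dropout=0.1, causal=True, fixed_seed_offset=fixed)
out2, _ = flash_attention(q, k, v, dropout=0.1, causal=True, fixed_seed_offset=fixed)
assert paddle.allclose(out1, out2).item()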
@@ -134,6 +143,8 @@ def flash_attn_unpadded(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    fixed_seed_offset=None,
+    rng_name="",
     training=True,
     name=None,
 ):
@@ -174,6 +185,8 @@
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor, optional): A fixed (seed, offset) pair used to generate the dropout mask.
+        rng_name(str): The name of the Generator used to seed dropout.
         training(bool): Whether it is in the training phase.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
@@ -203,6 +216,7 @@
             value,
             cu_seqlens_q,
             cu_seqlens_k,
+            fixed_seed_offset,
             max_seqlen_q,
             max_seqlen_k,
             scale,
@@ -210,6 +224,7 @@
             causal,
             return_softmax,
             not training,
+            rng_name,
         )
         return result_attention, result_softmax if return_softmax else None
@@ -225,6 +240,7 @@
        'v': value,
        'cu_seqlens_q': cu_seqlens_q,
        'cu_seqlens_k': cu_seqlens_k,
+        'fixed_seed_offset': fixed_seed_offset,
    }
    outputs = {
        'out': out,
@@ -244,6 +260,7 @@
            'causal': causal,
            'return_softmax': return_softmax,
            'is_test': not training,
+            'rng_name': rng_name,
        },
    )
    return out, softmax if return_softmax else None
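For the varlen entry point, sequences are packed along dim 0 and described by int32 cumulative-length tensors of shape [batch_size + 1]. A hedged sketch, again assuming a CUDA build, fp16 inputs, and a head size supported by the FA2 kernels:

import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

# Two sequences of lengths 3 and 5 packed into [total_tokens, num_heads, head_dim].
q = paddle.randn([8, 8, 64]).astype('float16')
k = paddle.randn([8, 8, 64]).astype('float16')
v = paddle.randn([8, 8, 64]).astype('float16')
cu_q = paddle.to_tensor([0, 3, 8], dtype='int32')  # prefix sums of sequence lengths
cu_k = paddle.to_tensor([0, 3, 8], dtype='int32')

out, _ = flash_attn_unpadded(
    q, k, v, cu_q, cu_k,
    max_seqlen_q=5, max_seqlen_k=5,
    scale=1.0 / 64 ** 0.5,  # 1/sqrt(head_dim)
    dropout=0.0, causal=False,
)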