Unverified commit 0473369f, authored by umiswing, committed by GitHub

[WIP] Integrate FlashAttention-2 (#55758)

* Padded fwd works for FA-2. Code to be cleaned.

* Unpadded fwd works for FA-2.

* Padded bwd works; dk gets a small diff with np.random.seed(0).

* Pass Paddle's unit tests, except returning softmax without dropout.

* Clean code.

* Modify interface.

* Clean code and add some checks.

* Simplify compilation for development.

* Fix CI.

* Fix CI build.

* Add the -std=c++17 option again.

* Limit the max number of jobs when compiling FA-2.

* Remove const_cast.

* Add fwd params; to be cleaned.

* Clean code.

* Add bwd params.

* Clean code.

* Add enforce checks.

* Use v2.0.4.

* Pass the RNG state to the FA-2 C API.

* Address review comments.

* Add assert.

* Skip compilation for SM less than 80.
Parent 785684ad
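Much of the kernel churn in the diff below comes from a calling-convention change: the FlashAttention-1 C API was invoked twice (a first call with a null output pointer only to query the required workspace size, then the real call), whereas the FlashAttention-2 entry points are invoked once with scratch buffers the caller allocates up front (see FlashAttnFwdParamsV2/FlashAttnBwdParamsV2 further down). A minimal illustrative sketch of that difference, using hypothetical stand-in functions rather than the real libflashattn C API:

#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the real C API; illustration only.
static bool fa1_style_bwd(const float* /*q*/, float* workspace,
                          size_t* workspace_size) {
  if (workspace == nullptr) {  // first pass: only report the scratch size
    *workspace_size = 1024 * sizeof(float);
    return true;
  }
  return true;  // second pass: would run the kernel with the given scratch
}

static bool fa2_style_varlen_bwd(const float* /*q*/,
                                 float* /*softmax_d*/,
                                 float* /*dq_accum*/) {
  return true;  // single pass: scratch buffers are supplied by the caller
}

void old_two_pass_convention(const float* q) {
  size_t workspace_size = 0;
  fa1_style_bwd(q, /*workspace=*/nullptr, &workspace_size);  // query size
  std::vector<float> workspace(workspace_size / sizeof(float));
  fa1_style_bwd(q, workspace.data(), &workspace_size);       // real run
}

void new_single_pass_convention(const float* q,
                                float* softmax_d,
                                float* dq_accum) {
  // Buffers are sized up front (in this commit, by FlashAttnBwdParamsV2).
  fa2_style_varlen_bwd(q, softmax_d, dq_accum);
}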
@@ -17,7 +17,7 @@ include(ExternalProject)
 add_definitions(-DPADDLE_WITH_FLASHATTN)
 set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
-set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
+set(FLASHATTN_SOURCE_SUBDIR csrc)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
 set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)
@@ -62,7 +62,7 @@ else()
 set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
 set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
 set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
-set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS})
+set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
 set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
 set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
 endif()
@@ -92,6 +92,8 @@ ExternalProject_Add(
   -DBUILD_SHARED=ON
   -DCMAKE_POSITION_INDEPENDENT_CODE=ON
   -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
+  -DCMAKE_JOB_POOL_COMPILE:STRING=compile
+  -DCMAKE_JOB_POOLS:STRING=compile=4
   ${EXTERNAL_OPTIONAL_ARGS}
 CMAKE_CACHE_ARGS
   -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
......
@@ -548,10 +548,15 @@ if(WITH_GPU
     list(APPEND third_party_deps extern_cutlass)
     set(WITH_CUTLASS ON)
   endif()
-  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.2)
+  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
+    foreach(arch ${NVCC_ARCH_BIN})
+      if(${arch} GREATER_EQUAL 80)
     include(external/flashattn)
     list(APPEND third_party_deps extern_flashattn)
     set(WITH_FLASHATTN ON)
+        break()
+      endif()
+    endforeach()
   endif()
 endif()
......
@@ -45,7 +45,9 @@ extern void* flashattn_dso_handle;
 #define FLASHATTN_ROUTINE_EACH(__macro)       \
   __macro(flash_attn_fwd);                    \
+  __macro(flash_attn_varlen_fwd);             \
   __macro(flash_attn_bwd);                    \
+  __macro(flash_attn_varlen_bwd);             \
   __macro(flash_attn_fwd_with_bias_and_mask); \
   __macro(flash_attn_bwd_with_bias_and_mask); \
   __macro(flash_attn_error);
......
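For readers unfamiliar with the pattern in the hunk above: FLASHATTN_ROUTINE_EACH is an X-macro list, and each __macro(name) entry is expanded by whatever macro the use site supplies (in Paddle's dynload layer, one that declares a wrapper resolving the symbol from the flashattn shared library on first use), so adding the two varlen entries to the list is enough to expose them everywhere the list is expanded. A simplified, self-contained sketch of the idea, using a hypothetical wrapper macro rather than Paddle's actual dynamic-load machinery:

#include <cstdio>

// Stand-ins for symbols that would normally be resolved from libflashattn.so;
// purely illustrative, not the real signatures.
static bool flash_attn_fwd_impl() { return true; }
static bool flash_attn_varlen_fwd_impl() { return true; }
static bool flash_attn_bwd_impl() { return true; }
static bool flash_attn_varlen_bwd_impl() { return true; }

// X-macro list, mirroring FLASHATTN_ROUTINE_EACH above.
#define DEMO_ROUTINE_EACH(__macro) \
  __macro(flash_attn_fwd);         \
  __macro(flash_attn_varlen_fwd);  \
  __macro(flash_attn_bwd);         \
  __macro(flash_attn_varlen_bwd);

// The use site supplies the per-routine expansion.  Paddle's real macro
// declares a wrapper that loads the symbol on first call; here we simply
// generate a forwarding function per entry.
#define DEMO_DECLARE_WRAP(__name) \
  static bool __name() { return __name##_impl(); }

DEMO_ROUTINE_EACH(DEMO_DECLARE_WRAP)

int main() {
  // Call through the generated wrappers, as the kernels call phi::dynload::*.
  std::printf("%d %d %d %d\n", flash_attn_fwd(), flash_attn_varlen_fwd(),
              flash_attn_bwd(), flash_attn_varlen_bwd());
  return 0;
}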
@@ -25,6 +25,7 @@
 #ifdef PADDLE_WITH_FLASHATTN
 #include "paddle/phi/backends/dynload/flashattn.h"
+#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #endif
 DECLARE_bool(cudnn_deterministic);
@@ -55,115 +56,89 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   ctx.template Alloc<T>(dk);
   ctx.template Alloc<T>(dv);
-  cudaStream_t stream = ctx.stream();
-  bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
+  const cudaStream_t stream = ctx.stream();
   // q,k,v [total_*, num_heads, head_dim]
   auto dims = q.dims();
-  int64_t total_q = dims[0];
-  int64_t num_heads = dims[1];
-  int64_t head_size = dims[2];
-  int64_t total_k = k.dims()[0];
-  int64_t batch_size = cu_seqlens_q.numel() - 1;
-  int num_splits = 0;  // 0 for an internal heuristic, which is optimal
-  if (FLAGS_cudnn_deterministic) {
-    num_splits = 1;
-  }
-  bool zero_tensors = false;
-  const int64_t* seed_offset_data = seed_offset.data<int64_t>();
-  uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
-  uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
-  VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
-          << ", num_splits:" << num_splits;
-  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
-  DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
-  uint64_t workspace_size;
-  // calculate workspace size before execution
-  bool succ = phi::dynload::flash_attn_bwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      nullptr,  // for calculation workspace size
-      dout.data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
-      head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      const_cast<float*>(softmax_lse.data<float>()),
-      dsoftmax.data(),
-      nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
-  if (!succ) {
-    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
-  }
-  DenseTensor workspace;
-  if (workspace_size > 0) {
-    workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
-  }
-  succ = phi::dynload::flash_attn_bwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      out.data(),
-      dout.data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
-      head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      const_cast<float*>(softmax_lse.data<float>()),
-      dsoftmax.data(),
-      workspace_size > 0 ? workspace.data() : nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
+  const int64_t total_q = dims[0];
+  const int batch_size = cu_seqlens_q.numel() - 1;
+  const int num_heads = dims[1];
+  const int head_size_og = dout.dims()[2];
+  const int head_size = dims[2];
+  const int total_k = k.dims()[0];
+  const int num_heads_k = k.dims()[1];
+  // TODO(umiswing): add deterministic in fa2.
+  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  // if (FLAGS_cudnn_deterministic) {
+  //   num_splits = 1;
+  // }
+  const bool zero_tensors = false;
+  // TODO(umiswing): add shape check
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
+      head_size,
+      phi::errors::InvalidArgument(
+          "flash_attn_bwd receive input with head_size_og == head_size"));
+  FlashAttnBwdParamsV2 params =
+      FlashAttnBwdParamsV2(ctx,
+                           batch_size,
+                           max_seqlen_q,
+                           max_seqlen_k,
+                           num_heads,
+                           num_heads_k,
+                           head_size,
+                           dropout,
+                           scale,
+                           causal,
+                           q.dtype(),
+                           seed_offset.data<int64_t>());
+  VLOG(4) << "FlashAttn bwd seed: " << params.seed
+          << ", offset: " << params.offset;
+  const bool succ =
+      phi::dynload::flash_attn_varlen_bwd(dout.data(),
+                                          q.data(),
+                                          k.data(),
+                                          v.data(),
+                                          out.data(),
+                                          params.softmax_d.data(),
+                                          softmax_lse.data(),
+                                          cu_seqlens_q.data<int32_t>(),
+                                          cu_seqlens_k.data<int32_t>(),
+                                          params.rng_state.data(),
+                                          dq->data(),
+                                          dk->data(),
+                                          dv->data(),
+                                          params.dq_accum.data(),
+                                          params.batch_size,
+                                          params.max_seqlen_q,
+                                          params.max_seqlen_k,
+                                          params.seqlen_q_rounded,
+                                          params.seqlen_k_rounded,
+                                          params.num_heads,
+                                          params.num_heads_k,
+                                          params.head_size,
+                                          params.head_size_rounded,
+                                          params.dropout,
+                                          params.scale,
+                                          params.causal,
+                                          params.is_bf16,
+                                          stream,
+                                          params.seed,
+                                          params.offset);
   if (!succ) {
     PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
   }
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
@@ -185,52 +160,86 @@ void FlashAttnGradKernel(const Context& ctx,
   // q,k,v [batch_size, seq_len, num_heads, head_dim]
   auto dims = q.dims();
-  int64_t batch_size = dims[0];
-  int64_t seq_len_q = dims[1];
-  int64_t num_heads = dims[2];
-  int64_t head_size = dims[3];
-  int64_t seq_len_k = k.dims()[1];
-  int64_t total_q = batch_size * seq_len_q;
-  int64_t total_k = batch_size * seq_len_k;
-  float scale = 1.0f / std::sqrt(head_size);
+  const int batch_size = dims[0];
+  const int seqlen_q = dims[1];
+  const int num_heads = dims[2];
+  const int head_size_og = dout.dims()[3];
+  const int head_size = dims[3];
+  const int seqlen_k = k.dims()[1];
+  const int num_heads_k = k.dims()[2];
+  // TODO(umiswing): add shape check
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
+      head_size,
+      phi::errors::InvalidArgument(
+          "flash_attn_bwd receive input with head_size_og == head_size"));
   VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
           << "], v[" << v.dims() << "]";
-  DenseTensor q_t_s, k_t_s, v_t_s;
-  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
-  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
-  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
-  DenseTensor cu_seqlens_q;
-  DenseTensor cu_seqlens_k;
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
-  FlashAttnUnpaddedGradKernel<T, Context>(ctx,
-                                          q_t_s,
-                                          k_t_s,
-                                          v_t_s,
-                                          cu_seqlens_q,
-                                          cu_seqlens_k,
-                                          out,
-                                          softmax_lse,
-                                          seed_offset,
-                                          dout,
-                                          seq_len_q,
-                                          seq_len_k,
-                                          scale,
-                                          dropout,
-                                          causal,
-                                          dq,
-                                          dk,
-                                          dv);
+  const float scale = 1.0f / std::sqrt(head_size);
+  FlashAttnBwdParamsV2 params =
+      FlashAttnBwdParamsV2(ctx,
+                           batch_size,
+                           seqlen_q,
+                           seqlen_k,
+                           num_heads,
+                           num_heads_k,
+                           head_size,
+                           dropout,
+                           scale,
+                           causal,
+                           q.dtype(),
+                           seed_offset.data<int64_t>());
+  ctx.template Alloc<T>(dq);
+  ctx.template Alloc<T>(dk);
+  ctx.template Alloc<T>(dv);
+  cudaStream_t stream = ctx.stream();
+  VLOG(4) << "FlashAttn bwd seed: " << params.seed
+          << ", offset: " << params.offset;
+  const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
+                                                 q.data(),
+                                                 k.data(),
+                                                 v.data(),
+                                                 out.data(),
+                                                 params.softmax_d.data(),
+                                                 softmax_lse.data(),
+                                                 params.rng_state.data(),
+                                                 dq->data(),
+                                                 dk->data(),
+                                                 dv->data(),
+                                                 params.dq_accum.data(),
+                                                 params.batch_size,
+                                                 params.max_seqlen_q,
+                                                 params.max_seqlen_k,
+                                                 params.seqlen_q_rounded,
+                                                 params.seqlen_k_rounded,
+                                                 params.num_heads,
+                                                 params.num_heads_k,
+                                                 params.head_size,
+                                                 params.head_size_rounded,
+                                                 params.dropout,
+                                                 params.scale,
+                                                 params.causal,
+                                                 params.is_bf16,
+                                                 stream,
+                                                 params.seed,
+                                                 params.offset);
+  PADDLE_ENFORCE_EQ(
+      succ,
+      true,
+      phi::errors::External("Error in Flash-Attention-2, detail information is",
+                            phi::dynload::flash_attn_error()));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
......
@@ -27,6 +27,7 @@
 #ifdef PADDLE_WITH_FLASHATTN
 #include "paddle/phi/backends/dynload/flashattn.h"
+#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #endif
 DECLARE_bool(cudnn_deterministic);
@@ -55,12 +56,10 @@ void FlashAttnUnpaddedKernel(
     DenseTensor* softmax_lse,
     DenseTensor* seed_offset) {
 #ifdef PADDLE_WITH_FLASHATTN
-  if (is_test) dropout = 0.0f;
   ctx.template Alloc<T>(out);
   cudaStream_t stream = ctx.stream();
-  bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
   // q,k,v [total_*, num_heads, head_dim]
@@ -71,141 +70,79 @@ void FlashAttnUnpaddedKernel(
       phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
                                    "[total_seq_len, num_heads, head_dim]"));
-  int64_t total_q = dims[0];
-  int64_t num_heads = dims[1];
-  int64_t head_size = dims[2];
-  int64_t total_k = k.dims()[0];
-  int64_t batch_size = cu_seqlens_q.numel() - 1;
-  int num_splits = 0;  // 0 for an internal heuristic, which is optimal
-  if (FLAGS_cudnn_deterministic) {
-    num_splits = 1;
-  }
-  bool zero_tensors = false;
-  uint64_t seed;
-  uint64_t offset;
-  if (fixed_seed_offset.get_ptr()) {
-    const int64_t* fixed_seed_offset_data =
-        fixed_seed_offset.get_ptr()->data<int64_t>();
-    seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
-    offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
-  } else {
-    uint64_t inc = batch_size * num_heads * 32;
-    std::pair<uint64_t, uint64_t> seed_offset_pair;
-    if (rng_name != "") {
-      auto gen = phi::GetRandomSeedGenerator(rng_name);
-      seed_offset_pair = gen->IncrementOffset(inc);
-    } else {
-      auto* gen = ctx.GetGenerator();
-      seed_offset_pair = gen->IncrementOffset(inc);
-    }
-    seed = seed_offset_pair.first;
-    offset = seed_offset_pair.second;
-  }
-  VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
-          << ", num_splits:" << num_splits;
-  seed_offset->Resize({2});
-  int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
-  seed_offset_data[0] = static_cast<int64_t>(seed);
-  seed_offset_data[1] = static_cast<int64_t>(offset);
-  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
-  softmax_lse->Resize({batch_size, num_heads, seq_len_q});
-  ctx.template Alloc<float>(softmax_lse);
-  if (return_softmax) {
-    // may allocate more space than *max_seqlen_k*
-    int64_t blocksize_c = head_size > 64 ? 128 : 256;
-    int64_t seq_len_k =
-        ((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
-    if (max_seqlen_k <= 128) {
-      seq_len_k = 128;
-    } else if (max_seqlen_k <= 256) {
-      seq_len_k = 256;
-    }
-    softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
-    ctx.template Alloc<T>(softmax);
-  }
-  uint64_t workspace_size;
-  // TODO(kuizhiqing) pass allocation/empty func in capi to decouple
-  // calculate workspace size before execution
-  bool succ =
-      phi::dynload::flash_attn_fwd(q.data(),
-                                   k.data(),
-                                   v.data(),
-                                   nullptr,  // for calculation workspace size
-                                   cu_seqlens_q.data(),
-                                   cu_seqlens_k.data(),
-                                   total_q,
-                                   total_k,
-                                   batch_size,
-                                   num_heads,
-                                   head_size,
-                                   max_seqlen_q,
-                                   max_seqlen_k,
-                                   dropout,
-                                   scale,
-                                   zero_tensors,
-                                   causal,
-                                   is_bf16,
-                                   num_splits,
-                                   softmax_lse->data(),
-                                   return_softmax ? softmax->data() : nullptr,
-                                   nullptr,
-                                   &workspace_size,
-                                   stream,
-                                   seed,
-                                   offset);
-  if (!succ) {
-    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
-  }
-  DenseTensor workspace;
-  if (workspace_size > 0) {
-    workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
-  }
-  succ = phi::dynload::flash_attn_fwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      out->data(),
-      cu_seqlens_q.data(),
-      cu_seqlens_k.data(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
-      head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      softmax_lse->data(),
-      return_softmax ? softmax->data() : nullptr,
-      workspace_size > 0 ? workspace.data() : nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
+  const int64_t total_q = dims[0];
+  const int num_heads = dims[1];
+  const int head_size = dims[2];
+  const int total_k = k.dims()[0];
+  const int num_heads_k = k.dims()[1];
+  const int batch_size = cu_seqlens_q.numel() - 1;
+  // TODO(umiswing): add deterministic in fa2.
+  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  // if (FLAGS_cudnn_deterministic) {
+  //   num_splits = 1;
+  // }
+  // TODO(umiswing): add shape check
+  FlashAttnFwdParamsV2<T> params =
+      FlashAttnFwdParamsV2<T>(ctx,
+                              batch_size,
+                              max_seqlen_q,
+                              max_seqlen_k,
+                              num_heads,
+                              num_heads_k,
+                              head_size,
+                              dropout,
+                              scale,
+                              causal,
+                              return_softmax,
+                              q.dtype(),
+                              is_test,
+                              rng_name,
+                              fixed_seed_offset.get_ptr(),
+                              softmax,
+                              softmax_lse,
+                              seed_offset);
+  VLOG(4) << "FlashAttn fwd seed: " << params.seed
+          << ", offset: " << params.offset;
+  const bool succ = phi::dynload::flash_attn_varlen_fwd(
+      q.data(),
+      k.data(),
+      v.data(),
+      cu_seqlens_q.data<int32_t>(),
+      cu_seqlens_k.data<int32_t>(),
+      params.rng_state.data(),
+      out->data(),
+      params.return_softmax ? softmax->data() : nullptr,
+      softmax_lse->data(),
+      params.batch_size,
+      params.max_seqlen_q,
+      params.max_seqlen_k,
+      params.seqlen_q_rounded,
+      params.seqlen_k_rounded,
+      params.num_heads,
+      params.num_heads_k,
+      params.head_size,
+      params.head_size_rounded,
+      params.dropout,
+      params.scale,
+      params.causal,
+      params.return_softmax,
+      params.is_bf16,
+      stream,
+      params.seed,
+      params.offset);
   if (!succ) {
     PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
   }
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
@@ -234,53 +171,81 @@ void FlashAttnKernel(const Context& ctx,
           "flash_attn receive input with dim "
           "[batch_size, seq_len, num_heads, head_dim]"));
-  int64_t batch_size = dims[0];
-  int64_t seq_len_q = dims[1];
-  int64_t num_heads = dims[2];
-  int64_t head_size = dims[3];
-  int64_t seq_len_k = k.dims()[1];
-  int64_t total_q = batch_size * seq_len_q;
-  int64_t total_k = batch_size * seq_len_k;
-  float scale = 1.0f / std::sqrt(head_size);
-  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
-          << "], v[" << v.dims() << "]";
-  DenseTensor q_t_s, k_t_s, v_t_s;
-  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
-  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
-  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
-  DenseTensor cu_seqlens_q;
-  DenseTensor cu_seqlens_k;
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
-  FlashAttnUnpaddedKernel<T, Context>(ctx,
-                                      q_t_s,
-                                      k_t_s,
-                                      v_t_s,
-                                      cu_seqlens_q,
-                                      cu_seqlens_k,
-                                      fixed_seed_offset,
-                                      seq_len_q,
-                                      seq_len_k,
-                                      scale,
-                                      dropout,
-                                      causal,
-                                      return_softmax,
-                                      is_test,
-                                      rng_name,
-                                      out,
-                                      softmax,
-                                      softmax_lse,
-                                      seed_offset);
+  const int batch_size = dims[0];
+  const int seqlen_q = dims[1];
+  const int num_heads = dims[2];
+  const int head_size = dims[3];
+  const int seqlen_k = k.dims()[1];
+  const int num_heads_k = k.dims()[2];
+  // TODO(umiswing): Add check shape
+  const float scale = 1.0f / std::sqrt(head_size);
+  FlashAttnFwdParamsV2<T> params =
+      FlashAttnFwdParamsV2<T>(ctx,
+                              batch_size,
+                              seqlen_q,
+                              seqlen_k,
+                              num_heads,
+                              num_heads_k,
+                              head_size,
+                              dropout,
+                              scale,
+                              causal,
+                              return_softmax,
+                              q.dtype(),
+                              is_test,
+                              rng_name,
+                              fixed_seed_offset.get_ptr(),
+                              softmax,
+                              softmax_lse,
+                              seed_offset);
+  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
+  ctx.template Alloc<T>(out);
+  cudaStream_t stream = ctx.stream();
+  VLOG(4) << "FlashAttn fwd seed: " << params.seed
+          << ", offset: " << params.offset;
+  bool succ = phi::dynload::flash_attn_fwd(
+      q.data(),
+      k.data(),
+      v.data(),
+      params.rng_state.data(),
+      out->data(),
+      params.return_softmax ? params.softmax->data() : nullptr,
+      params.softmax_lse->data(),
+      params.batch_size,
+      params.max_seqlen_q,
+      params.max_seqlen_k,
+      params.seqlen_q_rounded,
+      params.seqlen_k_rounded,
+      params.num_heads,
+      params.num_heads_k,
+      params.head_size,
+      params.head_size_rounded,
+      params.dropout,
+      params.scale,
+      params.causal,
+      params.return_softmax,
+      params.is_bf16,
+      stream,
+      params.seed,
+      params.offset);
+  PADDLE_ENFORCE_EQ(
+      succ,
+      true,
+      phi::errors::External("Error in Flash-Attention-2, detail information is",
+                            phi::dynload::flash_attn_error()));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

namespace phi {

template <typename T>
struct FlashAttnFwdParamsV2 {
  int batch_size;
  // for padded kernel, max_seqlen_q and seqlen_q is the same.
  int64_t max_seqlen_q;
  // for padded kernel, max_seqlen_k and seqlen_k is the same.
  int64_t max_seqlen_k;
  int seqlen_q_rounded;
  int seqlen_k_rounded;
  int num_heads;
  int num_heads_k;
  int head_size;
  int head_size_rounded;
  float dropout;
  float scale;
  bool causal;
  bool return_softmax;
  bool is_bf16;
  uint64_t seed;
  uint64_t offset;
  DenseTensor rng_state;
  DenseTensor* softmax;
  DenseTensor* softmax_lse;
  DenseTensor* seed_offset;

  FlashAttnFwdParamsV2(const GPUContext& ctx,
                       const int _batch_size,
                       const int64_t _max_seqlen_q,
                       const int64_t _max_seqlen_k,
                       const int _num_heads,
                       const int _num_heads_k,
                       const int _head_size,
                       const float _dropout,
                       const float _scale,
                       const bool _causal,
                       const bool _return_softmax,
                       const DataType q_dtype,
                       const bool is_test,
                       const std::string& rng_name,
                       const DenseTensor* const fixed_seed_offset_ptr,
                       DenseTensor* _softmax,
                       DenseTensor* _softmax_lse,
                       DenseTensor* _seed_offset)
      : batch_size(_batch_size),
        max_seqlen_q(_max_seqlen_q),
        max_seqlen_k(_max_seqlen_k),
        num_heads(_num_heads),
        num_heads_k(_num_heads),
        head_size(_head_size),
        scale(_scale),
        dropout(_dropout),
        causal(_causal),
        return_softmax(_return_softmax),
        softmax(_softmax),
        softmax_lse(_softmax_lse),
        seed_offset(_seed_offset) {
    dropout = is_test ? 0.0f : _dropout;
    is_bf16 = q_dtype == DataType::BFLOAT16;

    // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
    // with the same size.
    rng_state = Empty<int64_t>(ctx, {2});

    if (fixed_seed_offset_ptr) {
      const int64_t* fixed_seed_offset_data =
          fixed_seed_offset_ptr->data<int64_t>();
      seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
      offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
    } else {
      uint64_t inc = batch_size * num_heads * 32;
      std::pair<uint64_t, uint64_t> seed_offset_pair;
      if (rng_name != "") {
        auto gen = phi::GetRandomSeedGenerator(rng_name);
        seed_offset_pair = gen->IncrementOffset(inc);
      } else {
        auto* gen = ctx.GetGenerator();
        seed_offset_pair = gen->IncrementOffset(inc);
      }
      seed = seed_offset_pair.first;
      offset = seed_offset_pair.second;
    }

    seed_offset->Resize({2});
    int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
    seed_offset_data[0] = static_cast<int64_t>(seed);
    seed_offset_data[1] = static_cast<int64_t>(offset);

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    head_size_rounded = round_multiple(head_size, 32);
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    softmax_lse->Resize({batch_size, num_heads, max_seqlen_q});
    ctx.template Alloc<float>(softmax_lse);

    if (return_softmax) {
      softmax->Resize(
          {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded});
      ctx.template Alloc<T>(softmax);
    }
  }
};

struct FlashAttnBwdParamsV2 {
  int batch_size;
  int64_t max_seqlen_q;
  int64_t max_seqlen_k;
  int seqlen_q_rounded;
  int seqlen_k_rounded;
  int num_heads;
  int num_heads_k;
  int head_size;
  int head_size_rounded;
  float dropout;
  float scale;
  bool causal;
  bool is_bf16;
  uint64_t seed;
  uint64_t offset;
  DenseTensor softmax_d;
  DenseTensor dq_accum;
  DenseTensor rng_state;

  FlashAttnBwdParamsV2(const GPUContext& ctx,
                       const int _batch_size,
                       const int64_t _max_seqlen_q,
                       const int64_t _max_seqlen_k,
                       const int _num_heads,
                       const int _num_heads_k,
                       const int _head_size,
                       const float _dropout,
                       const float _scale,
                       const bool _causal,
                       const DataType q_dtype,
                       const int64_t* seed_offset_data)
      : batch_size(_batch_size),
        max_seqlen_q(_max_seqlen_q),
        max_seqlen_k(_max_seqlen_k),
        num_heads(_num_heads),
        num_heads_k(_num_heads_k),
        head_size(_head_size),
        dropout(_dropout),
        scale(_scale),
        causal(_causal) {
    is_bf16 = q_dtype == DataType::BFLOAT16;
    seed = static_cast<uint64_t>(seed_offset_data[0]);
    offset = static_cast<uint64_t>(seed_offset_data[1]);

    // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
    // with the same size.
    rng_state = Empty<int64_t>(ctx, {2});

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    head_size_rounded = round_multiple(head_size, 32);
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    softmax_d = Empty<float>(ctx, {batch_size, num_heads, seqlen_q_rounded});
    dq_accum = Empty<float>(
        ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
  }
};

}  // namespace phi
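To make the rounding in these helpers concrete, a small standalone sketch (with assumed example shapes, not values taken from the kernels) of what round_multiple produces and which scratch buffers are padded accordingly:

#include <cstdio>

// Same helper as in FlashAttnFwdParamsV2/FlashAttnBwdParamsV2: round x up
// to the next multiple of m.
static int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

int main() {
  // Assumed example shapes, for illustration only.
  const int head_size = 80;       // -> head_size_rounded = 96  (multiple of 32)
  const int max_seqlen_q = 1000;  // -> seqlen_q_rounded = 1024 (multiple of 128)
  const int max_seqlen_k = 2048;  // -> seqlen_k_rounded = 2048 (already aligned)

  std::printf("head_size_rounded = %d\n", round_multiple(head_size, 32));
  std::printf("seqlen_q_rounded  = %d\n", round_multiple(max_seqlen_q, 128));
  std::printf("seqlen_k_rounded  = %d\n", round_multiple(max_seqlen_k, 128));

  // softmax_d is allocated as [batch_size, num_heads, seqlen_q_rounded] and
  // dq_accum as [batch_size, num_heads, seqlen_q_rounded, head_size_rounded],
  // so both scratch buffers are padded to these rounded extents.
  return 0;
}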
@@ -57,25 +57,27 @@ def attention_naive(q, k, v, causal=False):
     return paddle.transpose(o, [0, 2, 1, 3])
-is_sm75 = (
-    core.is_compiled_with_cuda()
-    and paddle.device.cuda.get_device_capability()[0] == 7
-    and paddle.device.cuda.get_device_capability()[1] == 5
-)
 is_sm8x = (
     core.is_compiled_with_cuda()
     and paddle.device.cuda.get_device_capability()[0] == 8
     and paddle.device.cuda.get_device_capability()[1] >= 0
 )
-is_sm_supported = is_sm75 or is_sm8x
+is_sm90 = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 9
+    and paddle.device.cuda.get_device_capability()[1] == 0
+)
+is_sm_supported = is_sm8x or is_sm90
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11030
+    or get_cuda_version() < 11040
     or not is_sm_supported,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
-    "and device's compute capability must be 7.5 or 8.x",
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
+    "and device's compute capability must be 8.x or 90",
 )
 class TestFlashAttentionAPI(unittest.TestCase):
     def setUp(self):
......
-Subproject commit 18106c1ba0ccee81b97ca947397c08a141815a47
+Subproject commit b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a