Unverified commit 0473369f, authored by umiswing, committed by GitHub

[WIP] Integrate FlashAttention 2 (#55758)

* FA-2 padded fwd works; code still to be cleaned.

* FA-2 unpadded fwd works.

* Padded bwd works; dk shows a small diff with np.random.seed(0).

* Pass Paddle's unit tests, except for returning softmax without dropout.

* Clean code.

* Modify the interface.

* Clean code and add some checks.

* Make compilation easier for development.

* Fix CI.

* Fix the CI build.

* Add the -std=c++17 option again.

* Limit the maximum number of jobs when compiling FA-2.

* Remove const_cast.

* Add fwd params; to be cleaned.

* Clean code.

* Add bwd params.

* Clean code.

* Add enforce checks.

* Use v2.0.4.

* Pass the RNG state to the FA-2 C API.

* Address review comments.

* Add assert.

* Skip compilation for SM less than 80.
Parent 785684ad
......@@ -17,7 +17,7 @@ include(ExternalProject)
add_definitions(-DPADDLE_WITH_FLASHATTN)
set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_SOURCE_SUBDIR csrc)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)
......@@ -62,7 +62,7 @@ else()
set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()
......@@ -92,6 +92,8 @@ ExternalProject_Add(
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_JOB_POOL_COMPILE:STRING=compile
-DCMAKE_JOB_POOLS:STRING=compile=4
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
......
......@@ -548,10 +548,15 @@ if(WITH_GPU
list(APPEND third_party_deps extern_cutlass)
set(WITH_CUTLASS ON)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.2)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
foreach(arch ${NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
include(external/flashattn)
list(APPEND third_party_deps extern_flashattn)
set(WITH_FLASHATTN ON)
break()
endif()
endforeach()
endif()
endif()
......
......@@ -45,7 +45,9 @@ extern void* flashattn_dso_handle;
#define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_varlen_fwd); \
__macro(flash_attn_bwd); \
__macro(flash_attn_varlen_bwd); \
__macro(flash_attn_fwd_with_bias_and_mask); \
__macro(flash_attn_bwd_with_bias_and_mask); \
__macro(flash_attn_error);
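// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this patch: the header above is an X-macro
// list. FLASHATTN_ROUTINE_EACH applies a caller-supplied macro to every
// exported FA-2 routine, so declaring the new varlen entry points only
// requires adding their names to the list. The DEMO_* names below are
// hypothetical and only demonstrate the pattern; Paddle's real dynload
// wrapper macro differs.
#include <dlfcn.h>
#include <cstdio>

#define DEMO_ROUTINE_EACH(__macro) \
  __macro(flash_attn_fwd);         \
  __macro(flash_attn_varlen_fwd);  \
  __macro(flash_attn_bwd);         \
  __macro(flash_attn_varlen_bwd);

// One possible expansion: resolve each routine from an already-opened handle.
#define DEMO_RESOLVE_SYMBOL(__name)                              \
  void* __name##_fn = handle ? dlsym(handle, #__name) : nullptr; \
  std::printf(#__name ": %s\n", __name##_fn ? "found" : "missing")

int main() {
  void* handle = dlopen("libflashattn.so", RTLD_NOW | RTLD_LOCAL);
  DEMO_ROUTINE_EACH(DEMO_RESOLVE_SYMBOL);
  if (handle) dlclose(handle);
  return 0;
}
// ---------------------------------------------------------------------------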
......
......@@ -25,6 +25,7 @@
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
DECLARE_bool(cudnn_deterministic);
......@@ -55,115 +56,89 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
const cudaStream_t stream = ctx.stream();
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
bool zero_tensors = false;
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
uint64_t workspace_size;
const int64_t total_q = dims[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = dims[1];
const int head_size_og = dout.dims()[2];
const int head_size = dims[2];
const int total_k = k.dims()[0];
const int num_heads_k = k.dims()[1];
// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
const bool zero_tensors = false;
// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
head_size_og,
head_size,
phi::errors::InvalidArgument(
"flash_attn_bwd receive input with head_size_og == head_size"));
// calculate workspace size before execution
bool succ = phi::dynload::flash_attn_bwd(
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
nullptr, // for calculation workspace size
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(ctx,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
nullptr,
&workspace_size,
stream,
seed,
offset);
q.dtype(),
seed_offset.data<int64_t>());
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
VLOG(4) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_bwd(
const bool succ =
phi::dynload::flash_attn_varlen_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
out.data(),
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
seed,
offset);
params.seed,
params.offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......@@ -185,52 +160,86 @@ void FlashAttnGradKernel(const Context& ctx,
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
float scale = 1.0f / std::sqrt(head_size);
const int batch_size = dims[0];
const int seqlen_q = dims[1];
const int num_heads = dims[2];
const int head_size_og = dout.dims()[3];
const int head_size = dims[3];
const int seqlen_k = k.dims()[1];
const int num_heads_k = k.dims()[2];
// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
head_size_og,
head_size,
phi::errors::InvalidArgument(
"flash_attn_bwd receive input with head_size_og == head_size"));
VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
FlashAttnUnpaddedGradKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
out,
softmax_lse,
seed_offset,
dout,
seq_len_q,
seq_len_k,
scale,
const float scale = 1.0f / std::sqrt(head_size);
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(ctx,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
dq,
dk,
dv);
q.dtype(),
seed_offset.data<int64_t>());
ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
cudaStream_t stream = ctx.stream();
VLOG(4) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
PADDLE_ENFORCE_EQ(
succ,
true,
phi::errors::External("Error in Flash-Attention-2, detail information is",
phi::dynload::flash_attn_error()));
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......
......@@ -27,6 +27,7 @@
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
DECLARE_bool(cudnn_deterministic);
......@@ -55,12 +56,10 @@ void FlashAttnUnpaddedKernel(
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
if (is_test) dropout = 0.0f;
ctx.template Alloc<T>(out);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
// q,k,v [total_*, num_heads, head_dim]
......@@ -71,141 +70,79 @@ void FlashAttnUnpaddedKernel(
phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
"[total_seq_len, num_heads, head_dim]"));
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
const int64_t total_q = dims[0];
const int num_heads = dims[1];
const int head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
const int total_k = k.dims()[0];
const int num_heads_k = k.dims()[1];
const int batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
bool zero_tensors = false;
uint64_t seed;
uint64_t offset;
if (fixed_seed_offset.get_ptr()) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset.get_ptr()->data<int64_t>();
seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
}
VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
softmax_lse->Resize({batch_size, num_heads, seq_len_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
// may allocate more space than *max_seqlen_k*
int64_t blocksize_c = head_size > 64 ? 128 : 256;
int64_t seq_len_k =
((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
if (max_seqlen_k <= 128) {
seq_len_k = 128;
} else if (max_seqlen_k <= 256) {
seq_len_k = 256;
}
softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
ctx.template Alloc<T>(softmax);
}
// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
uint64_t workspace_size;
// TODO(umiswing): add shape check
// TODO(kuizhiqing) pass allocation/empty func in capi to decouple
// calculate workspace size before execution
bool succ =
phi::dynload::flash_attn_fwd(q.data(),
k.data(),
v.data(),
nullptr, // for calculation workspace size
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
FlashAttnFwdParamsV2<T> params =
FlashAttnFwdParamsV2<T>(ctx,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
return_softmax,
q.dtype(),
is_test,
rng_name,
fixed_seed_offset.get_ptr(),
softmax,
softmax_lse,
seed_offset);
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
VLOG(4) << "FlashAttn fwd seed: " << params.seed
<< ", offset: " << params.offset;
succ = phi::dynload::flash_attn_fwd(
const bool succ = phi::dynload::flash_attn_varlen_fwd(
q.data(),
k.data(),
v.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
out->data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
params.return_softmax ? softmax->data() : nullptr,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.return_softmax,
params.is_bf16,
stream,
seed,
offset);
params.seed,
params.offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......@@ -234,53 +171,81 @@ void FlashAttnKernel(const Context& ctx,
"flash_attn receive input with dim "
"[batch_size, seq_len, num_heads, head_dim]"));
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
const int batch_size = dims[0];
const int seqlen_q = dims[1];
const int num_heads = dims[2];
const int head_size = dims[3];
const int seqlen_k = k.dims()[1];
const int num_heads_k = k.dims()[2];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
// TODO(umiswing): add shape check
float scale = 1.0f / std::sqrt(head_size);
const float scale = 1.0f / std::sqrt(head_size);
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
FlashAttnUnpaddedKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
seq_len_q,
seq_len_k,
scale,
FlashAttnFwdParamsV2<T> params =
FlashAttnFwdParamsV2<T>(ctx,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
return_softmax,
q.dtype(),
is_test,
rng_name,
out,
fixed_seed_offset.get_ptr(),
softmax,
softmax_lse,
seed_offset);
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
ctx.template Alloc<T>(out);
cudaStream_t stream = ctx.stream();
VLOG(4) << "FlashAttn fwd seed: " << params.seed
<< ", offset: " << params.offset;
bool succ = phi::dynload::flash_attn_fwd(
q.data(),
k.data(),
v.data(),
params.rng_state.data(),
out->data(),
params.return_softmax ? params.softmax->data() : nullptr,
params.softmax_lse->data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.return_softmax,
params.is_bf16,
stream,
params.seed,
params.offset);
PADDLE_ENFORCE_EQ(
succ,
true,
phi::errors::External("Error in Flash-Attention-2, detail information is",
phi::dynload::flash_attn_error()));
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
template <typename T>
struct FlashAttnFwdParamsV2 {
int batch_size;
// For the padded kernel, max_seqlen_q and seqlen_q are the same.
int64_t max_seqlen_q;
// For the padded kernel, max_seqlen_k and seqlen_k are the same.
int64_t max_seqlen_k;
int seqlen_q_rounded;
int seqlen_k_rounded;
int num_heads;
int num_heads_k;
int head_size;
int head_size_rounded;
float dropout;
float scale;
bool causal;
bool return_softmax;
bool is_bf16;
uint64_t seed;
uint64_t offset;
DenseTensor rng_state;
DenseTensor* softmax;
DenseTensor* softmax_lse;
DenseTensor* seed_offset;
FlashAttnFwdParamsV2(const GPUContext& ctx,
const int _batch_size,
const int64_t _max_seqlen_q,
const int64_t _max_seqlen_k,
const int _num_heads,
const int _num_heads_k,
const int _head_size,
const float _dropout,
const float _scale,
const bool _causal,
const bool _return_softmax,
const DataType q_dtype,
const bool is_test,
const std::string& rng_name,
const DenseTensor* const fixed_seed_offset_ptr,
DenseTensor* _softmax,
DenseTensor* _softmax_lse,
DenseTensor* _seed_offset)
: batch_size(_batch_size),
max_seqlen_q(_max_seqlen_q),
max_seqlen_k(_max_seqlen_k),
num_heads(_num_heads),
num_heads_k(_num_heads_k),
head_size(_head_size),
scale(_scale),
dropout(_dropout),
causal(_causal),
return_softmax(_return_softmax),
softmax(_softmax),
softmax_lse(_softmax_lse),
seed_offset(_seed_offset) {
dropout = is_test ? 0.0f : _dropout;
is_bf16 = q_dtype == DataType::BFLOAT16;
// (umiswing): There is no suitable kernel for uint64_t, so allocate as
// int64_t of the same size.
rng_state = Empty<int64_t>(ctx, {2});
if (fixed_seed_offset_ptr) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset_ptr->data<int64_t>();
seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
}
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
head_size_rounded = round_multiple(head_size, 32);
seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
softmax_lse->Resize({batch_size, num_heads, max_seqlen_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
softmax->Resize(
{batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded});
ctx.template Alloc<T>(softmax);
}
}
};
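// Worked example (illustrative only, not part of this patch): the rounding
// above pads head_size to a multiple of 32 and the sequence lengths to
// multiples of 128, which fixes the shapes allocated here: softmax_lse is
// [batch_size, num_heads, max_seqlen_q] and the optional returned softmax is
// [batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded].
//
//   auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
//   round_multiple(96, 32);    // head_size 96     -> head_size_rounded 96
//   round_multiple(100, 128);  // max_seqlen_q 100 -> seqlen_q_rounded 128
//   round_multiple(300, 128);  // max_seqlen_k 300 -> seqlen_k_rounded 384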
struct FlashAttnBwdParamsV2 {
int batch_size;
int64_t max_seqlen_q;
int64_t max_seqlen_k;
int seqlen_q_rounded;
int seqlen_k_rounded;
int num_heads;
int num_heads_k;
int head_size;
int head_size_rounded;
float dropout;
float scale;
bool causal;
bool is_bf16;
uint64_t seed;
uint64_t offset;
DenseTensor softmax_d;
DenseTensor dq_accum;
DenseTensor rng_state;
FlashAttnBwdParamsV2(const GPUContext& ctx,
const int _batch_size,
const int64_t _max_seqlen_q,
const int64_t _max_seqlen_k,
const int _num_heads,
const int _num_heads_k,
const int _head_size,
const float _dropout,
const float _scale,
const bool _causal,
const DataType q_dtype,
const int64_t* seed_offset_data)
: batch_size(_batch_size),
max_seqlen_q(_max_seqlen_q),
max_seqlen_k(_max_seqlen_k),
num_heads(_num_heads),
num_heads_k(_num_heads_k),
head_size(_head_size),
dropout(_dropout),
scale(_scale),
causal(_causal) {
is_bf16 = q_dtype == DataType::BFLOAT16;
seed = static_cast<uint64_t>(seed_offset_data[0]);
offset = static_cast<uint64_t>(seed_offset_data[1]);
// (umiswing): There is no suitable kernel for uint64_t, so allocate as
// int64_t of the same size.
rng_state = Empty<int64_t>(ctx, {2});
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
head_size_rounded = round_multiple(head_size, 32);
seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
softmax_d = Empty<float>(ctx, {batch_size, num_heads, seqlen_q_rounded});
dq_accum = Empty<float>(
ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
}
};
} // namespace phi
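// Illustrative sketch, not part of this patch: the RNG state round trip that
// connects FlashAttnFwdParamsV2 and FlashAttnBwdParamsV2. The forward stores
// {seed, offset} into the 2-element int64 seed_offset host tensor; the
// backward reads them back as uint64 and passes them, together with the
// 2-element rng_state buffer, to the FA-2 C API. The helper names below are
// hypothetical.
#include <cstdint>
#include <utility>

void SaveSeedOffset(uint64_t seed, uint64_t offset, int64_t* seed_offset_data) {
  // Mirrors FlashAttnFwdParamsV2: the freshly drawn state is written back as
  // int64 so later ops (and the backward pass) can consume it.
  seed_offset_data[0] = static_cast<int64_t>(seed);
  seed_offset_data[1] = static_cast<int64_t>(offset);
}

std::pair<uint64_t, uint64_t> LoadSeedOffset(const int64_t* seed_offset_data) {
  // Mirrors FlashAttnBwdParamsV2: reinterpret the saved values as the uint64
  // seed/offset expected by the dropout RNG inside the FA-2 kernels.
  return {static_cast<uint64_t>(seed_offset_data[0]),
          static_cast<uint64_t>(seed_offset_data[1])};
}

int main() {
  int64_t seed_offset_data[2];
  SaveSeedOffset(/*seed=*/42, /*offset=*/2 * 16 * 32, seed_offset_data);
  const auto [seed, offset] = LoadSeedOffset(seed_offset_data);
  return (seed == 42 && offset == 1024) ? 0 : 1;
}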
......@@ -57,25 +57,27 @@ def attention_naive(q, k, v, causal=False):
return paddle.transpose(o, [0, 2, 1, 3])
is_sm75 = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 7
and paddle.device.cuda.get_device_capability()[1] == 5
)
is_sm8x = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 8
and paddle.device.cuda.get_device_capability()[1] >= 0
)
is_sm_supported = is_sm75 or is_sm8x
is_sm90 = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 9
and paddle.device.cuda.get_device_capability()[1] == 0
)
is_sm_supported = is_sm8x or is_sm90
@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11030
or get_cuda_version() < 11040
or not is_sm_supported,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
"and device's compute capability must be 7.5 or 8.x",
"core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
"and device's compute capability must be 8.x or 90",
)
class TestFlashAttentionAPI(unittest.TestCase):
def setUp(self):
......
Subproject commit 18106c1ba0ccee81b97ca947397c08a141815a47
Subproject commit b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a