diff --git a/cmake/external/flashattn.cmake b/cmake/external/flashattn.cmake
index b071f25dd4d3a57d4585a87ffcbcf14189773e24..515ffc958680365e0fc55366203b2647dcbcfef9 100644
--- a/cmake/external/flashattn.cmake
+++ b/cmake/external/flashattn.cmake
@@ -17,7 +17,7 @@ include(ExternalProject)
 add_definitions(-DPADDLE_WITH_FLASHATTN)
 
 set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
-set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
+set(FLASHATTN_SOURCE_SUBDIR csrc)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
 set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)
@@ -62,7 +62,7 @@ else()
   set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
   set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
   set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
-  set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS})
+  set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
   set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
   set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
 endif()
@@ -92,6 +92,8 @@ ExternalProject_Add(
              -DBUILD_SHARED=ON
              -DCMAKE_POSITION_INDEPENDENT_CODE=ON
              -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
+             -DCMAKE_JOB_POOL_COMPILE:STRING=compile
+             -DCMAKE_JOB_POOLS:STRING=compile=4
              ${EXTERNAL_OPTIONAL_ARGS}
   CMAKE_CACHE_ARGS
              -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index 48dcfa97020b0f966194378e61b4237cbc30ad62..6365adfda243b33f51b2c299bdeb1364957e1a70 100755
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -548,10 +548,15 @@ if(WITH_GPU
     list(APPEND third_party_deps extern_cutlass)
     set(WITH_CUTLASS ON)
   endif()
-  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.2)
-    include(external/flashattn)
-    list(APPEND third_party_deps extern_flashattn)
-    set(WITH_FLASHATTN ON)
+  if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
+    foreach(arch ${NVCC_ARCH_BIN})
+      if(${arch} GREATER_EQUAL 80)
+        include(external/flashattn)
+        list(APPEND third_party_deps extern_flashattn)
+        set(WITH_FLASHATTN ON)
+        break()
+      endif()
+    endforeach()
   endif()
 endif()
diff --git a/paddle/phi/backends/dynload/flashattn.h b/paddle/phi/backends/dynload/flashattn.h
index 8948ec6a46988b980142ba2d90da3bb4109e873b..e4728cf43405e1c0d246c300e8c8554f3eb26b6f 100644
--- a/paddle/phi/backends/dynload/flashattn.h
+++ b/paddle/phi/backends/dynload/flashattn.h
@@ -45,7 +45,9 @@ extern void* flashattn_dso_handle;
 #define FLASHATTN_ROUTINE_EACH(__macro)       \
   __macro(flash_attn_fwd);                    \
+  __macro(flash_attn_varlen_fwd);             \
   __macro(flash_attn_bwd);                    \
+  __macro(flash_attn_varlen_bwd);             \
   __macro(flash_attn_fwd_with_bias_and_mask); \
   __macro(flash_attn_bwd_with_bias_and_mask); \
   __macro(flash_attn_error);
diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index b75f4b4aea4b88cc790d64ba6389d3677a37b38b..1ae6c887aee027fe0dba0fa700ee841e64be6d12 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -25,6 +25,7 @@
 #ifdef PADDLE_WITH_FLASHATTN
 #include "paddle/phi/backends/dynload/flashattn.h"
+#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #endif
 
 DECLARE_bool(cudnn_deterministic);
@@ -55,115 +56,89 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   ctx.template Alloc<T>(dk);
   ctx.template Alloc<T>(dv);
 
-  cudaStream_t stream = ctx.stream();
-  bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
+  const cudaStream_t stream = ctx.stream();
 
   // q,k,v [total_*, num_heads, head_dim]
   auto dims = q.dims();
-  int64_t total_q = dims[0];
-  int64_t num_heads = dims[1];
-  int64_t head_size = dims[2];
-
-  int64_t total_k = k.dims()[0];
-  int64_t batch_size = cu_seqlens_q.numel() - 1;
-
-  int num_splits = 0;  // 0 for an internal heuristic, which is optimal
-  if (FLAGS_cudnn_deterministic) {
-    num_splits = 1;
-  }
-  bool zero_tensors = false;
-
-  const int64_t* seed_offset_data = seed_offset.data<int64_t>();
-  uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
-  uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
-
-  VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
-          << ", num_splits:" << num_splits;
-
-  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
-  DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
-
-  uint64_t workspace_size;
-
-  // calculate workspace size before execution
-  bool succ = phi::dynload::flash_attn_bwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      nullptr,  // for calculation workspace size
-      dout.data(),
-      cu_seqlens_q.data<int32_t>(),
-      cu_seqlens_k.data<int32_t>(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
-      head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      const_cast<float*>(softmax_lse.data<float>()),
-      dsoftmax.data(),
-      nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
-
-  if (!succ) {
-    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
-  }
-
-  DenseTensor workspace;
-  if (workspace_size > 0) {
-    workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
-  }
-
-  succ = phi::dynload::flash_attn_bwd(
-      q.data(),
-      k.data(),
-      v.data(),
-      dq->data(),
-      dk->data(),
-      dv->data(),
-      out.data(),
-      dout.data(),
-      cu_seqlens_q.data<int32_t>(),
-      cu_seqlens_k.data<int32_t>(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
+  const int64_t total_q = dims[0];
+  const int batch_size = cu_seqlens_q.numel() - 1;
+  const int num_heads = dims[1];
+  const int head_size_og = dout.dims()[2];
+  const int head_size = dims[2];
+  const int total_k = k.dims()[0];
+  const int num_heads_k = k.dims()[1];
+
+  // TODO(umiswing): add deterministic in fa2.
+  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  // if (FLAGS_cudnn_deterministic) {
+  //   num_splits = 1;
+  // }
+
+  const bool zero_tensors = false;
+
+  // TODO(umiswing): add shape check
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
       head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
-      const_cast<float*>(softmax_lse.data<float>()),
-      dsoftmax.data(),
-      workspace_size > 0 ? workspace.data() : nullptr,
-      &workspace_size,
-      stream,
-      seed,
-      offset);
+      phi::errors::InvalidArgument(
+          "flash_attn_bwd receive input with head_size_og == head_size"));
+
+  FlashAttnBwdParamsV2 params =
+      FlashAttnBwdParamsV2(ctx,
+                           batch_size,
+                           max_seqlen_q,
+                           max_seqlen_k,
+                           num_heads,
+                           num_heads_k,
+                           head_size,
+                           dropout,
+                           scale,
+                           causal,
+                           q.dtype(),
+                           seed_offset.data<int64_t>());
+
+  VLOG(4) << "FlashAttn bwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  const bool succ =
+      phi::dynload::flash_attn_varlen_bwd(dout.data(),
+                                          q.data(),
+                                          k.data(),
+                                          v.data(),
+                                          out.data(),
+                                          params.softmax_d.data(),
+                                          softmax_lse.data(),
+                                          cu_seqlens_q.data<int32_t>(),
+                                          cu_seqlens_k.data<int32_t>(),
+                                          params.rng_state.data(),
+                                          dq->data(),
+                                          dk->data(),
+                                          dv->data(),
+                                          params.dq_accum.data(),
+                                          params.batch_size,
+                                          params.max_seqlen_q,
+                                          params.max_seqlen_k,
+                                          params.seqlen_q_rounded,
+                                          params.seqlen_k_rounded,
+                                          params.num_heads,
+                                          params.num_heads_k,
+                                          params.head_size,
+                                          params.head_size_rounded,
+                                          params.dropout,
+                                          params.scale,
+                                          params.causal,
+                                          params.is_bf16,
+                                          stream,
+                                          params.seed,
+                                          params.offset);
 
   if (!succ) {
     PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
   }
-
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
@@ -185,52 +160,86 @@ void FlashAttnGradKernel(const Context& ctx,
   // q,k,v [batch_size, seq_len, num_heads, head_dim]
   auto dims = q.dims();
-  int64_t batch_size = dims[0];
-  int64_t seq_len_q = dims[1];
-  int64_t num_heads = dims[2];
-  int64_t head_size = dims[3];
-
-  int64_t seq_len_k = k.dims()[1];
-
-  int64_t total_q = batch_size * seq_len_q;
-  int64_t total_k = batch_size * seq_len_k;
-
-  float scale = 1.0f / std::sqrt(head_size);
+  const int batch_size = dims[0];
+  const int seqlen_q = dims[1];
+  const int num_heads = dims[2];
+  const int head_size_og = dout.dims()[3];
+  const int head_size = dims[3];
+  const int seqlen_k = k.dims()[1];
+  const int num_heads_k = k.dims()[2];
+
+  // TODO(umiswing): add shape check
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
+      head_size,
+      phi::errors::InvalidArgument(
+          "flash_attn_bwd receive input with head_size_og == head_size"));
 
   VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
           << "], v[" << v.dims() << "]";
 
-  DenseTensor q_t_s, k_t_s, v_t_s;
-  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
-  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
-  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
-
-  DenseTensor cu_seqlens_q;
-  DenseTensor cu_seqlens_k;
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
-
-  FlashAttnUnpaddedGradKernel<T, Context>(ctx,
-                                          q_t_s,
-                                          k_t_s,
-                                          v_t_s,
-                                          cu_seqlens_q,
-                                          cu_seqlens_k,
-                                          out,
-                                          softmax_lse,
-                                          seed_offset,
-                                          dout,
-                                          seq_len_q,
-                                          seq_len_k,
-                                          scale,
-                                          dropout,
-                                          causal,
-                                          dq,
-                                          dk,
-                                          dv);
+  const float scale = 1.0f / std::sqrt(head_size);
+
+  FlashAttnBwdParamsV2 params =
+      FlashAttnBwdParamsV2(ctx,
+                           batch_size,
+                           seqlen_q,
+                           seqlen_k,
+                           num_heads,
+                           num_heads_k,
+                           head_size,
+                           dropout,
+                           scale,
+                           causal,
+                           q.dtype(),
+                           seed_offset.data<int64_t>());
+
+  ctx.template Alloc<T>(dq);
+  ctx.template Alloc<T>(dk);
+  ctx.template Alloc<T>(dv);
+
+  cudaStream_t stream = ctx.stream();
+  VLOG(4) << "FlashAttn bwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
+                                                 q.data(),
+                                                 k.data(),
+                                                 v.data(),
+                                                 out.data(),
+                                                 params.softmax_d.data(),
+                                                 softmax_lse.data(),
+                                                 params.rng_state.data(),
+                                                 dq->data(),
+                                                 dk->data(),
+                                                 dv->data(),
+                                                 params.dq_accum.data(),
+                                                 params.batch_size,
+                                                 params.max_seqlen_q,
+                                                 params.max_seqlen_k,
+                                                 params.seqlen_q_rounded,
+                                                 params.seqlen_k_rounded,
+                                                 params.num_heads,
+                                                 params.num_heads_k,
+                                                 params.head_size,
+                                                 params.head_size_rounded,
+                                                 params.dropout,
+                                                 params.scale,
+                                                 params.causal,
+                                                 params.is_bf16,
+                                                 stream,
+                                                 params.seed,
+                                                 params.offset);
+
+  PADDLE_ENFORCE_EQ(
+      succ,
+      true,
+      phi::errors::External("Error in Flash-Attention-2, detail information is",
+                            phi::dynload::flash_attn_error()));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
index 714edf4be6f3c5313c81ee0581573326de6be6c7..e943b7bbf78519cc1ff10be93e83c9d1d8302ed9 100644
--- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -27,6 +27,7 @@
 #ifdef PADDLE_WITH_FLASHATTN
 #include "paddle/phi/backends/dynload/flashattn.h"
+#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
 #endif
 
 DECLARE_bool(cudnn_deterministic);
@@ -55,12 +56,10 @@ void FlashAttnUnpaddedKernel(
     DenseTensor* softmax_lse,
     DenseTensor* seed_offset) {
 #ifdef PADDLE_WITH_FLASHATTN
-  if (is_test) dropout = 0.0f;
 
   ctx.template Alloc<T>(out);
 
   cudaStream_t stream = ctx.stream();
-  bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
 
   // q,k,v [total_*, num_heads, head_dim]
@@ -71,141 +70,79 @@ void FlashAttnUnpaddedKernel(
       phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
                                    "[total_seq_len, num_heads, head_dim]"));
 
-  int64_t total_q = dims[0];
-  int64_t num_heads = dims[1];
-  int64_t head_size = dims[2];
-
-  int64_t total_k = k.dims()[0];
-  int64_t batch_size = cu_seqlens_q.numel() - 1;
-
-  int num_splits = 0;  // 0 for an internal heuristic, which is optimal
-  if (FLAGS_cudnn_deterministic) {
-    num_splits = 1;
-  }
-  bool zero_tensors = false;
-
-  uint64_t seed;
-  uint64_t offset;
-
-  if (fixed_seed_offset.get_ptr()) {
-    const int64_t* fixed_seed_offset_data =
-        fixed_seed_offset.get_ptr()->data<int64_t>();
-    seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
-    offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
-  } else {
-    uint64_t inc = batch_size * num_heads * 32;
-    std::pair<uint64_t, uint64_t> seed_offset_pair;
-    if (rng_name != "") {
-      auto gen = phi::GetRandomSeedGenerator(rng_name);
-      seed_offset_pair = gen->IncrementOffset(inc);
-    } else {
-      auto* gen = ctx.GetGenerator();
-      seed_offset_pair = gen->IncrementOffset(inc);
-    }
-    seed = seed_offset_pair.first;
-    offset = seed_offset_pair.second;
-  }
-
-  VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
-          << ", num_splits:" << num_splits;
-
-  seed_offset->Resize({2});
-  int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
-  seed_offset_data[0] = static_cast<int64_t>(seed);
-  seed_offset_data[1] = static_cast<int64_t>(offset);
-
-  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
-
-  softmax_lse->Resize({batch_size, num_heads, seq_len_q});
-  ctx.template Alloc<float>(softmax_lse);
-
-  if (return_softmax) {
-    // may allocate more space than *max_seqlen_k*
-    int64_t blocksize_c = head_size > 64 ? 128 : 256;
-    int64_t seq_len_k =
-        ((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
-    if (max_seqlen_k <= 128) {
-      seq_len_k = 128;
-    } else if (max_seqlen_k <= 256) {
-      seq_len_k = 256;
-    }
-    softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
-    ctx.template Alloc<T>(softmax);
-  }
-
-  uint64_t workspace_size;
-
-  // TODO(kuizhiqing) pass allocation/empty func in capi to decouple
-  // calculate workspace size before execution
-  bool succ =
-      phi::dynload::flash_attn_fwd(q.data(),
-                                   k.data(),
-                                   v.data(),
-                                   nullptr,  // for calculation workspace size
-                                   cu_seqlens_q.data<int32_t>(),
-                                   cu_seqlens_k.data<int32_t>(),
-                                   total_q,
-                                   total_k,
-                                   batch_size,
-                                   num_heads,
-                                   head_size,
-                                   max_seqlen_q,
-                                   max_seqlen_k,
-                                   dropout,
-                                   scale,
-                                   zero_tensors,
-                                   causal,
-                                   is_bf16,
-                                   num_splits,
-                                   softmax_lse->data(),
-                                   return_softmax ? softmax->data() : nullptr,
-                                   nullptr,
-                                   &workspace_size,
-                                   stream,
-                                   seed,
-                                   offset);
-
-  if (!succ) {
-    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
-  }
-
-  DenseTensor workspace;
-  if (workspace_size > 0) {
-    workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
-  }
-
-  succ = phi::dynload::flash_attn_fwd(
+  const int64_t total_q = dims[0];
+  const int num_heads = dims[1];
+  const int head_size = dims[2];
+
+  const int total_k = k.dims()[0];
+  const int num_heads_k = k.dims()[1];
+  const int batch_size = cu_seqlens_q.numel() - 1;
+
+  // TODO(umiswing): add deterministic in fa2.
+  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  // if (FLAGS_cudnn_deterministic) {
+  //   num_splits = 1;
+  // }
+
+  // TODO(umiswing): add shape check
+
+  FlashAttnFwdParamsV2<T> params =
+      FlashAttnFwdParamsV2<T>(ctx,
+                              batch_size,
+                              max_seqlen_q,
+                              max_seqlen_k,
+                              num_heads,
+                              num_heads_k,
+                              head_size,
+                              dropout,
+                              scale,
+                              causal,
+                              return_softmax,
+                              q.dtype(),
+                              is_test,
+                              rng_name,
+                              fixed_seed_offset.get_ptr(),
+                              softmax,
+                              softmax_lse,
+                              seed_offset);
+
+  VLOG(4) << "FlashAttn fwd seed: " << params.seed
+          << ", offset: " << params.offset;
+
+  const bool succ = phi::dynload::flash_attn_varlen_fwd(
       q.data(),
       k.data(),
       v.data(),
+      cu_seqlens_q.data<int32_t>(),
+      cu_seqlens_k.data<int32_t>(),
+      params.rng_state.data(),
       out->data(),
-      cu_seqlens_q.data<int32_t>(),
-      cu_seqlens_k.data<int32_t>(),
-      total_q,
-      total_k,
-      batch_size,
-      num_heads,
-      head_size,
-      max_seqlen_q,
-      max_seqlen_k,
-      dropout,
-      scale,
-      zero_tensors,
-      causal,
-      is_bf16,
-      num_splits,
+      params.return_softmax ? softmax->data() : nullptr,
       softmax_lse->data(),
-      return_softmax ? softmax->data() : nullptr,
-      workspace_size > 0 ? workspace.data() : nullptr,
-      &workspace_size,
+      params.batch_size,
+      params.max_seqlen_q,
+      params.max_seqlen_k,
+      params.seqlen_q_rounded,
+      params.seqlen_k_rounded,
+      params.num_heads,
+      params.num_heads_k,
+      params.head_size,
+      params.head_size_rounded,
+      params.dropout,
+      params.scale,
+      params.causal,
+      params.return_softmax,
+      params.is_bf16,
       stream,
-      seed,
-      offset);
+      params.seed,
+      params.offset);
 
   if (!succ) {
     PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
   }
-
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
@@ -234,53 +171,81 @@ void FlashAttnKernel(const Context& ctx,
                         "flash_attn receive input with dim "
                         "[batch_size, seq_len, num_heads, head_dim]"));
 
-  int64_t batch_size = dims[0];
-  int64_t seq_len_q = dims[1];
-  int64_t num_heads = dims[2];
-  int64_t head_size = dims[3];
+  const int batch_size = dims[0];
+  const int seqlen_q = dims[1];
+  const int num_heads = dims[2];
+  const int head_size = dims[3];
+  const int seqlen_k = k.dims()[1];
+  const int num_heads_k = k.dims()[2];
+
+  // TODO(umiswing): Add check shape
+
+  const float scale = 1.0f / std::sqrt(head_size);
+
+  FlashAttnFwdParamsV2<T> params =
+      FlashAttnFwdParamsV2<T>(ctx,
+                              batch_size,
+                              seqlen_q,
+                              seqlen_k,
+                              num_heads,
+                              num_heads_k,
+                              head_size,
+                              dropout,
+                              scale,
+                              causal,
+                              return_softmax,
+                              q.dtype(),
+                              is_test,
+                              rng_name,
+                              fixed_seed_offset.get_ptr(),
+                              softmax,
+                              softmax_lse,
+                              seed_offset);
 
-  int64_t seq_len_k = k.dims()[1];
+  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
 
-  int64_t total_q = batch_size * seq_len_q;
-  int64_t total_k = batch_size * seq_len_k;
+  ctx.template Alloc<T>(out);
 
-  float scale = 1.0f / std::sqrt(head_size);
+  cudaStream_t stream = ctx.stream();
 
-  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
-          << "], v[" << v.dims() << "]";
+  VLOG(4) << "FlashAttn fwd seed: " << params.seed
+          << ", offset: " << params.offset;
 
-  DenseTensor q_t_s, k_t_s, v_t_s;
-  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
-  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
-  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
-
-  DenseTensor cu_seqlens_q;
-  DenseTensor cu_seqlens_k;
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
-  ArangeNullaryKernel<int32_t, Context>(
-      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
-
-  FlashAttnUnpaddedKernel<T, Context>(ctx,
-                                      q_t_s,
-                                      k_t_s,
-                                      v_t_s,
-                                      cu_seqlens_q,
-                                      cu_seqlens_k,
-                                      fixed_seed_offset,
-                                      seq_len_q,
-                                      seq_len_k,
-                                      scale,
-                                      dropout,
-                                      causal,
-                                      return_softmax,
-                                      is_test,
-                                      rng_name,
-                                      out,
-                                      softmax,
-                                      softmax_lse,
-                                      seed_offset);
+  bool succ = phi::dynload::flash_attn_fwd(
+      q.data(),
+      k.data(),
+      v.data(),
+      params.rng_state.data(),
+      out->data(),
+      params.return_softmax ? params.softmax->data() : nullptr,
+      params.softmax_lse->data(),
+      params.batch_size,
+      params.max_seqlen_q,
+      params.max_seqlen_k,
+      params.seqlen_q_rounded,
+      params.seqlen_k_rounded,
+      params.num_heads,
+      params.num_heads_k,
+      params.head_size,
+      params.head_size_rounded,
+      params.dropout,
+      params.scale,
+      params.causal,
+      params.return_softmax,
+      params.is_bf16,
+      stream,
+      params.seed,
+      params.offset);
+
+  PADDLE_ENFORCE_EQ(
+      succ,
+      true,
+      phi::errors::External("Error in Flash-Attention-2, detail information is",
+                            phi::dynload::flash_attn_error()));
+#else
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "FlashAttention is unsupported, please set use_flash_attn to false."));
 #endif
 }
diff --git a/paddle/phi/kernels/gpu/flash_attn_utils.h b/paddle/phi/kernels/gpu/flash_attn_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..62d0f4ec95b37eaf1265b62ae82802fa9caa4297
--- /dev/null
+++ b/paddle/phi/kernels/gpu/flash_attn_utils.h
@@ -0,0 +1,181 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace phi {
+
+template <typename T>
+struct FlashAttnFwdParamsV2 {
+  int batch_size;
+  // for padded kernel, max_seqlen_q and seqlen_q are the same.
+  int64_t max_seqlen_q;
+  // for padded kernel, max_seqlen_k and seqlen_k are the same.
+  int64_t max_seqlen_k;
+  int seqlen_q_rounded;
+  int seqlen_k_rounded;
+  int num_heads;
+  int num_heads_k;
+  int head_size;
+  int head_size_rounded;
+  float dropout;
+  float scale;
+  bool causal;
+  bool return_softmax;
+  bool is_bf16;
+  uint64_t seed;
+  uint64_t offset;
+  DenseTensor rng_state;
+  DenseTensor* softmax;
+  DenseTensor* softmax_lse;
+  DenseTensor* seed_offset;
+
+  FlashAttnFwdParamsV2(const GPUContext& ctx,
+                       const int _batch_size,
+                       const int64_t _max_seqlen_q,
+                       const int64_t _max_seqlen_k,
+                       const int _num_heads,
+                       const int _num_heads_k,
+                       const int _head_size,
+                       const float _dropout,
+                       const float _scale,
+                       const bool _causal,
+                       const bool _return_softmax,
+                       const DataType q_dtype,
+                       const bool is_test,
+                       const std::string& rng_name,
+                       const DenseTensor* const fixed_seed_offset_ptr,
+                       DenseTensor* _softmax,
+                       DenseTensor* _softmax_lse,
+                       DenseTensor* _seed_offset)
+      : batch_size(_batch_size),
+        max_seqlen_q(_max_seqlen_q),
+        max_seqlen_k(_max_seqlen_k),
+        num_heads(_num_heads),
+        num_heads_k(_num_heads_k),
+        head_size(_head_size),
+        scale(_scale),
+        dropout(_dropout),
+        causal(_causal),
+        return_softmax(_return_softmax),
+        softmax(_softmax),
+        softmax_lse(_softmax_lse),
+        seed_offset(_seed_offset) {
+    dropout = is_test ? 0.0f : _dropout;
+    is_bf16 = q_dtype == DataType::BFLOAT16;
+
+    // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
+    // with the same size.
+    rng_state = Empty<int64_t>(ctx, {2});
+    if (fixed_seed_offset_ptr) {
+      const int64_t* fixed_seed_offset_data =
+          fixed_seed_offset_ptr->data<int64_t>();
+      seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
+      offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
+    } else {
+      uint64_t inc = batch_size * num_heads * 32;
+      std::pair<uint64_t, uint64_t> seed_offset_pair;
+      if (rng_name != "") {
+        auto gen = phi::GetRandomSeedGenerator(rng_name);
+        seed_offset_pair = gen->IncrementOffset(inc);
+      } else {
+        auto* gen = ctx.GetGenerator();
+        seed_offset_pair = gen->IncrementOffset(inc);
+      }
+      seed = seed_offset_pair.first;
+      offset = seed_offset_pair.second;
+    }
+
+    seed_offset->Resize({2});
+    int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
+    seed_offset_data[0] = static_cast<int64_t>(seed);
+    seed_offset_data[1] = static_cast<int64_t>(offset);
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    head_size_rounded = round_multiple(head_size, 32);
+    seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
+    seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
+
+    softmax_lse->Resize({batch_size, num_heads, max_seqlen_q});
+    ctx.template Alloc<float>(softmax_lse);
+
+    if (return_softmax) {
+      softmax->Resize(
+          {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded});
+      ctx.template Alloc<T>(softmax);
+    }
+  }
+};
+
+struct FlashAttnBwdParamsV2 {
+  int batch_size;
+  int64_t max_seqlen_q;
+  int64_t max_seqlen_k;
+  int seqlen_q_rounded;
+  int seqlen_k_rounded;
+  int num_heads;
+  int num_heads_k;
+  int head_size;
+  int head_size_rounded;
+  float dropout;
+  float scale;
+  bool causal;
+  bool is_bf16;
+  uint64_t seed;
+  uint64_t offset;
+  DenseTensor softmax_d;
+  DenseTensor dq_accum;
+  DenseTensor rng_state;
+
+  FlashAttnBwdParamsV2(const GPUContext& ctx,
+                       const int _batch_size,
+                       const int64_t _max_seqlen_q,
+                       const int64_t _max_seqlen_k,
+                       const int _num_heads,
+                       const int _num_heads_k,
+                       const int _head_size,
+                       const float _dropout,
+                       const float _scale,
+                       const bool _causal,
+                       const DataType q_dtype,
+                       const int64_t* seed_offset_data)
+      : batch_size(_batch_size),
+        max_seqlen_q(_max_seqlen_q),
+        max_seqlen_k(_max_seqlen_k),
+        num_heads(_num_heads),
+        num_heads_k(_num_heads_k),
+        head_size(_head_size),
+        dropout(_dropout),
+        scale(_scale),
+        causal(_causal) {
+    is_bf16 = q_dtype == DataType::BFLOAT16;
+    seed = static_cast<uint64_t>(seed_offset_data[0]);
+    offset = static_cast<uint64_t>(seed_offset_data[1]);
+
+    // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
+    // with the same size.
+    rng_state = Empty<int64_t>(ctx, {2});
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+
+    head_size_rounded = round_multiple(head_size, 32);
+    seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
+    seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
+
+    softmax_d = Empty<float>(ctx, {batch_size, num_heads, seqlen_q_rounded});
+    dq_accum = Empty<float>(
+        ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
+  }
+};
+}  // namespace phi
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index 6bde691bd2f952ee2213efbdbdcbb474db29b81f..cc23331eadf56e8d8e875a0a7aec60f4752befb5 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -57,25 +57,27 @@ def attention_naive(q, k, v, causal=False):
     return paddle.transpose(o, [0, 2, 1, 3])
 
 
-is_sm75 = (
-    core.is_compiled_with_cuda()
-    and paddle.device.cuda.get_device_capability()[0] == 7
-    and paddle.device.cuda.get_device_capability()[1] == 5
-)
 is_sm8x = (
     core.is_compiled_with_cuda()
     and paddle.device.cuda.get_device_capability()[0] == 8
     and paddle.device.cuda.get_device_capability()[1] >= 0
 )
-is_sm_supported = is_sm75 or is_sm8x
+
+is_sm90 = (
+    core.is_compiled_with_cuda()
+    and paddle.device.cuda.get_device_capability()[0] == 9
+    and paddle.device.cuda.get_device_capability()[1] == 0
+)
+
+is_sm_supported = is_sm8x or is_sm90
 
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
-    or get_cuda_version() < 11030
+    or get_cuda_version() < 11040
     or not is_sm_supported,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
-    "and device's compute capability must be 7.5 or 8.x",
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
+    "and device's compute capability must be 8.x or 90",
 )
 class TestFlashAttentionAPI(unittest.TestCase):
     def setUp(self):
diff --git a/third_party/flashattn b/third_party/flashattn
index 18106c1ba0ccee81b97ca947397c08a141815a47..b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a 160000
--- a/third_party/flashattn
+++ b/third_party/flashattn
@@ -1 +1 @@
-Subproject commit 18106c1ba0ccee81b97ca947397c08a141815a47
+Subproject commit b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a
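For reference, the sketch below reproduces the shape bookkeeping introduced in flash_attn_utils.h: FlashAttnFwdParamsV2 and FlashAttnBwdParamsV2 round the sequence lengths up to a multiple of 128 and the head size up to a multiple of 32, and the patch allocates softmax (when returned), softmax_d, and dq_accum with those rounded extents. This is an illustrative, standalone C++ example; the concrete sizes are hypothetical and it is not part of the patch itself.

// Illustrative sketch of the round_multiple arithmetic used by the
// FlashAttnFwdParamsV2 / FlashAttnBwdParamsV2 helpers above.
#include <iostream>

// Round x up to the next multiple of m, same as the round_multiple lambda
// in flash_attn_utils.h.
static int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

int main() {
  // Hypothetical sizes, chosen only to show the arithmetic.
  const int max_seqlen_q = 1000;
  const int max_seqlen_k = 1000;
  const int head_size = 96;

  const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);  // 1024
  const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);  // 1024
  const int head_size_rounded = round_multiple(head_size, 32);     // 96

  // In the patch, softmax (if requested) is sized
  //   {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}
  // and dq_accum is sized
  //   {batch_size, num_heads, seqlen_q_rounded, head_size_rounded}.
  std::cout << seqlen_q_rounded << ' ' << seqlen_k_rounded << ' '
            << head_size_rounded << '\n';
  return 0;
}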