diff --git a/cmake/external/flashattn.cmake b/cmake/external/flashattn.cmake
index eae35d90f50f097f9f63075bc919d11cb093c861..95893ad27a6a1b30c465ec1676ce23c205c461ec 100644
--- a/cmake/external/flashattn.cmake
+++ b/cmake/external/flashattn.cmake
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
 set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
-set(FLASHATTN_TAG f0edf243a813a65d05c75fcb331b2a95faf96bbc)
+set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
 
 set(FLASHATTN_INCLUDE_DIR
     "${FLASHATTN_INSTALL_DIR}/include"
diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index 8e75ecc473f2cb758afec60e183233367284d41f..4bd6aa99fe5c3e83fcdbc30ee99a80c2a09efdcc 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -13,8 +13,10 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/flash_attn_grad_kernel.h"
+#include "glog/logging.h"  // For VLOG()
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
@@ -25,6 +27,8 @@
 #include "paddle/phi/backends/dynload/flashattn.h"
 #endif
 
+DECLARE_bool(cudnn_deterministic);
+
 namespace phi {
 
 template <typename T, typename Context>
@@ -67,10 +71,17 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
   bool zero_tensors = false;
 
+  if (FLAGS_cudnn_deterministic) {
+    num_splits = 1;
+  }
+
   const int64_t* seed_offset_data = seed_offset.data<int64_t>();
   uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
   uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
 
+  VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
+          << ", num_splits:" << num_splits;
+
   int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
 
   DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
@@ -187,6 +198,9 @@ void FlashAttnGradKernel(const Context& ctx,
 
   float scale = 1.0f / std::sqrt(head_size);
 
+  VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
+
   DenseTensor q_t_s, k_t_s, v_t_s;
   q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
   k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
index 7c2cd423dd03283b3431565151aa4ea331941f5f..63bd4d10cbdd526f1244203a38a4a2ddb08d86f5 100644
--- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -14,9 +14,11 @@
 
 #include "paddle/phi/kernels/flash_attn_kernel.h"
 
+#include "glog/logging.h"  // For VLOG()
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 
@@ -28,6 +30,8 @@
 #include "paddle/phi/backends/dynload/flashattn.h"
 #endif
 
+DECLARE_bool(cudnn_deterministic);
+
 namespace phi {
 
 template <typename T, typename Context>
@@ -73,6 +77,9 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
   int64_t batch_size = cu_seqlens_q.numel() - 1;
 
   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  if (FLAGS_cudnn_deterministic) {
+    num_splits = 1;
+  }
   bool zero_tensors = false;
 
   auto gen = ctx.GetGenerator();
@@ -82,6 +89,9 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
   uint64_t seed = seed_offset_pair.first;
   uint64_t offset = seed_offset_pair.second;
 
+  VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
+          << ", num_splits:" << num_splits;
+
   seed_offset->Resize({2});
   auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
   seed_offset_data[0] = static_cast<int64_t>(seed);
@@ -217,6 +227,9 @@ void FlashAttnKernel(const Context& ctx,
 
   float scale = 1.0f / std::sqrt(head_size);
 
+  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
+
   DenseTensor q_t_s, k_t_s, v_t_s;
   q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
   k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
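
Usage note (not part of the patch): the sketch below shows how the new path might be
exercised end to end. It assumes that FLAGS_cudnn_deterministic and GLOG_v are picked up
from the environment before importing paddle, and that
paddle.nn.functional.flash_attention.flash_attention is the user-facing entry point; both
are assumptions about the surrounding stack, not something this diff introduces.

    # Hedged sketch: run FlashAttention with the deterministic path enabled.
    import os

    os.environ["FLAGS_cudnn_deterministic"] = "1"  # the kernels above then force num_splits = 1
    os.environ["GLOG_v"] = "4"                     # surface the new VLOG(4) messages

    import paddle
    from paddle.nn.functional.flash_attention import flash_attention

    # [batch, seqlen, num_heads, head_dim], fp16, on a FlashAttention-capable GPU
    q = paddle.randn([2, 128, 8, 64], dtype="float16")
    k = paddle.randn([2, 128, 8, 64], dtype="float16")
    v = paddle.randn([2, 128, 8, 64], dtype="float16")
    q.stop_gradient = False

    out, _ = flash_attention(q, k, v, dropout=0.0, causal=True)
    out.sum().backward()  # the backward kernel also takes the num_splits = 1 branch

With the flag set, both the forward and backward kernels pass num_splits = 1 to the
flash-attention library instead of letting its internal heuristic choose a split count,
presumably trading some speed for a fixed reduction order.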