// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/flash_attn_grad_kernel.h"

#include "glog/logging.h"  // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"

#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif

DECLARE_bool(cudnn_deterministic);

namespace phi {

template <typename T, typename Context>
void FlashAttnUnpaddedGradKernel(const Context& ctx,
                                 const DenseTensor& q,
                                 const DenseTensor& k,
                                 const DenseTensor& v,
                                 const DenseTensor& cu_seqlens_q,
                                 const DenseTensor& cu_seqlens_k,
                                 const DenseTensor& out,
                                 const DenseTensor& softmax_lse,
                                 const DenseTensor& seed_offset,
                                 const DenseTensor& dout,
                                 int64_t max_seqlen_q,
                                 int64_t max_seqlen_k,
                                 float scale,
                                 float dropout,
                                 bool causal,
                                 DenseTensor* dq,
                                 DenseTensor* dk,
                                 DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
  ctx.template Alloc<T>(dq);
  ctx.template Alloc<T>(dk);
  ctx.template Alloc<T>(dv);

  const cudaStream_t stream = ctx.stream();

  // q, k, v: [total_*, num_heads, head_dim]
  auto dims = q.dims();
  const int64_t total_q = dims[0];
  const int batch_size = cu_seqlens_q.numel() - 1;
  const int num_heads = dims[1];
  const int head_size_og = dout.dims()[2];
  const int head_size = dims[2];
  const int total_k = k.dims()[0];
  // num_heads_k may differ from num_heads when grouped-query or
  // multi-query attention is used.
  const int num_heads_k = k.dims()[1];

  // TODO(umiswing): add deterministic in fa2.
  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
  // if (FLAGS_cudnn_deterministic) {
  //   num_splits = 1;
  // }

  const bool zero_tensors = false;

  // TODO(umiswing): add shape check
  PADDLE_ENFORCE_EQ(
      head_size_og,
      head_size,
      phi::errors::InvalidArgument(
          "flash_attn_bwd expects head_size_og == head_size."));

  FlashAttnBwdParamsV2 params =
      FlashAttnBwdParamsV2(ctx,
                           batch_size,
                           max_seqlen_q,
                           max_seqlen_k,
                           num_heads,
                           num_heads_k,
                           head_size,
                           dropout,
                           scale,
                           causal,
                           q.dtype(),
                           seed_offset.data<int64_t>());

  VLOG(4) << "FlashAttn bwd seed: " << params.seed
          << ", offset: " << params.offset;

  const bool succ =
      phi::dynload::flash_attn_varlen_bwd(dout.data(),
                                          q.data(),
                                          k.data(),
                                          v.data(),
                                          out.data(),
                                          params.softmax_d.data(),
                                          softmax_lse.data(),
                                          cu_seqlens_q.data<int32_t>(),
                                          cu_seqlens_k.data<int32_t>(),
                                          params.rng_state.data(),
                                          dq->data(),
                                          dk->data(),
                                          dv->data(),
                                          params.dq_accum.data(),
                                          params.batch_size,
                                          params.max_seqlen_q,
                                          params.max_seqlen_k,
                                          params.seqlen_q_rounded,
                                          params.seqlen_k_rounded,
                                          params.num_heads,
                                          params.num_heads_k,
                                          params.head_size,
                                          params.head_size_rounded,
                                          params.dropout,
                                          params.scale,
                                          params.causal,
                                          params.is_bf16,
                                          stream,
                                          params.seed,
                                          params.offset);

  if (!succ) {
    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
  }
#else
  PADDLE_THROW(phi::errors::Unimplemented(
      "FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}

template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
                         const DenseTensor& q,
                         const DenseTensor& k,
                         const DenseTensor& v,
                         const DenseTensor& out,
                         const DenseTensor& softmax_lse,
                         const DenseTensor& seed_offset,
                         const DenseTensor& dout,
                         float dropout,
                         bool causal,
                         DenseTensor* dq,
                         DenseTensor* dk,
                         DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
  // q, k, v: [batch_size, seq_len, num_heads, head_dim]
  auto dims = q.dims();
  const int batch_size = dims[0];
  const int seqlen_q = dims[1];
  const int num_heads = dims[2];
  const int head_size_og = dout.dims()[3];
  const int head_size = dims[3];
  const int seqlen_k = k.dims()[1];
  const int num_heads_k = k.dims()[2];

  // TODO(umiswing): add shape check
  PADDLE_ENFORCE_EQ(
      head_size_og,
      head_size,
      phi::errors::InvalidArgument(
          "flash_attn_bwd expects head_size_og == head_size."));

  VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
          << "], v[" << v.dims() << "]";

  // Default softmax scaling factor: 1 / sqrt(head_dim).
  const float scale = 1.0f / std::sqrt(head_size);

  FlashAttnBwdParamsV2 params =
      FlashAttnBwdParamsV2(ctx,
                           batch_size,
                           seqlen_q,
                           seqlen_k,
                           num_heads,
                           num_heads_k,
                           head_size,
                           dropout,
                           scale,
                           causal,
                           q.dtype(),
                           seed_offset.data<int64_t>());

  ctx.template Alloc<T>(dq);
  ctx.template Alloc<T>(dk);
  ctx.template Alloc<T>(dv);

  cudaStream_t stream = ctx.stream();

  VLOG(4) << "FlashAttn bwd seed: " << params.seed
          << ", offset: " << params.offset;

  const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
                                                 q.data(),
                                                 k.data(),
                                                 v.data(),
                                                 out.data(),
                                                 params.softmax_d.data(),
                                                 softmax_lse.data(),
                                                 params.rng_state.data(),
                                                 dq->data(),
                                                 dk->data(),
                                                 dv->data(),
                                                 params.dq_accum.data(),
                                                 params.batch_size,
                                                 params.max_seqlen_q,
                                                 params.max_seqlen_k,
                                                 params.seqlen_q_rounded,
                                                 params.seqlen_k_rounded,
                                                 params.num_heads,
                                                 params.num_heads_k,
                                                 params.head_size,
                                                 params.head_size_rounded,
                                                 params.dropout,
                                                 params.scale,
                                                 params.causal,
                                                 params.is_bf16,
                                                 stream,
                                                 params.seed,
                                                 params.offset);

  PADDLE_ENFORCE_EQ(
      succ,
      true,
      phi::errors::External(
          "Error in Flash-Attention-2, detail information is %s",
          phi::dynload::flash_attn_error()));
#else
  PADDLE_THROW(phi::errors::Unimplemented(
      "FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}

}  // namespace phi

PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlashAttnUnpaddedGradKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);  // seed_offset
}

PD_REGISTER_KERNEL(flash_attn_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlashAttnGradKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);  // seed_offset
}