# Based on: https://github.com/HazyResearch/flash-attention/blob/4a6eaa9f27df6fff7ffb2c24e894938a687dd870/flash_attn/flash_attn_interface.py
import torch
import torch.nn as nn
import torch.nn.functional as F

import flash_attn_cuda


def _flash_attn_forward(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    return_softmax,
    num_splits=0,
    generator=None,
):
    """
    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
    it will be set by an internal heuristic. We're exposing num_splits mostly for
    benchmarking. Don't change it unless you know what you're doing.
    """
    softmax_lse, *rest = flash_attn_cuda.fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        return_softmax,
        num_splits,
        generator,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return out, softmax_lse, S_dmask


def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    num_splits=0,
    generator=None,
):
    """
    num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
    not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
    Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same
    kernel as num_splits=3), so effectively the choices are 0, 1, and 2. This
    hyperparameter can be tuned for performance, but the default value (heuristic)
    should work fine.
    """
    _, _, _, softmax_d = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        num_splits,
        generator,
    )
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, softmax_lse, S_dmask = _flash_attn_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            torch.empty_like(qkv[:, 0]),
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax,
        )
        ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = torch.empty_like(qkv)
        _flash_attn_backward(
            dout,
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None


def flash_attn_unpadded_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    return FlashAttnQKVPackedFunc.apply(
        qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
    )
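

# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of calling flash_attn_unpadded_qkvpacked_func. It assumes
# the flash_attn_cuda extension is built and a CUDA device is available, and that the
# packed-QKV layout is (total_tokens, 3, nheads, headdim) in fp16, with cu_seqlens
# holding int32 cumulative sequence lengths of the packed batch. The shapes and
# sequence lengths below are illustrative only.
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda"
    nheads, headdim = 8, 64
    seqlens = [128, 256]  # two variable-length sequences packed into one batch
    total_tokens = sum(seqlens)
    max_seqlen = max(seqlens)
    # Cumulative sequence lengths of [128, 256]: [0, 128, 384]
    cu_seqlens = torch.tensor([0, 128, 384], dtype=torch.int32, device=device)

    qkv = torch.randn(
        total_tokens, 3, nheads, headdim,
        dtype=torch.float16, device=device, requires_grad=True,
    )
    out = flash_attn_unpadded_qkvpacked_func(
        qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=True
    )
    # out has shape (total_tokens, nheads, headdim); gradients flow back through
    # FlashAttnQKVPackedFunc.backward into qkv.
    out.sum().backward()
    print(out.shape, qkv.grad.shape)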