    Add flash attention to speed up fused_gate_attention. (#52731) · d29c1f8e
    Committed by limingshu
    * Reorganize the forward codes of flash-attention.
    
    * Fix forward.
    
    * Remove some unused code.
    
    * Simplify the code and fix the backward pass.
    
    * Change all LOG(INFO) to VLOG and fix the backward pass.
    
    * Add scaling for the AF2 flash_attn path; many thanks to xreki and shaojie for debugging this code.
    
    * Decrease the effect of debug printing on performance.
    
    * Unify the initialization of the flash-attn arguments.
    
    * Rewrite the reshape of temp_mask and temp_bias.
    
    * Add API support for use_flash_attn.
    
    * Fix compilation error on CI.
    
    * Try to crop the flash-attention lib.
    
    * Correct the condition for whether flash-attn can be used.
    
    * Remove the softmax_out argument.
    
    * Remove is_causal.
    
    * Polish the code.
    
    * Fix qkv_transpose_out's shape and the scaling of Q * K (see the sketch after this message).
    
    * Update commit of flash-attention.
    
    ---------
    Co-authored-by: Liu Yiqun <liuyiqun01@baidu.com>
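
    The scaling items above refer to the standard scaled dot-product attention
    computation, softmax(Q * K^T / sqrt(head_dim)) * V. The snippet below is a
    minimal NumPy sketch of that scaling only, added here for illustration; it
    is not the fused CUDA kernel this commit changes, and the optional `bias`
    argument merely stands in for the temp_mask/temp_bias tensors mentioned
    above.

    import numpy as np

    def scaled_dot_product_attention(q, k, v, bias=None):
        # q, k, v: [num_heads, seq_len, head_dim]
        head_dim = q.shape[-1]
        # Scale Q * K^T by 1/sqrt(head_dim) before the softmax.
        logits = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
        if bias is not None:
            logits = logits + bias  # additive attention bias/mask
        weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
        return weights @ v

    # Example: 4 heads, sequence length 8, head dimension 32.
    q = np.random.rand(4, 8, 32).astype("float32")
    k = np.random.rand(4, 8, 32).astype("float32")
    v = np.random.rand(4, 8, 32).astype("float32")
    print(scaled_dot_product_attention(q, k, v).shape)  # (4, 8, 32)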