    [cherry-pick] Integration flash attention 2 (#56015) · cc9a7688
    Committed by umiswing
    * [FlashAttn] add flash randomness control (#52902)
    
    * add flash randomness control
    
    * fix VLOG undefined
    
    * [WIP] Integration flash attention 2 (#55758)
    
    * Works for fa2 padded fwd. Code to be cleaned.
    
    * Works for fa2 unpadded fwd.
    
    * Works for padded bwd; dk shows a small diff with np.random.seed(0).
    
    * Passes Paddle's unit tests, except for returning softmax without dropout.
    
    * Clean code.
    
    * Modify interface.
    
    * Clean code and add some check.
    
    * Easy compile for dev.
    
    * Fix ci.
    
    * Fix ci-build.
    
    * Add std c++17 option again.
    
    * Limit max job when compiling fa2.
    
    * Remove const_cast
    
    * Add fwd params, to be cleaned.
    
    * Clean code.
    
    * Add bwd params.
    
    * Clean code.
    
    * Add enforce.
    
    * Use v2.0.4
    
    * Pass RNG state to fa2 capi
    
    * Fix review.
    
    * Add assert
    
    * Skip compile for sm less than 80.
    
    ---------
    Co-authored-by: Chitsing KUI <kuizhiqing@msn.com>
dist_flash_attn.py 3.5 KB