[cherry-pick] Integration flash attention 2 (#56015)
* [FlashAttn] add flash randomness control (#52902)
* add flash randomness control
* fix VLOG undefied
* [WIP] Integration flash attention 2 (#55758)
* Work for fa-2 padded fwd. Code to be cleaned.
* Work for fa2 unpadded fwd.
* Work for padded-bwd, dk get small diff on np.random.seed(0)
* Anyway I pass paddle's utest, except return softmax without dropout.
* Clean code.
* Modify interface.
* Clean code and add some check.
* Easy compile for dev.
* Fix ci.
* Fix ci-build.
* Add std c++17 option again.
* Limit max job when compiling fa2.
* Remove const_cast
* Add fwd params, to be cleaned.
* Clean code.
* Add bwd params.
* Clean code.
* Add enforce.
* Use v2.0.4
* Pass RNG state to fa2 capi
* Fix review.
* Add assert
* Skip compile for sm less than 80.
---------
Co-authored-by: NChitsing KUI <kuizhiqing@msn.com>
Showing
想要评论请 注册 或 登录