Add flash attention to speed up fused_gate_attention. (#52731)
* Reorganize the forward code of flash-attention.
* Fix forward.
* Remove some unused code.
* Simplify code and fix backward.
* Change all LOG(INFO) to VLOG and fix the backward.
* Add scale for AF2 flash_attn. Many thanks to xreki and shaojie for debugging this code.
* Decrease the effect of debug printing on performance.
* Unify the initialization of the flashattn arguments.
* Rewrite the reshape of temp_mask and temp_bias.
* Add use_flash_attn support to the API.
* Fix compiling error on CI.
* Try to crop the flash-attention lib.
* Correct the condition for whether flash-attn can be used.
* Remove the softmax_out argument.
* Remove is_causal.
* Polish code.
* Fix qkv_transpose_out's shape and the scaling of Q * K (see the sketch after this list).
* Update commit of flash-attention.
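For context on the Q * K scaling and the use_flash_attn flag mentioned above, here is a minimal sketch of the non-fused attention math that the fused and flash-attention paths are expected to match numerically. This is not the fused_gate_attention kernel or the actual Paddle API; the function name, tensor shapes, and mask convention are assumptions for illustration only.

```python
import math

import paddle
import paddle.nn.functional as F


def gate_attention_reference(q, k, v, bias=None, mask=None):
    """Non-fused reference attention, with Q scaled by 1/sqrt(head_dim).

    Assumed shapes (illustrative): q, k, v are [batch, num_heads, seq_len, head_dim];
    bias and mask broadcast to [batch, num_heads, seq_len, seq_len].
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    # Scale Q before the matmul, then compute attention logits: (Q * scale) @ K^T.
    logits = paddle.matmul(q * scale, k, transpose_y=True)  # [b, h, s, s]
    if bias is not None:
        logits = logits + bias
    if mask is not None:
        # Assumed convention: mask == 1 keeps a position, mask == 0 suppresses it.
        logits = logits + (1.0 - mask) * -1e9
    probs = F.softmax(logits, axis=-1)
    return paddle.matmul(probs, v)  # [b, h, s, head_dim]
```

A flash-attention path would compute the same result without materializing the full [s, s] logits/softmax tensors, which is why the softmax_out argument could be removed for that path.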
---------
Co-authored-by: Liu Yiqun <liuyiqun01@baidu.com>