Add attn_mask support for FlashAttnKernel. (#55969)
* add mask
* add backward
* add enforce info
* update scale
* integrate code
* update enforce
* add enforce eq
* add error type
* update enforce
* add test_flash_attention
* Polish code and fix compilation errors.
* Set num_splits to 0 for flash-attn with tensor mask.
* Fix the compilation error for the non flash-attn case.
---------
Co-authored-by: Liu Yiqun <liuyiqun01@baidu.com>
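
To illustrate what the changes above enable: a minimal NumPy sketch (not Paddle's actual FlashAttnKernel implementation) of how an additive `attn_mask` and a `scale` factor participate in scaled dot-product attention. The function name and signature here are illustrative assumptions, not the PR's API.

```python
import numpy as np

def masked_attention(q, k, v, attn_mask=None, scale=None):
    """Sketch of attention with an additive mask.

    q, k, v: [seq_len, head_dim] arrays.
    attn_mask: optional [seq_len, seq_len] additive mask; masked
               positions hold -inf (or a large negative value).
    scale: score scaling factor; defaults to 1/sqrt(head_dim).
    """
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ k.T) * scale
    if attn_mask is not None:
        # The mask is added to the scaled scores, so exp(-inf) -> 0
        # removes masked positions from the softmax.
        scores = scores + attn_mask
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

For example, an upper-triangular mask of `-inf` makes the attention causal: the first query position can attend only to the first key, so its output equals `v[0]`.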