From caa0f3774cdd93b766f048bfcda16ddde6b06b67 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Wed, 9 Aug 2023 09:55:15 +0800
Subject: [PATCH] fix codestyle (#56066)

---
 paddle/phi/api/yaml/ops.yaml                   | 2 --
 python/paddle/nn/functional/flash_attention.py | 5 +++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 9d7f6d6c0dd..9cba9266fe0 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -687,7 +687,6 @@
   kernel :
     func : flash_attn
     data_type : q
-  intermediate : softmax_lse, seed_offset
   backward : flash_attn_grad
 
 - op : flash_attn_unpadded
@@ -712,7 +711,6 @@
   kernel :
     func : flash_attn_v1
     data_type : q
-  intermediate : softmax_lse, seed_offset
   backward : flash_attn_v1_grad
 
 - op : flash_attn_v1_unpadded
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 238fc504e94..2250c573ffa 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -92,7 +92,7 @@ def flash_attention(
     """
     if in_dynamic_mode():
         if g_use_flash_attn_v1:
-            (result_attention, result_softmax,) = _C_ops.flash_attn_v1(
+            (result_attention, result_softmax, _, _) = _C_ops.flash_attn_v1(
                 query,
                 key,
                 value,
@@ -101,8 +101,9 @@
                 return_softmax,
                 not training,
             )
+
         else:
-            (result_attention, result_softmax,) = _C_ops.flash_attn(
+            (result_attention, result_softmax, _, _) = _C_ops.flash_attn(
                 query,
                 key,
                 value,
--
GitLab
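
Note on the change (not part of the patch): dropping `intermediate : softmax_lse, seed_offset` from the op definitions means the generated `_C_ops.flash_attn` and `_C_ops.flash_attn_v1` bindings now surface all four outputs, so the Python wrapper unpacks four values and discards the last two. The public `flash_attention` API is unaffected. A minimal usage sketch follows; the shapes and dtype are illustrative, and it assumes a CUDA build of Paddle with FlashAttention support, which the kernel requires:

    # Illustrative sketch, not part of the patch. Requires a CUDA build of
    # Paddle with FlashAttention support; the kernel expects half precision.
    import paddle
    from paddle.nn.functional.flash_attention import flash_attention

    # [batch, seq_len, num_heads, head_dim]; shapes here are arbitrary.
    q = paddle.randn([2, 128, 8, 64], dtype='float16')
    k = paddle.randn([2, 128, 8, 64], dtype='float16')
    v = paddle.randn([2, 128, 8, 64], dtype='float16')

    # The wrapper still returns two values; softmax is None unless
    # return_softmax=True is passed.
    out, softmax = flash_attention(q, k, v, dropout=0.0, causal=True)
    print(out.shape)  # [2, 128, 8, 64]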