From d29c1f8e07d1330ed0128ac4dce7fab67654a0b5 Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Fri, 19 May 2023 15:30:44 +0800 Subject: [PATCH] Add flash attention to speedup fused_gate_attention. (#52731) * Reorganize the forward codes of flash-attention. * Fix forward. * Remove some unused codes. * Simplify codes and fix backward. * Change all LOG(INFO) to VLOG and fix the backward. * add scale for AF2 flash_attn, much thanks to xreki and shaojie for debugging these codes * decrease the effect of debug print on performance * Unify the initialization of flashattn arguments. * Rewrite the reshape of temp_mask and temp_bias. * Add API support for use_flash_attn. * Fix compiling error on CI. * Try to crop the flash-attention lib. * Correct the condition of whether flash-attn can be used. * Remove the softmax_out argument. * Remove is_causal. * Polish codes. * Fix qkv_transpose_out's shape and scaling of Q * K. * Update commit of flash-attention. --------- Co-authored-by: Liu Yiqun --- cmake/external/flashattn.cmake | 2 +- .../manual/fluid_manual/dygraph_forward_api.h | 1 + .../forwards/fused_gate_attention_fwd_func.cc | 20 +- .../nodes/fused_gate_attention_node.cc | 12 + .../api/manual/fluid_manual/nodes/nodes.h | 10 + .../operators/fused/fused_gate_attention.h | 510 +++++++++++++++++- .../fused/fused_gate_attention_op.cc | 42 +- .../fused/fused_gate_attention_op.cu | 109 +++- paddle/phi/backends/dynload/flashattn.h | 8 +- .../unittests/test_fused_gate_attention_op.py | 1 + .../nn/functional/fused_gate_attention.py | 3 +- 11 files changed, 670 insertions(+), 48 deletions(-) diff --git a/cmake/external/flashattn.cmake b/cmake/external/flashattn.cmake index 95893ad27a6..68957406756 100644 --- a/cmake/external/flashattn.cmake +++ b/cmake/external/flashattn.cmake @@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn) set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn) set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn) set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git) -set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717) +set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47) set(FLASHATTN_INCLUDE_DIR "${FLASHATTN_INSTALL_DIR}/include" diff --git a/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h b/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h index 7fda5aa69b7..d8a4fee0caf 100644 --- a/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h +++ b/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h @@ -27,6 +27,7 @@ std::tuple fused_gate_attention_dygraph_function( const paddle::Tensor& Query, diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc index fd3d32401d9..546b60438fe 100644 --- a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc +++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc @@ -26,6 +26,7 @@ std::tuple fused_gate_attention_dygraph_function( const paddle::Tensor& Query, @@ -181,6 +182,9 @@ fused_gate_attention_dygraph_function( {"SoftmaxOut", {std::make_shared( egr::Controller::Instance().GenerateUniqueName())}}, + {"SoftmaxLse", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, {"FMHAOut", {std::make_shared( egr::Controller::Instance().GenerateUniqueName())}}, @@ -256,6 +260,8 @@ 
fused_gate_attention_dygraph_function( egr::EagerUtils::GetOutput(outs["QKVTransposeOut"][0], &QKVTransposeOut); paddle::Tensor SoftmaxOut; egr::EagerUtils::GetOutput(outs["SoftmaxOut"][0], &SoftmaxOut); + paddle::Tensor SoftmaxLse; + egr::EagerUtils::GetOutput(outs["SoftmaxLse"][0], &SoftmaxLse); paddle::Tensor FMHAOut; egr::EagerUtils::GetOutput(outs["FMHAOut"][0], &FMHAOut); paddle::Tensor GateOut; @@ -296,7 +302,7 @@ fused_gate_attention_dygraph_function( p_autograd_Out); // Create GradOpNode auto grad_node = std::shared_ptr( - new fused_gate_attentionGradNodeCompat(8, 12)); + new fused_gate_attentionGradNodeCompat(9, 12)); bool merge_qkv = true; if (attrs.count("merge_qkv")) { @@ -308,6 +314,11 @@ fused_gate_attention_dygraph_function( has_gating = PADDLE_GET_CONST(bool, attrs.at("has_gating")); } + bool use_flash_attn = false; + if (attrs.count("use_flash_attn")) { + use_flash_attn = PADDLE_GET_CONST(bool, attrs.at("use_flash_attn")); + } + // Set Attributes grad_node->SetAttrMap(std::move(attrs)); grad_node->SetDefaultAttrMap(std::move(default_attrs)); @@ -354,6 +365,12 @@ fused_gate_attention_dygraph_function( grad_node->SetGradOutMeta(NonbatchedBias, 6); } + if (use_flash_attn) { + grad_node->SetTensorWrapperSoftmaxLse(SoftmaxLse); + grad_node->SetTensorWrapperSrcMask(SrcMask); + grad_node->SetGradOutMeta(SrcMask, 7); + } + egr::EagerUtils::SetOutRankWithSlot(p_autograd_QueryTransposeOut, 0); grad_node->SetGradInMeta(QueryTransposeOut, 0); egr::EagerUtils::SetOutRankWithSlot(p_autograd_KeyTransposeOut, 1); @@ -379,6 +396,7 @@ fused_gate_attention_dygraph_function( ValueTransposeOut, QKVTransposeOut, SoftmaxOut, + SoftmaxLse, FMHAOut, GateOut, Out); diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_gate_attention_node.cc b/paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_gate_attention_node.cc index 8c427eba8cd..3692a20faed 100644 --- a/paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_gate_attention_node.cc +++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_gate_attention_node.cc @@ -45,6 +45,11 @@ fused_gate_attentionGradNodeCompat::operator()( has_gating = PADDLE_GET_CONST(bool, attr_map_.at("has_gating")); } + bool use_flash_attn = false; + if (attr_map_.count("use_flash_attn")) { + use_flash_attn = PADDLE_GET_CONST(bool, attr_map_.at("use_flash_attn")); + } + std::map>> ins0 = {{"FMHAOut", egr::EagerUtils::TrySyncToVars( @@ -168,6 +173,13 @@ fused_gate_attentionGradNodeCompat::operator()( egr::Controller::Instance().GenerateUniqueName())}; } + if (use_flash_attn) { + auto SrcMask = egr::EagerUtils::RecoverTensorWrapper(&this->SrcMask_); + ins0["SrcMask"] = egr::EagerUtils::TrySyncToVars(SrcMask); + auto SoftmaxLse = egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxLse_); + ins0["SoftmaxLse"] = egr::EagerUtils::TrySyncToVars(SoftmaxLse); + } + auto& attrs_map0 = this->attr_map_; // Pass the entire attribute map to TraceOp // The underlying kernel will pickup whatever attribute they need at runtime diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h index b0576672ae1..212f9d9f1da 100644 --- a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h +++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h @@ -61,12 +61,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase { GateOut_.clear(); GateWeight_.clear(); NonbatchedBias_.clear(); + SrcMask_.clear(); OutLinearBias_.clear(); OutLinearWeight_.clear(); QKVTransposeOut_.clear(); 
QKVWeight_.clear(); Query_.clear(); SoftmaxOut_.clear(); + SoftmaxLse_.clear(); Key_.clear(); QueryWeight_.clear(); KeyWeight_.clear(); @@ -103,6 +105,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase { void SetTensorWrapperNonbatchedBias(const paddle::Tensor& NonbatchedBias) { NonbatchedBias_ = egr::TensorWrapper(NonbatchedBias, false); } + void SetTensorWrapperSrcMask(const paddle::Tensor& SrcMask) { + SrcMask_ = egr::TensorWrapper(SrcMask, false); + } void SetTensorWrapperOutLinearBias(const paddle::Tensor& OutLinearBias) { OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false); } @@ -121,6 +126,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase { void SetTensorWrapperSoftmaxOut(const paddle::Tensor& SoftmaxOut) { SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false); } + void SetTensorWrapperSoftmaxLse(const paddle::Tensor& SoftmaxLse) { + SoftmaxLse_ = egr::TensorWrapper(SoftmaxLse, false); + } void SetTensorWrapperKey(const paddle::Tensor& Key) { Key_ = egr::TensorWrapper(Key, false); } @@ -160,12 +168,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase { egr::TensorWrapper GateOut_; egr::TensorWrapper GateWeight_; egr::TensorWrapper NonbatchedBias_; + egr::TensorWrapper SrcMask_; egr::TensorWrapper OutLinearBias_; egr::TensorWrapper OutLinearWeight_; egr::TensorWrapper QKVTransposeOut_; egr::TensorWrapper QKVWeight_; egr::TensorWrapper Query_; egr::TensorWrapper SoftmaxOut_; + egr::TensorWrapper SoftmaxLse_; egr::TensorWrapper Key_; egr::TensorWrapper QueryWeight_; diff --git a/paddle/fluid/operators/fused/fused_gate_attention.h b/paddle/fluid/operators/fused/fused_gate_attention.h index 105647baf1c..5cbc4788a0c 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention.h +++ b/paddle/fluid/operators/fused/fused_gate_attention.h @@ -14,8 +14,12 @@ limitations under the License. */ #pragma once +#ifdef PADDLE_WITH_FLASHATTN +#include "paddle/phi/backends/dynload/flashattn.h" +#endif + +#include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" -#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/funcs/transpose_function.cu.h" @@ -24,6 +28,13 @@ limitations under the License. 
*/ namespace paddle { namespace operators { +template +__global__ void SimleScaleKernel(int64_t numel, float scale, T* inout) { + CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) { + inout[i] = static_cast(scale * static_cast(inout[i])); + } +} + inline std::string MemoryDebugString(const phi::DenseTensor& t) { int device_id = platform::GetCurrentDeviceId(); int64_t allocated = @@ -46,7 +57,45 @@ void AllocWithDebugInfo(const phi::GPUContext& dev_ctx, const std::string& info, phi::DenseTensor* t) { dev_ctx.Alloc(t, t->numel() * sizeof(T)); - VLOG(4) << info << ": " << MemoryDebugString(*t); + if (VLOG_IS_ON(4)) { + VLOG(4) << info << ": " << MemoryDebugString(*t); + } +} + +inline std::string TensorDebugString(const phi::DenseTensor* t, + const std::string& info) { + std::stringstream ss; + ss << info << ": "; + if (t) { + if (t->initialized()) { + ss << "shape=[" << t->dims() << "], ptr=" << t->data(); + } else { + ss << "not initialized"; + } + } else { + ss << "nullptr"; + } + return ss.str(); +} + +inline void WaitWithDebugInfo(const phi::GPUContext& dev_ctx) { + if (VLOG_IS_ON(5)) { + dev_ctx.Wait(); + VLOG(5) << "[Flash attn Synchronize] "; + } +} + +template +inline void TypeDebugInfo() { + if (VLOG_IS_ON(4)) { + if (std::is_same::value) { + VLOG(4) << "[Grad]: T is phi::dtype::float16."; + } else if (std::is_same::value) { + VLOG(4) << "[Grad]: T is phi::dtype::bfloat16."; + } else if (std::is_same::value) { + VLOG(4) << "[Grad]: T is float."; + } + } } template @@ -61,6 +110,7 @@ struct GateAttentionConfig { bool merge_qkv; bool has_gating; + bool use_flash_attn; int64_t batch_size; int64_t seq_len_m; @@ -90,8 +140,12 @@ struct GateAttentionConfig { const phi::DenseTensor* query_weight, const phi::DenseTensor* qkv_weight, bool merge_qkv, - bool has_gating) - : dev_ctx(dev_ctx), merge_qkv(merge_qkv), has_gating(has_gating) { + bool has_gating, + bool use_flash_attn) + : dev_ctx(dev_ctx), + merge_qkv(merge_qkv), + has_gating(has_gating), + use_flash_attn(use_flash_attn) { // query: shape=[batch_size, seq_len_m, seq_len_r, q_dim] batch_size = query->dims()[0]; seq_len_m = query->dims()[1]; @@ -146,6 +200,22 @@ struct GateAttentionConfig { gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim}; } + bool CanUseFlashAttn() const { +#ifdef PADDLE_WITH_FLASHATTN + if (!std::is_same::value && + !std::is_same::value) { + return false; + } + + if (merge_qkv && batch_size == 1) { + if (head_dim == 32 || head_dim == 64 || head_dim == 128) { + return use_flash_attn; + } + } +#endif + return false; + } + int64_t GetQuerySize() const { return batch_size * seq_len_m * seq_len_r * num_heads * head_dim; } @@ -253,14 +323,16 @@ struct GateAttentionGradConfig : public GateAttentionConfig { const phi::DenseTensor* query_weight, const phi::DenseTensor* qkv_weight, bool merge_qkv, - bool has_gating) + bool has_gating, + bool use_flash_attn) : GateAttentionConfig(dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, - has_gating) {} + has_gating, + use_flash_attn) {} phi::DenseTensor* GetQKVOutGrad() { if (!qkv_out_grad.IsInitialized()) { @@ -336,6 +408,7 @@ class FMHAGateRef { T* q_ptr = nullptr; T* k_ptr = nullptr; T* v_ptr = nullptr; + if (merge_qkv_) { // qkv_transpose_out = transpose(qkv_out) PADDLE_ENFORCE_NOT_NULL( @@ -381,7 +454,6 @@ class FMHAGateRef { k_ptr = k_transpose_out->data(); v_ptr = v_transpose_out->data(); } - // qk_out = BatchedGEMM(Q, K^T) // [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] * // [batch_size, seq_len_m, num_heads, m_size, head_dim] @@ -394,8 +466,8 
@@ class FMHAGateRef { int64_t gemm_m = config->seq_len_r; int64_t gemm_n = config->m_size; int64_t gemm_k = config->head_dim; - T alpha = static_cast(1.0 / sqrt(config->head_dim)); + // attn = matmul(q, k.transpose(-1, -2)) ComputeBatchedGEMM(q_ptr, k_ptr, qk_out_ptr, @@ -407,6 +479,7 @@ class FMHAGateRef { gemm_batch_size, alpha); + // attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias) // softmax_out = softmax(qk_out + nonbatched_bias + src_mask) ComputeBiasMaskSoftmaxForward( nonbatched_bias, src_mask, qk_out, softmax_out); @@ -414,7 +487,7 @@ class FMHAGateRef { // qktv_out = BatchedGEMM(softmax_out, V) // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] * - // [batch_size, seq_len_m, num_heads, m_size, head_dim] + // [batch_size, seq_len_m, num_heads, m_size, head_dim] // -> [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] phi::DenseTensor* qktv_out = config->GetQKTVOut(gate_out); T* qktv_out_ptr = qktv_out->data(); @@ -423,6 +496,7 @@ class FMHAGateRef { gemm_n = config->head_dim; gemm_k = config->m_size; + // o = matmul(attn, v) T* softmax_out_ptr = softmax_out->data(); ComputeBatchedGEMM(softmax_out_ptr, v_ptr, @@ -435,7 +509,9 @@ class FMHAGateRef { gemm_batch_size); // fmha_out = transpose(qktv_out) + // o = o.transpose(-2, -3).contiguous() ComputeQKTVTransposeForward(*qktv_out, fmha_out); + config->ClearQKTVOut(); if (config->has_gating) { gate_out->Resize(config->gate_out_dims); @@ -463,6 +539,7 @@ class FMHAGateRef { phi::DenseTensor k_transpose_out_grad; phi::DenseTensor v_transpose_out_grad; phi::DenseTensor qkv_transpose_out_grad; + if (merge_qkv_) { PADDLE_ENFORCE_NOT_NULL( qkv_transpose_out, @@ -752,11 +829,11 @@ class FMHAGateRef { int64_t batch_size, T alpha = static_cast(1.0), T beta = static_cast(0.0)) { - CBLAS_TRANSPOSE cblas_trans_a = trans_a ? CblasTrans : CblasNoTrans; - CBLAS_TRANSPOSE cblas_trans_b = trans_b ? CblasTrans : CblasNoTrans; int64_t stride_a = m * k; int64_t stride_b = k * n; + CBLAS_TRANSPOSE cblas_trans_a = trans_a ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE cblas_trans_b = trans_b ? CblasTrans : CblasNoTrans; auto blas = phi::funcs::GetBlas(dev_ctx_); blas.BatchedGEMM(cblas_trans_a, cblas_trans_b, @@ -777,5 +854,416 @@ class FMHAGateRef { bool merge_qkv_; }; +template +class FlashAttnWithGating { + public: + FlashAttnWithGating(const phi::GPUContext& dev_ctx, bool merge_qkv) + : dev_ctx_(dev_ctx), merge_qkv_(merge_qkv) {} + + void ComputeForward(const phi::DenseTensor* nonbatched_bias, + const phi::DenseTensor* src_mask, + phi::DenseTensor* qkv_transpose_out, + phi::DenseTensor* softmax_lse, + phi::DenseTensor* fmha_out, + GateAttentionConfig* config) { +#ifdef PADDLE_WITH_FLASHATTN + bool is_bf16 = + qkv_transpose_out->dtype() == DataType::BFLOAT16 ? true : false; + TypeDebugInfo(); + + PADDLE_ENFORCE_NOT_NULL( + qkv_transpose_out, + platform::errors::NotFound("The input qkv_transpose_out can not be " + "nullptr when merge_qkv is true.")); + + // 1. Transpose qkv_out for flash_attn. + phi::DenseTensor* qkv_out = config->GetQKVOut(); + ComputeQKVTransposeForward(*qkv_out, qkv_transpose_out); + config->ClearQKVOut(); + + // q_size == k_size + int64_t q_size = config->GetQuerySize(); + T* q_ptr = qkv_transpose_out->data(); + T* k_ptr = q_ptr + q_size; + T* v_ptr = k_ptr + q_size; + + // 2. Scale Q: q_ptr = alpha * q_ptr + ComputeScaleQ(q_size, config->head_dim, q_ptr); + + // 3. flash_attn parameter setting. 
+ phi::DenseTensor cu_seq_q; + phi::DenseTensor cu_seq_k; + InitArgumentsAndSeqTensors(config, &cu_seq_q, &cu_seq_k); + + std::vector temp_mask_dim = GetCompressedDim(src_mask); + std::vector temp_bias_dim = GetCompressedDim(nonbatched_bias); + + softmax_lse->Resize({fa_batch_size_, fa_num_heads_, fa_softmax_lse_dim_}); + AllocWithDebugInfo(dev_ctx_, "softmax_lse", softmax_lse); + + if (VLOG_IS_ON(6)) { + VLOG(6) << "temp_mask_dim={" << phi::make_ddim(temp_mask_dim) << "}"; + VLOG(6) << "temp_bias_dim={" << phi::make_ddim(temp_bias_dim) << "}"; + VLOG(6) << TensorDebugString(&cu_seq_q, "cu_seq_q"); + VLOG(6) << TensorDebugString(&cu_seq_k, "cu_seq_k"); + VLOG(6) << TensorDebugString(nonbatched_bias, "nonbatched_bias"); + VLOG(6) << TensorDebugString(src_mask, "src_mask"); + VLOG(6) << TensorDebugString(qkv_transpose_out, "qkv_transpose_out"); + VLOG(6) << TensorDebugString(softmax_lse, "softmax_lse"); + VLOG(6) << TensorDebugString(fmha_out, "fmha_out"); + } + + // 4. Get worksapce size and run the flash-attention kernel. + uint64_t workspace_size = 0; + phi::DenseTensor workspace; + cudaStream_t stream = dev_ctx_.stream(); + for (bool need_calc : {false, true}) { + // first calling, need_calc=false, set out_ptr to nullptr to calculate + // workspace size second calling, need_calc=true, run flash-attention + // kernel. + void* out_ptr = + need_calc ? static_cast(fmha_out->data()) : nullptr; + void* workspace_ptr = nullptr; + if (need_calc) { + VLOG(6) << "Step 2: Call the flash-attention kernel"; + if (workspace_size > 0) { + workspace = CreateWorkspace(workspace_size); + workspace_ptr = static_cast(workspace.data()); + } + } else { + VLOG(6) << "Step 1: Calculate the workspace_size"; + } + bool succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( + static_cast(q_ptr), + static_cast(k_ptr), + static_cast(v_ptr), + out_ptr, // set out to nullptr to calculate workspace size + cu_seq_q.data(), + cu_seq_k.data(), + fa_total_q_, + fa_total_k_, + fa_batch_size_, + fa_num_heads_, + fa_head_size_, + fa_max_seqlen_q_, + fa_max_seqlen_k_, + fa_dropout_prob_, + fa_softmax_scale_, + fa_zero_tensors_, + is_bf16, + fa_num_splits_, + softmax_lse->data(), + workspace_ptr, + &workspace_size, + stream, + fa_seed_, + fa_offset_, + src_mask ? src_mask->data() : nullptr, + nonbatched_bias ? nonbatched_bias->data() : nullptr, + src_mask ? temp_mask_dim.data() : nullptr, + nonbatched_bias ? temp_bias_dim.data() : nullptr); + PADDLE_ENFORCE_EQ( + succ, true, phi::errors::External(phi::dynload::flash_attn_error())); + WaitWithDebugInfo(dev_ctx_); + } +#else + PADDLE_THROW(phi::errors::Unimplemented( + "FlashAttention is unsupported, please set use_flash_attn to false.")); +#endif + } + + void ComputeBackward(const phi::DenseTensor* qkv_transpose_out, + const phi::DenseTensor* src_mask, + const phi::DenseTensor* nonbatched_bias, + const phi::DenseTensor* softmax_lse, + const phi::DenseTensor* fmha_out, + const phi::DenseTensor* fmha_out_grad, + phi::DenseTensor* src_mask_grad, + phi::DenseTensor* nonbatched_bias_grad, + GateAttentionGradConfig* config) { +#ifdef PADDLE_WITH_FLASHATTN + bool is_bf16 = + qkv_transpose_out->dtype() == DataType::BFLOAT16 ? 
true : false; + TypeDebugInfo(); + + PADDLE_ENFORCE_NOT_NULL( + qkv_transpose_out, + platform::errors::NotFound("The input qkv_transpose_out can not be" + "nullptr when merge_qkv is true.")); + + int64_t q_size = config->GetQuerySize(); + const T* q_ptr = qkv_transpose_out->data(); + const T* k_ptr = q_ptr + q_size; + const T* v_ptr = k_ptr + q_size; + + phi::DenseTensor qkv_transpose_out_grad; + qkv_transpose_out_grad.Resize(phi::make_ddim({3, + config->batch_size, + config->seq_len_m, + config->seq_len_r, + config->num_heads, + config->head_dim})); + AllocWithDebugInfo( + dev_ctx_, "qkv_transpose_out_grad", &qkv_transpose_out_grad); + + T* q_grad_ptr = qkv_transpose_out_grad.data(); + T* k_grad_ptr = q_grad_ptr + q_size; + T* v_grad_ptr = k_grad_ptr + q_size; + WaitWithDebugInfo(dev_ctx_); + + // 1. flash_attn parameter setting. + phi::DenseTensor cu_seq_q; + phi::DenseTensor cu_seq_k; + InitArgumentsAndSeqTensors(config, &cu_seq_q, &cu_seq_k); + const int32_t* cu_seq_q_ptr = cu_seq_q.data(); + const int32_t* cu_seq_k_ptr = cu_seq_k.data(); + + std::vector temp_mask_dim = GetCompressedDim(src_mask); + std::vector temp_bias_dim = GetCompressedDim(nonbatched_bias); + + phi::DenseTensor softmax_d; + softmax_d.Resize(softmax_lse->dims()); + AllocWithDebugInfo(dev_ctx_, "d_softmax_lse", &softmax_d); + + phi::DenseTensor bias_d; + if (nonbatched_bias) { + bias_d.Resize( + {fa_batch_size_, fa_num_heads_, fa_max_seqlen_q_, fa_max_seqlen_k_}); + AllocWithDebugInfo(dev_ctx_, "d_bias", &bias_d); + } + + if (VLOG_IS_ON(6)) { + VLOG(6) << TensorDebugString(fmha_out, "fmha_out"); + VLOG(6) << TensorDebugString(fmha_out_grad, "fmha_out_grad"); + VLOG(6) << TensorDebugString(softmax_lse, "softmax_lse"); + VLOG(6) << TensorDebugString(&softmax_d, "softmax_d"); + VLOG(6) << TensorDebugString(nonbatched_bias, "nonbatched_bias"); + VLOG(6) << TensorDebugString(&bias_d, "bias_d"); + } + + // 2. Get worksapce size and run the flash-attention kernel. + uint64_t workspace_size = 0; + phi::DenseTensor workspace; + cudaStream_t stream = dev_ctx_.stream(); + for (bool need_calc : {false, true}) { + // first calling, need_calc=false, set out_ptr to nullptr to calculate + // workspace size second calling, need_calc=true, run flash-attention + // kernel. + const void* out_ptr = + need_calc ? static_cast(fmha_out->data()) : nullptr; + void* workspace_ptr = nullptr; + if (need_calc) { + VLOG(6) << "Step 2: Call the flash-attention kernel"; + if (workspace_size > 0) { + workspace = CreateWorkspace(workspace_size); + workspace_ptr = static_cast(workspace.data()); + } + } else { + VLOG(6) << "Step 1: Calculate the workspace_size"; + } + + bool succ = phi::dynload::flash_attn_bwd_with_bias_and_mask( + static_cast(q_ptr), + static_cast(k_ptr), + static_cast(v_ptr), + static_cast(q_grad_ptr), + static_cast(k_grad_ptr), + static_cast(v_grad_ptr), + out_ptr, // set out to nullptr to calculate workspace size + static_cast(fmha_out_grad->data()), + cu_seq_q_ptr, + cu_seq_k_ptr, + fa_total_q_, + fa_total_k_, + fa_batch_size_, + fa_num_heads_, + fa_head_size_, + fa_max_seqlen_q_, + fa_max_seqlen_k_, + fa_dropout_prob_, + fa_softmax_scale_, + fa_zero_tensors_, + is_bf16, + fa_num_splits_, + softmax_lse->data(), + softmax_d.data(), + nonbatched_bias ? bias_d.data() : nullptr, + workspace_ptr, + &workspace_size, + stream, + fa_seed_, + fa_offset_, + src_mask ? src_mask->data() : nullptr, + nonbatched_bias ? nonbatched_bias->data() : nullptr, + src_mask ? temp_mask_dim.data() : nullptr, + nonbatched_bias ? 
temp_bias_dim.data() : nullptr); + PADDLE_ENFORCE_EQ( + succ, true, phi::errors::External(phi::dynload::flash_attn_error())); + WaitWithDebugInfo(dev_ctx_); + } + + if (nonbatched_bias) { + // compare block reduce + auto dbias_first_dim = bias_d.numel() / nonbatched_bias->numel(); + bias_d.Resize({dbias_first_dim, + temp_bias_dim[0], + temp_bias_dim[1], + temp_bias_dim[2], + temp_bias_dim[3]}); + phi::funcs::ReduceKernel>( + dev_ctx_, + bias_d, + nonbatched_bias_grad, + kps::IdentityFunctor(), + {0}); + } + + // 3. Scale Q's grad: q_grad_ptr = alpha * q_grad_ptr + ComputeScaleQ(q_size, config->head_dim, q_grad_ptr); + + // 4. Compute the grad of qkv_out. + phi::DenseTensor* qkv_out_grad = config->GetQKVOutGrad(); + ComputeQKVTransposeBackward(qkv_transpose_out_grad, qkv_out_grad); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "FlashAttention is unsupported, please set use_flash_attn to false.")); +#endif + } + + private: + std::vector GetCompressedDim(const phi::DenseTensor* tensor) { + std::vector compressed_dims; + if (tensor) { + int64_t first_dim = 1; + const auto& origin_dims = tensor->dims(); + auto rank = origin_dims.size(); + for (int i = 0; i < rank - 3; ++i) { + first_dim *= origin_dims[i]; + } + compressed_dims = {first_dim, + origin_dims[rank - 3], + origin_dims[rank - 2], + origin_dims[rank - 1]}; + } + return compressed_dims; + } + + phi::DenseTensor CreateWorkspace(uint64_t workspace_size) { + phi::DenseTensor workspace; + if (workspace_size > 0) { + workspace = phi::Empty( + dev_ctx_, {int64_t(workspace_size / sizeof(float))}); + } + VLOG(5) << "Allocate workspace: workspace_size=" << workspace_size; + return workspace; + } + + void GenerateSeedAndOffset(int64_t batch_size, int64_t num_heads) { + auto gen = dev_ctx_.GetGenerator(); + uint64_t inc = batch_size * num_heads * 32; + auto seed_offset_pair = gen->IncrementOffset(inc); + fa_seed_ = seed_offset_pair.first; + fa_offset_ = seed_offset_pair.second; + } + + void InitArgumentsAndSeqTensors(GateAttentionConfig* config, + phi::DenseTensor* cu_seq_q, + phi::DenseTensor* cu_seq_k) { + fa_batch_size_ = static_cast(config->batch_size) * + static_cast(config->seq_len_m); + fa_num_heads_ = static_cast(config->num_heads); // qkv_dims[2]; + fa_head_size_ = static_cast(config->head_dim); // qkv_dims[3]; + fa_max_seqlen_q_ = config->seq_len_r; + fa_max_seqlen_k_ = config->m_size; + fa_total_q_ = fa_batch_size_ * fa_max_seqlen_q_; + fa_total_k_ = fa_batch_size_ * fa_max_seqlen_k_; + + // 0 for an internal heuristic, which is optimal + fa_num_splits_ = 0; + fa_zero_tensors_ = false; + + fa_softmax_lse_dim_ = ((fa_max_seqlen_q_ + 16 - 1) / 16) * 16; + fa_softmax_scale_ = 1.0f; + fa_dropout_prob_ = 0.0f; + GenerateSeedAndOffset(fa_batch_size_, fa_num_heads_); + + phi::ArangeNullaryKernel( + dev_ctx_, + 0, + (fa_batch_size_ + 1) * fa_max_seqlen_q_, + fa_max_seqlen_q_, + cu_seq_q); + phi::ArangeNullaryKernel( + dev_ctx_, + 0, + (fa_batch_size_ + 1) * fa_max_seqlen_k_, + fa_max_seqlen_k_, + cu_seq_k); + + if (VLOG_IS_ON(6)) { + VLOG(6) << "fa_batch_size : " << fa_batch_size_; + VLOG(6) << "fa_total_q : " << fa_total_q_; + VLOG(6) << "fa_total_k : " << fa_total_k_; + VLOG(6) << "fa_num_heads : " << fa_num_heads_; + VLOG(6) << "fa_head_size : " << fa_head_size_; + VLOG(6) << "fa_max_seqlen_q : " << fa_max_seqlen_q_; + VLOG(6) << "fa_max_seqlen_k : " << fa_max_seqlen_k_; + VLOG(6) << "fa_num_splits : " << fa_num_splits_; + VLOG(6) << "fa_softmax_lse_dim : " << fa_softmax_lse_dim_; + VLOG(6) << "fa_softmax_scale : " << 
fa_softmax_scale_; + VLOG(6) << "fa_dropout_prob : " << fa_dropout_prob_; + } + } + + // [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] -> + // [3, batch_size, seq_len_m, seq_len_r, num_heads, head_dim] + void ComputeQKVTransposeForward(const phi::DenseTensor& qkv_out, + phi::DenseTensor* qkv_transpose_out) { + std::vector perm = {3, 0, 1, 2, 4, 5}; + phi::funcs::TransposeGPUKernelDriver( + dev_ctx_, qkv_out, perm, qkv_transpose_out); + } + + // [3, batch_size, seq_len_m, seq_len_r, num_heads, head_dim] -> + // [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] + void ComputeQKVTransposeBackward( + const phi::DenseTensor& qkv_transpose_out_grad, + phi::DenseTensor* qkv_out_grad) { + std::vector perm = {1, 2, 3, 0, 4, 5}; + phi::funcs::TransposeGPUKernelDriver( + dev_ctx_, qkv_transpose_out_grad, perm, qkv_out_grad); + } + + void ComputeScaleQ(int64_t numel, int64_t head_dim, T* ptr) { + float scale = static_cast(1.0f / std::sqrt(head_dim)); + VLOG(6) << "[ComputeScaleQ] numel=" << numel << ", scale=" << scale; + + auto gpu_config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx_, numel, 1); + SimleScaleKernel<<>>(numel, scale, ptr); + } + + const phi::GPUContext& dev_ctx_; + bool merge_qkv_; + + int fa_batch_size_; + int fa_total_q_; + int fa_total_k_; + int fa_num_heads_; + int fa_head_size_; + int fa_max_seqlen_q_; + int fa_max_seqlen_k_; + int fa_num_splits_; + int fa_softmax_lse_dim_; + float fa_softmax_scale_{1.0f}; + float fa_dropout_prob_{0.0f}; + uint64_t fa_seed_{0}; + uint64_t fa_offset_{0}; + bool fa_zero_tensors_{false}; +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_gate_attention_op.cc b/paddle/fluid/operators/fused/fused_gate_attention_op.cc index c91bca47cf4..7175a20787b 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_gate_attention_op.cc @@ -157,6 +157,9 @@ class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker { .AsIntermediate() .AsDispensable(); AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate(); + AddOutput("SoftmaxLse", "Result of the flash attention.") + .AsIntermediate() + .AsDispensable(); AddOutput("FMHAOut", "Result in fmha.").AsIntermediate(); AddOutput("GateOut", "Result of the gating module.") .AsIntermediate() @@ -170,6 +173,11 @@ class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "if true, calculation with merged qkv, " "[default true].") .SetDefault(true); + AddAttr( + "use_flash_attn", + "if true, the attention op will be computed in flash_attn branch, " + "[default false].") + .SetDefault(false); AddComment(R"DOC( Add fused attention op whose logic is as follows: { @@ -223,15 +231,15 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("QueryWeight"), "Input", "QueryWeight", - "fused_aate_attention_arad"); + "fused_gate_attention_arad"); OP_INOUT_CHECK(ctx->HasInput("KeyWeight"), "Input", "KeyWeight", - "fused_aate_attention_arad"); + "fused_gate_attention_arad"); OP_INOUT_CHECK(ctx->HasInput("ValueWeight"), "Input", "ValueWeight", - "fused_aate_attention_arad"); + "fused_gate_attention_arad"); for (auto& name : {"QueryWeight", "KeyWeight", "ValueWeight"}) { ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name)); @@ -259,6 +267,27 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), ctx->GetInputDim("OutLinearBias")); 
} + + protected: + phi::KernelKey GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input = ctx.Input("Query"); + auto input_data_type = framework::TransToProtoVarType(input->dtype()); + return phi::KernelKey(input_data_type, ctx.GetPlace()); + } + + phi::KernelKey GetKernelTypeForVar( + const std::string& var_name, + const phi::DenseTensor& tensor, + const phi::KernelKey& expected_kernel_type) const override { + if (var_name == "SoftmaxLse") { + return phi::KernelKey(phi::Backend::ALL_BACKEND, + expected_kernel_type.layout(), + expected_kernel_type.dtype()); + } + return phi::KernelKey( + tensor.place(), tensor.layout(), expected_kernel_type.dtype()); + } }; template @@ -276,11 +305,18 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetAttrMap(this->Attrs()); bool merge_qkv = PADDLE_GET_CONST(bool, op->GetAttr("merge_qkv")); + bool use_flash_attn = PADDLE_GET_CONST(bool, op->GetAttr("use_flash_attn")); + if (merge_qkv) { op->SetInput("QKVWeight", this->Input("QKVWeight")); op->SetOutput(framework::GradVarName("QKVWeight"), this->InputGrad("QKVWeight")); op->SetInput("QKVTransposeOut", this->Output("QKVTransposeOut")); + + if (use_flash_attn) { + op->SetInput("SrcMask", this->Input("SrcMask")); + op->SetInput("SoftmaxLse", this->Output("SoftmaxLse")); + } } else { op->SetInput("Key", this->Input("Key")); op->SetOutput(framework::GradVarName("Key"), this->InputGrad("Key")); diff --git a/paddle/fluid/operators/fused/fused_gate_attention_op.cu b/paddle/fluid/operators/fused/fused_gate_attention_op.cu index b5e3e7e4401..e2cdb513fea 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_gate_attention_op.cu @@ -371,17 +371,16 @@ class FusedGateAttentionOpKernel : public framework::OpKernel { auto *v_transpose_out = ctx.Output("ValueTransposeOut"); auto *qkv_transpose_out = ctx.Output("QKVTransposeOut"); - auto *softmax_out = ctx.Output("SoftmaxOut"); auto *fmha_out = ctx.Output("FMHAOut"); auto *gate_out = ctx.Output("GateOut"); auto *out = ctx.Output("Out"); const bool merge_qkv = ctx.Attr("merge_qkv"); const bool has_gating = ctx.Attr("has_gating"); + const bool use_flash_attn = ctx.Attr("use_flash_attn"); bool use_fused_matmul_bias = true; auto &dev_ctx = ctx.template device_context(); - AllocWithDebugInfo(dev_ctx, "softmax_out", softmax_out); AllocWithDebugInfo(dev_ctx, "fmha_out", fmha_out); if (has_gating) { AllocWithDebugInfo(dev_ctx, "gate_out", gate_out); @@ -389,8 +388,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel { AllocWithDebugInfo(dev_ctx, "out", out); // When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged. - GateAttentionConfig config( - dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating); + GateAttentionConfig config(dev_ctx, + query, + key, + query_weight, + qkv_weight, + merge_qkv, + has_gating, + use_flash_attn); if (merge_qkv) { PADDLE_ENFORCE_EQ( @@ -406,6 +411,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel { phi::DenseTensor *qkv_out = config.GetQKVOut(); ComputeMergedQKVMatmulForward(ctx, config, query, qkv_out); + if (config.CanUseFlashAttn()) { + qkv_transpose_out->Resize(phi::make_ddim({3, + config.batch_size, + config.seq_len_m, + config.seq_len_r, + config.num_heads, + config.head_dim})); + } AllocWithDebugInfo(dev_ctx, "qkv_transpose_out", qkv_transpose_out); } else { // 1. 
Separated QKV Matmul @@ -421,17 +434,31 @@ class FusedGateAttentionOpKernel : public framework::OpKernel { } // 2. FMHA - auto fmha_compute = FMHAGateRef(dev_ctx, merge_qkv); - fmha_compute.ComputeForward(nonbatched_bias, - src_mask, - q_transpose_out, - k_transpose_out, - v_transpose_out, - qkv_transpose_out, - softmax_out, - fmha_out, - gate_out, - &config); + if (config.CanUseFlashAttn()) { + auto *softmax_lse = ctx.Output("SoftmaxLse"); + auto fmha_compute = FlashAttnWithGating(dev_ctx, merge_qkv); + fmha_compute.ComputeForward(nonbatched_bias, + src_mask, + qkv_transpose_out, + softmax_lse, + fmha_out, + &config); + } else { + auto *softmax_out = ctx.Output("SoftmaxOut"); + AllocWithDebugInfo(dev_ctx, "softmax_out", softmax_out); + + auto fmha_compute = FMHAGateRef(dev_ctx, merge_qkv); + fmha_compute.ComputeForward(nonbatched_bias, + src_mask, + q_transpose_out, + k_transpose_out, + v_transpose_out, + qkv_transpose_out, + softmax_out, + fmha_out, + gate_out, + &config); + } // 3. Gating Linear if (has_gating) { @@ -465,7 +492,6 @@ class FusedGateAttentionGradKernel : public framework::OpKernel { ctx.Input("ValueTransposeOut"); const auto *qkv_transpose_out = ctx.Input("QKVTransposeOut"); - const auto *softmax_out = ctx.Input("SoftmaxOut"); const auto *fmha_out = ctx.Input("FMHAOut"); const auto *gate_out = ctx.Input("GateOut"); @@ -477,13 +503,20 @@ class FusedGateAttentionGradKernel : public framework::OpKernel { bool has_gating = ctx.Attr("has_gating"); bool merge_qkv = ctx.Attr("merge_qkv"); + bool use_flash_attn = ctx.Attr("use_flash_attn"); bool use_fused_matmul_bias = true; auto &dev_ctx = ctx.template device_context(); AllocWithDebugInfo(dev_ctx, "query_grad", query_grad); - GateAttentionGradConfig config( - dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating); + GateAttentionGradConfig config(dev_ctx, + query, + key, + query_weight, + qkv_weight, + merge_qkv, + has_gating, + use_flash_attn); phi::DenseTensor fmha_out_grad; fmha_out_grad.Resize(config.gate_out_dims); @@ -518,16 +551,36 @@ class FusedGateAttentionGradKernel : public framework::OpKernel { dev_ctx, "nonbatched_bias_grad", nonbatched_bias_grad); } - auto fmha_compute = FMHAGateRef(dev_ctx, merge_qkv); - fmha_compute.ComputeBackward(q_transpose_out, - k_transpose_out, - v_transpose_out, - qkv_transpose_out, - softmax_out, - &fmha_out_grad, - nullptr, - nonbatched_bias_grad, - &config); + if (config.CanUseFlashAttn()) { + const auto *nonbatched_bias = + ctx.Input("NonbatchedBias"); + const auto *src_mask = ctx.Input("SrcMask"); + const auto *softmax_lse = ctx.Input("SoftmaxLse"); + + auto fmha_compute = FlashAttnWithGating(dev_ctx, merge_qkv); + fmha_compute.ComputeBackward(qkv_transpose_out, + src_mask, + nonbatched_bias, + softmax_lse, + fmha_out, + &fmha_out_grad, + nullptr, + nonbatched_bias_grad, + &config); + } else { + const auto *softmax_out = ctx.Input("SoftmaxOut"); + + auto fmha_compute = FMHAGateRef(dev_ctx, merge_qkv); + fmha_compute.ComputeBackward(q_transpose_out, + k_transpose_out, + v_transpose_out, + qkv_transpose_out, + softmax_out, + &fmha_out_grad, + nullptr, + nonbatched_bias_grad, + &config); + } bool use_addto = has_gating ? 
true : false; if (merge_qkv) { diff --git a/paddle/phi/backends/dynload/flashattn.h b/paddle/phi/backends/dynload/flashattn.h index ec443fd9f8e..8948ec6a469 100644 --- a/paddle/phi/backends/dynload/flashattn.h +++ b/paddle/phi/backends/dynload/flashattn.h @@ -43,9 +43,11 @@ extern void* flashattn_dso_handle; #define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \ DYNAMIC_LOAD_FLASHATTN_WRAP(__name) -#define FLASHATTN_ROUTINE_EACH(__macro) \ - __macro(flash_attn_fwd); \ - __macro(flash_attn_bwd); \ +#define FLASHATTN_ROUTINE_EACH(__macro) \ + __macro(flash_attn_fwd); \ + __macro(flash_attn_bwd); \ + __macro(flash_attn_fwd_with_bias_and_mask); \ + __macro(flash_attn_bwd_with_bias_and_mask); \ __macro(flash_attn_error); FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP); diff --git a/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py index d9bc88d4b2c..e88b923f287 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py @@ -274,6 +274,7 @@ class TestFusedGateAttentionOp(OpTest): _, _, softmax_out, + _, fmha_out, gate_out, out, diff --git a/python/paddle/incubate/nn/functional/fused_gate_attention.py b/python/paddle/incubate/nn/functional/fused_gate_attention.py index 13833683449..78f4abac823 100644 --- a/python/paddle/incubate/nn/functional/fused_gate_attention.py +++ b/python/paddle/incubate/nn/functional/fused_gate_attention.py @@ -83,6 +83,7 @@ def fused_gate_attention( attn_mask (Tensor, optional): The attention mask. The shape is [batch_size, msa_len, 1, 1, res_len]. Default None. has_gating (bool, optional): Whether has the gating linear. Default True. merge_qkv (bool, optional): Whether has the gating linear. Default True. + use_flash_attn (bool, optional): Whether use flash-attention to speedup. Default False. Returns: Tensor: The output Tensor, the data type and shape is same as `query`. @@ -142,7 +143,7 @@ def fused_gate_attention( """ if _non_static_mode(): - _, _, _, _, _, _, _, out = _legacy_C_ops.fused_gate_attention( + _, _, _, _, _, _, _, _, out = _legacy_C_ops.fused_gate_attention( query, key, query_weight, -- GitLab
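
For reference, below is a minimal dygraph usage sketch of the new use_flash_attn flag. It is not part of the patch: it only mirrors the constraints enforced by CanUseFlashAttn() above (merge_qkv=True, batch_size == 1, head_dim of 32/64/128, float16 or bfloat16 data) and the shapes given in the Python docstring (query: [batch_size, msa_len, res_len, q_dim], attn_mask: [batch_size, msa_len, 1, 1, res_len]), and it needs a GPU build compiled with PADDLE_WITH_FLASHATTN. The keyword names other than query, attn_mask, has_gating, merge_qkv, and use_flash_attn, as well as the weight shapes, are assumptions about the helper's signature, not something this patch defines.

import paddle
from paddle.incubate.nn.functional import fused_gate_attention

# Flash-attn branch per CanUseFlashAttn(): batch_size must be 1,
# head_dim must be 32/64/128, dtype must be float16 or bfloat16.
batch_size, msa_len, res_len = 1, 32, 64
num_heads, head_dim = 8, 32
q_dim = num_heads * head_dim

query = paddle.randn([batch_size, msa_len, res_len, q_dim]).astype("float16")
# All-zero additive mask, i.e. nothing is masked out in this sketch.
attn_mask = paddle.zeros([batch_size, msa_len, 1, 1, res_len], dtype="float16")

# Assumed weight layouts for the merged-QKV, no-gating branch (hypothetical
# shapes for illustration only).
qkv_weight = paddle.randn([3, num_heads, head_dim, q_dim]).astype("float16")
out_linear_weight = paddle.randn([num_heads, head_dim, q_dim]).astype("float16")
out_linear_bias = paddle.zeros([q_dim], dtype="float16")

out = fused_gate_attention(
    query=query,
    qkv_weight=qkv_weight,
    out_linear_weight=out_linear_weight,
    out_linear_bias=out_linear_bias,
    attn_mask=attn_mask,
    has_gating=False,
    merge_qkv=True,        # required by CanUseFlashAttn()
    use_flash_attn=True,   # new flag added in this patch
)
print(out.shape)           # same shape as query

When use_flash_attn is False, or any of the conditions above is not met, the op falls back to the original FMHAGateRef path and the call is otherwise unchanged.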