Unverified commit d29c1f8e, authored by limingshu, committed by GitHub

Add flash attention to speedup fused_gate_attention. (#52731)

* Reorganize the forward code of flash-attention.

* Fix forward.

* Remove some unused code.

* Simplify code and fix backward.

* Change all LOG(INFO) to VLOG and fix the backward.

* Add scale for AF2 flash_attn; many thanks to xreki and shaojie for debugging this code.

* Decrease the effect of debug printing on performance.

* Unify the initialize of flashattn arguments.

* Rewrite the reshape of temp_mask and temp_bias.

* Add use_flash_attn support to the API.

* Fix compiling error on CI.

* Try to crop the flash-attention lib.

* Correct the condition for whether flash-attn can be used.

* Remove the softmax_out argument.

* Remove is_causal.

* Polish code.

* Fix qkv_transpose_out's shape and scaling of Q * K.

* Update commit of flash-attention.

---------
Co-authored-by: Liu Yiqun <liuyiqun01@baidu.com>
Parent: 4dc28b54
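As context for the diff below: the op's reference path (FMHAGateRef) computes attn = softmax(Q·K^T / sqrt(head_dim) + nonbatched_bias + src_mask) and out = attn·V, and the new flash-attention branch is expected to produce the same result through flash_attn_fwd_with_bias_and_mask. A minimal NumPy sketch of that reference math (function name and shapes here are illustrative only, not the op's exact I/O):

    import numpy as np

    def gate_attention_reference(q, k, v, bias=None, mask=None):
        # q: [batch, seq_len_m, num_heads, seq_len_r, head_dim]
        # k, v: [batch, seq_len_m, num_heads, m_size, head_dim]
        head_dim = q.shape[-1]
        logits = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(head_dim)
        if bias is not None:
            logits = logits + bias   # nonbatched_bias (broadcast over batch)
        if mask is not None:
            logits = logits + mask   # src_mask (large negative where masked)
        logits = logits - logits.max(axis=-1, keepdims=True)
        attn = np.exp(logits)
        attn = attn / attn.sum(axis=-1, keepdims=True)   # softmax over m_size
        return np.matmul(attn, v)    # [batch, seq_len_m, num_heads, seq_len_r, head_dim]

    q = np.random.randn(1, 2, 4, 8, 32).astype(np.float32)
    k = np.random.randn(1, 2, 4, 8, 32).astype(np.float32)
    v = np.random.randn(1, 2, 4, 8, 32).astype(np.float32)
    print(gate_attention_reference(q, k, v).shape)  # (1, 2, 4, 8, 32)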
...@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn) ...@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn) set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn) set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git) set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717) set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)
set(FLASHATTN_INCLUDE_DIR set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include" "${FLASHATTN_INSTALL_DIR}/include"
......
...@@ -27,6 +27,7 @@ std::tuple<paddle::Tensor, ...@@ -27,6 +27,7 @@ std::tuple<paddle::Tensor,
paddle::Tensor, paddle::Tensor,
paddle::Tensor, paddle::Tensor,
paddle::Tensor, paddle::Tensor,
paddle::Tensor,
paddle::Tensor> paddle::Tensor>
fused_gate_attention_dygraph_function( fused_gate_attention_dygraph_function(
const paddle::Tensor& Query, const paddle::Tensor& Query,
......
...@@ -26,6 +26,7 @@ std::tuple<paddle::Tensor, ...@@ -26,6 +26,7 @@ std::tuple<paddle::Tensor,
paddle::Tensor, paddle::Tensor,
paddle::Tensor, paddle::Tensor,
paddle::Tensor, paddle::Tensor,
paddle::Tensor,
paddle::Tensor> paddle::Tensor>
fused_gate_attention_dygraph_function( fused_gate_attention_dygraph_function(
const paddle::Tensor& Query, const paddle::Tensor& Query,
...@@ -181,6 +182,9 @@ fused_gate_attention_dygraph_function( ...@@ -181,6 +182,9 @@ fused_gate_attention_dygraph_function(
{"SoftmaxOut", {"SoftmaxOut",
{std::make_shared<egr::EagerVariable>( {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}}, egr::Controller::Instance().GenerateUniqueName())}},
{"SoftmaxLse",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"FMHAOut", {"FMHAOut",
{std::make_shared<egr::EagerVariable>( {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}}, egr::Controller::Instance().GenerateUniqueName())}},
...@@ -256,6 +260,8 @@ fused_gate_attention_dygraph_function( ...@@ -256,6 +260,8 @@ fused_gate_attention_dygraph_function(
egr::EagerUtils::GetOutput(outs["QKVTransposeOut"][0], &QKVTransposeOut); egr::EagerUtils::GetOutput(outs["QKVTransposeOut"][0], &QKVTransposeOut);
paddle::Tensor SoftmaxOut; paddle::Tensor SoftmaxOut;
egr::EagerUtils::GetOutput(outs["SoftmaxOut"][0], &SoftmaxOut); egr::EagerUtils::GetOutput(outs["SoftmaxOut"][0], &SoftmaxOut);
paddle::Tensor SoftmaxLse;
egr::EagerUtils::GetOutput(outs["SoftmaxLse"][0], &SoftmaxLse);
paddle::Tensor FMHAOut; paddle::Tensor FMHAOut;
egr::EagerUtils::GetOutput(outs["FMHAOut"][0], &FMHAOut); egr::EagerUtils::GetOutput(outs["FMHAOut"][0], &FMHAOut);
paddle::Tensor GateOut; paddle::Tensor GateOut;
...@@ -296,7 +302,7 @@ fused_gate_attention_dygraph_function( ...@@ -296,7 +302,7 @@ fused_gate_attention_dygraph_function(
p_autograd_Out); p_autograd_Out);
// Create GradOpNode // Create GradOpNode
auto grad_node = std::shared_ptr<fused_gate_attentionGradNodeCompat>( auto grad_node = std::shared_ptr<fused_gate_attentionGradNodeCompat>(
new fused_gate_attentionGradNodeCompat(8, 12)); new fused_gate_attentionGradNodeCompat(9, 12));
bool merge_qkv = true; bool merge_qkv = true;
if (attrs.count("merge_qkv")) { if (attrs.count("merge_qkv")) {
...@@ -308,6 +314,11 @@ fused_gate_attention_dygraph_function( ...@@ -308,6 +314,11 @@ fused_gate_attention_dygraph_function(
has_gating = PADDLE_GET_CONST(bool, attrs.at("has_gating")); has_gating = PADDLE_GET_CONST(bool, attrs.at("has_gating"));
} }
bool use_flash_attn = false;
if (attrs.count("use_flash_attn")) {
use_flash_attn = PADDLE_GET_CONST(bool, attrs.at("use_flash_attn"));
}
// Set Attributes // Set Attributes
grad_node->SetAttrMap(std::move(attrs)); grad_node->SetAttrMap(std::move(attrs));
grad_node->SetDefaultAttrMap(std::move(default_attrs)); grad_node->SetDefaultAttrMap(std::move(default_attrs));
...@@ -354,6 +365,12 @@ fused_gate_attention_dygraph_function( ...@@ -354,6 +365,12 @@ fused_gate_attention_dygraph_function(
grad_node->SetGradOutMeta(NonbatchedBias, 6); grad_node->SetGradOutMeta(NonbatchedBias, 6);
} }
if (use_flash_attn) {
grad_node->SetTensorWrapperSoftmaxLse(SoftmaxLse);
grad_node->SetTensorWrapperSrcMask(SrcMask);
grad_node->SetGradOutMeta(SrcMask, 7);
}
egr::EagerUtils::SetOutRankWithSlot(p_autograd_QueryTransposeOut, 0); egr::EagerUtils::SetOutRankWithSlot(p_autograd_QueryTransposeOut, 0);
grad_node->SetGradInMeta(QueryTransposeOut, 0); grad_node->SetGradInMeta(QueryTransposeOut, 0);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_KeyTransposeOut, 1); egr::EagerUtils::SetOutRankWithSlot(p_autograd_KeyTransposeOut, 1);
...@@ -379,6 +396,7 @@ fused_gate_attention_dygraph_function( ...@@ -379,6 +396,7 @@ fused_gate_attention_dygraph_function(
ValueTransposeOut, ValueTransposeOut,
QKVTransposeOut, QKVTransposeOut,
SoftmaxOut, SoftmaxOut,
SoftmaxLse,
FMHAOut, FMHAOut,
GateOut, GateOut,
Out); Out);
......
...@@ -45,6 +45,11 @@ fused_gate_attentionGradNodeCompat::operator()( ...@@ -45,6 +45,11 @@ fused_gate_attentionGradNodeCompat::operator()(
has_gating = PADDLE_GET_CONST(bool, attr_map_.at("has_gating")); has_gating = PADDLE_GET_CONST(bool, attr_map_.at("has_gating"));
} }
bool use_flash_attn = false;
if (attr_map_.count("use_flash_attn")) {
use_flash_attn = PADDLE_GET_CONST(bool, attr_map_.at("use_flash_attn"));
}
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> ins0 = std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> ins0 =
{{"FMHAOut", {{"FMHAOut",
egr::EagerUtils::TrySyncToVars( egr::EagerUtils::TrySyncToVars(
...@@ -168,6 +173,13 @@ fused_gate_attentionGradNodeCompat::operator()( ...@@ -168,6 +173,13 @@ fused_gate_attentionGradNodeCompat::operator()(
egr::Controller::Instance().GenerateUniqueName())}; egr::Controller::Instance().GenerateUniqueName())};
} }
if (use_flash_attn) {
auto SrcMask = egr::EagerUtils::RecoverTensorWrapper(&this->SrcMask_);
ins0["SrcMask"] = egr::EagerUtils::TrySyncToVars(SrcMask);
auto SoftmaxLse = egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxLse_);
ins0["SoftmaxLse"] = egr::EagerUtils::TrySyncToVars(SoftmaxLse);
}
auto& attrs_map0 = this->attr_map_; auto& attrs_map0 = this->attr_map_;
// Pass the entire attribute map to TraceOp // Pass the entire attribute map to TraceOp
// The underlying kernel will pickup whatever attribute they need at runtime // The underlying kernel will pickup whatever attribute they need at runtime
......
...@@ -61,12 +61,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase { ...@@ -61,12 +61,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
GateOut_.clear(); GateOut_.clear();
GateWeight_.clear(); GateWeight_.clear();
NonbatchedBias_.clear(); NonbatchedBias_.clear();
SrcMask_.clear();
OutLinearBias_.clear(); OutLinearBias_.clear();
OutLinearWeight_.clear(); OutLinearWeight_.clear();
QKVTransposeOut_.clear(); QKVTransposeOut_.clear();
QKVWeight_.clear(); QKVWeight_.clear();
Query_.clear(); Query_.clear();
SoftmaxOut_.clear(); SoftmaxOut_.clear();
SoftmaxLse_.clear();
Key_.clear(); Key_.clear();
QueryWeight_.clear(); QueryWeight_.clear();
KeyWeight_.clear(); KeyWeight_.clear();
...@@ -103,6 +105,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase { ...@@ -103,6 +105,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
void SetTensorWrapperNonbatchedBias(const paddle::Tensor& NonbatchedBias) { void SetTensorWrapperNonbatchedBias(const paddle::Tensor& NonbatchedBias) {
NonbatchedBias_ = egr::TensorWrapper(NonbatchedBias, false); NonbatchedBias_ = egr::TensorWrapper(NonbatchedBias, false);
} }
void SetTensorWrapperSrcMask(const paddle::Tensor& SrcMask) {
SrcMask_ = egr::TensorWrapper(SrcMask, false);
}
void SetTensorWrapperOutLinearBias(const paddle::Tensor& OutLinearBias) { void SetTensorWrapperOutLinearBias(const paddle::Tensor& OutLinearBias) {
OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false); OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false);
} }
...@@ -121,6 +126,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase { ...@@ -121,6 +126,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
void SetTensorWrapperSoftmaxOut(const paddle::Tensor& SoftmaxOut) { void SetTensorWrapperSoftmaxOut(const paddle::Tensor& SoftmaxOut) {
SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false); SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false);
} }
void SetTensorWrapperSoftmaxLse(const paddle::Tensor& SoftmaxLse) {
SoftmaxLse_ = egr::TensorWrapper(SoftmaxLse, false);
}
void SetTensorWrapperKey(const paddle::Tensor& Key) { void SetTensorWrapperKey(const paddle::Tensor& Key) {
Key_ = egr::TensorWrapper(Key, false); Key_ = egr::TensorWrapper(Key, false);
} }
...@@ -160,12 +168,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase { ...@@ -160,12 +168,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
egr::TensorWrapper GateOut_; egr::TensorWrapper GateOut_;
egr::TensorWrapper GateWeight_; egr::TensorWrapper GateWeight_;
egr::TensorWrapper NonbatchedBias_; egr::TensorWrapper NonbatchedBias_;
egr::TensorWrapper SrcMask_;
egr::TensorWrapper OutLinearBias_; egr::TensorWrapper OutLinearBias_;
egr::TensorWrapper OutLinearWeight_; egr::TensorWrapper OutLinearWeight_;
egr::TensorWrapper QKVTransposeOut_; egr::TensorWrapper QKVTransposeOut_;
egr::TensorWrapper QKVWeight_; egr::TensorWrapper QKVWeight_;
egr::TensorWrapper Query_; egr::TensorWrapper Query_;
egr::TensorWrapper SoftmaxOut_; egr::TensorWrapper SoftmaxOut_;
egr::TensorWrapper SoftmaxLse_;
egr::TensorWrapper Key_; egr::TensorWrapper Key_;
egr::TensorWrapper QueryWeight_; egr::TensorWrapper QueryWeight_;
......
...@@ -14,8 +14,12 @@ limitations under the License. */ ...@@ -14,8 +14,12 @@ limitations under the License. */
#pragma once #pragma once
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h" #include "paddle/phi/kernels/funcs/transpose_function.cu.h"
...@@ -24,6 +28,13 @@ limitations under the License. */ ...@@ -24,6 +28,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
__global__ void SimleScaleKernel(int64_t numel, float scale, T* inout) {
CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) {
inout[i] = static_cast<T>(scale * static_cast<float>(inout[i]));
}
}
inline std::string MemoryDebugString(const phi::DenseTensor& t) { inline std::string MemoryDebugString(const phi::DenseTensor& t) {
int device_id = platform::GetCurrentDeviceId(); int device_id = platform::GetCurrentDeviceId();
int64_t allocated = int64_t allocated =
...@@ -46,7 +57,45 @@ void AllocWithDebugInfo(const phi::GPUContext& dev_ctx, ...@@ -46,7 +57,45 @@ void AllocWithDebugInfo(const phi::GPUContext& dev_ctx,
const std::string& info, const std::string& info,
phi::DenseTensor* t) { phi::DenseTensor* t) {
dev_ctx.Alloc<T>(t, t->numel() * sizeof(T)); dev_ctx.Alloc<T>(t, t->numel() * sizeof(T));
VLOG(4) << info << ": " << MemoryDebugString(*t); if (VLOG_IS_ON(4)) {
VLOG(4) << info << ": " << MemoryDebugString(*t);
}
}
inline std::string TensorDebugString(const phi::DenseTensor* t,
const std::string& info) {
std::stringstream ss;
ss << info << ": ";
if (t) {
if (t->initialized()) {
ss << "shape=[" << t->dims() << "], ptr=" << t->data();
} else {
ss << "not initialized";
}
} else {
ss << "nullptr";
}
return ss.str();
}
inline void WaitWithDebugInfo(const phi::GPUContext& dev_ctx) {
if (VLOG_IS_ON(5)) {
dev_ctx.Wait();
VLOG(5) << "[Flash attn Synchronize] ";
}
}
template <typename T>
inline void TypeDebugInfo() {
if (VLOG_IS_ON(4)) {
if (std::is_same<T, phi::dtype::float16>::value) {
VLOG(4) << "[Grad]: T is phi::dtype::float16.";
} else if (std::is_same<T, phi::dtype::bfloat16>::value) {
VLOG(4) << "[Grad]: T is phi::dtype::bfloat16.";
} else if (std::is_same<T, float>::value) {
VLOG(4) << "[Grad]: T is float.";
}
}
} }
template <typename T> template <typename T>
...@@ -61,6 +110,7 @@ struct GateAttentionConfig { ...@@ -61,6 +110,7 @@ struct GateAttentionConfig {
bool merge_qkv; bool merge_qkv;
bool has_gating; bool has_gating;
bool use_flash_attn;
int64_t batch_size; int64_t batch_size;
int64_t seq_len_m; int64_t seq_len_m;
...@@ -90,8 +140,12 @@ struct GateAttentionConfig { ...@@ -90,8 +140,12 @@ struct GateAttentionConfig {
const phi::DenseTensor* query_weight, const phi::DenseTensor* query_weight,
const phi::DenseTensor* qkv_weight, const phi::DenseTensor* qkv_weight,
bool merge_qkv, bool merge_qkv,
bool has_gating) bool has_gating,
: dev_ctx(dev_ctx), merge_qkv(merge_qkv), has_gating(has_gating) { bool use_flash_attn)
: dev_ctx(dev_ctx),
merge_qkv(merge_qkv),
has_gating(has_gating),
use_flash_attn(use_flash_attn) {
// query: shape=[batch_size, seq_len_m, seq_len_r, q_dim] // query: shape=[batch_size, seq_len_m, seq_len_r, q_dim]
batch_size = query->dims()[0]; batch_size = query->dims()[0];
seq_len_m = query->dims()[1]; seq_len_m = query->dims()[1];
...@@ -146,6 +200,22 @@ struct GateAttentionConfig { ...@@ -146,6 +200,22 @@ struct GateAttentionConfig {
gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim}; gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
} }
bool CanUseFlashAttn() const {
#ifdef PADDLE_WITH_FLASHATTN
if (!std::is_same<T, phi::dtype::bfloat16>::value &&
!std::is_same<T, phi::dtype::float16>::value) {
return false;
}
if (merge_qkv && batch_size == 1) {
if (head_dim == 32 || head_dim == 64 || head_dim == 128) {
return use_flash_attn;
}
}
#endif
return false;
}
int64_t GetQuerySize() const { int64_t GetQuerySize() const {
return batch_size * seq_len_m * seq_len_r * num_heads * head_dim; return batch_size * seq_len_m * seq_len_r * num_heads * head_dim;
} }
...@@ -253,14 +323,16 @@ struct GateAttentionGradConfig : public GateAttentionConfig<T> { ...@@ -253,14 +323,16 @@ struct GateAttentionGradConfig : public GateAttentionConfig<T> {
const phi::DenseTensor* query_weight, const phi::DenseTensor* query_weight,
const phi::DenseTensor* qkv_weight, const phi::DenseTensor* qkv_weight,
bool merge_qkv, bool merge_qkv,
bool has_gating) bool has_gating,
bool use_flash_attn)
: GateAttentionConfig<T>(dev_ctx, : GateAttentionConfig<T>(dev_ctx,
query, query,
key, key,
query_weight, query_weight,
qkv_weight, qkv_weight,
merge_qkv, merge_qkv,
has_gating) {} has_gating,
use_flash_attn) {}
phi::DenseTensor* GetQKVOutGrad() { phi::DenseTensor* GetQKVOutGrad() {
if (!qkv_out_grad.IsInitialized()) { if (!qkv_out_grad.IsInitialized()) {
...@@ -336,6 +408,7 @@ class FMHAGateRef { ...@@ -336,6 +408,7 @@ class FMHAGateRef {
T* q_ptr = nullptr; T* q_ptr = nullptr;
T* k_ptr = nullptr; T* k_ptr = nullptr;
T* v_ptr = nullptr; T* v_ptr = nullptr;
if (merge_qkv_) { if (merge_qkv_) {
// qkv_transpose_out = transpose(qkv_out) // qkv_transpose_out = transpose(qkv_out)
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
...@@ -381,7 +454,6 @@ class FMHAGateRef { ...@@ -381,7 +454,6 @@ class FMHAGateRef {
k_ptr = k_transpose_out->data<T>(); k_ptr = k_transpose_out->data<T>();
v_ptr = v_transpose_out->data<T>(); v_ptr = v_transpose_out->data<T>();
} }
// qk_out = BatchedGEMM(Q, K^T) // qk_out = BatchedGEMM(Q, K^T)
// [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] * // [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] *
// [batch_size, seq_len_m, num_heads, m_size, head_dim] // [batch_size, seq_len_m, num_heads, m_size, head_dim]
...@@ -394,8 +466,8 @@ class FMHAGateRef { ...@@ -394,8 +466,8 @@ class FMHAGateRef {
int64_t gemm_m = config->seq_len_r; int64_t gemm_m = config->seq_len_r;
int64_t gemm_n = config->m_size; int64_t gemm_n = config->m_size;
int64_t gemm_k = config->head_dim; int64_t gemm_k = config->head_dim;
T alpha = static_cast<T>(1.0 / sqrt(config->head_dim)); T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));
// attn = matmul(q, k.transpose(-1, -2))
ComputeBatchedGEMM(q_ptr, ComputeBatchedGEMM(q_ptr,
k_ptr, k_ptr,
qk_out_ptr, qk_out_ptr,
...@@ -407,6 +479,7 @@ class FMHAGateRef { ...@@ -407,6 +479,7 @@ class FMHAGateRef {
gemm_batch_size, gemm_batch_size,
alpha); alpha);
// attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias)
// softmax_out = softmax(qk_out + nonbatched_bias + src_mask) // softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
ComputeBiasMaskSoftmaxForward( ComputeBiasMaskSoftmaxForward(
nonbatched_bias, src_mask, qk_out, softmax_out); nonbatched_bias, src_mask, qk_out, softmax_out);
...@@ -414,7 +487,7 @@ class FMHAGateRef { ...@@ -414,7 +487,7 @@ class FMHAGateRef {
// qktv_out = BatchedGEMM(softmax_out, V) // qktv_out = BatchedGEMM(softmax_out, V)
// [batch_size, seq_len_m, num_heads, seq_len_r, m_size] * // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] *
// [batch_size, seq_len_m, num_heads, m_size, head_dim] // [batch_size, seq_len_m, num_heads, m_size, head_dim]
// -> [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] // -> [batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
phi::DenseTensor* qktv_out = config->GetQKTVOut(gate_out); phi::DenseTensor* qktv_out = config->GetQKTVOut(gate_out);
T* qktv_out_ptr = qktv_out->data<T>(); T* qktv_out_ptr = qktv_out->data<T>();
...@@ -423,6 +496,7 @@ class FMHAGateRef { ...@@ -423,6 +496,7 @@ class FMHAGateRef {
gemm_n = config->head_dim; gemm_n = config->head_dim;
gemm_k = config->m_size; gemm_k = config->m_size;
// o = matmul(attn, v)
T* softmax_out_ptr = softmax_out->data<T>(); T* softmax_out_ptr = softmax_out->data<T>();
ComputeBatchedGEMM(softmax_out_ptr, ComputeBatchedGEMM(softmax_out_ptr,
v_ptr, v_ptr,
...@@ -435,7 +509,9 @@ class FMHAGateRef { ...@@ -435,7 +509,9 @@ class FMHAGateRef {
gemm_batch_size); gemm_batch_size);
// fmha_out = transpose(qktv_out) // fmha_out = transpose(qktv_out)
// o = o.transpose(-2, -3).contiguous()
ComputeQKTVTransposeForward(*qktv_out, fmha_out); ComputeQKTVTransposeForward(*qktv_out, fmha_out);
config->ClearQKTVOut(); config->ClearQKTVOut();
if (config->has_gating) { if (config->has_gating) {
gate_out->Resize(config->gate_out_dims); gate_out->Resize(config->gate_out_dims);
...@@ -463,6 +539,7 @@ class FMHAGateRef { ...@@ -463,6 +539,7 @@ class FMHAGateRef {
phi::DenseTensor k_transpose_out_grad; phi::DenseTensor k_transpose_out_grad;
phi::DenseTensor v_transpose_out_grad; phi::DenseTensor v_transpose_out_grad;
phi::DenseTensor qkv_transpose_out_grad; phi::DenseTensor qkv_transpose_out_grad;
if (merge_qkv_) { if (merge_qkv_) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
qkv_transpose_out, qkv_transpose_out,
...@@ -752,11 +829,11 @@ class FMHAGateRef { ...@@ -752,11 +829,11 @@ class FMHAGateRef {
int64_t batch_size, int64_t batch_size,
T alpha = static_cast<T>(1.0), T alpha = static_cast<T>(1.0),
T beta = static_cast<T>(0.0)) { T beta = static_cast<T>(0.0)) {
CBLAS_TRANSPOSE cblas_trans_a = trans_a ? CblasTrans : CblasNoTrans;
CBLAS_TRANSPOSE cblas_trans_b = trans_b ? CblasTrans : CblasNoTrans;
int64_t stride_a = m * k; int64_t stride_a = m * k;
int64_t stride_b = k * n; int64_t stride_b = k * n;
CBLAS_TRANSPOSE cblas_trans_a = trans_a ? CblasTrans : CblasNoTrans;
CBLAS_TRANSPOSE cblas_trans_b = trans_b ? CblasTrans : CblasNoTrans;
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_); auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
blas.BatchedGEMM(cblas_trans_a, blas.BatchedGEMM(cblas_trans_a,
cblas_trans_b, cblas_trans_b,
...@@ -777,5 +854,416 @@ class FMHAGateRef { ...@@ -777,5 +854,416 @@ class FMHAGateRef {
bool merge_qkv_; bool merge_qkv_;
}; };
template <typename T>
class FlashAttnWithGating {
public:
FlashAttnWithGating(const phi::GPUContext& dev_ctx, bool merge_qkv)
: dev_ctx_(dev_ctx), merge_qkv_(merge_qkv) {}
void ComputeForward(const phi::DenseTensor* nonbatched_bias,
const phi::DenseTensor* src_mask,
phi::DenseTensor* qkv_transpose_out,
phi::DenseTensor* softmax_lse,
phi::DenseTensor* fmha_out,
GateAttentionConfig<T>* config) {
#ifdef PADDLE_WITH_FLASHATTN
bool is_bf16 =
qkv_transpose_out->dtype() == DataType::BFLOAT16 ? true : false;
TypeDebugInfo<T>();
PADDLE_ENFORCE_NOT_NULL(
qkv_transpose_out,
platform::errors::NotFound("The input qkv_transpose_out can not be "
"nullptr when merge_qkv is true."));
// 1. Transpose qkv_out for flash_attn.
phi::DenseTensor* qkv_out = config->GetQKVOut();
ComputeQKVTransposeForward(*qkv_out, qkv_transpose_out);
config->ClearQKVOut();
// q_size == k_size
int64_t q_size = config->GetQuerySize();
T* q_ptr = qkv_transpose_out->data<T>();
T* k_ptr = q_ptr + q_size;
T* v_ptr = k_ptr + q_size;
// 2. Scale Q: q_ptr = alpha * q_ptr
ComputeScaleQ(q_size, config->head_dim, q_ptr);
// 3. flash_attn parameter setting.
phi::DenseTensor cu_seq_q;
phi::DenseTensor cu_seq_k;
InitArgumentsAndSeqTensors(config, &cu_seq_q, &cu_seq_k);
std::vector<int64_t> temp_mask_dim = GetCompressedDim(src_mask);
std::vector<int64_t> temp_bias_dim = GetCompressedDim(nonbatched_bias);
softmax_lse->Resize({fa_batch_size_, fa_num_heads_, fa_softmax_lse_dim_});
AllocWithDebugInfo<float>(dev_ctx_, "softmax_lse", softmax_lse);
if (VLOG_IS_ON(6)) {
VLOG(6) << "temp_mask_dim={" << phi::make_ddim(temp_mask_dim) << "}";
VLOG(6) << "temp_bias_dim={" << phi::make_ddim(temp_bias_dim) << "}";
VLOG(6) << TensorDebugString(&cu_seq_q, "cu_seq_q");
VLOG(6) << TensorDebugString(&cu_seq_k, "cu_seq_k");
VLOG(6) << TensorDebugString(nonbatched_bias, "nonbatched_bias");
VLOG(6) << TensorDebugString(src_mask, "src_mask");
VLOG(6) << TensorDebugString(qkv_transpose_out, "qkv_transpose_out");
VLOG(6) << TensorDebugString(softmax_lse, "softmax_lse");
VLOG(6) << TensorDebugString(fmha_out, "fmha_out");
}
// 4. Get workspace size and run the flash-attention kernel.
uint64_t workspace_size = 0;
phi::DenseTensor workspace;
cudaStream_t stream = dev_ctx_.stream();
for (bool need_calc : {false, true}) {
      // First pass (need_calc = false): out_ptr is nullptr, so the call only
      // computes the required workspace size.
      // Second pass (need_calc = true): run the flash-attention kernel.
void* out_ptr =
need_calc ? static_cast<void*>(fmha_out->data()) : nullptr;
void* workspace_ptr = nullptr;
if (need_calc) {
VLOG(6) << "Step 2: Call the flash-attention kernel";
if (workspace_size > 0) {
workspace = CreateWorkspace(workspace_size);
workspace_ptr = static_cast<void*>(workspace.data());
}
} else {
VLOG(6) << "Step 1: Calculate the workspace_size";
}
bool succ = phi::dynload::flash_attn_fwd_with_bias_and_mask(
static_cast<const void*>(q_ptr),
static_cast<const void*>(k_ptr),
static_cast<const void*>(v_ptr),
out_ptr, // set out to nullptr to calculate workspace size
cu_seq_q.data<int32_t>(),
cu_seq_k.data<int32_t>(),
fa_total_q_,
fa_total_k_,
fa_batch_size_,
fa_num_heads_,
fa_head_size_,
fa_max_seqlen_q_,
fa_max_seqlen_k_,
fa_dropout_prob_,
fa_softmax_scale_,
fa_zero_tensors_,
is_bf16,
fa_num_splits_,
softmax_lse->data(),
workspace_ptr,
&workspace_size,
stream,
fa_seed_,
fa_offset_,
src_mask ? src_mask->data() : nullptr,
nonbatched_bias ? nonbatched_bias->data() : nullptr,
src_mask ? temp_mask_dim.data() : nullptr,
nonbatched_bias ? temp_bias_dim.data() : nullptr);
PADDLE_ENFORCE_EQ(
succ, true, phi::errors::External(phi::dynload::flash_attn_error()));
WaitWithDebugInfo(dev_ctx_);
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
void ComputeBackward(const phi::DenseTensor* qkv_transpose_out,
const phi::DenseTensor* src_mask,
const phi::DenseTensor* nonbatched_bias,
const phi::DenseTensor* softmax_lse,
const phi::DenseTensor* fmha_out,
const phi::DenseTensor* fmha_out_grad,
phi::DenseTensor* src_mask_grad,
phi::DenseTensor* nonbatched_bias_grad,
GateAttentionGradConfig<T>* config) {
#ifdef PADDLE_WITH_FLASHATTN
bool is_bf16 =
qkv_transpose_out->dtype() == DataType::BFLOAT16 ? true : false;
TypeDebugInfo<T>();
PADDLE_ENFORCE_NOT_NULL(
qkv_transpose_out,
        platform::errors::NotFound("The input qkv_transpose_out can not be "
"nullptr when merge_qkv is true."));
int64_t q_size = config->GetQuerySize();
const T* q_ptr = qkv_transpose_out->data<T>();
const T* k_ptr = q_ptr + q_size;
const T* v_ptr = k_ptr + q_size;
phi::DenseTensor qkv_transpose_out_grad;
qkv_transpose_out_grad.Resize(phi::make_ddim({3,
config->batch_size,
config->seq_len_m,
config->seq_len_r,
config->num_heads,
config->head_dim}));
AllocWithDebugInfo<T>(
dev_ctx_, "qkv_transpose_out_grad", &qkv_transpose_out_grad);
T* q_grad_ptr = qkv_transpose_out_grad.data<T>();
T* k_grad_ptr = q_grad_ptr + q_size;
T* v_grad_ptr = k_grad_ptr + q_size;
WaitWithDebugInfo(dev_ctx_);
// 1. flash_attn parameter setting.
phi::DenseTensor cu_seq_q;
phi::DenseTensor cu_seq_k;
InitArgumentsAndSeqTensors(config, &cu_seq_q, &cu_seq_k);
const int32_t* cu_seq_q_ptr = cu_seq_q.data<int32_t>();
const int32_t* cu_seq_k_ptr = cu_seq_k.data<int32_t>();
std::vector<int64_t> temp_mask_dim = GetCompressedDim(src_mask);
std::vector<int64_t> temp_bias_dim = GetCompressedDim(nonbatched_bias);
phi::DenseTensor softmax_d;
softmax_d.Resize(softmax_lse->dims());
AllocWithDebugInfo<float>(dev_ctx_, "d_softmax_lse", &softmax_d);
phi::DenseTensor bias_d;
if (nonbatched_bias) {
bias_d.Resize(
{fa_batch_size_, fa_num_heads_, fa_max_seqlen_q_, fa_max_seqlen_k_});
AllocWithDebugInfo<T>(dev_ctx_, "d_bias", &bias_d);
}
if (VLOG_IS_ON(6)) {
VLOG(6) << TensorDebugString(fmha_out, "fmha_out");
VLOG(6) << TensorDebugString(fmha_out_grad, "fmha_out_grad");
VLOG(6) << TensorDebugString(softmax_lse, "softmax_lse");
VLOG(6) << TensorDebugString(&softmax_d, "softmax_d");
VLOG(6) << TensorDebugString(nonbatched_bias, "nonbatched_bias");
VLOG(6) << TensorDebugString(&bias_d, "bias_d");
}
// 2. Get workspace size and run the flash-attention kernel.
uint64_t workspace_size = 0;
phi::DenseTensor workspace;
cudaStream_t stream = dev_ctx_.stream();
for (bool need_calc : {false, true}) {
      // First pass (need_calc = false): out_ptr is nullptr, so the call only
      // computes the required workspace size.
      // Second pass (need_calc = true): run the flash-attention kernel.
const void* out_ptr =
need_calc ? static_cast<const void*>(fmha_out->data()) : nullptr;
void* workspace_ptr = nullptr;
if (need_calc) {
VLOG(6) << "Step 2: Call the flash-attention kernel";
if (workspace_size > 0) {
workspace = CreateWorkspace(workspace_size);
workspace_ptr = static_cast<void*>(workspace.data());
}
} else {
VLOG(6) << "Step 1: Calculate the workspace_size";
}
bool succ = phi::dynload::flash_attn_bwd_with_bias_and_mask(
static_cast<const void*>(q_ptr),
static_cast<const void*>(k_ptr),
static_cast<const void*>(v_ptr),
static_cast<void*>(q_grad_ptr),
static_cast<void*>(k_grad_ptr),
static_cast<void*>(v_grad_ptr),
out_ptr, // set out to nullptr to calculate workspace size
static_cast<const void*>(fmha_out_grad->data()),
cu_seq_q_ptr,
cu_seq_k_ptr,
fa_total_q_,
fa_total_k_,
fa_batch_size_,
fa_num_heads_,
fa_head_size_,
fa_max_seqlen_q_,
fa_max_seqlen_k_,
fa_dropout_prob_,
fa_softmax_scale_,
fa_zero_tensors_,
is_bf16,
fa_num_splits_,
softmax_lse->data(),
softmax_d.data(),
nonbatched_bias ? bias_d.data() : nullptr,
workspace_ptr,
&workspace_size,
stream,
fa_seed_,
fa_offset_,
src_mask ? src_mask->data() : nullptr,
nonbatched_bias ? nonbatched_bias->data() : nullptr,
src_mask ? temp_mask_dim.data() : nullptr,
nonbatched_bias ? temp_bias_dim.data() : nullptr);
PADDLE_ENFORCE_EQ(
succ, true, phi::errors::External(phi::dynload::flash_attn_error()));
WaitWithDebugInfo(dev_ctx_);
}
if (nonbatched_bias) {
      // Sum bias_d over the leading broadcast dimension to obtain nonbatched_bias_grad.
auto dbias_first_dim = bias_d.numel() / nonbatched_bias->numel();
bias_d.Resize({dbias_first_dim,
temp_bias_dim[0],
temp_bias_dim[1],
temp_bias_dim[2],
temp_bias_dim[3]});
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx_,
bias_d,
nonbatched_bias_grad,
kps::IdentityFunctor<T>(),
{0});
}
// 3. Scale Q's grad: q_grad_ptr = alpha * q_grad_ptr
ComputeScaleQ(q_size, config->head_dim, q_grad_ptr);
// 4. Compute the grad of qkv_out.
phi::DenseTensor* qkv_out_grad = config->GetQKVOutGrad();
ComputeQKVTransposeBackward(qkv_transpose_out_grad, qkv_out_grad);
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
private:
std::vector<int64_t> GetCompressedDim(const phi::DenseTensor* tensor) {
std::vector<int64_t> compressed_dims;
if (tensor) {
int64_t first_dim = 1;
const auto& origin_dims = tensor->dims();
auto rank = origin_dims.size();
for (int i = 0; i < rank - 3; ++i) {
first_dim *= origin_dims[i];
}
compressed_dims = {first_dim,
origin_dims[rank - 3],
origin_dims[rank - 2],
origin_dims[rank - 1]};
}
return compressed_dims;
}
phi::DenseTensor CreateWorkspace(uint64_t workspace_size) {
phi::DenseTensor workspace;
if (workspace_size > 0) {
workspace = phi::Empty<float, phi::GPUContext>(
dev_ctx_, {int64_t(workspace_size / sizeof(float))});
}
VLOG(5) << "Allocate workspace: workspace_size=" << workspace_size;
return workspace;
}
void GenerateSeedAndOffset(int64_t batch_size, int64_t num_heads) {
auto gen = dev_ctx_.GetGenerator();
uint64_t inc = batch_size * num_heads * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
fa_seed_ = seed_offset_pair.first;
fa_offset_ = seed_offset_pair.second;
}
void InitArgumentsAndSeqTensors(GateAttentionConfig<T>* config,
phi::DenseTensor* cu_seq_q,
phi::DenseTensor* cu_seq_k) {
fa_batch_size_ = static_cast<int>(config->batch_size) *
static_cast<int>(config->seq_len_m);
fa_num_heads_ = static_cast<int>(config->num_heads); // qkv_dims[2];
fa_head_size_ = static_cast<int>(config->head_dim); // qkv_dims[3];
fa_max_seqlen_q_ = config->seq_len_r;
fa_max_seqlen_k_ = config->m_size;
fa_total_q_ = fa_batch_size_ * fa_max_seqlen_q_;
fa_total_k_ = fa_batch_size_ * fa_max_seqlen_k_;
// 0 for an internal heuristic, which is optimal
fa_num_splits_ = 0;
fa_zero_tensors_ = false;
fa_softmax_lse_dim_ = ((fa_max_seqlen_q_ + 16 - 1) / 16) * 16;
fa_softmax_scale_ = 1.0f;
fa_dropout_prob_ = 0.0f;
GenerateSeedAndOffset(fa_batch_size_, fa_num_heads_);
phi::ArangeNullaryKernel<int32_t, phi::GPUContext>(
dev_ctx_,
0,
(fa_batch_size_ + 1) * fa_max_seqlen_q_,
fa_max_seqlen_q_,
cu_seq_q);
phi::ArangeNullaryKernel<int32_t, phi::GPUContext>(
dev_ctx_,
0,
(fa_batch_size_ + 1) * fa_max_seqlen_k_,
fa_max_seqlen_k_,
cu_seq_k);
if (VLOG_IS_ON(6)) {
VLOG(6) << "fa_batch_size : " << fa_batch_size_;
VLOG(6) << "fa_total_q : " << fa_total_q_;
VLOG(6) << "fa_total_k : " << fa_total_k_;
VLOG(6) << "fa_num_heads : " << fa_num_heads_;
VLOG(6) << "fa_head_size : " << fa_head_size_;
VLOG(6) << "fa_max_seqlen_q : " << fa_max_seqlen_q_;
VLOG(6) << "fa_max_seqlen_k : " << fa_max_seqlen_k_;
VLOG(6) << "fa_num_splits : " << fa_num_splits_;
VLOG(6) << "fa_softmax_lse_dim : " << fa_softmax_lse_dim_;
VLOG(6) << "fa_softmax_scale : " << fa_softmax_scale_;
VLOG(6) << "fa_dropout_prob : " << fa_dropout_prob_;
}
}
// [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] ->
// [3, batch_size, seq_len_m, seq_len_r, num_heads, head_dim]
void ComputeQKVTransposeForward(const phi::DenseTensor& qkv_out,
phi::DenseTensor* qkv_transpose_out) {
std::vector<int> perm = {3, 0, 1, 2, 4, 5};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, qkv_out, perm, qkv_transpose_out);
}
// [3, batch_size, seq_len_m, seq_len_r, num_heads, head_dim] ->
// [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim]
void ComputeQKVTransposeBackward(
const phi::DenseTensor& qkv_transpose_out_grad,
phi::DenseTensor* qkv_out_grad) {
std::vector<int> perm = {1, 2, 3, 0, 4, 5};
phi::funcs::TransposeGPUKernelDriver<T>(
dev_ctx_, qkv_transpose_out_grad, perm, qkv_out_grad);
}
void ComputeScaleQ(int64_t numel, int64_t head_dim, T* ptr) {
float scale = static_cast<float>(1.0f / std::sqrt(head_dim));
VLOG(6) << "[ComputeScaleQ] numel=" << numel << ", scale=" << scale;
auto gpu_config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx_, numel, 1);
SimleScaleKernel<T><<<gpu_config.block_per_grid,
gpu_config.thread_per_block,
0,
dev_ctx_.stream()>>>(numel, scale, ptr);
}
const phi::GPUContext& dev_ctx_;
bool merge_qkv_;
int fa_batch_size_;
int fa_total_q_;
int fa_total_k_;
int fa_num_heads_;
int fa_head_size_;
int fa_max_seqlen_q_;
int fa_max_seqlen_k_;
int fa_num_splits_;
int fa_softmax_lse_dim_;
float fa_softmax_scale_{1.0f};
float fa_dropout_prob_{0.0f};
uint64_t fa_seed_{0};
uint64_t fa_offset_{0};
bool fa_zero_tensors_{false};
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
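A note on the scaling handled above: FlashAttnWithGating calls the kernel with fa_softmax_scale_ = 1.0f and instead pre-scales Q (and, in backward, Q's gradient) by 1/sqrt(head_dim) via ComputeScaleQ. A minimal NumPy check of the algebra behind that choice (names here are illustrative; this snippet is not part of the op):

    import numpy as np

    np.random.seed(0)
    head_dim = 32
    q = np.random.randn(4, head_dim)
    k = np.random.randn(6, head_dim)

    def softmax(x):
        x = x - x.max(axis=-1, keepdims=True)
        e = np.exp(x)
        return e / e.sum(axis=-1, keepdims=True)

    scale = 1.0 / np.sqrt(head_dim)
    # Variant 1: apply the scale to the logits (softmax_scale = 1/sqrt(head_dim)).
    attn_scaled_logits = softmax(scale * (q @ k.T))
    # Variant 2: pre-scale Q and use softmax_scale = 1.0, which is what
    # ComputeScaleQ + fa_softmax_scale_ = 1.0f do in the code above.
    attn_prescaled_q = softmax((scale * q) @ k.T)

    assert np.allclose(attn_scaled_logits, attn_prescaled_q)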
...@@ -157,6 +157,9 @@ class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -157,6 +157,9 @@ class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
.AsIntermediate() .AsIntermediate()
.AsDispensable(); .AsDispensable();
AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate(); AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate();
AddOutput("SoftmaxLse", "Result of the flash attention.")
.AsIntermediate()
.AsDispensable();
AddOutput("FMHAOut", "Result in fmha.").AsIntermediate(); AddOutput("FMHAOut", "Result in fmha.").AsIntermediate();
AddOutput("GateOut", "Result of the gating module.") AddOutput("GateOut", "Result of the gating module.")
.AsIntermediate() .AsIntermediate()
...@@ -170,6 +173,11 @@ class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -170,6 +173,11 @@ class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"if true, calculation with merged qkv, " "if true, calculation with merged qkv, "
"[default true].") "[default true].")
.SetDefault(true); .SetDefault(true);
AddAttr<bool>(
"use_flash_attn",
"if true, the attention op will be computed in flash_attn branch, "
"[default false].")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Add fused attention op whose logic is as follows: Add fused attention op whose logic is as follows:
{ {
...@@ -223,15 +231,15 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel { ...@@ -223,15 +231,15 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("QueryWeight"), OP_INOUT_CHECK(ctx->HasInput("QueryWeight"),
"Input", "Input",
"QueryWeight", "QueryWeight",
"fused_aate_attention_arad"); "fused_gate_attention_arad");
OP_INOUT_CHECK(ctx->HasInput("KeyWeight"), OP_INOUT_CHECK(ctx->HasInput("KeyWeight"),
"Input", "Input",
"KeyWeight", "KeyWeight",
"fused_aate_attention_arad"); "fused_gate_attention_arad");
OP_INOUT_CHECK(ctx->HasInput("ValueWeight"), OP_INOUT_CHECK(ctx->HasInput("ValueWeight"),
"Input", "Input",
"ValueWeight", "ValueWeight",
"fused_aate_attention_arad"); "fused_gate_attention_arad");
for (auto& name : {"QueryWeight", "KeyWeight", "ValueWeight"}) { for (auto& name : {"QueryWeight", "KeyWeight", "ValueWeight"}) {
ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name)); ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name));
...@@ -259,6 +267,27 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel { ...@@ -259,6 +267,27 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
ctx->GetInputDim("OutLinearBias")); ctx->GetInputDim("OutLinearBias"));
} }
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input = ctx.Input<phi::DenseTensor>("Query");
auto input_data_type = framework::TransToProtoVarType(input->dtype());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override {
if (var_name == "SoftmaxLse") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
}; };
template <typename T> template <typename T>
...@@ -276,11 +305,18 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -276,11 +305,18 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
bool merge_qkv = PADDLE_GET_CONST(bool, op->GetAttr("merge_qkv")); bool merge_qkv = PADDLE_GET_CONST(bool, op->GetAttr("merge_qkv"));
bool use_flash_attn = PADDLE_GET_CONST(bool, op->GetAttr("use_flash_attn"));
if (merge_qkv) { if (merge_qkv) {
op->SetInput("QKVWeight", this->Input("QKVWeight")); op->SetInput("QKVWeight", this->Input("QKVWeight"));
op->SetOutput(framework::GradVarName("QKVWeight"), op->SetOutput(framework::GradVarName("QKVWeight"),
this->InputGrad("QKVWeight")); this->InputGrad("QKVWeight"));
op->SetInput("QKVTransposeOut", this->Output("QKVTransposeOut")); op->SetInput("QKVTransposeOut", this->Output("QKVTransposeOut"));
if (use_flash_attn) {
op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("SoftmaxLse", this->Output("SoftmaxLse"));
}
} else { } else {
op->SetInput("Key", this->Input("Key")); op->SetInput("Key", this->Input("Key"));
op->SetOutput(framework::GradVarName("Key"), this->InputGrad("Key")); op->SetOutput(framework::GradVarName("Key"), this->InputGrad("Key"));
......
...@@ -371,17 +371,16 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> { ...@@ -371,17 +371,16 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
auto *v_transpose_out = ctx.Output<phi::DenseTensor>("ValueTransposeOut"); auto *v_transpose_out = ctx.Output<phi::DenseTensor>("ValueTransposeOut");
auto *qkv_transpose_out = ctx.Output<phi::DenseTensor>("QKVTransposeOut"); auto *qkv_transpose_out = ctx.Output<phi::DenseTensor>("QKVTransposeOut");
auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
auto *fmha_out = ctx.Output<phi::DenseTensor>("FMHAOut"); auto *fmha_out = ctx.Output<phi::DenseTensor>("FMHAOut");
auto *gate_out = ctx.Output<phi::DenseTensor>("GateOut"); auto *gate_out = ctx.Output<phi::DenseTensor>("GateOut");
auto *out = ctx.Output<phi::DenseTensor>("Out"); auto *out = ctx.Output<phi::DenseTensor>("Out");
const bool merge_qkv = ctx.Attr<bool>("merge_qkv"); const bool merge_qkv = ctx.Attr<bool>("merge_qkv");
const bool has_gating = ctx.Attr<bool>("has_gating"); const bool has_gating = ctx.Attr<bool>("has_gating");
const bool use_flash_attn = ctx.Attr<bool>("use_flash_attn");
bool use_fused_matmul_bias = true; bool use_fused_matmul_bias = true;
auto &dev_ctx = ctx.template device_context<phi::GPUContext>(); auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out); AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out);
if (has_gating) { if (has_gating) {
AllocWithDebugInfo<T>(dev_ctx, "gate_out", gate_out); AllocWithDebugInfo<T>(dev_ctx, "gate_out", gate_out);
...@@ -389,8 +388,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> { ...@@ -389,8 +388,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
AllocWithDebugInfo<T>(dev_ctx, "out", out); AllocWithDebugInfo<T>(dev_ctx, "out", out);
// When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged. // When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
GateAttentionConfig<T> config( GateAttentionConfig<T> config(dev_ctx,
dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating); query,
key,
query_weight,
qkv_weight,
merge_qkv,
has_gating,
use_flash_attn);
if (merge_qkv) { if (merge_qkv) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -406,6 +411,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> { ...@@ -406,6 +411,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
phi::DenseTensor *qkv_out = config.GetQKVOut(); phi::DenseTensor *qkv_out = config.GetQKVOut();
ComputeMergedQKVMatmulForward<T>(ctx, config, query, qkv_out); ComputeMergedQKVMatmulForward<T>(ctx, config, query, qkv_out);
if (config.CanUseFlashAttn()) {
qkv_transpose_out->Resize(phi::make_ddim({3,
config.batch_size,
config.seq_len_m,
config.seq_len_r,
config.num_heads,
config.head_dim}));
}
AllocWithDebugInfo<T>(dev_ctx, "qkv_transpose_out", qkv_transpose_out); AllocWithDebugInfo<T>(dev_ctx, "qkv_transpose_out", qkv_transpose_out);
} else { } else {
// 1. Separated QKV Matmul // 1. Separated QKV Matmul
...@@ -421,17 +434,31 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> { ...@@ -421,17 +434,31 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
} }
// 2. FMHA // 2. FMHA
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv); if (config.CanUseFlashAttn()) {
fmha_compute.ComputeForward(nonbatched_bias, auto *softmax_lse = ctx.Output<phi::DenseTensor>("SoftmaxLse");
src_mask, auto fmha_compute = FlashAttnWithGating<T>(dev_ctx, merge_qkv);
q_transpose_out, fmha_compute.ComputeForward(nonbatched_bias,
k_transpose_out, src_mask,
v_transpose_out, qkv_transpose_out,
qkv_transpose_out, softmax_lse,
softmax_out, fmha_out,
fmha_out, &config);
gate_out, } else {
&config); auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeForward(nonbatched_bias,
src_mask,
q_transpose_out,
k_transpose_out,
v_transpose_out,
qkv_transpose_out,
softmax_out,
fmha_out,
gate_out,
&config);
}
// 3. Gating Linear // 3. Gating Linear
if (has_gating) { if (has_gating) {
...@@ -465,7 +492,6 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> { ...@@ -465,7 +492,6 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
ctx.Input<phi::DenseTensor>("ValueTransposeOut"); ctx.Input<phi::DenseTensor>("ValueTransposeOut");
const auto *qkv_transpose_out = const auto *qkv_transpose_out =
ctx.Input<phi::DenseTensor>("QKVTransposeOut"); ctx.Input<phi::DenseTensor>("QKVTransposeOut");
const auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
const auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut"); const auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
const auto *gate_out = ctx.Input<phi::DenseTensor>("GateOut"); const auto *gate_out = ctx.Input<phi::DenseTensor>("GateOut");
...@@ -477,13 +503,20 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> { ...@@ -477,13 +503,20 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
bool has_gating = ctx.Attr<bool>("has_gating"); bool has_gating = ctx.Attr<bool>("has_gating");
bool merge_qkv = ctx.Attr<bool>("merge_qkv"); bool merge_qkv = ctx.Attr<bool>("merge_qkv");
bool use_flash_attn = ctx.Attr<bool>("use_flash_attn");
bool use_fused_matmul_bias = true; bool use_fused_matmul_bias = true;
auto &dev_ctx = ctx.template device_context<phi::GPUContext>(); auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad); AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad);
GateAttentionGradConfig<T> config( GateAttentionGradConfig<T> config(dev_ctx,
dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating); query,
key,
query_weight,
qkv_weight,
merge_qkv,
has_gating,
use_flash_attn);
phi::DenseTensor fmha_out_grad; phi::DenseTensor fmha_out_grad;
fmha_out_grad.Resize(config.gate_out_dims); fmha_out_grad.Resize(config.gate_out_dims);
...@@ -518,16 +551,36 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> { ...@@ -518,16 +551,36 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
dev_ctx, "nonbatched_bias_grad", nonbatched_bias_grad); dev_ctx, "nonbatched_bias_grad", nonbatched_bias_grad);
} }
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv); if (config.CanUseFlashAttn()) {
fmha_compute.ComputeBackward(q_transpose_out, const auto *nonbatched_bias =
k_transpose_out, ctx.Input<phi::DenseTensor>("NonbatchedBias");
v_transpose_out, const auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
qkv_transpose_out, const auto *softmax_lse = ctx.Input<phi::DenseTensor>("SoftmaxLse");
softmax_out,
&fmha_out_grad, auto fmha_compute = FlashAttnWithGating<T>(dev_ctx, merge_qkv);
nullptr, fmha_compute.ComputeBackward(qkv_transpose_out,
nonbatched_bias_grad, src_mask,
&config); nonbatched_bias,
softmax_lse,
fmha_out,
&fmha_out_grad,
nullptr,
nonbatched_bias_grad,
&config);
} else {
const auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeBackward(q_transpose_out,
k_transpose_out,
v_transpose_out,
qkv_transpose_out,
softmax_out,
&fmha_out_grad,
nullptr,
nonbatched_bias_grad,
&config);
}
bool use_addto = has_gating ? true : false; bool use_addto = has_gating ? true : false;
if (merge_qkv) { if (merge_qkv) {
......
...@@ -43,9 +43,11 @@ extern void* flashattn_dso_handle; ...@@ -43,9 +43,11 @@ extern void* flashattn_dso_handle;
#define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
DYNAMIC_LOAD_FLASHATTN_WRAP(__name) DYNAMIC_LOAD_FLASHATTN_WRAP(__name)
#define FLASHATTN_ROUTINE_EACH(__macro) \ #define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \ __macro(flash_attn_fwd); \
__macro(flash_attn_bwd); \ __macro(flash_attn_bwd); \
__macro(flash_attn_fwd_with_bias_and_mask); \
__macro(flash_attn_bwd_with_bias_and_mask); \
__macro(flash_attn_error); __macro(flash_attn_error);
FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP); FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);
......
...@@ -274,6 +274,7 @@ class TestFusedGateAttentionOp(OpTest): ...@@ -274,6 +274,7 @@ class TestFusedGateAttentionOp(OpTest):
_, _,
_, _,
softmax_out, softmax_out,
_,
fmha_out, fmha_out,
gate_out, gate_out,
out, out,
......
...@@ -83,6 +83,7 @@ def fused_gate_attention( ...@@ -83,6 +83,7 @@ def fused_gate_attention(
attn_mask (Tensor, optional): The attention mask. The shape is [batch_size, msa_len, 1, 1, res_len]. Default None. attn_mask (Tensor, optional): The attention mask. The shape is [batch_size, msa_len, 1, 1, res_len]. Default None.
has_gating (bool, optional): Whether has the gating linear. Default True. has_gating (bool, optional): Whether has the gating linear. Default True.
merge_qkv (bool, optional): Whether has the gating linear. Default True. merge_qkv (bool, optional): Whether has the gating linear. Default True.
use_flash_attn (bool, optional): Whether to use flash-attention to speed up the computation. Default False.
Returns: Returns:
Tensor: The output Tensor, the data type and shape is same as `query`. Tensor: The output Tensor, the data type and shape is same as `query`.
...@@ -142,7 +143,7 @@ def fused_gate_attention( ...@@ -142,7 +143,7 @@ def fused_gate_attention(
""" """
if _non_static_mode(): if _non_static_mode():
_, _, _, _, _, _, _, out = _legacy_C_ops.fused_gate_attention( _, _, _, _, _, _, _, _, out = _legacy_C_ops.fused_gate_attention(
query, query,
key, key,
query_weight, query_weight,
......
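Finally, a note on when the new branch actually runs: per GateAttentionConfig<T>::CanUseFlashAttn() in the header diff above, setting use_flash_attn=True only takes effect when Paddle is built with PADDLE_WITH_FLASHATTN, the dtype is float16 or bfloat16, merge_qkv is true, batch_size == 1, and head_dim is 32, 64, or 128; otherwise the op falls back to the FMHAGateRef path. A small Python mirror of that check (a hypothetical helper for illustration, not a Paddle API):

    def can_use_flash_attn(dtype, merge_qkv, batch_size, head_dim, use_flash_attn):
        """Mirror of GateAttentionConfig<T>::CanUseFlashAttn(); the real check
        additionally requires Paddle to be built with PADDLE_WITH_FLASHATTN."""
        if dtype not in ('float16', 'bfloat16'):
            return False
        if not (merge_qkv and batch_size == 1):
            return False
        if head_dim not in (32, 64, 128):
            return False
        return use_flash_attn

    # fp16 with merge_qkv, batch_size 1, and head_dim 32 takes the flash-attention branch.
    assert can_use_flash_attn('float16', True, 1, 32, True)
    # bfloat16 with head_dim 64 also qualifies; float32 never does.
    assert can_use_flash_attn('bfloat16', True, 1, 64, True)
    assert not can_use_flash_attn('float32', True, 1, 64, True)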