Unverified commit d29c1f8e, authored by limingshu, committed by GitHub

Add flash attention to speedup fused_gate_attention. (#52731)

* Reorganize the forward code of flash-attention.

* Fix forward.

* Remove some unused code.

* Simplify the code and fix backward.

* Change all LOG(INFO) to VLOG and fix the backward.

* Add scale for AF2 flash_attn; many thanks to xreki and shaojie for debugging this code.

* Decrease the effect of debug prints on performance.

* Unify the initialization of flashattn arguments.

* Rewrite the reshape of temp_mask and temp_bias.

* Support use_flash_attn in the API.

* Fix compiling error on CI.

* Try to crop the flash-attention lib.

* Correct the condition for whether flash-attn can be used.

* Remove the softmax_out argument.

* Remove is_causal.

* Polish the code.

* Fix qkv_transpose_out's shape and the scaling of Q * K.

* Update the commit of flash-attention.

---------
Co-authored-by: Liu Yiqun <liuyiqun01@baidu.com>
Parent 4dc28b54
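For readers skimming the diff: the op fuses AlphaFold2-style gate attention, and the "scaling of Q * K" mentioned above is the usual 1/sqrt(head_dim) factor. A minimal sketch of the math, assuming the standard AF2 formulation (symbol names are illustrative, not the kernel's variable names):

```latex
% logits combine scaled QK^T with the nonbatched bias and the attention mask
\mathrm{logits} = \frac{Q K^{\top}}{\sqrt{d_{\mathrm{head}}}} + B_{\mathrm{nonbatched}} + M_{\mathrm{mask}},
\qquad
\mathrm{fmha\_out} = \operatorname{softmax}(\mathrm{logits})\, V,
\qquad
\mathrm{out} = \bigl(\sigma(\mathrm{gate}(Q_{\mathrm{in}})) \odot \mathrm{fmha\_out}\bigr) W_{\mathrm{out}} + b_{\mathrm{out}}.
```

The flash-attention path computes the same result without materializing softmax(logits); it keeps only a per-row log-sum-exp (the new SoftmaxLse output) for the backward pass.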
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
 set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
-set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
+set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)
 set(FLASHATTN_INCLUDE_DIR
   "${FLASHATTN_INSTALL_DIR}/include"
......
@@ -27,6 +27,7 @@ std::tuple<paddle::Tensor,
            paddle::Tensor,
            paddle::Tensor,
            paddle::Tensor,
+           paddle::Tensor,
            paddle::Tensor>
 fused_gate_attention_dygraph_function(
     const paddle::Tensor& Query,
......
@@ -26,6 +26,7 @@ std::tuple<paddle::Tensor,
            paddle::Tensor,
            paddle::Tensor,
            paddle::Tensor,
+           paddle::Tensor,
            paddle::Tensor>
 fused_gate_attention_dygraph_function(
     const paddle::Tensor& Query,
@@ -181,6 +182,9 @@ fused_gate_attention_dygraph_function(
       {"SoftmaxOut",
        {std::make_shared<egr::EagerVariable>(
            egr::Controller::Instance().GenerateUniqueName())}},
+      {"SoftmaxLse",
+       {std::make_shared<egr::EagerVariable>(
+           egr::Controller::Instance().GenerateUniqueName())}},
       {"FMHAOut",
        {std::make_shared<egr::EagerVariable>(
            egr::Controller::Instance().GenerateUniqueName())}},
@@ -256,6 +260,8 @@ fused_gate_attention_dygraph_function(
   egr::EagerUtils::GetOutput(outs["QKVTransposeOut"][0], &QKVTransposeOut);
   paddle::Tensor SoftmaxOut;
   egr::EagerUtils::GetOutput(outs["SoftmaxOut"][0], &SoftmaxOut);
+  paddle::Tensor SoftmaxLse;
+  egr::EagerUtils::GetOutput(outs["SoftmaxLse"][0], &SoftmaxLse);
   paddle::Tensor FMHAOut;
   egr::EagerUtils::GetOutput(outs["FMHAOut"][0], &FMHAOut);
   paddle::Tensor GateOut;
@@ -296,7 +302,7 @@ fused_gate_attention_dygraph_function(
                                      p_autograd_Out);
     // Create GradOpNode
     auto grad_node = std::shared_ptr<fused_gate_attentionGradNodeCompat>(
-        new fused_gate_attentionGradNodeCompat(8, 12));
+        new fused_gate_attentionGradNodeCompat(9, 12));

     bool merge_qkv = true;
     if (attrs.count("merge_qkv")) {
@@ -308,6 +314,11 @@ fused_gate_attention_dygraph_function(
       has_gating = PADDLE_GET_CONST(bool, attrs.at("has_gating"));
     }

+    bool use_flash_attn = false;
+    if (attrs.count("use_flash_attn")) {
+      use_flash_attn = PADDLE_GET_CONST(bool, attrs.at("use_flash_attn"));
+    }
+
     // Set Attributes
     grad_node->SetAttrMap(std::move(attrs));
     grad_node->SetDefaultAttrMap(std::move(default_attrs));
@@ -354,6 +365,12 @@ fused_gate_attention_dygraph_function(
       grad_node->SetGradOutMeta(NonbatchedBias, 6);
     }

+    if (use_flash_attn) {
+      grad_node->SetTensorWrapperSoftmaxLse(SoftmaxLse);
+      grad_node->SetTensorWrapperSrcMask(SrcMask);
+      grad_node->SetGradOutMeta(SrcMask, 7);
+    }
+
     egr::EagerUtils::SetOutRankWithSlot(p_autograd_QueryTransposeOut, 0);
     grad_node->SetGradInMeta(QueryTransposeOut, 0);
     egr::EagerUtils::SetOutRankWithSlot(p_autograd_KeyTransposeOut, 1);
@@ -379,6 +396,7 @@ fused_gate_attention_dygraph_function(
                          ValueTransposeOut,
                          QKVTransposeOut,
                          SoftmaxOut,
+                         SoftmaxLse,
                          FMHAOut,
                          GateOut,
                          Out);
......
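The new SoftmaxLse output threaded through the eager function above is the per-row log-sum-exp of the attention logits. A small NumPy sketch of what that quantity is and why it is enough for the backward pass (illustrative only; names and shapes are assumptions, not the kernel's layout):

```python
# Illustrative only: what the "softmax LSE" saved for backward contains.
import numpy as np

def softmax_and_lse(logits):
    # logits: [num_heads, seq_len, seq_len] attention scores (after scaling/bias/mask)
    m = logits.max(axis=-1, keepdims=True)                     # row-wise max for stability
    lse = m.squeeze(-1) + np.log(np.exp(logits - m).sum(-1))   # log-sum-exp per query row
    softmax = np.exp(logits - lse[..., None])                  # softmax recoverable from logits + lse
    return softmax, lse

logits = np.random.randn(4, 8, 8).astype(np.float32)
p, lse = softmax_and_lse(logits)
np.testing.assert_allclose(p.sum(-1), 1.0, rtol=1e-5)
# Backward can rebuild any softmax row from (logits, lse) without storing the full
# [seq_len, seq_len] matrix, which is why the flash-attn branch saves the much smaller
# SoftmaxLse instead of SoftmaxOut.
```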
@@ -45,6 +45,11 @@ fused_gate_attentionGradNodeCompat::operator()(
     has_gating = PADDLE_GET_CONST(bool, attr_map_.at("has_gating"));
   }

+  bool use_flash_attn = false;
+  if (attr_map_.count("use_flash_attn")) {
+    use_flash_attn = PADDLE_GET_CONST(bool, attr_map_.at("use_flash_attn"));
+  }
+
   std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> ins0 =
       {{"FMHAOut",
         egr::EagerUtils::TrySyncToVars(
@@ -168,6 +173,13 @@ fused_gate_attentionGradNodeCompat::operator()(
         egr::Controller::Instance().GenerateUniqueName())};
   }

+  if (use_flash_attn) {
+    auto SrcMask = egr::EagerUtils::RecoverTensorWrapper(&this->SrcMask_);
+    ins0["SrcMask"] = egr::EagerUtils::TrySyncToVars(SrcMask);
+    auto SoftmaxLse = egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxLse_);
+    ins0["SoftmaxLse"] = egr::EagerUtils::TrySyncToVars(SoftmaxLse);
+  }
+
   auto& attrs_map0 = this->attr_map_;
   // Pass the entire attribute map to TraceOp
   // The underlying kernel will pickup whatever attribute they need at runtime
......
@@ -61,12 +61,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
     GateOut_.clear();
     GateWeight_.clear();
     NonbatchedBias_.clear();
+    SrcMask_.clear();
     OutLinearBias_.clear();
     OutLinearWeight_.clear();
     QKVTransposeOut_.clear();
     QKVWeight_.clear();
     Query_.clear();
     SoftmaxOut_.clear();
+    SoftmaxLse_.clear();
     Key_.clear();
     QueryWeight_.clear();
     KeyWeight_.clear();
@@ -103,6 +105,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
   void SetTensorWrapperNonbatchedBias(const paddle::Tensor& NonbatchedBias) {
     NonbatchedBias_ = egr::TensorWrapper(NonbatchedBias, false);
   }
+  void SetTensorWrapperSrcMask(const paddle::Tensor& SrcMask) {
+    SrcMask_ = egr::TensorWrapper(SrcMask, false);
+  }
   void SetTensorWrapperOutLinearBias(const paddle::Tensor& OutLinearBias) {
     OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false);
   }
@@ -121,6 +126,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
   void SetTensorWrapperSoftmaxOut(const paddle::Tensor& SoftmaxOut) {
     SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false);
   }
+  void SetTensorWrapperSoftmaxLse(const paddle::Tensor& SoftmaxLse) {
+    SoftmaxLse_ = egr::TensorWrapper(SoftmaxLse, false);
+  }
   void SetTensorWrapperKey(const paddle::Tensor& Key) {
     Key_ = egr::TensorWrapper(Key, false);
   }
@@ -160,12 +168,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
   egr::TensorWrapper GateOut_;
   egr::TensorWrapper GateWeight_;
   egr::TensorWrapper NonbatchedBias_;
+  egr::TensorWrapper SrcMask_;
   egr::TensorWrapper OutLinearBias_;
   egr::TensorWrapper OutLinearWeight_;
   egr::TensorWrapper QKVTransposeOut_;
   egr::TensorWrapper QKVWeight_;
   egr::TensorWrapper Query_;
   egr::TensorWrapper SoftmaxOut_;
+  egr::TensorWrapper SoftmaxLse_;
   egr::TensorWrapper Key_;
   egr::TensorWrapper QueryWeight_;
......
@@ -157,6 +157,9 @@ class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
         .AsIntermediate()
         .AsDispensable();
     AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate();
+    AddOutput("SoftmaxLse", "Result of the flash attention.")
+        .AsIntermediate()
+        .AsDispensable();
     AddOutput("FMHAOut", "Result in fmha.").AsIntermediate();
     AddOutput("GateOut", "Result of the gating module.")
         .AsIntermediate()
@@ -170,6 +173,11 @@ class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
                   "if true, calculation with merged qkv, "
                   "[default true].")
         .SetDefault(true);
+    AddAttr<bool>(
+        "use_flash_attn",
+        "if true, the attention op will be computed in flash_attn branch, "
+        "[default false].")
+        .SetDefault(false);
     AddComment(R"DOC(
   Add fused attention op whose logic is as follows:
   {
@@ -223,15 +231,15 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("QueryWeight"),
                    "Input",
                    "QueryWeight",
-                   "fused_aate_attention_arad");
+                   "fused_gate_attention_arad");
     OP_INOUT_CHECK(ctx->HasInput("KeyWeight"),
                    "Input",
                    "KeyWeight",
-                   "fused_aate_attention_arad");
+                   "fused_gate_attention_arad");
     OP_INOUT_CHECK(ctx->HasInput("ValueWeight"),
                    "Input",
                    "ValueWeight",
-                   "fused_aate_attention_arad");
+                   "fused_gate_attention_arad");

     for (auto& name : {"QueryWeight", "KeyWeight", "ValueWeight"}) {
       ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name));
@@ -259,6 +267,27 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
                         ctx->GetInputDim("OutLinearBias"));
     }
+
+ protected:
+  phi::KernelKey GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input = ctx.Input<phi::DenseTensor>("Query");
+    auto input_data_type = framework::TransToProtoVarType(input->dtype());
+    return phi::KernelKey(input_data_type, ctx.GetPlace());
+  }
+
+  phi::KernelKey GetKernelTypeForVar(
+      const std::string& var_name,
+      const phi::DenseTensor& tensor,
+      const phi::KernelKey& expected_kernel_type) const override {
+    if (var_name == "SoftmaxLse") {
+      return phi::KernelKey(phi::Backend::ALL_BACKEND,
+                            expected_kernel_type.layout(),
+                            expected_kernel_type.dtype());
+    }
+    return phi::KernelKey(
+        tensor.place(), tensor.layout(), expected_kernel_type.dtype());
+  }
 };

 template <typename T>
@@ -276,11 +305,18 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetAttrMap(this->Attrs());

     bool merge_qkv = PADDLE_GET_CONST(bool, op->GetAttr("merge_qkv"));
+    bool use_flash_attn = PADDLE_GET_CONST(bool, op->GetAttr("use_flash_attn"));
     if (merge_qkv) {
       op->SetInput("QKVWeight", this->Input("QKVWeight"));
       op->SetOutput(framework::GradVarName("QKVWeight"),
                     this->InputGrad("QKVWeight"));
       op->SetInput("QKVTransposeOut", this->Output("QKVTransposeOut"));
+
+      if (use_flash_attn) {
+        op->SetInput("SrcMask", this->Input("SrcMask"));
+        op->SetInput("SoftmaxLse", this->Output("SoftmaxLse"));
+      }
     } else {
       op->SetInput("Key", this->Input("Key"));
       op->SetOutput(framework::GradVarName("Key"), this->InputGrad("Key"));
......
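To summarize the grad-op wiring added above: SrcMask and SoftmaxLse become inputs of the grad op only on the merged-QKV, flash-attn path. A hypothetical Python mirror of that selection logic (illustrative pseudo-wiring only, not a Paddle API; `forward_ins`/`forward_outs` and the dict layout are assumptions for the sketch):

```python
# Hypothetical mirror of FusedGateAttentionGradOpMaker::Apply's input selection.
def build_grad_inputs(forward_ins, forward_outs, merge_qkv: bool, use_flash_attn: bool):
    grad_inputs = {}
    if merge_qkv:
        grad_inputs["QKVWeight"] = forward_ins["QKVWeight"]
        grad_inputs["QKVTransposeOut"] = forward_outs["QKVTransposeOut"]
        if use_flash_attn:
            # Extra inputs consumed by FlashAttnWithGating::ComputeBackward in the diff below.
            grad_inputs["SrcMask"] = forward_ins["SrcMask"]
            grad_inputs["SoftmaxLse"] = forward_outs["SoftmaxLse"]
    else:
        # Separated Q/K/V path: Key plus the per-tensor weights and transposes, as in the
        # reference branch of the grad op maker.
        grad_inputs["Key"] = forward_ins["Key"]
    return grad_inputs
```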
@@ -371,17 +371,16 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
     auto *v_transpose_out = ctx.Output<phi::DenseTensor>("ValueTransposeOut");
     auto *qkv_transpose_out = ctx.Output<phi::DenseTensor>("QKVTransposeOut");

-    auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
     auto *fmha_out = ctx.Output<phi::DenseTensor>("FMHAOut");
     auto *gate_out = ctx.Output<phi::DenseTensor>("GateOut");
     auto *out = ctx.Output<phi::DenseTensor>("Out");

     const bool merge_qkv = ctx.Attr<bool>("merge_qkv");
     const bool has_gating = ctx.Attr<bool>("has_gating");
+    const bool use_flash_attn = ctx.Attr<bool>("use_flash_attn");

     bool use_fused_matmul_bias = true;
     auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
-    AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
     AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out);
     if (has_gating) {
       AllocWithDebugInfo<T>(dev_ctx, "gate_out", gate_out);
@@ -389,8 +388,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
     AllocWithDebugInfo<T>(dev_ctx, "out", out);

     // When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
-    GateAttentionConfig<T> config(
-        dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating);
+    GateAttentionConfig<T> config(dev_ctx,
+                                  query,
+                                  key,
+                                  query_weight,
+                                  qkv_weight,
+                                  merge_qkv,
+                                  has_gating,
+                                  use_flash_attn);

     if (merge_qkv) {
       PADDLE_ENFORCE_EQ(
@@ -406,6 +411,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
       phi::DenseTensor *qkv_out = config.GetQKVOut();
       ComputeMergedQKVMatmulForward<T>(ctx, config, query, qkv_out);

+      if (config.CanUseFlashAttn()) {
+        qkv_transpose_out->Resize(phi::make_ddim({3,
+                                                  config.batch_size,
+                                                  config.seq_len_m,
+                                                  config.seq_len_r,
+                                                  config.num_heads,
+                                                  config.head_dim}));
+      }
       AllocWithDebugInfo<T>(dev_ctx, "qkv_transpose_out", qkv_transpose_out);
     } else {
       // 1. Separated QKV Matmul
@@ -421,17 +434,31 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
     }

     // 2. FMHA
-    auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
-    fmha_compute.ComputeForward(nonbatched_bias,
-                                src_mask,
-                                q_transpose_out,
-                                k_transpose_out,
-                                v_transpose_out,
-                                qkv_transpose_out,
-                                softmax_out,
-                                fmha_out,
-                                gate_out,
-                                &config);
+    if (config.CanUseFlashAttn()) {
+      auto *softmax_lse = ctx.Output<phi::DenseTensor>("SoftmaxLse");
+
+      auto fmha_compute = FlashAttnWithGating<T>(dev_ctx, merge_qkv);
+      fmha_compute.ComputeForward(nonbatched_bias,
+                                  src_mask,
+                                  qkv_transpose_out,
+                                  softmax_lse,
+                                  fmha_out,
+                                  &config);
+    } else {
+      auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
+      AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
+      auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
+      fmha_compute.ComputeForward(nonbatched_bias,
+                                  src_mask,
+                                  q_transpose_out,
+                                  k_transpose_out,
+                                  v_transpose_out,
+                                  qkv_transpose_out,
+                                  softmax_out,
+                                  fmha_out,
+                                  gate_out,
+                                  &config);
+    }

     // 3. Gating Linear
     if (has_gating) {
@@ -465,7 +492,6 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
         ctx.Input<phi::DenseTensor>("ValueTransposeOut");
     const auto *qkv_transpose_out =
         ctx.Input<phi::DenseTensor>("QKVTransposeOut");
-    const auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
     const auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
     const auto *gate_out = ctx.Input<phi::DenseTensor>("GateOut");
@@ -477,13 +503,20 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
     bool has_gating = ctx.Attr<bool>("has_gating");
     bool merge_qkv = ctx.Attr<bool>("merge_qkv");
+    bool use_flash_attn = ctx.Attr<bool>("use_flash_attn");

     bool use_fused_matmul_bias = true;
     auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
     AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad);

-    GateAttentionGradConfig<T> config(
-        dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating);
+    GateAttentionGradConfig<T> config(dev_ctx,
+                                      query,
+                                      key,
+                                      query_weight,
+                                      qkv_weight,
+                                      merge_qkv,
+                                      has_gating,
+                                      use_flash_attn);

     phi::DenseTensor fmha_out_grad;
     fmha_out_grad.Resize(config.gate_out_dims);
@@ -518,16 +551,36 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
           dev_ctx, "nonbatched_bias_grad", nonbatched_bias_grad);
     }

-    auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
-    fmha_compute.ComputeBackward(q_transpose_out,
-                                 k_transpose_out,
-                                 v_transpose_out,
-                                 qkv_transpose_out,
-                                 softmax_out,
-                                 &fmha_out_grad,
-                                 nullptr,
-                                 nonbatched_bias_grad,
-                                 &config);
+    if (config.CanUseFlashAttn()) {
+      const auto *nonbatched_bias =
+          ctx.Input<phi::DenseTensor>("NonbatchedBias");
+      const auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
+      const auto *softmax_lse = ctx.Input<phi::DenseTensor>("SoftmaxLse");
+
+      auto fmha_compute = FlashAttnWithGating<T>(dev_ctx, merge_qkv);
+      fmha_compute.ComputeBackward(qkv_transpose_out,
+                                   src_mask,
+                                   nonbatched_bias,
+                                   softmax_lse,
+                                   fmha_out,
+                                   &fmha_out_grad,
+                                   nullptr,
+                                   nonbatched_bias_grad,
+                                   &config);
+    } else {
+      const auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
+
+      auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
+      fmha_compute.ComputeBackward(q_transpose_out,
+                                   k_transpose_out,
+                                   v_transpose_out,
+                                   qkv_transpose_out,
+                                   softmax_out,
+                                   &fmha_out_grad,
+                                   nullptr,
+                                   nonbatched_bias_grad,
+                                   &config);
+    }

     bool use_addto = has_gating ? true : false;
     if (merge_qkv) {
......
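A quick shape check for the flash-attn branch in the forward kernel above. The six-dimensional resize of qkv_transpose_out comes straight from the diff; the flattening into a token-major view is only an assumption used here for illustration:

```python
# Shapes used by the fused op when config.CanUseFlashAttn() is true.
# batch_size / seq_len_m / seq_len_r / num_heads / head_dim come from GateAttentionConfig.
batch_size, seq_len_m, seq_len_r, num_heads, head_dim = 1, 8, 256, 4, 32

# qkv_transpose_out is resized so the Q/K/V split is the leading dimension:
qkv_transpose_out_shape = (3, batch_size, seq_len_m, seq_len_r, num_heads, head_dim)

# Flash-attention style kernels typically view each Q/K/V slice with a flat token
# dimension (an assumption for illustration, not the exact internal layout):
total_tokens = batch_size * seq_len_m * seq_len_r
per_slice_shape = (total_tokens, num_heads, head_dim)

print(qkv_transpose_out_shape, per_slice_shape)  # (3, 1, 8, 256, 4, 32) (2048, 4, 32)
```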
@@ -43,9 +43,11 @@ extern void* flashattn_dso_handle;
 #define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
   DYNAMIC_LOAD_FLASHATTN_WRAP(__name)

 #define FLASHATTN_ROUTINE_EACH(__macro)         \
   __macro(flash_attn_fwd);                      \
   __macro(flash_attn_bwd);                      \
+  __macro(flash_attn_fwd_with_bias_and_mask);   \
+  __macro(flash_attn_bwd_with_bias_and_mask);   \
   __macro(flash_attn_error);

 FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);
......
@@ -274,6 +274,7 @@ class TestFusedGateAttentionOp(OpTest):
                 _,
                 _,
                 softmax_out,
+                _,
                 fmha_out,
                 gate_out,
                 out,
......
@@ -83,6 +83,7 @@ def fused_gate_attention(
         attn_mask (Tensor, optional): The attention mask. The shape is [batch_size, msa_len, 1, 1, res_len]. Default None.
         has_gating (bool, optional): Whether has the gating linear. Default True.
         merge_qkv (bool, optional): Whether has the gating linear. Default True.
+        use_flash_attn (bool, optional): Whether use flash-attention to speedup. Default False.

     Returns:
         Tensor: The output Tensor, the data type and shape is same as `query`.
@@ -142,7 +143,7 @@ def fused_gate_attention(
     """

     if _non_static_mode():
-        _, _, _, _, _, _, _, out = _legacy_C_ops.fused_gate_attention(
+        _, _, _, _, _, _, _, _, out = _legacy_C_ops.fused_gate_attention(
            query,
            key,
            query_weight,
......
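For reference, a small self-contained sketch of the unfused computation this API fuses, written with plain Paddle ops. The tensor layouts below are chosen for the sketch and are not claimed to match the fused op's exact input layouts:

```python
# Reference (non-fused, non-flash) gate-attention core, for comparison with the fused op.
import paddle
import paddle.nn.functional as F

B, M, R, H, D = 1, 2, 16, 4, 8            # batch, msa_len, res_len, num_heads, head_dim
q = paddle.randn([B, M, R, H, D])
k = paddle.randn([B, M, R, H, D])
v = paddle.randn([B, M, R, H, D])
nonbatched_bias = paddle.randn([B, 1, H, R, R])
attn_mask = paddle.zeros([B, M, 1, 1, R])  # 0 keeps a position, a large negative masks it out

scale = 1.0 / (D ** 0.5)
# logits: [B, M, H, R, R] -- the "scaling of Q * K" mentioned in the commit message
logits = paddle.einsum("bmqhd,bmkhd->bmhqk", q, k) * scale
logits = logits + nonbatched_bias + attn_mask
probs = F.softmax(logits, axis=-1)                        # what SoftmaxOut stores on the reference path
fmha_out = paddle.einsum("bmhqk,bmkhd->bmqhd", probs, v)  # what FMHAOut stores
print(fmha_out.shape)  # [1, 2, 16, 4, 8]
```

With use_flash_attn=True, the fused op produces the same fmha_out through the flash-attention branch while storing only the per-row log-sum-exp (SoftmaxLse) instead of the full probs tensor.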