Unverified commit d29c1f8e, authored by limingshu and committed by GitHub

Add flash attention to speed up fused_gate_attention. (#52731)

* Reorganize the forward code of flash-attention.

* Fix the forward pass.

* Remove some unused code.

* Simplify the code and fix the backward pass.

* Change all LOG(INFO) to VLOG and fix the backward pass.

* Add scaling for AF2 flash_attn; many thanks to xreki and shaojie for debugging this code.

* Decrease the effect of debug printing on performance.

* Unify the initialization of flashattn arguments.

* Rewrite the reshape of temp_mask and temp_bias.

* API support for use_flash_attn (a usage sketch follows this commit header).

* Fix a compiling error on CI.

* Try to crop the flash-attention lib.

* Correct the condition for whether flash-attn can be used.

* Remove the softmax_out argument.

* Remove is_causal.

* Polish the code.

* Fix qkv_transpose_out's shape and the scaling of Q * K.

* Update the pinned commit of flash-attention.

---------
Co-authored-by: Liu Yiqun <liuyiqun01@baidu.com>
Parent: 4dc28b54
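
For orientation before the diff, here is a minimal usage sketch of the new flag through the Python API whose docstring is updated at the bottom of this change. Only attn_mask, has_gating, merge_qkv, and use_flash_attn appear in that docstring; the import path, the remaining argument names, and the tensor shapes below are assumptions made for illustration, not part of the diff.

import paddle
from paddle.incubate.nn.functional import fused_gate_attention  # import path assumed

# Hypothetical AlphaFold2-style sizes: batch, MSA rows, residues, heads, head dim.
batch, msa_len, res_len, num_heads, head_dim = 1, 2, 16, 8, 32
q_dim = num_heads * head_dim

# flash-attention expects half precision on GPU; shapes/layouts here are assumed.
query = paddle.randn([batch, msa_len, res_len, q_dim]).astype('float16')
qkv_weight = paddle.randn([3, num_heads, head_dim, q_dim]).astype('float16')        # assumed layout
out_linear_weight = paddle.randn([num_heads, head_dim, q_dim]).astype('float16')    # assumed layout
out_linear_bias = paddle.zeros([q_dim], dtype='float16')
attn_mask = paddle.zeros([batch, msa_len, 1, 1, res_len], dtype='float16')

out = fused_gate_attention(
    query=query,
    qkv_weight=qkv_weight,
    out_linear_weight=out_linear_weight,
    out_linear_bias=out_linear_bias,
    attn_mask=attn_mask,
    has_gating=False,      # skip the gating branch so no gate weights are needed
    merge_qkv=True,
    use_flash_attn=True,   # new flag; the kernel still checks CanUseFlashAttn() before taking the flash path
)

When use_flash_attn is left at its default of False, or the runtime check rejects the flash path, the existing FMHAGateRef reference path shown further down in this diff is used unchanged.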
......@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
set(FLASHATTN_TAG 18106c1ba0ccee81b97ca947397c08a141815a47)
set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
......
......@@ -27,6 +27,7 @@ std::tuple<paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor>
fused_gate_attention_dygraph_function(
const paddle::Tensor& Query,
......
......@@ -26,6 +26,7 @@ std::tuple<paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor>
fused_gate_attention_dygraph_function(
const paddle::Tensor& Query,
......@@ -181,6 +182,9 @@ fused_gate_attention_dygraph_function(
{"SoftmaxOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"SoftmaxLse",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"FMHAOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
......@@ -256,6 +260,8 @@ fused_gate_attention_dygraph_function(
egr::EagerUtils::GetOutput(outs["QKVTransposeOut"][0], &QKVTransposeOut);
paddle::Tensor SoftmaxOut;
egr::EagerUtils::GetOutput(outs["SoftmaxOut"][0], &SoftmaxOut);
paddle::Tensor SoftmaxLse;
egr::EagerUtils::GetOutput(outs["SoftmaxLse"][0], &SoftmaxLse);
paddle::Tensor FMHAOut;
egr::EagerUtils::GetOutput(outs["FMHAOut"][0], &FMHAOut);
paddle::Tensor GateOut;
......@@ -296,7 +302,7 @@ fused_gate_attention_dygraph_function(
p_autograd_Out);
// Create GradOpNode
auto grad_node = std::shared_ptr<fused_gate_attentionGradNodeCompat>(
new fused_gate_attentionGradNodeCompat(8, 12));
new fused_gate_attentionGradNodeCompat(9, 12));
bool merge_qkv = true;
if (attrs.count("merge_qkv")) {
......@@ -308,6 +314,11 @@ fused_gate_attention_dygraph_function(
has_gating = PADDLE_GET_CONST(bool, attrs.at("has_gating"));
}
bool use_flash_attn = false;
if (attrs.count("use_flash_attn")) {
use_flash_attn = PADDLE_GET_CONST(bool, attrs.at("use_flash_attn"));
}
// Set Attributes
grad_node->SetAttrMap(std::move(attrs));
grad_node->SetDefaultAttrMap(std::move(default_attrs));
......@@ -354,6 +365,12 @@ fused_gate_attention_dygraph_function(
grad_node->SetGradOutMeta(NonbatchedBias, 6);
}
if (use_flash_attn) {
grad_node->SetTensorWrapperSoftmaxLse(SoftmaxLse);
grad_node->SetTensorWrapperSrcMask(SrcMask);
grad_node->SetGradOutMeta(SrcMask, 7);
}
egr::EagerUtils::SetOutRankWithSlot(p_autograd_QueryTransposeOut, 0);
grad_node->SetGradInMeta(QueryTransposeOut, 0);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_KeyTransposeOut, 1);
......@@ -379,6 +396,7 @@ fused_gate_attention_dygraph_function(
ValueTransposeOut,
QKVTransposeOut,
SoftmaxOut,
SoftmaxLse,
FMHAOut,
GateOut,
Out);
......
......@@ -45,6 +45,11 @@ fused_gate_attentionGradNodeCompat::operator()(
has_gating = PADDLE_GET_CONST(bool, attr_map_.at("has_gating"));
}
bool use_flash_attn = false;
if (attr_map_.count("use_flash_attn")) {
use_flash_attn = PADDLE_GET_CONST(bool, attr_map_.at("use_flash_attn"));
}
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> ins0 =
{{"FMHAOut",
egr::EagerUtils::TrySyncToVars(
......@@ -168,6 +173,13 @@ fused_gate_attentionGradNodeCompat::operator()(
egr::Controller::Instance().GenerateUniqueName())};
}
if (use_flash_attn) {
auto SrcMask = egr::EagerUtils::RecoverTensorWrapper(&this->SrcMask_);
ins0["SrcMask"] = egr::EagerUtils::TrySyncToVars(SrcMask);
auto SoftmaxLse = egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxLse_);
ins0["SoftmaxLse"] = egr::EagerUtils::TrySyncToVars(SoftmaxLse);
}
auto& attrs_map0 = this->attr_map_;
// Pass the entire attribute map to TraceOp
// The underlying kernel will pickup whatever attribute they need at runtime
......
......@@ -61,12 +61,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
GateOut_.clear();
GateWeight_.clear();
NonbatchedBias_.clear();
SrcMask_.clear();
OutLinearBias_.clear();
OutLinearWeight_.clear();
QKVTransposeOut_.clear();
QKVWeight_.clear();
Query_.clear();
SoftmaxOut_.clear();
SoftmaxLse_.clear();
Key_.clear();
QueryWeight_.clear();
KeyWeight_.clear();
......@@ -103,6 +105,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
void SetTensorWrapperNonbatchedBias(const paddle::Tensor& NonbatchedBias) {
NonbatchedBias_ = egr::TensorWrapper(NonbatchedBias, false);
}
void SetTensorWrapperSrcMask(const paddle::Tensor& SrcMask) {
SrcMask_ = egr::TensorWrapper(SrcMask, false);
}
void SetTensorWrapperOutLinearBias(const paddle::Tensor& OutLinearBias) {
OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false);
}
......@@ -121,6 +126,9 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
void SetTensorWrapperSoftmaxOut(const paddle::Tensor& SoftmaxOut) {
SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false);
}
void SetTensorWrapperSoftmaxLse(const paddle::Tensor& SoftmaxLse) {
SoftmaxLse_ = egr::TensorWrapper(SoftmaxLse, false);
}
void SetTensorWrapperKey(const paddle::Tensor& Key) {
Key_ = egr::TensorWrapper(Key, false);
}
......@@ -160,12 +168,14 @@ class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
egr::TensorWrapper GateOut_;
egr::TensorWrapper GateWeight_;
egr::TensorWrapper NonbatchedBias_;
egr::TensorWrapper SrcMask_;
egr::TensorWrapper OutLinearBias_;
egr::TensorWrapper OutLinearWeight_;
egr::TensorWrapper QKVTransposeOut_;
egr::TensorWrapper QKVWeight_;
egr::TensorWrapper Query_;
egr::TensorWrapper SoftmaxOut_;
egr::TensorWrapper SoftmaxLse_;
egr::TensorWrapper Key_;
egr::TensorWrapper QueryWeight_;
......
......@@ -157,6 +157,9 @@ class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
.AsIntermediate()
.AsDispensable();
AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate();
AddOutput("SoftmaxLse", "Result of the flash attention.")
.AsIntermediate()
.AsDispensable();
AddOutput("FMHAOut", "Result in fmha.").AsIntermediate();
AddOutput("GateOut", "Result of the gating module.")
.AsIntermediate()
......@@ -170,6 +173,11 @@ class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"if true, calculation with merged qkv, "
"[default true].")
.SetDefault(true);
AddAttr<bool>(
"use_flash_attn",
"if true, the attention op will be computed in flash_attn branch, "
"[default false].")
.SetDefault(false);
AddComment(R"DOC(
Add fused attention op whose logic is as follows:
{
......@@ -223,15 +231,15 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("QueryWeight"),
"Input",
"QueryWeight",
"fused_aate_attention_arad");
"fused_gate_attention_arad");
OP_INOUT_CHECK(ctx->HasInput("KeyWeight"),
"Input",
"KeyWeight",
"fused_aate_attention_arad");
"fused_gate_attention_arad");
OP_INOUT_CHECK(ctx->HasInput("ValueWeight"),
"Input",
"ValueWeight",
"fused_aate_attention_arad");
"fused_gate_attention_arad");
for (auto& name : {"QueryWeight", "KeyWeight", "ValueWeight"}) {
ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name));
......@@ -259,6 +267,27 @@ class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
ctx->GetInputDim("OutLinearBias"));
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input = ctx.Input<phi::DenseTensor>("Query");
auto input_data_type = framework::TransToProtoVarType(input->dtype());
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override {
if (var_name == "SoftmaxLse") {
return phi::KernelKey(phi::Backend::ALL_BACKEND,
expected_kernel_type.layout(),
expected_kernel_type.dtype());
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
template <typename T>
......@@ -276,11 +305,18 @@ class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
bool merge_qkv = PADDLE_GET_CONST(bool, op->GetAttr("merge_qkv"));
bool use_flash_attn = PADDLE_GET_CONST(bool, op->GetAttr("use_flash_attn"));
if (merge_qkv) {
op->SetInput("QKVWeight", this->Input("QKVWeight"));
op->SetOutput(framework::GradVarName("QKVWeight"),
this->InputGrad("QKVWeight"));
op->SetInput("QKVTransposeOut", this->Output("QKVTransposeOut"));
if (use_flash_attn) {
op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("SoftmaxLse", this->Output("SoftmaxLse"));
}
} else {
op->SetInput("Key", this->Input("Key"));
op->SetOutput(framework::GradVarName("Key"), this->InputGrad("Key"));
......
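
A note on the new SoftmaxLse output declared in the op maker above and the dropped softmax_out argument; this is the standard flash-attention bookkeeping, not something specific to this kernel. With scaled attention logits plus the bias and mask terms,

    s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d}} + b_{ij} + m_{ij}, \qquad
    \mathrm{lse}_i = \log \sum_j e^{s_{ij}}, \qquad
    \mathrm{softmax}(s)_{ij} = e^{\, s_{ij} - \mathrm{lse}_i},

the backward pass can recompute the attention probabilities tile by tile from Q, K, the bias/mask, and lse alone, so only one scalar per query row and head needs to be saved. That is why the flash-attention branch stores SoftmaxLse for the gradient kernel instead of materializing the dense SoftmaxOut tensor.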
......@@ -371,17 +371,16 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
auto *v_transpose_out = ctx.Output<phi::DenseTensor>("ValueTransposeOut");
auto *qkv_transpose_out = ctx.Output<phi::DenseTensor>("QKVTransposeOut");
auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
auto *fmha_out = ctx.Output<phi::DenseTensor>("FMHAOut");
auto *gate_out = ctx.Output<phi::DenseTensor>("GateOut");
auto *out = ctx.Output<phi::DenseTensor>("Out");
const bool merge_qkv = ctx.Attr<bool>("merge_qkv");
const bool has_gating = ctx.Attr<bool>("has_gating");
const bool use_flash_attn = ctx.Attr<bool>("use_flash_attn");
bool use_fused_matmul_bias = true;
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out);
if (has_gating) {
AllocWithDebugInfo<T>(dev_ctx, "gate_out", gate_out);
......@@ -389,8 +388,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
AllocWithDebugInfo<T>(dev_ctx, "out", out);
// When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
GateAttentionConfig<T> config(
dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating);
GateAttentionConfig<T> config(dev_ctx,
query,
key,
query_weight,
qkv_weight,
merge_qkv,
has_gating,
use_flash_attn);
if (merge_qkv) {
PADDLE_ENFORCE_EQ(
......@@ -406,6 +411,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
phi::DenseTensor *qkv_out = config.GetQKVOut();
ComputeMergedQKVMatmulForward<T>(ctx, config, query, qkv_out);
if (config.CanUseFlashAttn()) {
qkv_transpose_out->Resize(phi::make_ddim({3,
config.batch_size,
config.seq_len_m,
config.seq_len_r,
config.num_heads,
config.head_dim}));
}
AllocWithDebugInfo<T>(dev_ctx, "qkv_transpose_out", qkv_transpose_out);
} else {
// 1. Separated QKV Matmul
......@@ -421,17 +434,31 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
}
// 2. FMHA
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeForward(nonbatched_bias,
src_mask,
q_transpose_out,
k_transpose_out,
v_transpose_out,
qkv_transpose_out,
softmax_out,
fmha_out,
gate_out,
&config);
if (config.CanUseFlashAttn()) {
auto *softmax_lse = ctx.Output<phi::DenseTensor>("SoftmaxLse");
auto fmha_compute = FlashAttnWithGating<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeForward(nonbatched_bias,
src_mask,
qkv_transpose_out,
softmax_lse,
fmha_out,
&config);
} else {
auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeForward(nonbatched_bias,
src_mask,
q_transpose_out,
k_transpose_out,
v_transpose_out,
qkv_transpose_out,
softmax_out,
fmha_out,
gate_out,
&config);
}
// 3. Gating Linear
if (has_gating) {
......@@ -465,7 +492,6 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
ctx.Input<phi::DenseTensor>("ValueTransposeOut");
const auto *qkv_transpose_out =
ctx.Input<phi::DenseTensor>("QKVTransposeOut");
const auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
const auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
const auto *gate_out = ctx.Input<phi::DenseTensor>("GateOut");
......@@ -477,13 +503,20 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
bool has_gating = ctx.Attr<bool>("has_gating");
bool merge_qkv = ctx.Attr<bool>("merge_qkv");
bool use_flash_attn = ctx.Attr<bool>("use_flash_attn");
bool use_fused_matmul_bias = true;
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad);
GateAttentionGradConfig<T> config(
dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating);
GateAttentionGradConfig<T> config(dev_ctx,
query,
key,
query_weight,
qkv_weight,
merge_qkv,
has_gating,
use_flash_attn);
phi::DenseTensor fmha_out_grad;
fmha_out_grad.Resize(config.gate_out_dims);
......@@ -518,16 +551,36 @@ class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
dev_ctx, "nonbatched_bias_grad", nonbatched_bias_grad);
}
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeBackward(q_transpose_out,
k_transpose_out,
v_transpose_out,
qkv_transpose_out,
softmax_out,
&fmha_out_grad,
nullptr,
nonbatched_bias_grad,
&config);
if (config.CanUseFlashAttn()) {
const auto *nonbatched_bias =
ctx.Input<phi::DenseTensor>("NonbatchedBias");
const auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
const auto *softmax_lse = ctx.Input<phi::DenseTensor>("SoftmaxLse");
auto fmha_compute = FlashAttnWithGating<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeBackward(qkv_transpose_out,
src_mask,
nonbatched_bias,
softmax_lse,
fmha_out,
&fmha_out_grad,
nullptr,
nonbatched_bias_grad,
&config);
} else {
const auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
fmha_compute.ComputeBackward(q_transpose_out,
k_transpose_out,
v_transpose_out,
qkv_transpose_out,
softmax_out,
&fmha_out_grad,
nullptr,
nonbatched_bias_grad,
&config);
}
bool use_addto = has_gating ? true : false;
if (merge_qkv) {
......
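
The resize of qkv_transpose_out in the CanUseFlashAttn() branch above switches the packed QKV tensor to a [3, batch_size, seq_len_m, seq_len_r, num_heads, head_dim] layout, so the (batch, msa, residue) axes can be folded straight into the token-major q/k/v tensors that flash-attention-style kernels consume. The folding itself happens inside FlashAttnWithGating and is only assumed here; the numpy sketch below just illustrates the shape bookkeeping with made-up sizes.

import numpy as np

# Hypothetical sizes: batch, MSA rows, residues, heads, head dim.
B, M, R, H, D = 1, 2, 16, 8, 32

qkv = np.zeros((3, B, M, R, H, D), dtype=np.float16)  # layout set by the Resize() in the flash-attn branch
q, k, v = qkv[0], qkv[1], qkv[2]                      # each of shape [B, M, R, H, D]
q_tokens = q.reshape(B * M * R, H, D)                 # one row per (batch, msa, residue) token
assert q_tokens.shape == (B * M * R, H, D)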
......@@ -43,9 +43,11 @@ extern void* flashattn_dso_handle;
#define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
DYNAMIC_LOAD_FLASHATTN_WRAP(__name)
#define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_bwd); \
#define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_bwd); \
__macro(flash_attn_fwd_with_bias_and_mask); \
__macro(flash_attn_bwd_with_bias_and_mask); \
__macro(flash_attn_error);
FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);
......
......@@ -274,6 +274,7 @@ class TestFusedGateAttentionOp(OpTest):
_,
_,
softmax_out,
_,
fmha_out,
gate_out,
out,
......
......@@ -83,6 +83,7 @@ def fused_gate_attention(
attn_mask (Tensor, optional): The attention mask. The shape is [batch_size, msa_len, 1, 1, res_len]. Default None.
has_gating (bool, optional): Whether the gating linear is applied. Default True.
merge_qkv (bool, optional): Whether to compute with a merged QKV matmul. Default True.
use_flash_attn (bool, optional): Whether to use flash-attention to speed up the computation. Default False.
Returns:
Tensor: The output Tensor, the data type and shape is same as `query`.
......@@ -142,7 +143,7 @@ def fused_gate_attention(
"""
if _non_static_mode():
_, _, _, _, _, _, _, out = _legacy_C_ops.fused_gate_attention(
_, _, _, _, _, _, _, _, out = _legacy_C_ops.fused_gate_attention(
query,
key,
query_weight,
......
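
For completeness, a hypothetical helper for the attn_mask parameter documented above. The helper name and the additive-mask convention (zeros at valid positions, a large negative value at padded residues) are assumptions; only the [batch_size, msa_len, 1, 1, res_len] shape comes from the docstring.

import paddle

def build_attn_mask(valid, msa_len, big_neg=-1e4):
    # valid: bool tensor of shape [batch_size, res_len], True for real residues.
    mask = (1.0 - valid.astype('float16')) * big_neg    # 0 where valid, big_neg where padded
    mask = mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)  # -> [batch_size, 1, 1, 1, res_len]
    return mask.expand([valid.shape[0], msa_len, 1, 1, valid.shape[1]])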