未验证 提交 1a8786cf 编写于 作者: L Li Min 提交者: GitHub

Add support bias is none for fused_attention op. (#37411)

Add support for bias is none for fused_attention op.
上级 4812eda5
...@@ -28,12 +28,8 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ...@@ -28,12 +28,8 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionOp"); "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionOp");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) { if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
...@@ -54,8 +50,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ...@@ -54,8 +50,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head] // qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut",
"FusedAttentionOp"); "FusedAttentionOp");
if (ctx->HasInput("QKVBias")) {
OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut",
"FusedAttentionOp"); "FusedAttentionOp");
}
OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2", OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2",
"FusedAttentionOp"); "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut", OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut",
...@@ -107,6 +105,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ...@@ -107,6 +105,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]", "input qkv_weight = [%s]",
x_dim, y_dim)); x_dim, y_dim));
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) { if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
...@@ -119,8 +124,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ...@@ -119,8 +124,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// [batch_size, seq_len, 3, num_head, head_size] // [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut", ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
if (ctx->HasInput("QKVBias")) {
ctx->SetOutputDim("QKVBiasOut", ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
}
// [3, batch_size, num_head, seq_len, head_size] // [3, batch_size, num_head, seq_len, head_size]
ctx->SetOutputDim("TransposeOut2", ctx->SetOutputDim("TransposeOut2",
{y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
...@@ -173,11 +181,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -173,11 +181,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"H. Here, H represents the last dimension of its input tensor.") "H. Here, H represents the last dimension of its input tensor.")
.AsDispensable(); .AsDispensable();
AddInput("QKVW", "The qkv weight tensor."); AddInput("QKVW", "The qkv weight tensor.");
AddInput("QKVBias", "The qkv bias tensor."); AddInput("QKVBias", "The qkv bias tensor.").AsDispensable();
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable(); .AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor."); AddInput("OutLinearW", "The out_linear weight tensor.");
AddInput("OutLinearBias", "The out_linear bias tensor."); AddInput("OutLinearBias", "The out_linear bias tensor.").AsDispensable();
AddInput("Ln2Scale", AddInput("Ln2Scale",
"(optional) Scale is a 1-dimensional tensor of size " "(optional) Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.") "H. Here, H represents the last dimension of its input tensor.")
...@@ -379,12 +387,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ...@@ -379,12 +387,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
"FusedAttentionGrad"); "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionGrad"); "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) { if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
if (ctx->HasOutput(framework::GradVarName("LnScale"))) { if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
...@@ -399,14 +403,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ...@@ -399,14 +403,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
if (ctx->HasOutput(framework::GradVarName("OutLinearBias"))) {
ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
ctx->GetInputDim("OutLinearBias")); ctx->GetInputDim("OutLinearBias"));
}
ctx->SetOutputDim(framework::GradVarName("OutLinearW"), ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
ctx->GetInputDim("OutLinearW")); ctx->GetInputDim("OutLinearW"));
ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW")); ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW"));
if (ctx->HasOutput(framework::GradVarName("QKVBias"))) {
ctx->SetOutputDim(framework::GradVarName("QKVBias"), ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias")); ctx->GetInputDim("QKVBias"));
}
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) { if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim(framework::GradVarName("LnOut"), ctx->SetOutputDim(framework::GradVarName("LnOut"),
...@@ -434,8 +441,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ...@@ -434,8 +441,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim(framework::GradVarName("QKVOut"), ctx->SetOutputDim(framework::GradVarName("QKVOut"),
ctx->GetInputDim("QKVOut")); ctx->GetInputDim("QKVOut"));
if (ctx->HasOutput(framework::GradVarName("QKVBiasOut"))) {
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"), ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
ctx->GetInputDim("QKVBiasOut")); ctx->GetInputDim("QKVBiasOut"));
}
ctx->SetOutputDim(framework::GradVarName("OutLinearOut"), ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
ctx->GetInputDim("OutLinearOut")); ctx->GetInputDim("OutLinearOut"));
} }
...@@ -462,7 +471,15 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -462,7 +471,15 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
// inputs x, parameters and their grad. // inputs x, parameters and their grad.
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("QKVW", this->Input("QKVW")); op->SetInput("QKVW", this->Input("QKVW"));
if (this->HasInput("QKVBias")) {
op->SetInput("QKVBias", this->Input("QKVBias")); op->SetInput("QKVBias", this->Input("QKVBias"));
op->SetOutput(framework::GradVarName("QKVBias"),
this->InputGrad("QKVBias"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
}
if (this->HasInput("SrcMask")) { if (this->HasInput("SrcMask")) {
op->SetInput("SrcMask", this->Input("SrcMask")); op->SetInput("SrcMask", this->Input("SrcMask"));
...@@ -472,7 +489,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -472,7 +489,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
op->SetInput("OutLinearW", this->Input("OutLinearW")); op->SetInput("OutLinearW", this->Input("OutLinearW"));
if (this->HasInput("OutLinearBias")) {
op->SetInput("OutLinearBias", this->Input("OutLinearBias")); op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
op->SetOutput(framework::GradVarName("OutLinearBias"),
this->InputGrad("OutLinearBias"));
}
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
bool is_pre_layer_norm = bool is_pre_layer_norm =
...@@ -503,10 +524,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -503,10 +524,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW")); op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
op->SetOutput(framework::GradVarName("QKVBias"),
this->InputGrad("QKVBias"));
op->SetOutput(framework::GradVarName("OutLinearBias"),
this->InputGrad("OutLinearBias"));
op->SetOutput(framework::GradVarName("OutLinearW"), op->SetOutput(framework::GradVarName("OutLinearW"),
this->InputGrad("OutLinearW")); this->InputGrad("OutLinearW"));
...@@ -528,7 +546,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -528,7 +546,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->Output("BiasDropoutResidualOut")); this->Output("BiasDropoutResidualOut"));
} }
op->SetInput("QKVOut", this->Output("QKVOut")); op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetInput("TransposeOut2", this->Output("TransposeOut2")); op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
op->SetInput("QKOut", this->Output("QKOut")); op->SetInput("QKOut", this->Output("QKOut"));
op->SetInput("QKTVOut", this->Output("QKTVOut")); op->SetInput("QKTVOut", this->Output("QKTVOut"));
...@@ -553,8 +571,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -553,8 +571,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut")); op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
op->SetOutput(framework::GradVarName("QKTVOut"), op->SetOutput(framework::GradVarName("QKTVOut"),
this->OutputGrad("QKTVOut")); this->OutputGrad("QKTVOut"));
op->SetOutput(framework::GradVarName("TransposeOut2"), op->SetOutput(framework::GradVarName("TransposeOut2"),
......
...@@ -96,9 +96,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> { ...@@ -96,9 +96,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *x_data = input_x->data<T>(); auto *x_data = input_x->data<T>();
auto *qkv_weight_data = qkv_weight->data<T>(); auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>(); auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace()); auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
auto *qkv_bias_out_data = qkv_bias_out->mutable_data<T>(ctx.GetPlace()); auto *qkv_bias_out_data =
(qkv_bias == nullptr) ? nullptr
: qkv_bias_out->mutable_data<T>(ctx.GetPlace());
// get data ptr for FMHA. // get data ptr for FMHA.
auto *transpose_out_2_data = auto *transpose_out_2_data =
...@@ -117,7 +119,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> { ...@@ -117,7 +119,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// get data ptr for out_linear. // get data ptr for out_linear.
auto *out_linear_weight_data = out_linear_weight->data<T>(); auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data = out_linear_bias->data<T>(); auto *out_linear_bias_data =
(out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace()); auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());
// get data ptr for bias+dropout+residual+layernorm // get data ptr for bias+dropout+residual+layernorm
...@@ -139,9 +142,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> { ...@@ -139,9 +142,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(), auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed); epsilon, bsz_seq, dim_embed);
bool compute_bias = true;
if (qkv_bias == nullptr) {
compute_bias = false;
}
// (transA, transB, compute_bias) = (false, true, true) // (transA, transB, compute_bias) = (false, true, true)
auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, true, auto qkv_compute =
bsz_seq, output_size, input_size, true); AttnMatMul<T>(ctx.cuda_device_context(), false, true, bsz_seq,
output_size, input_size, compute_bias);
AttnDropoutParam attn_dropout_param( AttnDropoutParam attn_dropout_param(
is_test_1, dropout_implementation_1, attn_dropout_rate, is_test_1, dropout_implementation_1, attn_dropout_rate,
...@@ -176,10 +185,17 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> { ...@@ -176,10 +185,17 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out, qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
qkv_bias_out); qkv_bias_out);
} }
if (qkv_bias == nullptr) {
fmha_ref_compute.ComputeForward(*qkv_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
} else {
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2, fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out, qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out, attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out); qktv_out, fmha_out);
}
// fmha_out: [batch_size, seq_len, num_head, head_dim] // fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim] // weight: [embed_dim, embed_dim]
...@@ -249,9 +265,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -249,9 +265,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias"); auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>()); auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
auto *qkv_weight_data = qkv_weight->data<T>(); auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>(); auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
auto *out_linear_weight_data = out_linear_weight->data<T>(); auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data = out_linear_bias->data<T>(); auto *out_linear_bias_data =
(out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
// fw output // fw output
auto *fmha_out = ctx.Input<Tensor>("FMHAOut"); auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
...@@ -299,8 +316,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -299,8 +316,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_bias_dropout_residual_out = auto *d_bias_dropout_residual_out =
ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut")); ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace()); auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_out_data = d_qkv_out->mutable_data<T>(ctx.GetPlace()); // when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the
auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data<T>(ctx.GetPlace()); // space can be reused.
auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
? nullptr
: d_qkv_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_out_data =
(d_qkv_bias_out == nullptr)
? nullptr
: d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace()); auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
auto *d_transpose_out_2_data = auto *d_transpose_out_2_data =
d_transpose_out_2->mutable_data<T>(ctx.GetPlace()); d_transpose_out_2->mutable_data<T>(ctx.GetPlace());
...@@ -326,11 +350,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -326,11 +350,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias")); auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace()); auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_data = d_qkv_bias->mutable_data<T>(ctx.GetPlace()); auto *d_qkv_bias_data = (d_qkv_bias == nullptr)
? nullptr
: d_qkv_bias->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_weight_data = auto *d_out_linear_weight_data =
d_out_linear_weight->mutable_data<T>(ctx.GetPlace()); d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_bias_data = auto *d_out_linear_bias_data =
d_out_linear_bias->mutable_data<T>(ctx.GetPlace()); (d_out_linear_bias == nullptr)
? nullptr
: d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
const auto input_x_dims = input_x->dims(); const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims(); const auto qkv_w_dims = qkv_weight->dims();
...@@ -352,12 +380,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -352,12 +380,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
bool transA = false; bool transA = false;
bool transB = true; bool transB = true;
bool compute_bias = true; bool compute_qkv_bias = true;
if (qkv_bias == nullptr) {
compute_qkv_bias = false;
}
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(), auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed); epsilon, bsz_seq, dim_embed);
auto qkv_compute = auto qkv_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq, AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias); output_size, input_size, compute_qkv_bias);
AttnDropoutParam attn_dropout_param( AttnDropoutParam attn_dropout_param(
is_test_1, dropout_implementation_1, attn_dropout_prob, is_test_1, dropout_implementation_1, attn_dropout_prob,
is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
...@@ -367,7 +398,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -367,7 +398,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
output_size = hidden_size; output_size = hidden_size;
transA = false; transA = false;
transB = false; transB = false;
compute_bias = false; bool compute_bias = false;
auto out_linear_compute = auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq, AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias); output_size, input_size, compute_bias);
...@@ -405,14 +436,19 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -405,14 +436,19 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_out, d_fmha_out, d_out_linear_out, d_fmha_out,
d_out_linear_weight, nullptr); d_out_linear_weight, nullptr);
if (qkv_bias != nullptr) {
fmha_ref_compute.ComputeBackward( fmha_ref_compute.ComputeBackward(
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out, *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_bias_out); d_transpose_out_2, nullptr, d_qkv_bias_out);
cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data, } else {
bsz_seq * 3 * num_head * dim_head * sizeof(T), fmha_ref_compute.ComputeBackward(
cudaMemcpyDeviceToDevice); *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_out);
}
if (pre_layer_norm) { if (pre_layer_norm) {
auto *ln_mean = ctx.Input<Tensor>("LnMean"); auto *ln_mean = ctx.Input<Tensor>("LnMean");
...@@ -432,15 +468,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -432,15 +468,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_ln_bias_data = auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr (d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace())); : d_ln_bias->mutable_data<U>(ctx.GetPlace()));
if (qkv_bias != nullptr) {
qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out, qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out,
d_ln_out, d_qkv_weight, d_qkv_bias);
} else {
qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_out, d_ln_out,
d_qkv_weight, d_qkv_bias); d_qkv_weight, d_qkv_bias);
}
layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data, layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
ln_mean_data, ln_var_data, d_x_data, ln_mean_data, ln_var_data, d_x_data,
d_ln_scale_data, d_ln_bias_data); d_ln_scale_data, d_ln_bias_data);
} else { } else {
if (qkv_bias != nullptr) {
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x, qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
d_qkv_weight, d_qkv_bias); d_qkv_weight, d_qkv_bias);
} else {
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_out, d_x,
d_qkv_weight, d_qkv_bias);
}
} }
// gradient accumulation // gradient accumulation
std::vector<const Tensor *> ins; std::vector<const Tensor *> ins;
......
...@@ -168,15 +168,27 @@ class TestFusedAttentionOp(OpTest): ...@@ -168,15 +168,27 @@ class TestFusedAttentionOp(OpTest):
paddle.disable_static(place=paddle.CUDAPlace(0)) paddle.disable_static(place=paddle.CUDAPlace(0))
q_proj_weight = paddle.to_tensor( q_proj_weight = paddle.to_tensor(
self.q_proj.weight, stop_gradient=False) self.q_proj.weight, stop_gradient=False)
q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False)
k_proj_weight = paddle.to_tensor( k_proj_weight = paddle.to_tensor(
self.k_proj.weight, stop_gradient=False) self.k_proj.weight, stop_gradient=False)
k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False)
v_proj_weight = paddle.to_tensor( v_proj_weight = paddle.to_tensor(
self.v_proj.weight, stop_gradient=False) self.v_proj.weight, stop_gradient=False)
v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False)
out_linear_weight = paddle.to_tensor( out_linear_weight = paddle.to_tensor(
self.out_proj.weight, stop_gradient=False) self.out_proj.weight, stop_gradient=False)
if self.bias_attr is False:
qkv_bias_tensor = None
out_linear_bias = None
else:
q_proj_bias = paddle.to_tensor(
self.q_proj.bias, stop_gradient=False)
k_proj_bias = paddle.to_tensor(
self.k_proj.bias, stop_gradient=False)
v_proj_bias = paddle.to_tensor(
self.v_proj.bias, stop_gradient=False)
qkv_bias = np.concatenate(
(q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
out_linear_bias = paddle.to_tensor( out_linear_bias = paddle.to_tensor(
self.out_proj.bias, stop_gradient=False) self.out_proj.bias, stop_gradient=False)
...@@ -193,17 +205,12 @@ class TestFusedAttentionOp(OpTest): ...@@ -193,17 +205,12 @@ class TestFusedAttentionOp(OpTest):
qkv_weight = qkv_weight.reshape( qkv_weight = qkv_weight.reshape(
(3, self.num_heads, self.head_dim, self.embed_dim)) (3, self.num_heads, self.head_dim, self.embed_dim))
qkv_bias = np.concatenate(
(q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
x = paddle.to_tensor(self.query, stop_gradient=False) x = paddle.to_tensor(self.query, stop_gradient=False)
if self.has_attn_mask: if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else: else:
attn_mask = None attn_mask = None
qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
epsilon = 1e-05 epsilon = 1e-05
ln2_epsilon = 1e-05 ln2_epsilon = 1e-05
...@@ -227,6 +234,36 @@ class TestFusedAttentionOp(OpTest): ...@@ -227,6 +234,36 @@ class TestFusedAttentionOp(OpTest):
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = False
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
class TestFusedAttentionOpPreLn(TestFusedAttentionOp): class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
def config(self): def config(self):
self.x_type = np.float32 self.x_type = np.float32
......
...@@ -356,6 +356,9 @@ def fused_multi_head_attention(x, ...@@ -356,6 +356,9 @@ def fused_multi_head_attention(x,
0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]." 0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
assert qkv_weight.shape[3] == x.shape[ assert qkv_weight.shape[3] == x.shape[
2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim." 2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[
3], "embed_dim must be divisible by num_heads."
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention(
x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask,
linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册