Unverified commit ad44a40c, authored by Li Min, committed by GitHub

Fix fused_attention_op scope. (#37065)

Bug fix: scope the second layer norm of fused_attention_op by the pre_layer_norm attribute. When pre_layer_norm is true, the op now skips the final layer norm, so Ln2Mean, Ln2Variance, and BiasDropoutResidualOut are neither required nor produced; when it is false, those tensors are checked, shaped, and wired through the forward kernel, the backward kernel, and the grad op/maker. The unit-test tolerances are unified at atol=1e-4.

Parent 48d53cfc
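The behavioral change is easiest to see in the updated DOC pseudocode and the Python compute_reference used by the tests: after the out_linear projection, the second layer norm is applied only when pre_layer_norm is false. Below is a minimal NumPy sketch of that epilogue, assuming identity dropout (the configuration exercised by the tests) and using illustrative names; layer_norm_ref and fused_attention_epilogue are not Paddle APIs.

import numpy as np


def layer_norm_ref(x, scale, bias, epsilon=1e-5):
    # Plain layer norm over the last (embedding) dimension.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * scale + bias


def fused_attention_epilogue(pre_layer_norm, residual, out_linear_out,
                             out_linear_bias, ln_2_scale, ln_2_bias):
    # Epilogue after out_linear; dropout is treated as identity here,
    # matching the reference path exercised by the unit tests.
    bias_dropout_residual_out = residual + (out_linear_out + out_linear_bias)
    if pre_layer_norm:
        # pre_layer_norm == True: the second layer norm is skipped, so
        # Ln2Mean, Ln2Variance and BiasDropoutResidualOut are not produced.
        return bias_dropout_residual_out
    # pre_layer_norm == False: apply the second layer norm (Ln2*).
    return layer_norm_ref(bias_dropout_residual_out, ln_2_scale, ln_2_bias)


# Example shapes follow the op: [batch_size, seq_len, embed_dim].
residual = np.random.rand(2, 4, 8).astype("float32")
proj = np.random.rand(2, 4, 8).astype("float32")
bias = np.zeros(8, dtype="float32")
gamma = np.ones(8, dtype="float32")
beta = np.zeros(8, dtype="float32")
y_pre = fused_attention_epilogue(True, residual, proj, bias, gamma, beta)
y_post = fused_attention_epilogue(False, residual, proj, bias, gamma, beta)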
@@ -42,6 +42,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                      "FusedAttentionOp");
       OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
                      "FusedAttentionOp");
+    } else {
+      OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
+                     "FusedAttentionOp");
+      OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
+                     "FusedAttentionOp");
+      OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
+                     "BiasDropoutResidualOut", "FusedAttentionOp");
     }
     // qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
@@ -70,12 +77,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                    "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut",
                    "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
-                   "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
-                   "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
-                   "BiasDropoutResidualOut", "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut",
                    "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp");
@@ -109,6 +111,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
       ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
       ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
+    } else {
+      ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
+      ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
+      ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
     }
     // [batch_size, seq_len, 3, num_head, head_size]
     ctx->SetOutputDim("QKVOut",
@@ -138,12 +144,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
     ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));
-    ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
-    ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
     if (ctx->Attrs().Get<bool>("dropout_is_test") == false) {
       ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X"));
     }
-    ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
     ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
   }
@@ -314,25 +318,28 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
     });
     AddComment(R"DOC(
   Add fused attention op whose logic is as follows:
   // @input: [batch_size, seq_len, 3, num_head, head_dim]
   // @final_out: [batch_size, seq_len, num_heads, head_dim]
   if (pre_layernorm)
     out = layer_norm(input);
   out = compute_qkv(out) + bias;
   // fmha module
   {
     out = transpose(out, perm=[2, 0, 3, 1, 4]);
     out = q * k^t;
     out = attn_mask + out;
     out = softmax(out);
     out = dropout(out);
     out = out * v;
     out = transpose(out, perm=[0, 2, 1, 3]);
   }
   out = out_linear(out);
-  final_out = layer_norm(residual + dropout(bias + out));
+  if (pre_layernorm)
+    final_out = residual + dropout(bias + out);
+  else
+    final_out = layer_norm(residual + dropout(bias + out));
   )DOC");
   }
 };
@@ -347,20 +354,20 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
                       platform::errors::InvalidArgument(
                           "GradOp is only callable when attn_dropout_is_test is false"));
-    OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
-                   "FusedAttentionGrad");
-    OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
-                   "FusedAttentionGrad");
-    if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
-      ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
-                        ctx->GetInputDim("Ln2Scale"));
-    }
-    if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
-      ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
-                        ctx->GetInputDim("Ln2Bias"));
-    }
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
-    if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
+    if (ctx->Attrs().Get<bool>("pre_layer_norm") == false) {
+      OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
+                     "FusedAttentionGrad");
+      OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
+                     "FusedAttentionGrad");
+      if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
+        ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
+                          ctx->GetInputDim("Ln2Scale"));
+      }
+      if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
+        ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
+                          ctx->GetInputDim("Ln2Bias"));
+      }
+    } else {
       OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
                      "FusedAttentionGrad");
       OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
@@ -368,6 +375,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
       OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
                      "FusedAttentionGrad");
     }
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
                    "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
@@ -402,6 +411,9 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
     if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
       ctx->SetOutputDim(framework::GradVarName("LnOut"),
                         ctx->GetInputDim("LnOut"));
+    } else {
+      ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
+                        ctx->GetInputDim("BiasDropoutResidualOut"));
     }
     ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
                       ctx->GetInputDim("FMHAOut"));
@@ -426,8 +438,6 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
                       ctx->GetInputDim("QKVBiasOut"));
     ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
                       ctx->GetInputDim("OutLinearOut"));
-    ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
-                      ctx->GetInputDim("BiasDropoutResidualOut"));
   }
 protected:
@@ -478,17 +488,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
         op->SetOutput(framework::GradVarName("LnBias"),
                       this->InputGrad("LnBias"));
       }
-    }
-
-    if (this->HasInput("Ln2Scale")) {
-      op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
-      op->SetOutput(framework::GradVarName("Ln2Scale"),
-                    this->InputGrad("Ln2Scale"));
-    }
-    if (this->HasInput("Ln2Bias")) {
-      op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
-      op->SetOutput(framework::GradVarName("Ln2Bias"),
-                    this->InputGrad("Ln2Bias"));
-    }
+    } else {
+      if (this->HasInput("Ln2Scale")) {
+        op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
+        op->SetOutput(framework::GradVarName("Ln2Scale"),
+                      this->InputGrad("Ln2Scale"));
+      }
+      if (this->HasInput("Ln2Bias")) {
+        op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
+        op->SetOutput(framework::GradVarName("Ln2Bias"),
+                      this->InputGrad("Ln2Bias"));
+      }
+    }
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
@@ -511,6 +521,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
       if (this->HasOutput("LnVariance")) {
         op->SetInput("LnVariance", this->Output("LnVariance"));
       }
+    } else {
+      op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
+      op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
+      op->SetInput("BiasDropoutResidualOut",
+                   this->Output("BiasDropoutResidualOut"));
     }
     op->SetInput("QKVOut", this->Output("QKVOut"));
     op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
@@ -523,12 +538,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("FMHAOut", this->Output("FMHAOut"));
     op->SetInput("OutLinearOut", this->Output("OutLinearOut"));
-    op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
-    op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
     op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
-    op->SetInput("BiasDropoutResidualOut",
-                 this->Output("BiasDropoutResidualOut"));
     op->SetInput("QKVOut", this->Output("QKVOut"));
     // backward outputs: dinput
@@ -537,7 +547,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
       op->SetOutput(framework::GradVarName("LnOut"),
                     this->OutputGrad("LnOut"));
       }
+    } else {
+      op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
+                    this->OutputGrad("BiasDropoutResidualOut"));
     }
     op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
     op->SetOutput(framework::GradVarName("QKVBiasOut"),
                   this->OutputGrad("QKVBiasOut"));
@@ -553,8 +567,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("FMHAOut"),
                   this->OutputGrad("FMHAOut"));
-    op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
-                  this->OutputGrad("BiasDropoutResidualOut"));
     op->SetOutput(framework::GradVarName("OutLinearOut"),
                   this->OutputGrad("OutLinearOut"));
   }
......
@@ -95,15 +95,6 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     const auto qkv_w_dims = qkv_weight->dims();
     auto *x_data = input_x->data<T>();
-    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
-    auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
-    auto *ln_mean_data =
-        pre_layer_norm ? ln_mean->mutable_data<U>(ctx.GetPlace()) : nullptr;
-    auto *ln_var_data =
-        pre_layer_norm ? ln_var->mutable_data<U>(ctx.GetPlace()) : nullptr;
-    auto *ln_out_data =
-        pre_layer_norm ? ln_out->mutable_data<T>(ctx.GetPlace()) : nullptr;
     auto *qkv_weight_data = qkv_weight->data<T>();
     auto *qkv_bias_data = qkv_bias->data<T>();
     auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
@@ -130,16 +121,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());
     // get data ptr for bias+dropout+residual+layernorm
-    auto *ln_scale_2_data =
-        (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
-    auto *ln_bias_2_data =
-        (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
     auto *dropout_mask_out_data =
         dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
-    auto *bias_dropout_residual_out_data =
-        bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
-    auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
-    auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
     auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());
     int batch_size = input_x_dims[0];
@@ -178,6 +161,13 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
         ln_epsilon);
     if (pre_layer_norm) {
+      auto *ln_scale_data =
+          (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
+      auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
+      auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
+      auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
+      auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());
       layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
                                         ln_out_data, ln_mean_data, ln_var_data);
       qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
@@ -196,12 +186,27 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     // out_linear_out: [batch_size, seq_len, embed_dim]
     out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
                                       nullptr, out_linear_out_data, nullptr);
-    // output = layernorm(residual + dropout(input + bias))
-    fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
-        ctx.cuda_device_context(), out_linear_out_data, x_data,
-        out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
-        bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
-        ln_mean_2_data, ln_var_2_data);
+    if (pre_layer_norm) {
+      // output = (residual + dropout(input + bias))
+      fused_dropout_layernorm_helper.ResidualDropoutBias(
+          ctx.cuda_device_context(), out_linear_out_data, x_data,
+          out_linear_bias_data, final_out_data, dropout_mask_out_data);
+    } else {
+      auto *ln_scale_2_data =
+          (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
+      auto *ln_bias_2_data =
+          (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
+      auto *bias_dropout_residual_out_data =
+          bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
+      auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
+      auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
+      // output = layernorm(residual + dropout(input + bias))
+      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
+          ctx.cuda_device_context(), out_linear_out_data, x_data,
+          out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
+          bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
+          ln_mean_2_data, ln_var_2_data);
+    }
   }
 };
@@ -271,10 +276,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
     auto *out_linear_out_data = out_linear_out->data<T>();
-    auto *ln_2_mean_data = ln_2_mean->data<U>();
-    auto *ln_2_var_data = ln_2_var->data<U>();
     auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
-    auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data<T>();
     // output's grad
     auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
@@ -312,8 +314,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
     auto *d_out_linear_out_data =
         d_out_linear_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_bias_dropout_residual_out_data =
-        d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
     // parameter grad
     auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
@@ -331,12 +331,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
     auto *d_out_linear_bias_data =
         d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
-    auto *d_ln_2_scale_data =
-        (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data<U>(
-                                                 ctx.GetPlace()));
-    auto *d_ln_2_bias_data =
-        (d_ln_2_bias == nullptr ? nullptr
-                                : d_ln_2_bias->mutable_data<U>(ctx.GetPlace()));
     const auto input_x_dims = input_x->dims();
     const auto qkv_w_dims = qkv_weight->dims();
@@ -382,11 +376,30 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
         ln2epsilon);
-    fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
-        ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
-        dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
-        d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
-        d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
+    if (pre_layer_norm) {
+      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
+          ctx.cuda_device_context(), d_y_data, dropout_mask_out_data,
+          d_out_linear_out_data, d_residual_data, d_out_linear_bias_data);
+    } else {
+      auto *ln_2_mean_data = ln_2_mean->data<U>();
+      auto *ln_2_var_data = ln_2_var->data<U>();
+      auto *bias_dropout_residual_out_data =
+          bias_dropout_residual_out->data<T>();
+      auto *d_ln_2_scale_data =
+          (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data<U>(
+                                                   ctx.GetPlace()));
+      auto *d_ln_2_bias_data =
+          (d_ln_2_bias == nullptr ? nullptr : d_ln_2_bias->mutable_data<U>(
                                                  ctx.GetPlace()));
+      auto *d_bias_dropout_residual_out_data =
+          d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
+      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
+          ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
+          dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
+          d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
+          d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
+    }
     out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data,
                                        d_out_linear_out_data, d_fmha_out_data,
......
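The same branch appears in the backward kernel above. With pre_layer_norm, the epilogue is just a bias and residual add (dropout is identity in the tested configuration), so ResidualDropoutBiasGrad only has to fan the upstream gradient out to the residual, the out_linear output, and the bias. Without pre_layer_norm, LayernormResidualDropoutBiasGrad additionally consumes the saved Ln2Mean, Ln2Variance, and BiasDropoutResidualOut to differentiate through the second layer norm, which is why those tensors are now required as gradient inputs only in that case. A NumPy sketch of the simple branch, under the same identity-dropout assumption (residual_dropout_bias_grad_ref is an illustrative name, not the Paddle helper):

import numpy as np


def residual_dropout_bias_grad_ref(d_y):
    # Backward of final_out = residual + (out_linear_out + out_linear_bias)
    # with identity dropout: each addend receives d_y unchanged; the bias is
    # broadcast over [batch_size, seq_len], so its gradient is a sum.
    d_residual = d_y
    d_out_linear_out = d_y
    d_out_linear_bias = d_y.sum(axis=(0, 1))
    return d_residual, d_out_linear_out, d_out_linear_bias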
@@ -155,8 +155,8 @@ class TestFusedAttentionOp(OpTest):
         residual_out = residual + self.dropout(out)
         if not self.pre_layer_norm:
             final_out = self.norm1(residual_out)
-        if self.pre_layer_norm:
-            final_out = self.norm2(residual_out)
+        else:
+            final_out = residual_out
         paddle.autograd.backward(
             [final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
         return final_out, tensor_query.grad
@@ -219,9 +219,9 @@ class TestFusedAttentionOp(OpTest):
         final_out_ref, x_grad_ref = self.GetBaselineOut()
         final_out, x_grad = self.GetFusedAttentionOut()
         np.testing.assert_allclose(
-            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5)
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
         np.testing.assert_allclose(
-            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5)
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
 class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
@@ -249,9 +249,9 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
         final_out_ref, x_grad_ref = self.GetBaselineOut()
         final_out, x_grad = self.GetFusedAttentionOut()
         np.testing.assert_allclose(
-            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
         np.testing.assert_allclose(
-            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
 class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
@@ -279,9 +279,9 @@ class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
         final_out_ref, x_grad_ref = self.GetBaselineOut()
         final_out, x_grad = self.GetFusedAttentionOut()
         np.testing.assert_allclose(
-            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
         np.testing.assert_allclose(
-            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
 class TestFusedAttentionOpFp16(TestFusedAttentionOp):
......
@@ -138,9 +138,11 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
     out_linear_bias_out = out_linear_out + out_linear_bias
     out_linear_bias_dropout_out = out_linear_bias_out
     out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out
-    out_linear_bias_dropout_residual_ln_out = layer_norm(
-        out_linear_bias_dropout_residual_out, True, True, ln_2_scale, ln_2_bias)
-    return out_linear_bias_dropout_residual_ln_out
+    if not pre_layer_norm:
+        out_linear_bias_dropout_residual_out = layer_norm(
+            out_linear_bias_dropout_residual_out, True, True, ln_2_scale,
+            ln_2_bias)
+    return out_linear_bias_dropout_residual_out
 class TestFusedAttentionAPI(unittest.TestCase):
......