From ad44a40cc68cba777879c3d01274b46ade6b6ca6 Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Wed, 10 Nov 2021 14:11:34 +0800 Subject: [PATCH] Fix fused_attention_op scope. (#37065) att, bug fix --- .../operators/fused/fused_attention_op.cc | 128 ++++++++++-------- .../operators/fused/fused_attention_op.cu | 91 +++++++------ .../unittests/test_fused_attention_op.py | 16 +-- .../unittests/test_fused_attention_op_api.py | 8 +- 4 files changed, 135 insertions(+), 108 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 96e2a0fcad2..11601a5ce40 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -42,6 +42,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", "FusedAttentionOp"); + } else { + OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", + "BiasDropoutResidualOut", "FusedAttentionOp"); } // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] @@ -70,12 +77,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", - "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance", - "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", - "BiasDropoutResidualOut", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp"); @@ -109,6 +111,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("LnOut", ctx->GetInputDim("X")); + } else { + ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); } // [batch_size, seq_len, 3, num_head, head_size] ctx->SetOutputDim("QKVOut", @@ -138,12 +144,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]}); ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X")); - ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); - ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); if (ctx->Attrs().Get("dropout_is_test") == false) { ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X")); } - ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); + ctx->SetOutputDim("Y", ctx->GetInputDim("X")); } @@ -314,25 +318,28 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { }); AddComment(R"DOC( - Add fused attention op whose logic is as follows: - // @input: [batch_size, seq_len, 3, num_head, head_dim] - // @final_out: [batch_size, seq_len, num_heads, head_dim] - if (pre_layernorm) - out = layer_norm(input); + Add fused attention op whose logic is as follows: + // @input: [batch_size, seq_len, 3, num_head, head_dim] + // 
@final_out: [batch_size, seq_len, num_heads, head_dim] + if (pre_layernorm) + out = layer_norm(input); out = compute_qkv(out) + bias; // fmha module - { - out = transpose(out, perm=[2, 0, 3, 1, 4]); - out = q * k^t; - out = attn_mask + out; - out = softmax(out); - out = dropout(out); - out = out * v; - out = transpose(out, perm=[0, 2, 1, 3]); + { + out = transpose(out, perm=[2, 0, 3, 1, 4]); + out = q * k^t; + out = attn_mask + out; + out = softmax(out); + out = dropout(out); + out = out * v; + out = transpose(out, perm=[0, 2, 1, 3]); - } + } out = out_linear(out); - final_out = layer_norm(residual + dropout(bias + out)); + if (pre_layernorm) + final_out = residual + dropout(bias + out); + else + final_out = layer_norm(residual + dropout(bias + out)); )DOC"); } }; @@ -347,20 +354,20 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "GradOp is only callable when attn_dropout_is_test is false")); - OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", - "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance", - "FusedAttentionGrad"); - if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) { - ctx->SetOutputDim(framework::GradVarName("Ln2Scale"), - ctx->GetInputDim("Ln2Scale")); - } - if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) { - ctx->SetOutputDim(framework::GradVarName("Ln2Bias"), - ctx->GetInputDim("Ln2Bias")); - } - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); - if (ctx->Attrs().Get("pre_layer_norm") == true) { + if (ctx->Attrs().Get("pre_layer_norm") == false) { + OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", + "FusedAttentionGrad"); + OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance", + "FusedAttentionGrad"); + if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) { + ctx->SetOutputDim(framework::GradVarName("Ln2Scale"), + ctx->GetInputDim("Ln2Scale")); + } + if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Ln2Bias"), + ctx->GetInputDim("Ln2Bias")); + } + } else { OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance", @@ -368,6 +375,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut", "FusedAttentionGrad"); } + + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", @@ -402,6 +411,9 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { if (ctx->Attrs().Get("pre_layer_norm") == true) { ctx->SetOutputDim(framework::GradVarName("LnOut"), ctx->GetInputDim("LnOut")); + } else { + ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"), + ctx->GetInputDim("BiasDropoutResidualOut")); } ctx->SetOutputDim(framework::GradVarName("FMHAOut"), ctx->GetInputDim("FMHAOut")); @@ -426,8 +438,6 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ctx->GetInputDim("QKVBiasOut")); ctx->SetOutputDim(framework::GradVarName("OutLinearOut"), ctx->GetInputDim("OutLinearOut")); - ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"), - ctx->GetInputDim("BiasDropoutResidualOut")); } protected: @@ -478,17 +488,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { 
op->SetOutput(framework::GradVarName("LnBias"), this->InputGrad("LnBias")); } - } - - if (this->HasInput("Ln2Scale")) { - op->SetInput("Ln2Scale", this->Input("Ln2Scale")); - op->SetOutput(framework::GradVarName("Ln2Scale"), - this->InputGrad("Ln2Scale")); - } - if (this->HasInput("Ln2Bias")) { - op->SetInput("Ln2Bias", this->Input("Ln2Bias")); - op->SetOutput(framework::GradVarName("Ln2Bias"), - this->InputGrad("Ln2Bias")); + } else { + if (this->HasInput("Ln2Scale")) { + op->SetInput("Ln2Scale", this->Input("Ln2Scale")); + op->SetOutput(framework::GradVarName("Ln2Scale"), + this->InputGrad("Ln2Scale")); + } + if (this->HasInput("Ln2Bias")) { + op->SetInput("Ln2Bias", this->Input("Ln2Bias")); + op->SetOutput(framework::GradVarName("Ln2Bias"), + this->InputGrad("Ln2Bias")); + } } op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); @@ -511,6 +521,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { if (this->HasOutput("LnVariance")) { op->SetInput("LnVariance", this->Output("LnVariance")); } + } else { + op->SetInput("Ln2Mean", this->Output("Ln2Mean")); + op->SetInput("Ln2Variance", this->Output("Ln2Variance")); + op->SetInput("BiasDropoutResidualOut", + this->Output("BiasDropoutResidualOut")); } op->SetInput("QKVOut", this->Output("QKVOut")); op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); @@ -523,12 +538,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("FMHAOut", this->Output("FMHAOut")); op->SetInput("OutLinearOut", this->Output("OutLinearOut")); - - op->SetInput("Ln2Mean", this->Output("Ln2Mean")); - op->SetInput("Ln2Variance", this->Output("Ln2Variance")); op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut")); - op->SetInput("BiasDropoutResidualOut", - this->Output("BiasDropoutResidualOut")); op->SetInput("QKVOut", this->Output("QKVOut")); // backward outputs: dinput @@ -537,7 +547,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut")); } + } else { + op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"), + this->OutputGrad("BiasDropoutResidualOut")); } + op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut")); op->SetOutput(framework::GradVarName("QKVBiasOut"), this->OutputGrad("QKVBiasOut")); @@ -553,8 +567,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetOutput(framework::GradVarName("FMHAOut"), this->OutputGrad("FMHAOut")); - op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"), - this->OutputGrad("BiasDropoutResidualOut")); op->SetOutput(framework::GradVarName("OutLinearOut"), this->OutputGrad("OutLinearOut")); } diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 99f08d38b6d..76bcb7c9c3a 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -95,15 +95,6 @@ class FusedAttentionOpKernel : public framework::OpKernel { const auto qkv_w_dims = qkv_weight->dims(); auto *x_data = input_x->data(); - auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); - auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); - auto *ln_mean_data = - pre_layer_norm ? ln_mean->mutable_data(ctx.GetPlace()) : nullptr; - auto *ln_var_data = - pre_layer_norm ? ln_var->mutable_data(ctx.GetPlace()) : nullptr; - auto *ln_out_data = - pre_layer_norm ? 
ln_out->mutable_data(ctx.GetPlace()) : nullptr; - auto *qkv_weight_data = qkv_weight->data(); auto *qkv_bias_data = qkv_bias->data(); auto *qkv_out_data = qkv_out->mutable_data(ctx.GetPlace()); @@ -130,16 +121,8 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace()); // get data ptr for bias+dropout+residual+layernorm - auto *ln_scale_2_data = - (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); - auto *ln_bias_2_data = - (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data()); auto *dropout_mask_out_data = dropout_mask_out->mutable_data(ctx.GetPlace()); - auto *bias_dropout_residual_out_data = - bias_dropout_residual_out->mutable_data(ctx.GetPlace()); - auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); - auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); auto *final_out_data = out->mutable_data(ctx.GetPlace()); int batch_size = input_x_dims[0]; @@ -178,6 +161,13 @@ class FusedAttentionOpKernel : public framework::OpKernel { ln_epsilon); if (pre_layer_norm) { + auto *ln_scale_data = + (ln_scale == nullptr ? nullptr : ln_scale->data()); + auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); + auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); + auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); + auto *ln_out_data = ln_out->mutable_data(ctx.GetPlace()); + layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, ln_out_data, ln_mean_data, ln_var_data); qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data, @@ -196,12 +186,27 @@ class FusedAttentionOpKernel : public framework::OpKernel { // out_linear_out: [batch_size, seq_len, embed_dim] out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data, nullptr, out_linear_out_data, nullptr); - // output = layernorm(residual + dropout(input + bias)) - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - ctx.cuda_device_context(), out_linear_out_data, x_data, - out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, - bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, - ln_mean_2_data, ln_var_2_data); + if (pre_layer_norm) { + // output = (residual + dropout(input + bias)) + fused_dropout_layernorm_helper.ResidualDropoutBias( + ctx.cuda_device_context(), out_linear_out_data, x_data, + out_linear_bias_data, final_out_data, dropout_mask_out_data); + } else { + auto *ln_scale_2_data = + (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); + auto *ln_bias_2_data = + (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data()); + auto *bias_dropout_residual_out_data = + bias_dropout_residual_out->mutable_data(ctx.GetPlace()); + auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); + auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); + // output = layernorm(residual + dropout(input + bias)) + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + ctx.cuda_device_context(), out_linear_out_data, x_data, + out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, + bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, + ln_mean_2_data, ln_var_2_data); + } } }; @@ -271,10 +276,7 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *src_mask_out_data = (src_mask == nullptr) ? 
nullptr : src_mask_out->data(); auto *out_linear_out_data = out_linear_out->data(); - auto *ln_2_mean_data = ln_2_mean->data(); - auto *ln_2_var_data = ln_2_var->data(); auto *dropout_mask_out_data = dropout_mask_out->data(); - auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data(); // output's grad auto *d_x = ctx.Output(framework::GradVarName("X")); @@ -312,8 +314,6 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *d_fmha_out_data = d_fmha_out->mutable_data(ctx.GetPlace()); auto *d_out_linear_out_data = d_out_linear_out->mutable_data(ctx.GetPlace()); - auto *d_bias_dropout_residual_out_data = - d_bias_dropout_residual_out->mutable_data(ctx.GetPlace()); // parameter grad auto *d_qkv_weight = ctx.Output(framework::GradVarName("QKVW")); @@ -331,12 +331,6 @@ class FusedAttentionGradKernel : public framework::OpKernel { d_out_linear_weight->mutable_data(ctx.GetPlace()); auto *d_out_linear_bias_data = d_out_linear_bias->mutable_data(ctx.GetPlace()); - auto *d_ln_2_scale_data = - (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data( - ctx.GetPlace())); - auto *d_ln_2_bias_data = - (d_ln_2_bias == nullptr ? nullptr - : d_ln_2_bias->mutable_data(ctx.GetPlace())); const auto input_x_dims = input_x->dims(); const auto qkv_w_dims = qkv_weight->dims(); @@ -382,11 +376,30 @@ class FusedAttentionGradKernel : public framework::OpKernel { ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, ln2epsilon); - fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( - ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data, - dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data, - d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data, - d_out_linear_out_data, d_out_linear_bias_data, d_residual_data); + if (pre_layer_norm) { + fused_dropout_layernorm_helper.ResidualDropoutBiasGrad( + ctx.cuda_device_context(), d_y_data, dropout_mask_out_data, + d_out_linear_out_data, d_residual_data, d_out_linear_bias_data); + } else { + auto *ln_2_mean_data = ln_2_mean->data(); + auto *ln_2_var_data = ln_2_var->data(); + auto *bias_dropout_residual_out_data = + bias_dropout_residual_out->data(); + auto *d_ln_2_scale_data = + (d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data( + ctx.GetPlace())); + auto *d_ln_2_bias_data = + (d_ln_2_bias == nullptr ? 
nullptr : d_ln_2_bias->mutable_data( + ctx.GetPlace())); + auto *d_bias_dropout_residual_out_data = + d_bias_dropout_residual_out->mutable_data(ctx.GetPlace()); + + fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( + ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data, + dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data, + d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data, + d_out_linear_out_data, d_out_linear_bias_data, d_residual_data); + } out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data, d_out_linear_out_data, d_fmha_out_data, diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 41962d5ada0..c0b3e27e671 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -155,8 +155,8 @@ class TestFusedAttentionOp(OpTest): residual_out = residual + self.dropout(out) if not self.pre_layer_norm: final_out = self.norm1(residual_out) - if self.pre_layer_norm: - final_out = self.norm2(residual_out) + else: + final_out = residual_out paddle.autograd.backward( [final_out], [paddle.to_tensor(self.dout)], retain_graph=True) return final_out, tensor_query.grad @@ -219,9 +219,9 @@ class TestFusedAttentionOp(OpTest): final_out_ref, x_grad_ref = self.GetBaselineOut() final_out, x_grad = self.GetFusedAttentionOut() np.testing.assert_allclose( - final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) np.testing.assert_allclose( - x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5) + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) class TestFusedAttentionOpPreLn(TestFusedAttentionOp): @@ -249,9 +249,9 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp): final_out_ref, x_grad_ref = self.GetBaselineOut() final_out, x_grad = self.GetFusedAttentionOut() np.testing.assert_allclose( - final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) np.testing.assert_allclose( - x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1) + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp): @@ -279,9 +279,9 @@ class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp): final_out_ref, x_grad_ref = self.GetBaselineOut() final_out, x_grad = self.GetFusedAttentionOut() np.testing.assert_allclose( - final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) np.testing.assert_allclose( - x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1) + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) class TestFusedAttentionOpFp16(TestFusedAttentionOp): diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py index 02695be61c3..92acb5925a1 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py @@ -138,9 +138,11 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, out_linear_bias_out = out_linear_out + out_linear_bias out_linear_bias_dropout_out = out_linear_bias_out out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out - out_linear_bias_dropout_residual_ln_out = layer_norm( - out_linear_bias_dropout_residual_out, 
True, True, ln_2_scale, ln_2_bias) - return out_linear_bias_dropout_residual_ln_out + if not pre_layer_norm: + out_linear_bias_dropout_residual_out = layer_norm( + out_linear_bias_dropout_residual_out, True, True, ln_2_scale, + ln_2_bias) + return out_linear_bias_dropout_residual_out class TestFusedAttentionAPI(unittest.TestCase): -- GitLab
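
For reference, below is a minimal NumPy sketch of the epilogue behaviour this patch scopes by pre_layer_norm: the pre-layer-norm path returns residual + dropout(out_linear_out + bias) directly, while only the post-layer-norm path applies the second layer norm and therefore needs Ln2Mean, Ln2Variance and BiasDropoutResidualOut. This is an illustrative sketch, not the operator's CUDA implementation; the helper names and the identity dropout (dropout disabled, as on the is_test path) are assumptions.

    import numpy as np

    def layer_norm(x, scale, bias, eps=1e-5):
        # Normalize over the last (embedding) dimension, like the op's second layer norm.
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + eps) * scale + bias

    def attention_epilogue(residual, out_linear_out, out_linear_bias,
                           ln_2_scale, ln_2_bias, pre_layer_norm,
                           dropout=lambda t: t):
        # Common to both paths: residual + dropout(out_linear_out + bias).
        bias_dropout_residual_out = residual + dropout(out_linear_out + out_linear_bias)
        if pre_layer_norm:
            # Pre-layer-norm path: no second layer norm, so Ln2Mean, Ln2Variance and
            # BiasDropoutResidualOut are not produced as separate outputs.
            return bias_dropout_residual_out
        # Post-layer-norm path: apply the second layer norm (Ln2Scale / Ln2Bias).
        return layer_norm(bias_dropout_residual_out, ln_2_scale, ln_2_bias)

    # Toy shapes: [batch_size, seq_len, embed_dim] = [2, 4, 8]
    rng = np.random.default_rng(0)
    residual = rng.standard_normal((2, 4, 8)).astype("float32")
    linear_out = rng.standard_normal((2, 4, 8)).astype("float32")
    linear_bias = np.zeros(8, dtype="float32")
    scale = np.ones(8, dtype="float32")
    shift = np.zeros(8, dtype="float32")
    print(attention_epilogue(residual, linear_out, linear_bias, scale, shift, True).shape)
    print(attention_epilogue(residual, linear_out, linear_bias, scale, shift, False).shape)

Running the sketch prints (2, 4, 8) for both paths, consistent with the op always setting Y to the same shape as X regardless of pre_layer_norm.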