diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 11601a5ce40d5a7d82311e08d95db3d28d478d20..39c7e52cc465c7ffdb5d7dd579e7781477f691de 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -28,12 +28,8 @@ class FusedAttentionOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", - "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", - "FusedAttentionOp"); if (ctx->Attrs().Get("pre_layer_norm") == true) { OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", @@ -54,8 +50,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel { // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", "FusedAttentionOp"); - OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", - "FusedAttentionOp"); + if (ctx->HasInput("QKVBias")) { + OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", + "FusedAttentionOp"); + } OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2", "FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut", @@ -107,6 +105,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel { "input qkv_weight = [%s]", x_dim, y_dim)); + PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3], + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "and must satisfy the limitations: " + "(num_head * dim_head == dim_embed)")); + if (ctx->Attrs().Get("pre_layer_norm") == true) { ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); @@ -119,8 +124,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { // [batch_size, seq_len, 3, num_head, head_size] ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); - ctx->SetOutputDim("QKVBiasOut", - {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + + if (ctx->HasInput("QKVBias")) { + ctx->SetOutputDim("QKVBiasOut", + {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + } // [3, batch_size, num_head, seq_len, head_size] ctx->SetOutputDim("TransposeOut2", {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); @@ -173,11 +181,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "H. Here, H represents the last dimension of its input tensor.") .AsDispensable(); AddInput("QKVW", "The qkv weight tensor."); - AddInput("QKVBias", "The qkv bias tensor."); + AddInput("QKVBias", "The qkv bias tensor.").AsDispensable(); AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") .AsDispensable(); AddInput("OutLinearW", "The out_linear weight tensor."); - AddInput("OutLinearBias", "The out_linear bias tensor."); + AddInput("OutLinearBias", "The out_linear bias tensor.").AsDispensable(); AddInput("Ln2Scale", "(optional) Scale is a 1-dimensional tensor of size " "H. Here, H represents the last dimension of its input tensor.") @@ -379,12 +387,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", - "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", "FusedAttentionGrad"); - OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", - "FusedAttentionGrad"); if (ctx->Attrs().Get("pre_layer_norm") == true) { if (ctx->HasOutput(framework::GradVarName("LnScale"))) { @@ -399,14 +403,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { if (ctx->HasOutput(framework::GradVarName("X"))) { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } - - ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), - ctx->GetInputDim("OutLinearBias")); + if (ctx->HasOutput(framework::GradVarName("OutLinearBias"))) { + ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), + ctx->GetInputDim("OutLinearBias")); + } ctx->SetOutputDim(framework::GradVarName("OutLinearW"), ctx->GetInputDim("OutLinearW")); ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW")); - ctx->SetOutputDim(framework::GradVarName("QKVBias"), - ctx->GetInputDim("QKVBias")); + if (ctx->HasOutput(framework::GradVarName("QKVBias"))) { + ctx->SetOutputDim(framework::GradVarName("QKVBias"), + ctx->GetInputDim("QKVBias")); + } if (ctx->Attrs().Get("pre_layer_norm") == true) { ctx->SetOutputDim(framework::GradVarName("LnOut"), @@ -434,8 +441,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { } ctx->SetOutputDim(framework::GradVarName("QKVOut"), ctx->GetInputDim("QKVOut")); - ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"), - ctx->GetInputDim("QKVBiasOut")); + if (ctx->HasOutput(framework::GradVarName("QKVBiasOut"))) { + ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"), + ctx->GetInputDim("QKVBiasOut")); + } ctx->SetOutputDim(framework::GradVarName("OutLinearOut"), ctx->GetInputDim("OutLinearOut")); } @@ -462,7 +471,15 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { // inputs x, parameters and their grad. op->SetInput("X", this->Input("X")); op->SetInput("QKVW", this->Input("QKVW")); - op->SetInput("QKVBias", this->Input("QKVBias")); + + if (this->HasInput("QKVBias")) { + op->SetInput("QKVBias", this->Input("QKVBias")); + op->SetOutput(framework::GradVarName("QKVBias"), + this->InputGrad("QKVBias")); + op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); + op->SetOutput(framework::GradVarName("QKVBiasOut"), + this->OutputGrad("QKVBiasOut")); + } if (this->HasInput("SrcMask")) { op->SetInput("SrcMask", this->Input("SrcMask")); @@ -472,7 +489,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { } op->SetInput("OutLinearW", this->Input("OutLinearW")); - op->SetInput("OutLinearBias", this->Input("OutLinearBias")); + if (this->HasInput("OutLinearBias")) { + op->SetInput("OutLinearBias", this->Input("OutLinearBias")); + op->SetOutput(framework::GradVarName("OutLinearBias"), + this->InputGrad("OutLinearBias")); + } op->SetAttrMap(this->Attrs()); bool is_pre_layer_norm = @@ -503,10 +524,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW")); - op->SetOutput(framework::GradVarName("QKVBias"), - this->InputGrad("QKVBias")); - op->SetOutput(framework::GradVarName("OutLinearBias"), - this->InputGrad("OutLinearBias")); + op->SetOutput(framework::GradVarName("OutLinearW"), this->InputGrad("OutLinearW")); @@ -528,7 +546,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { this->Output("BiasDropoutResidualOut")); } op->SetInput("QKVOut", this->Output("QKVOut")); - op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); + op->SetInput("TransposeOut2", this->Output("TransposeOut2")); op->SetInput("QKOut", this->Output("QKOut")); op->SetInput("QKTVOut", this->Output("QKTVOut")); @@ -553,8 +571,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker { } op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut")); - op->SetOutput(framework::GradVarName("QKVBiasOut"), - this->OutputGrad("QKVBiasOut")); + op->SetOutput(framework::GradVarName("QKTVOut"), this->OutputGrad("QKTVOut")); op->SetOutput(framework::GradVarName("TransposeOut2"), diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 5bcf12856083698962d0186c8de40ff0363e5dc9..9f6d6e2270673d1a915d7ea71c7f218a9ddf35ea 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -96,9 +96,11 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto *x_data = input_x->data(); auto *qkv_weight_data = qkv_weight->data(); - auto *qkv_bias_data = qkv_bias->data(); + auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data(); auto *qkv_out_data = qkv_out->mutable_data(ctx.GetPlace()); - auto *qkv_bias_out_data = qkv_bias_out->mutable_data(ctx.GetPlace()); + auto *qkv_bias_out_data = + (qkv_bias == nullptr) ? nullptr + : qkv_bias_out->mutable_data(ctx.GetPlace()); // get data ptr for FMHA. auto *transpose_out_2_data = @@ -117,7 +119,8 @@ class FusedAttentionOpKernel : public framework::OpKernel { // get data ptr for out_linear. auto *out_linear_weight_data = out_linear_weight->data(); - auto *out_linear_bias_data = out_linear_bias->data(); + auto *out_linear_bias_data = + (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data(); auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace()); // get data ptr for bias+dropout+residual+layernorm @@ -139,9 +142,15 @@ class FusedAttentionOpKernel : public framework::OpKernel { auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed); + + bool compute_bias = true; + if (qkv_bias == nullptr) { + compute_bias = false; + } // (transA, transB, compute_bias) = (false, true, true) - auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), false, true, - bsz_seq, output_size, input_size, true); + auto qkv_compute = + AttnMatMul(ctx.cuda_device_context(), false, true, bsz_seq, + output_size, input_size, compute_bias); AttnDropoutParam attn_dropout_param( is_test_1, dropout_implementation_1, attn_dropout_rate, @@ -176,10 +185,17 @@ class FusedAttentionOpKernel : public framework::OpKernel { qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out); } - fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2, - qk_out, src_mask_out, softmax_out, - attn_dropout_mask_out, attn_dropout_out, - qktv_out, fmha_out); + if (qkv_bias == nullptr) { + fmha_ref_compute.ComputeForward(*qkv_out, src_mask, transpose_out_2, + qk_out, src_mask_out, softmax_out, + attn_dropout_mask_out, attn_dropout_out, + qktv_out, fmha_out); + } else { + fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2, + qk_out, src_mask_out, softmax_out, + attn_dropout_mask_out, attn_dropout_out, + qktv_out, fmha_out); + } // fmha_out: [batch_size, seq_len, num_head, head_dim] // weight: [embed_dim, embed_dim] @@ -249,9 +265,10 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *out_linear_bias = ctx.Input("OutLinearBias"); auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data()); auto *qkv_weight_data = qkv_weight->data(); - auto *qkv_bias_data = qkv_bias->data(); + auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data(); auto *out_linear_weight_data = out_linear_weight->data(); - auto *out_linear_bias_data = out_linear_bias->data(); + auto *out_linear_bias_data = + (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data(); // fw output auto *fmha_out = ctx.Input("FMHAOut"); @@ -299,8 +316,15 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *d_bias_dropout_residual_out = ctx.Output(framework::GradVarName("BiasDropoutResidualOut")); auto *d_x_data = d_x->mutable_data(ctx.GetPlace()); - auto *d_qkv_out_data = d_qkv_out->mutable_data(ctx.GetPlace()); - auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data(ctx.GetPlace()); + // when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the + // space can be reused. + auto *d_qkv_out_data = (d_qkv_bias_out != nullptr) + ? nullptr + : d_qkv_out->mutable_data(ctx.GetPlace()); + auto *d_qkv_bias_out_data = + (d_qkv_bias_out == nullptr) + ? nullptr + : d_qkv_bias_out->mutable_data(ctx.GetPlace()); auto *d_qktv_out_data = d_qktv_out->mutable_data(ctx.GetPlace()); auto *d_transpose_out_2_data = d_transpose_out_2->mutable_data(ctx.GetPlace()); @@ -326,11 +350,15 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *d_ln_2_bias = ctx.Output(framework::GradVarName("Ln2Bias")); auto *d_qkv_weight_data = d_qkv_weight->mutable_data(ctx.GetPlace()); - auto *d_qkv_bias_data = d_qkv_bias->mutable_data(ctx.GetPlace()); + auto *d_qkv_bias_data = (d_qkv_bias == nullptr) + ? nullptr + : d_qkv_bias->mutable_data(ctx.GetPlace()); auto *d_out_linear_weight_data = d_out_linear_weight->mutable_data(ctx.GetPlace()); auto *d_out_linear_bias_data = - d_out_linear_bias->mutable_data(ctx.GetPlace()); + (d_out_linear_bias == nullptr) + ? nullptr + : d_out_linear_bias->mutable_data(ctx.GetPlace()); const auto input_x_dims = input_x->dims(); const auto qkv_w_dims = qkv_weight->dims(); @@ -352,12 +380,15 @@ class FusedAttentionGradKernel : public framework::OpKernel { bool transA = false; bool transB = true; - bool compute_bias = true; + bool compute_qkv_bias = true; + if (qkv_bias == nullptr) { + compute_qkv_bias = false; + } auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed); auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, - output_size, input_size, compute_bias); + output_size, input_size, compute_qkv_bias); AttnDropoutParam attn_dropout_param( is_test_1, dropout_implementation_1, attn_dropout_prob, is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); @@ -367,7 +398,7 @@ class FusedAttentionGradKernel : public framework::OpKernel { output_size = hidden_size; transA = false; transB = false; - compute_bias = false; + bool compute_bias = false; auto out_linear_compute = AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq, output_size, input_size, compute_bias); @@ -405,14 +436,19 @@ class FusedAttentionGradKernel : public framework::OpKernel { d_out_linear_out, d_fmha_out, d_out_linear_weight, nullptr); - fmha_ref_compute.ComputeBackward( - *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out, - *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, - d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, - d_transpose_out_2, nullptr, d_qkv_bias_out); - cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data, - bsz_seq * 3 * num_head * dim_head * sizeof(T), - cudaMemcpyDeviceToDevice); + if (qkv_bias != nullptr) { + fmha_ref_compute.ComputeBackward( + *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out, + *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, + d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, + d_transpose_out_2, nullptr, d_qkv_bias_out); + } else { + fmha_ref_compute.ComputeBackward( + *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out, + *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out, + d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out, + d_transpose_out_2, nullptr, d_qkv_out); + } if (pre_layer_norm) { auto *ln_mean = ctx.Input("LnMean"); @@ -432,15 +468,24 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *d_ln_bias_data = (d_ln_bias == nullptr ? nullptr : d_ln_bias->mutable_data(ctx.GetPlace())); - - qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out, - d_qkv_weight, d_qkv_bias); + if (qkv_bias != nullptr) { + qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, + d_ln_out, d_qkv_weight, d_qkv_bias); + } else { + qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_out, d_ln_out, + d_qkv_weight, d_qkv_bias); + } layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data, ln_mean_data, ln_var_data, d_x_data, d_ln_scale_data, d_ln_bias_data); } else { - qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x, - d_qkv_weight, d_qkv_bias); + if (qkv_bias != nullptr) { + qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x, + d_qkv_weight, d_qkv_bias); + } else { + qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_out, d_x, + d_qkv_weight, d_qkv_bias); + } } // gradient accumulation std::vector ins; diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index b2b5cac2bff965363e087d6d09c38477e0e0847a..443703aa937d8aead8307b892961e7054ede6ed4 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -168,17 +168,29 @@ class TestFusedAttentionOp(OpTest): paddle.disable_static(place=paddle.CUDAPlace(0)) q_proj_weight = paddle.to_tensor( self.q_proj.weight, stop_gradient=False) - q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False) k_proj_weight = paddle.to_tensor( self.k_proj.weight, stop_gradient=False) - k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False) v_proj_weight = paddle.to_tensor( self.v_proj.weight, stop_gradient=False) - v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False) out_linear_weight = paddle.to_tensor( self.out_proj.weight, stop_gradient=False) - out_linear_bias = paddle.to_tensor( - self.out_proj.bias, stop_gradient=False) + + if self.bias_attr is False: + qkv_bias_tensor = None + out_linear_bias = None + else: + q_proj_bias = paddle.to_tensor( + self.q_proj.bias, stop_gradient=False) + k_proj_bias = paddle.to_tensor( + self.k_proj.bias, stop_gradient=False) + v_proj_bias = paddle.to_tensor( + self.v_proj.bias, stop_gradient=False) + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False) ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) @@ -193,17 +205,12 @@ class TestFusedAttentionOp(OpTest): qkv_weight = qkv_weight.reshape( (3, self.num_heads, self.head_dim, self.embed_dim)) - qkv_bias = np.concatenate( - (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) - qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) - x = paddle.to_tensor(self.query, stop_gradient=False) if self.has_attn_mask: attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) else: attn_mask = None qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) - qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) epsilon = 1e-05 ln2_epsilon = 1e-05 @@ -227,6 +234,36 @@ class TestFusedAttentionOp(OpTest): x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) +class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp): + def config(self): + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = False + self.has_attn_mask = True + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = False + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def test_fused_attention_op(self): + final_out_ref, x_grad_ref = self.GetBaselineOut() + final_out, x_grad = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4) + np.testing.assert_allclose( + x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4) + + class TestFusedAttentionOpPreLn(TestFusedAttentionOp): def config(self): self.x_type = np.float32 diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index df9cc68a02d8dc7ad0db20b30fec0414140a73b6..eafefd98298f542e0d16162e8e8c2bcc861ec622 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -356,6 +356,9 @@ def fused_multi_head_attention(x, 0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]." assert qkv_weight.shape[3] == x.shape[ 2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim." + assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[ + 3], "embed_dim must be divisible by num_heads." + _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm',