未验证 提交 1a8786cf 编写于 作者: L Li Min 提交者: GitHub

Add support bias is none for fused_attention op. (#37411)

Add support for bias is none for fused_attention op.
上级 4812eda5
......@@ -28,12 +28,8 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionOp");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
......@@ -54,8 +50,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut",
"FusedAttentionOp");
if (ctx->HasInput("QKVBias")) {
OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut",
"FusedAttentionOp");
}
OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut",
......@@ -107,6 +105,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]",
x_dim, y_dim));
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3],
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"and must satisfy the limitations: "
"(num_head * dim_head == dim_embed)"));
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
......@@ -119,8 +124,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
if (ctx->HasInput("QKVBias")) {
ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
}
// [3, batch_size, num_head, seq_len, head_size]
ctx->SetOutputDim("TransposeOut2",
{y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
......@@ -173,11 +181,11 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
AddInput("QKVW", "The qkv weight tensor.");
AddInput("QKVBias", "The qkv bias tensor.");
AddInput("QKVBias", "The qkv bias tensor.").AsDispensable();
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor.");
AddInput("OutLinearBias", "The out_linear bias tensor.");
AddInput("OutLinearBias", "The out_linear bias tensor.").AsDispensable();
AddInput("Ln2Scale",
"(optional) Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
......@@ -379,12 +387,8 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
......@@ -399,14 +403,17 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
ctx->GetInputDim("OutLinearBias"));
if (ctx->HasOutput(framework::GradVarName("OutLinearBias"))) {
ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
ctx->GetInputDim("OutLinearBias"));
}
ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
ctx->GetInputDim("OutLinearW"));
ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW"));
ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias"));
if (ctx->HasOutput(framework::GradVarName("QKVBias"))) {
ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias"));
}
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim(framework::GradVarName("LnOut"),
......@@ -434,8 +441,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
}
ctx->SetOutputDim(framework::GradVarName("QKVOut"),
ctx->GetInputDim("QKVOut"));
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
ctx->GetInputDim("QKVBiasOut"));
if (ctx->HasOutput(framework::GradVarName("QKVBiasOut"))) {
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
ctx->GetInputDim("QKVBiasOut"));
}
ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
ctx->GetInputDim("OutLinearOut"));
}
......@@ -462,7 +471,15 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
// inputs x, parameters and their grad.
op->SetInput("X", this->Input("X"));
op->SetInput("QKVW", this->Input("QKVW"));
op->SetInput("QKVBias", this->Input("QKVBias"));
if (this->HasInput("QKVBias")) {
op->SetInput("QKVBias", this->Input("QKVBias"));
op->SetOutput(framework::GradVarName("QKVBias"),
this->InputGrad("QKVBias"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
}
if (this->HasInput("SrcMask")) {
op->SetInput("SrcMask", this->Input("SrcMask"));
......@@ -472,7 +489,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
}
op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
if (this->HasInput("OutLinearBias")) {
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
op->SetOutput(framework::GradVarName("OutLinearBias"),
this->InputGrad("OutLinearBias"));
}
op->SetAttrMap(this->Attrs());
bool is_pre_layer_norm =
......@@ -503,10 +524,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
op->SetOutput(framework::GradVarName("QKVBias"),
this->InputGrad("QKVBias"));
op->SetOutput(framework::GradVarName("OutLinearBias"),
this->InputGrad("OutLinearBias"));
op->SetOutput(framework::GradVarName("OutLinearW"),
this->InputGrad("OutLinearW"));
......@@ -528,7 +546,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->Output("BiasDropoutResidualOut"));
}
op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
op->SetInput("QKOut", this->Output("QKOut"));
op->SetInput("QKTVOut", this->Output("QKTVOut"));
......@@ -553,8 +571,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
}
op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
op->SetOutput(framework::GradVarName("QKTVOut"),
this->OutputGrad("QKTVOut"));
op->SetOutput(framework::GradVarName("TransposeOut2"),
......
......@@ -96,9 +96,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *x_data = input_x->data<T>();
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
auto *qkv_bias_out_data = qkv_bias_out->mutable_data<T>(ctx.GetPlace());
auto *qkv_bias_out_data =
(qkv_bias == nullptr) ? nullptr
: qkv_bias_out->mutable_data<T>(ctx.GetPlace());
// get data ptr for FMHA.
auto *transpose_out_2_data =
......@@ -117,7 +119,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
// get data ptr for out_linear.
auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data = out_linear_bias->data<T>();
auto *out_linear_bias_data =
(out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());
// get data ptr for bias+dropout+residual+layernorm
......@@ -139,9 +142,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed);
bool compute_bias = true;
if (qkv_bias == nullptr) {
compute_bias = false;
}
// (transA, transB, compute_bias) = (false, true, true)
auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, true,
bsz_seq, output_size, input_size, true);
auto qkv_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, true, bsz_seq,
output_size, input_size, compute_bias);
AttnDropoutParam attn_dropout_param(
is_test_1, dropout_implementation_1, attn_dropout_rate,
......@@ -176,10 +185,17 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
qkv_bias_out);
}
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
if (qkv_bias == nullptr) {
fmha_ref_compute.ComputeForward(*qkv_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
} else {
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
}
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
......@@ -249,9 +265,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data = out_linear_bias->data<T>();
auto *out_linear_bias_data =
(out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
// fw output
auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
......@@ -299,8 +316,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_bias_dropout_residual_out =
ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_out_data = d_qkv_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
// when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the
// space can be reused.
auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
? nullptr
: d_qkv_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_out_data =
(d_qkv_bias_out == nullptr)
? nullptr
: d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
auto *d_transpose_out_2_data =
d_transpose_out_2->mutable_data<T>(ctx.GetPlace());
......@@ -326,11 +350,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_data = d_qkv_bias->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_data = (d_qkv_bias == nullptr)
? nullptr
: d_qkv_bias->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_weight_data =
d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_bias_data =
d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
(d_out_linear_bias == nullptr)
? nullptr
: d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
......@@ -352,12 +380,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
bool transA = false;
bool transB = true;
bool compute_bias = true;
bool compute_qkv_bias = true;
if (qkv_bias == nullptr) {
compute_qkv_bias = false;
}
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed);
auto qkv_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);
output_size, input_size, compute_qkv_bias);
AttnDropoutParam attn_dropout_param(
is_test_1, dropout_implementation_1, attn_dropout_prob,
is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
......@@ -367,7 +398,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
output_size = hidden_size;
transA = false;
transB = false;
compute_bias = false;
bool compute_bias = false;
auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);
......@@ -405,14 +436,19 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_out, d_fmha_out,
d_out_linear_weight, nullptr);
fmha_ref_compute.ComputeBackward(
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_bias_out);
cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data,
bsz_seq * 3 * num_head * dim_head * sizeof(T),
cudaMemcpyDeviceToDevice);
if (qkv_bias != nullptr) {
fmha_ref_compute.ComputeBackward(
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_bias_out);
} else {
fmha_ref_compute.ComputeBackward(
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_out);
}
if (pre_layer_norm) {
auto *ln_mean = ctx.Input<Tensor>("LnMean");
......@@ -432,15 +468,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));
qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out,
d_qkv_weight, d_qkv_bias);
if (qkv_bias != nullptr) {
qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out,
d_ln_out, d_qkv_weight, d_qkv_bias);
} else {
qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_out, d_ln_out,
d_qkv_weight, d_qkv_bias);
}
layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
ln_mean_data, ln_var_data, d_x_data,
d_ln_scale_data, d_ln_bias_data);
} else {
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
d_qkv_weight, d_qkv_bias);
if (qkv_bias != nullptr) {
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
d_qkv_weight, d_qkv_bias);
} else {
qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_out, d_x,
d_qkv_weight, d_qkv_bias);
}
}
// gradient accumulation
std::vector<const Tensor *> ins;
......
......@@ -168,17 +168,29 @@ class TestFusedAttentionOp(OpTest):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_proj_weight = paddle.to_tensor(
self.q_proj.weight, stop_gradient=False)
q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False)
k_proj_weight = paddle.to_tensor(
self.k_proj.weight, stop_gradient=False)
k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False)
v_proj_weight = paddle.to_tensor(
self.v_proj.weight, stop_gradient=False)
v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False)
out_linear_weight = paddle.to_tensor(
self.out_proj.weight, stop_gradient=False)
out_linear_bias = paddle.to_tensor(
self.out_proj.bias, stop_gradient=False)
if self.bias_attr is False:
qkv_bias_tensor = None
out_linear_bias = None
else:
q_proj_bias = paddle.to_tensor(
self.q_proj.bias, stop_gradient=False)
k_proj_bias = paddle.to_tensor(
self.k_proj.bias, stop_gradient=False)
v_proj_bias = paddle.to_tensor(
self.v_proj.bias, stop_gradient=False)
qkv_bias = np.concatenate(
(q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
out_linear_bias = paddle.to_tensor(
self.out_proj.bias, stop_gradient=False)
ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
......@@ -193,17 +205,12 @@ class TestFusedAttentionOp(OpTest):
qkv_weight = qkv_weight.reshape(
(3, self.num_heads, self.head_dim, self.embed_dim))
qkv_bias = np.concatenate(
(q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
x = paddle.to_tensor(self.query, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
attn_mask = None
qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
epsilon = 1e-05
ln2_epsilon = 1e-05
......@@ -227,6 +234,36 @@ class TestFusedAttentionOp(OpTest):
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = False
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-4)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-4)
class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
......
......@@ -356,6 +356,9 @@ def fused_multi_head_attention(x,
0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
assert qkv_weight.shape[3] == x.shape[
2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[
3], "embed_dim must be divisible by num_heads."
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention(
x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask,
linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册