未验证 提交 ff3018d7 编写于 作者: L Li Min 提交者: GitHub

Fix fused_attention_op and fused_feedforward_op bug when pre_layer_norm is false. (#36793)

* Fix bug when pre_layer_norm is false.
上级 9516108a
......@@ -37,12 +37,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionOp");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
"FusedAttentionOp");
}
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut",
"FusedAttentionOp");
......@@ -101,9 +104,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]",
x_dim, y_dim));
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
}
// [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
......@@ -351,11 +356,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("Ln2Bias"));
}
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
"FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
"FusedAttentionGrad");
}
......@@ -370,6 +375,7 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
ctx->SetOutputDim(framework::GradVarName("LnScale"),
ctx->GetInputDim("LnScale"));
......@@ -378,6 +384,7 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("LnBias"),
ctx->GetInputDim("LnBias"));
}
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
......@@ -390,8 +397,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias"));
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim(framework::GradVarName("LnOut"),
ctx->GetInputDim("LnOut"));
}
ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
ctx->GetInputDim("FMHAOut"));
ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
......@@ -442,6 +451,11 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
op->SetAttrMap(this->Attrs());
bool is_pre_layer_norm =
BOOST_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
if (is_pre_layer_norm) {
if (this->HasInput("LnScale")) {
op->SetInput("LnScale", this->Input("LnScale"));
op->SetOutput(framework::GradVarName("LnScale"),
......@@ -452,6 +466,8 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("LnBias"),
this->InputGrad("LnBias"));
}
}
if (this->HasInput("Ln2Scale")) {
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
......@@ -473,9 +489,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("OutLinearW"));
// use forward outputs as backward inputs.
if (is_pre_layer_norm) {
if (this->HasOutput("LnOut")) {
op->SetInput("LnOut", this->Output("LnOut"));
}
if (this->HasOutput("LnMean")) {
op->SetInput("LnMean", this->Output("LnMean"));
}
if (this->HasOutput("LnVariance")) {
op->SetInput("LnVariance", this->Output("LnVariance"));
}
}
op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
......@@ -496,7 +520,12 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("QKVOut", this->Output("QKVOut"));
// backward outputs: dinput
op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut"));
if (is_pre_layer_norm) {
if (this->HasOutput("LnOut")) {
op->SetOutput(framework::GradVarName("LnOut"),
this->OutputGrad("LnOut"));
}
}
op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
......@@ -517,8 +546,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->OutputGrad("BiasDropoutResidualOut"));
op->SetOutput(framework::GradVarName("OutLinearOut"),
this->OutputGrad("OutLinearOut"));
op->SetAttrMap(this->Attrs());
}
};
......
......@@ -97,9 +97,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *x_data = input_x->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_data =
pre_layer_norm ? ln_mean->mutable_data<U>(ctx.GetPlace()) : nullptr;
auto *ln_var_data =
pre_layer_norm ? ln_var->mutable_data<U>(ctx.GetPlace()) : nullptr;
auto *ln_out_data =
pre_layer_norm ? ln_out->mutable_data<T>(ctx.GetPlace()) : nullptr;
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
......@@ -243,9 +246,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *out_linear_bias_data = out_linear_bias->data<T>();
// fw output
auto *ln_mean = ctx.Input<Tensor>("LnMean");
auto *ln_var = ctx.Input<Tensor>("LnVariance");
auto *ln_out = ctx.Input<Tensor>("LnOut");
auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
auto *transpose_out_2 = ctx.Input<Tensor>("TransposeOut2");
auto *qk_out = ctx.Input<Tensor>("QKOut");
......@@ -260,9 +260,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Input<Tensor>("BiasDropoutResidualOut");
auto *ln_mean_data = ln_mean->data<U>();
auto *ln_var_data = ln_var->data<U>();
auto *ln_out_data = ln_out->data<T>();
auto *fmha_out_data = fmha_out->data<T>();
auto *transpose_out_2_data = transpose_out_2->data<T>();
auto *qk_out_data = qk_out->data<T>();
......@@ -277,7 +274,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
// output's grad
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
auto *d_qkv_out = ctx.Output<Tensor>(framework::GradVarName("QKVOut"));
auto *d_qkv_bias_out =
ctx.Output<Tensor>(framework::GradVarName("QKVBiasOut"));
......@@ -297,7 +293,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_bias_dropout_residual_out =
ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_out_data = d_qkv_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
......@@ -315,8 +310,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
// parameter grad
auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
auto *d_qkv_bias = ctx.Output<Tensor>(framework::GradVarName("QKVBias"));
auto *d_out_linear_weight =
......@@ -325,12 +318,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
auto *d_ln_2_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale"));
auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
auto *d_ln_scale_data =
(d_ln_scale == nullptr ? nullptr
: d_ln_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));
auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_data = d_qkv_bias->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_weight_data =
......@@ -407,6 +395,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
cudaMemcpyDeviceToDevice);
if (pre_layer_norm) {
auto *ln_mean = ctx.Input<Tensor>("LnMean");
auto *ln_var = ctx.Input<Tensor>("LnVariance");
auto *ln_out = ctx.Input<Tensor>("LnOut");
auto *ln_mean_data = ln_mean->data<U>();
auto *ln_var_data = ln_var->data<U>();
auto *ln_out_data = ln_out->data<T>();
auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
auto *d_ln_scale_data =
(d_ln_scale == nullptr ? nullptr
: d_ln_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));
qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
d_qkv_bias_out_data, d_ln_out_data,
d_qkv_weight_data, d_qkv_bias_data);
......
......@@ -41,18 +41,8 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Dropout2Mask"), "Output", "Dropout2Mask",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Mean"), "Output", "Ln1Mean",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Variance"), "Output", "Ln1Variance",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Linear1Out"), "Output", "Linear1Out",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Out"), "Output", "Ln1Out",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Dropout1Out"), "Output", "Dropout1Out",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Dropout2Out"), "Output", "Dropout2Out",
......@@ -76,7 +66,6 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
}
context->SetOutputDim("Dropout1Out", tmp_dim_x);
context->SetOutputDim("Linear1Out", tmp_dim_x);
context->SetOutputDim("Ln1Out", dim_x);
context->SetOutputDim("Dropout2Out", dim_x);
if (context->Attrs().Get<bool>("dropout2_is_test") == false) {
......@@ -84,10 +73,25 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
}
framework::DDim mean_dim =
framework::make_ddim({mat_dim_x.batch_size_ * mat_dim_x.height_});
bool pre_layer_norm = context->Attrs().Get<bool>("pre_layer_norm");
if (pre_layer_norm) {
OP_INOUT_CHECK(context->HasOutput("Ln1Mean"), "Output", "Ln1Mean",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Variance"), "Output", "Ln1Variance",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Out"), "Output", "Ln1Out",
"fused_feedforward");
context->SetOutputDim("Ln1Out", dim_x);
context->SetOutputDim("Ln1Mean", mean_dim);
context->SetOutputDim("Ln1Variance", mean_dim);
} else {
OP_INOUT_CHECK(context->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
"fused_feedforward");
context->SetOutputDim("Ln2Mean", mean_dim);
context->SetOutputDim("Ln2Variance", mean_dim);
}
context->ShareLoD("X", "Out");
}
......@@ -218,14 +222,13 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout2_is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when is_test is false"));
bool pre_layer_norm = ctx->Attrs().Get<bool>("pre_layer_norm");
OP_INOUT_CHECK(ctx->HasInput("Dropout1Mask"), "Input", "Dropout1Mask",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout2Mask"), "Input", "Dropout1Mask",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Linear1Out"), "Input", "Linear1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Out"), "Input", "Ln1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout1Out"), "Input", "Dropout1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"), "Input", "Dropout2Out",
......@@ -234,14 +237,19 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Linear2Weight"), "Input", "Linear2Weight",
"FusedFeedForwardGrad");
if (pre_layer_norm) {
OP_INOUT_CHECK(ctx->HasInput("Ln1Mean"), "Input", "Ln1Mean",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Variance"), "Input", "Ln1Variance",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Out"), "Input", "Ln1Out",
"FusedFeedForwardGrad");
} else {
OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
"FusedFeedForwardGrad");
}
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "FusedFeedForwardGrad");
......@@ -299,30 +307,36 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Linear1Weight", this->Input("Linear1Weight"));
op->SetInput("Linear1Bias", this->Input("Linear1Bias"));
op->SetInput("Linear2Weight", this->Input("Linear2Weight"));
op->SetInput("Ln1Scale", this->Input("Ln1Scale"));
op->SetInput("Ln1Bias", this->Input("Ln1Bias"));
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetInput("Dropout1Mask", this->Output("Dropout1Mask"));
op->SetInput("Dropout2Mask", this->Output("Dropout2Mask"));
op->SetInput("Linear1Out", this->Output("Linear1Out"));
op->SetInput("Ln1Out", this->Output("Ln1Out"));
op->SetInput("Ln1Mean", this->Output("Ln1Mean"));
op->SetInput("Ln1Variance", this->Output("Ln1Variance"));
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("Dropout1Out", this->Output("Dropout1Out"));
op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
op->SetAttrMap(this->Attrs());
bool pre_layer_norm = BOOST_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
if (pre_layer_norm) {
op->SetInput("Ln1Scale", this->Input("Ln1Scale"));
op->SetInput("Ln1Bias", this->Input("Ln1Bias"));
op->SetInput("Ln1Out", this->Output("Ln1Out"));
op->SetInput("Ln1Mean", this->Output("Ln1Mean"));
op->SetInput("Ln1Variance", this->Output("Ln1Variance"));
op->SetOutput(framework::GradVarName("Ln1Scale"),
this->InputGrad("Ln1Scale"));
op->SetOutput(framework::GradVarName("Ln1Bias"),
this->InputGrad("Ln1Bias"));
} else {
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
this->InputGrad("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Bias"),
this->InputGrad("Ln2Bias"));
}
op->SetOutput(framework::GradVarName("Linear1Weight"),
this->InputGrad("Linear1Weight"));
op->SetOutput(framework::GradVarName("Linear1Bias"),
......@@ -334,8 +348,6 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Linear2Bias"),
this->InputGrad("Linear2Bias"));
}
op->SetAttrMap(this->Attrs());
}
};
......
......@@ -113,26 +113,40 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias");
auto* linear2_weight = context.Input<framework::Tensor>("Linear2Weight");
auto* linear2_bias = context.Input<framework::Tensor>("Linear2Bias");
auto* ln1_scale = context.Input<framework::Tensor>("Ln1Scale");
auto* ln1_bias = context.Input<framework::Tensor>("Ln1Bias");
auto* ln2_scale = context.Input<framework::Tensor>("Ln2Scale");
auto* ln2_bias = context.Input<framework::Tensor>("Ln2Bias");
auto* ln1_mean = context.Output<framework::Tensor>("Ln1Mean");
auto* ln1_variance = context.Output<framework::Tensor>("Ln1Variance");
auto* ln2_mean = context.Output<framework::Tensor>("Ln2Mean");
auto* ln2_variance = context.Output<framework::Tensor>("Ln2Variance");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
auto* ln1_scale =
pre_layer_norm ? context.Input<framework::Tensor>("Ln1Scale") : nullptr;
auto* ln1_bias =
pre_layer_norm ? context.Input<framework::Tensor>("Ln1Bias") : nullptr;
auto* ln2_scale = !pre_layer_norm
? context.Input<framework::Tensor>("Ln2Scale")
: nullptr;
auto* ln2_bias =
!pre_layer_norm ? context.Input<framework::Tensor>("Ln2Bias") : nullptr;
auto* ln1_mean =
pre_layer_norm ? context.Output<framework::Tensor>("Ln1Mean") : nullptr;
auto* ln1_variance = pre_layer_norm
? context.Output<framework::Tensor>("Ln1Variance")
: nullptr;
auto* ln2_mean = !pre_layer_norm
? context.Output<framework::Tensor>("Ln2Mean")
: nullptr;
auto* ln2_variance = !pre_layer_norm
? context.Output<framework::Tensor>("Ln2Variance")
: nullptr;
auto* out = context.Output<framework::Tensor>("Out");
auto* dropout1_mask = context.Output<framework::Tensor>("Dropout1Mask");
auto* dropout2_mask = context.Output<framework::Tensor>("Dropout2Mask");
auto* linear1_out = context.Output<framework::Tensor>("Linear1Out");
auto* ln1_out = context.Output<framework::Tensor>("Ln1Out");
auto* ln1_out =
pre_layer_norm ? context.Output<framework::Tensor>("Ln1Out") : nullptr;
auto* dropout1_out = context.Output<framework::Tensor>("Dropout1Out");
auto* dropout2_out = context.Output<framework::Tensor>("Dropout2Out");
const std::string act_method = context.Attr<std::string>("act_method");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon");
......@@ -144,12 +158,16 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
out->mutable_data<T>(place);
dropout1_mask->mutable_data<uint8_t>(place);
dropout2_mask->mutable_data<uint8_t>(place);
if (pre_layer_norm) {
ln1_mean->mutable_data<U>(place);
ln1_variance->mutable_data<U>(place);
ln1_out->mutable_data<T>(place);
} else {
ln2_mean->mutable_data<U>(place);
ln2_variance->mutable_data<U>(place);
}
linear1_out->mutable_data<T>(place);
ln1_out->mutable_data<T>(place);
dropout1_out->mutable_data<T>(place);
dropout2_out->mutable_data<T>(place);
......@@ -193,16 +211,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const framework::Tensor& d_out, const framework::Tensor& x,
const framework::Tensor& dropout1_mask,
const framework::Tensor& dropout2_mask,
const framework::Tensor& linear1_out, const framework::Tensor& ln1_out,
const framework::Tensor& linear1_out, const framework::Tensor* ln1_out,
const framework::Tensor& dropout1_out,
const framework::Tensor& dropout2_out,
const framework::Tensor& linear1_weight,
const framework::Tensor* linear1_bias,
const framework::Tensor& linear2_weight,
const framework::Tensor* ln1_gamma, const framework::Tensor* ln1_beta,
const framework::Tensor& ln1_mean, const framework::Tensor& ln1_variance,
const framework::Tensor* ln1_mean, const framework::Tensor* ln1_variance,
const framework::Tensor* ln2_gamma, const framework::Tensor* ln2_beta,
const framework::Tensor& ln2_mean, const framework::Tensor& ln2_variance,
const framework::Tensor* ln2_mean, const framework::Tensor* ln2_variance,
framework::Tensor* d_x, framework::Tensor* d_linear1_weight,
framework::Tensor* d_linear1_bias, framework::Tensor* d_linear2_weight,
framework::Tensor* d_linear2_bias, framework::Tensor* d_ln1_gamma,
......@@ -252,8 +270,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
} else {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_out.data<T>(),
dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean.data<U>(),
ln2_variance.data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean->data<U>(),
ln2_variance->data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr,
d_residual.data<T>());
}
......@@ -273,12 +291,12 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
if (pre_layer_norm) {
framework::Tensor d_ln1_out;
d_ln1_out.mutable_data<T>({bsz_seq, d_model}, place);
MatMulGrad(ctx, d_linear1_out, ln1_out, linear1_weight, &d_ln1_out,
MatMulGrad(ctx, d_linear1_out, *ln1_out, linear1_weight, &d_ln1_out,
d_linear1_weight);
pre_layernorm_helper.LayerNormGrad(ctx, d_ln1_out.data<T>(), x.data<T>(),
ln1_gamma_ptr, ln1_mean.data<U>(),
ln1_variance.data<U>(), d_x->data<T>(),
pre_layernorm_helper.LayerNormGrad(
ctx, d_ln1_out.data<T>(), x.data<T>(), ln1_gamma_ptr,
ln1_mean->data<U>(), ln1_variance->data<U>(), d_x->data<T>(),
d_ln1_gamma_ptr, d_ln1_beta_ptr);
} else {
MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
......@@ -290,33 +308,52 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
auto d_out =
*context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto x = *context.Input<framework::Tensor>("X");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
auto dropout1_mask = *context.Input<framework::Tensor>("Dropout1Mask");
auto dropout2_mask = *context.Input<framework::Tensor>("Dropout2Mask");
auto linear1_out = *context.Input<framework::Tensor>("Linear1Out");
auto ln1_out = *context.Input<framework::Tensor>("Ln1Out");
auto* ln1_out =
pre_layer_norm ? context.Input<framework::Tensor>("Ln1Out") : nullptr;
auto dropout1_out = *context.Input<framework::Tensor>("Dropout1Out");
auto dropout2_out = *context.Input<framework::Tensor>("Dropout2Out");
auto linear1_weight = *context.Input<framework::Tensor>("Linear1Weight");
auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias");
auto linear2_weight = *context.Input<framework::Tensor>("Linear2Weight");
auto ln1_mean = *context.Input<framework::Tensor>("Ln1Mean");
auto ln1_variance = *context.Input<framework::Tensor>("Ln1Variance");
auto* ln1_scale = context.Input<framework::Tensor>("Ln1Scale");
auto* ln1_bias = context.Input<framework::Tensor>("Ln1Bias");
auto ln2_mean = *context.Input<framework::Tensor>("Ln2Mean");
auto ln2_variance = *context.Input<framework::Tensor>("Ln2Variance");
auto* ln2_scale = context.Input<framework::Tensor>("Ln2Scale");
auto* ln2_bias = context.Input<framework::Tensor>("Ln2Bias");
auto* ln1_mean =
pre_layer_norm ? context.Input<framework::Tensor>("Ln1Mean") : nullptr;
auto* ln1_variance = pre_layer_norm
? context.Input<framework::Tensor>("Ln1Variance")
: nullptr;
auto* ln1_scale =
pre_layer_norm ? context.Input<framework::Tensor>("Ln1Scale") : nullptr;
auto* ln1_bias =
pre_layer_norm ? context.Input<framework::Tensor>("Ln1Bias") : nullptr;
auto* ln2_mean =
!pre_layer_norm ? context.Input<framework::Tensor>("Ln2Mean") : nullptr;
auto* ln2_variance = !pre_layer_norm
? context.Input<framework::Tensor>("Ln2Variance")
: nullptr;
auto* ln2_scale = !pre_layer_norm
? context.Input<framework::Tensor>("Ln2Scale")
: nullptr;
auto* ln2_bias =
!pre_layer_norm ? context.Input<framework::Tensor>("Ln2Bias") : nullptr;
auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* d_ln1_scale =
context.Output<framework::Tensor>(framework::GradVarName("Ln1Scale"));
auto* d_ln1_bias =
context.Output<framework::Tensor>(framework::GradVarName("Ln1Bias"));
auto* d_ln1_scale = pre_layer_norm
? context.Output<framework::Tensor>(
framework::GradVarName("Ln1Scale"))
: nullptr;
auto* d_ln1_bias = pre_layer_norm
? context.Output<framework::Tensor>(
framework::GradVarName("Ln1Bias"))
: nullptr;
auto* d_ln2_scale =
context.Output<framework::Tensor>(framework::GradVarName("Ln2Scale"));
pre_layer_norm ? nullptr : context.Output<framework::Tensor>(
framework::GradVarName("Ln2Scale"));
auto* d_ln2_bias =
context.Output<framework::Tensor>(framework::GradVarName("Ln2Bias"));
pre_layer_norm ? nullptr : context.Output<framework::Tensor>(
framework::GradVarName("Ln2Bias"));
auto* d_linear1_weight = context.Output<framework::Tensor>(
framework::GradVarName("Linear1Weight"));
auto* d_linear1_bias = context.Output<framework::Tensor>(
......@@ -328,7 +365,6 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
const std::string act_method = context.Attr<std::string>("act_method");
DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2);
......
......@@ -65,7 +65,7 @@ class TestFusedAttentionOp(OpTest):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.pre_layer_norm = False
self.training = True
self.batch_size = 8
......@@ -213,11 +213,40 @@ class TestFusedAttentionOp(OpTest):
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5)
class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
class TestFusedAttentionOpFp16(TestFusedAttentionOp):
def config(self):
self.x_type = np.float16
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.pre_layer_norm = False
self.training = True
self.batch_size = 8
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册