未验证 提交 ae592233 编写于 作者: L Li Min 提交者: GitHub

Fix fused_attention_op and fused_feedforward_op bug when pre_layer_norm is false. (#36793) (#36816)

* Fix bug when pre_layer_norm is false.
上级 11b9f5f9
...@@ -37,12 +37,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ...@@ -37,12 +37,15 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionOp"); "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
"FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", "FusedAttentionOp");
"FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance",
OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", "FusedAttentionOp");
"FusedAttentionOp"); OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
"FusedAttentionOp");
}
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head] // qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut",
"FusedAttentionOp"); "FusedAttentionOp");
...@@ -101,9 +104,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { ...@@ -101,9 +104,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"input qkv_weight = [%s]", "input qkv_weight = [%s]",
x_dim, y_dim)); x_dim, y_dim));
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnOut", ctx->GetInputDim("X")); ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
}
// [batch_size, seq_len, 3, num_head, head_size] // [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut", ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
...@@ -351,11 +356,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ...@@ -351,11 +356,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("Ln2Bias")); ctx->GetInputDim("Ln2Bias"));
} }
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
"FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) { if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut", OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
"FusedAttentionGrad"); "FusedAttentionGrad");
} }
...@@ -370,13 +375,15 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ...@@ -370,13 +375,15 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionGrad"); "FusedAttentionGrad");
if (ctx->HasOutput(framework::GradVarName("LnScale"))) { if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->SetOutputDim(framework::GradVarName("LnScale"), if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
ctx->GetInputDim("LnScale")); ctx->SetOutputDim(framework::GradVarName("LnScale"),
} ctx->GetInputDim("LnScale"));
if (ctx->HasOutput(framework::GradVarName("LnBias"))) { }
ctx->SetOutputDim(framework::GradVarName("LnBias"), if (ctx->HasOutput(framework::GradVarName("LnBias"))) {
ctx->GetInputDim("LnBias")); ctx->SetOutputDim(framework::GradVarName("LnBias"),
ctx->GetInputDim("LnBias"));
}
} }
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
...@@ -390,8 +397,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ...@@ -390,8 +397,10 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("QKVBias"), ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias")); ctx->GetInputDim("QKVBias"));
ctx->SetOutputDim(framework::GradVarName("LnOut"), if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
ctx->GetInputDim("LnOut")); ctx->SetOutputDim(framework::GradVarName("LnOut"),
ctx->GetInputDim("LnOut"));
}
ctx->SetOutputDim(framework::GradVarName("FMHAOut"), ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
ctx->GetInputDim("FMHAOut")); ctx->GetInputDim("FMHAOut"));
ctx->SetOutputDim(framework::GradVarName("QKTVOut"), ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
...@@ -442,16 +451,23 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -442,16 +451,23 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("SrcMask", this->Input("SrcMask")); op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("OutLinearW", this->Input("OutLinearW")); op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias")); op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
if (this->HasInput("LnScale")) {
op->SetInput("LnScale", this->Input("LnScale")); op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("LnScale"), bool is_pre_layer_norm =
this->InputGrad("LnScale")); BOOST_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
} if (is_pre_layer_norm) {
if (this->HasInput("LnBias")) { if (this->HasInput("LnScale")) {
op->SetInput("LnBias", this->Input("LnBias")); op->SetInput("LnScale", this->Input("LnScale"));
op->SetOutput(framework::GradVarName("LnBias"), op->SetOutput(framework::GradVarName("LnScale"),
this->InputGrad("LnBias")); this->InputGrad("LnScale"));
}
if (this->HasInput("LnBias")) {
op->SetInput("LnBias", this->Input("LnBias"));
op->SetOutput(framework::GradVarName("LnBias"),
this->InputGrad("LnBias"));
}
} }
if (this->HasInput("Ln2Scale")) { if (this->HasInput("Ln2Scale")) {
op->SetInput("Ln2Scale", this->Input("Ln2Scale")); op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Scale"), op->SetOutput(framework::GradVarName("Ln2Scale"),
...@@ -473,9 +489,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -473,9 +489,17 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("OutLinearW")); this->InputGrad("OutLinearW"));
// use forward outputs as backward inputs. // use forward outputs as backward inputs.
op->SetInput("LnOut", this->Output("LnOut")); if (is_pre_layer_norm) {
op->SetInput("LnMean", this->Output("LnMean")); if (this->HasOutput("LnOut")) {
op->SetInput("LnVariance", this->Output("LnVariance")); op->SetInput("LnOut", this->Output("LnOut"));
}
if (this->HasOutput("LnMean")) {
op->SetInput("LnMean", this->Output("LnMean"));
}
if (this->HasOutput("LnVariance")) {
op->SetInput("LnVariance", this->Output("LnVariance"));
}
}
op->SetInput("QKVOut", this->Output("QKVOut")); op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut")); op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetInput("TransposeOut2", this->Output("TransposeOut2")); op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
...@@ -496,7 +520,12 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -496,7 +520,12 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("QKVOut", this->Output("QKVOut")); op->SetInput("QKVOut", this->Output("QKVOut"));
// backward outputs: dinput // backward outputs: dinput
op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut")); if (is_pre_layer_norm) {
if (this->HasOutput("LnOut")) {
op->SetOutput(framework::GradVarName("LnOut"),
this->OutputGrad("LnOut"));
}
}
op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut")); op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"), op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut")); this->OutputGrad("QKVBiasOut"));
...@@ -517,8 +546,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -517,8 +546,6 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->OutputGrad("BiasDropoutResidualOut")); this->OutputGrad("BiasDropoutResidualOut"));
op->SetOutput(framework::GradVarName("OutLinearOut"), op->SetOutput(framework::GradVarName("OutLinearOut"),
this->OutputGrad("OutLinearOut")); this->OutputGrad("OutLinearOut"));
op->SetAttrMap(this->Attrs());
} }
}; };
......
...@@ -97,9 +97,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> { ...@@ -97,9 +97,12 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
auto *x_data = input_x->data<T>(); auto *x_data = input_x->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>()); auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>()); auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace()); auto *ln_mean_data =
auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace()); pre_layer_norm ? ln_mean->mutable_data<U>(ctx.GetPlace()) : nullptr;
auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace()); auto *ln_var_data =
pre_layer_norm ? ln_var->mutable_data<U>(ctx.GetPlace()) : nullptr;
auto *ln_out_data =
pre_layer_norm ? ln_out->mutable_data<T>(ctx.GetPlace()) : nullptr;
auto *qkv_weight_data = qkv_weight->data<T>(); auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>(); auto *qkv_bias_data = qkv_bias->data<T>();
...@@ -243,9 +246,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -243,9 +246,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *out_linear_bias_data = out_linear_bias->data<T>(); auto *out_linear_bias_data = out_linear_bias->data<T>();
// fw output // fw output
auto *ln_mean = ctx.Input<Tensor>("LnMean");
auto *ln_var = ctx.Input<Tensor>("LnVariance");
auto *ln_out = ctx.Input<Tensor>("LnOut");
auto *fmha_out = ctx.Input<Tensor>("FMHAOut"); auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
auto *transpose_out_2 = ctx.Input<Tensor>("TransposeOut2"); auto *transpose_out_2 = ctx.Input<Tensor>("TransposeOut2");
auto *qk_out = ctx.Input<Tensor>("QKOut"); auto *qk_out = ctx.Input<Tensor>("QKOut");
...@@ -260,9 +260,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -260,9 +260,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut"); auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
auto *bias_dropout_residual_out = auto *bias_dropout_residual_out =
ctx.Input<Tensor>("BiasDropoutResidualOut"); ctx.Input<Tensor>("BiasDropoutResidualOut");
auto *ln_mean_data = ln_mean->data<U>();
auto *ln_var_data = ln_var->data<U>();
auto *ln_out_data = ln_out->data<T>();
auto *fmha_out_data = fmha_out->data<T>(); auto *fmha_out_data = fmha_out->data<T>();
auto *transpose_out_2_data = transpose_out_2->data<T>(); auto *transpose_out_2_data = transpose_out_2->data<T>();
auto *qk_out_data = qk_out->data<T>(); auto *qk_out_data = qk_out->data<T>();
...@@ -277,7 +274,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -277,7 +274,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
// output's grad // output's grad
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X")); auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
auto *d_qkv_out = ctx.Output<Tensor>(framework::GradVarName("QKVOut")); auto *d_qkv_out = ctx.Output<Tensor>(framework::GradVarName("QKVOut"));
auto *d_qkv_bias_out = auto *d_qkv_bias_out =
ctx.Output<Tensor>(framework::GradVarName("QKVBiasOut")); ctx.Output<Tensor>(framework::GradVarName("QKVBiasOut"));
...@@ -297,7 +293,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -297,7 +293,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_bias_dropout_residual_out = auto *d_bias_dropout_residual_out =
ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut")); ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace()); auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_out_data = d_qkv_out->mutable_data<T>(ctx.GetPlace()); auto *d_qkv_out_data = d_qkv_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data<T>(ctx.GetPlace()); auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace()); auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
...@@ -315,8 +310,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -315,8 +310,6 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace()); d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
// parameter grad // parameter grad
auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW")); auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
auto *d_qkv_bias = ctx.Output<Tensor>(framework::GradVarName("QKVBias")); auto *d_qkv_bias = ctx.Output<Tensor>(framework::GradVarName("QKVBias"));
auto *d_out_linear_weight = auto *d_out_linear_weight =
...@@ -325,12 +318,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -325,12 +318,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
ctx.Output<Tensor>(framework::GradVarName("OutLinearBias")); ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
auto *d_ln_2_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale")); auto *d_ln_2_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale"));
auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias")); auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
auto *d_ln_scale_data =
(d_ln_scale == nullptr ? nullptr
: d_ln_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));
auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace()); auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_data = d_qkv_bias->mutable_data<T>(ctx.GetPlace()); auto *d_qkv_bias_data = d_qkv_bias->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_weight_data = auto *d_out_linear_weight_data =
...@@ -407,6 +395,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> { ...@@ -407,6 +395,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
cudaMemcpyDeviceToDevice); cudaMemcpyDeviceToDevice);
if (pre_layer_norm) { if (pre_layer_norm) {
auto *ln_mean = ctx.Input<Tensor>("LnMean");
auto *ln_var = ctx.Input<Tensor>("LnVariance");
auto *ln_out = ctx.Input<Tensor>("LnOut");
auto *ln_mean_data = ln_mean->data<U>();
auto *ln_var_data = ln_var->data<U>();
auto *ln_out_data = ln_out->data<T>();
auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
auto *d_ln_scale_data =
(d_ln_scale == nullptr ? nullptr
: d_ln_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));
qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data, qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
d_qkv_bias_out_data, d_ln_out_data, d_qkv_bias_out_data, d_ln_out_data,
d_qkv_weight_data, d_qkv_bias_data); d_qkv_weight_data, d_qkv_bias_data);
......
...@@ -41,18 +41,8 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel { ...@@ -41,18 +41,8 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
"fused_feedforward"); "fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Dropout2Mask"), "Output", "Dropout2Mask", OP_INOUT_CHECK(context->HasOutput("Dropout2Mask"), "Output", "Dropout2Mask",
"fused_feedforward"); "fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Mean"), "Output", "Ln1Mean",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Variance"), "Output", "Ln1Variance",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Linear1Out"), "Output", "Linear1Out", OP_INOUT_CHECK(context->HasOutput("Linear1Out"), "Output", "Linear1Out",
"fused_feedforward"); "fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Out"), "Output", "Ln1Out",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Dropout1Out"), "Output", "Dropout1Out", OP_INOUT_CHECK(context->HasOutput("Dropout1Out"), "Output", "Dropout1Out",
"fused_feedforward"); "fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Dropout2Out"), "Output", "Dropout2Out", OP_INOUT_CHECK(context->HasOutput("Dropout2Out"), "Output", "Dropout2Out",
...@@ -76,7 +66,6 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel { ...@@ -76,7 +66,6 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
} }
context->SetOutputDim("Dropout1Out", tmp_dim_x); context->SetOutputDim("Dropout1Out", tmp_dim_x);
context->SetOutputDim("Linear1Out", tmp_dim_x); context->SetOutputDim("Linear1Out", tmp_dim_x);
context->SetOutputDim("Ln1Out", dim_x);
context->SetOutputDim("Dropout2Out", dim_x); context->SetOutputDim("Dropout2Out", dim_x);
if (context->Attrs().Get<bool>("dropout2_is_test") == false) { if (context->Attrs().Get<bool>("dropout2_is_test") == false) {
...@@ -84,10 +73,25 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel { ...@@ -84,10 +73,25 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
} }
framework::DDim mean_dim = framework::DDim mean_dim =
framework::make_ddim({mat_dim_x.batch_size_ * mat_dim_x.height_}); framework::make_ddim({mat_dim_x.batch_size_ * mat_dim_x.height_});
context->SetOutputDim("Ln1Mean", mean_dim); bool pre_layer_norm = context->Attrs().Get<bool>("pre_layer_norm");
context->SetOutputDim("Ln1Variance", mean_dim); if (pre_layer_norm) {
context->SetOutputDim("Ln2Mean", mean_dim); OP_INOUT_CHECK(context->HasOutput("Ln1Mean"), "Output", "Ln1Mean",
context->SetOutputDim("Ln2Variance", mean_dim); "fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Variance"), "Output", "Ln1Variance",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Out"), "Output", "Ln1Out",
"fused_feedforward");
context->SetOutputDim("Ln1Out", dim_x);
context->SetOutputDim("Ln1Mean", mean_dim);
context->SetOutputDim("Ln1Variance", mean_dim);
} else {
OP_INOUT_CHECK(context->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
"fused_feedforward");
context->SetOutputDim("Ln2Mean", mean_dim);
context->SetOutputDim("Ln2Variance", mean_dim);
}
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
} }
...@@ -218,14 +222,13 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel { ...@@ -218,14 +222,13 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout2_is_test"), false, PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout2_is_test"), false,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"GradOp is only callable when is_test is false")); "GradOp is only callable when is_test is false"));
bool pre_layer_norm = ctx->Attrs().Get<bool>("pre_layer_norm");
OP_INOUT_CHECK(ctx->HasInput("Dropout1Mask"), "Input", "Dropout1Mask", OP_INOUT_CHECK(ctx->HasInput("Dropout1Mask"), "Input", "Dropout1Mask",
"FusedFeedForwardGrad"); "FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout2Mask"), "Input", "Dropout1Mask", OP_INOUT_CHECK(ctx->HasInput("Dropout2Mask"), "Input", "Dropout1Mask",
"FusedFeedForwardGrad"); "FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Linear1Out"), "Input", "Linear1Out", OP_INOUT_CHECK(ctx->HasInput("Linear1Out"), "Input", "Linear1Out",
"FusedFeedForwardGrad"); "FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Out"), "Input", "Ln1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout1Out"), "Input", "Dropout1Out", OP_INOUT_CHECK(ctx->HasInput("Dropout1Out"), "Input", "Dropout1Out",
"FusedFeedForwardGrad"); "FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"), "Input", "Dropout2Out", OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"), "Input", "Dropout2Out",
...@@ -234,14 +237,19 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel { ...@@ -234,14 +237,19 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
"FusedFeedForwardGrad"); "FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Linear2Weight"), "Input", "Linear2Weight", OP_INOUT_CHECK(ctx->HasInput("Linear2Weight"), "Input", "Linear2Weight",
"FusedFeedForwardGrad"); "FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln1Mean"), "Input", "Ln1Mean", if (pre_layer_norm) {
"FusedFeedForwardGrad"); OP_INOUT_CHECK(ctx->HasInput("Ln1Mean"), "Input", "Ln1Mean",
OP_INOUT_CHECK(ctx->HasInput("Ln1Variance"), "Input", "Ln1Variance", "FusedFeedForwardGrad");
"FusedFeedForwardGrad"); OP_INOUT_CHECK(ctx->HasInput("Ln1Variance"), "Input", "Ln1Variance",
OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", "FusedFeedForwardGrad");
"FusedFeedForwardGrad"); OP_INOUT_CHECK(ctx->HasInput("Ln1Out"), "Input", "Ln1Out",
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance", "FusedFeedForwardGrad");
"FusedFeedForwardGrad"); } else {
OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
"FusedFeedForwardGrad");
}
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "FusedFeedForwardGrad"); framework::GradVarName("Out"), "FusedFeedForwardGrad");
...@@ -299,30 +307,36 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -299,30 +307,36 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Linear1Weight", this->Input("Linear1Weight")); op->SetInput("Linear1Weight", this->Input("Linear1Weight"));
op->SetInput("Linear1Bias", this->Input("Linear1Bias")); op->SetInput("Linear1Bias", this->Input("Linear1Bias"));
op->SetInput("Linear2Weight", this->Input("Linear2Weight")); op->SetInput("Linear2Weight", this->Input("Linear2Weight"));
op->SetInput("Ln1Scale", this->Input("Ln1Scale"));
op->SetInput("Ln1Bias", this->Input("Ln1Bias"));
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetInput("Dropout1Mask", this->Output("Dropout1Mask")); op->SetInput("Dropout1Mask", this->Output("Dropout1Mask"));
op->SetInput("Dropout2Mask", this->Output("Dropout2Mask")); op->SetInput("Dropout2Mask", this->Output("Dropout2Mask"));
op->SetInput("Linear1Out", this->Output("Linear1Out")); op->SetInput("Linear1Out", this->Output("Linear1Out"));
op->SetInput("Ln1Out", this->Output("Ln1Out"));
op->SetInput("Ln1Mean", this->Output("Ln1Mean"));
op->SetInput("Ln1Variance", this->Output("Ln1Variance"));
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("Dropout1Out", this->Output("Dropout1Out")); op->SetInput("Dropout1Out", this->Output("Dropout1Out"));
op->SetInput("Dropout2Out", this->Output("Dropout2Out")); op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
op->SetAttrMap(this->Attrs());
bool pre_layer_norm = BOOST_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Ln1Scale"), if (pre_layer_norm) {
this->InputGrad("Ln1Scale")); op->SetInput("Ln1Scale", this->Input("Ln1Scale"));
op->SetOutput(framework::GradVarName("Ln1Bias"), op->SetInput("Ln1Bias", this->Input("Ln1Bias"));
this->InputGrad("Ln1Bias")); op->SetInput("Ln1Out", this->Output("Ln1Out"));
op->SetOutput(framework::GradVarName("Ln2Scale"), op->SetInput("Ln1Mean", this->Output("Ln1Mean"));
this->InputGrad("Ln2Scale")); op->SetInput("Ln1Variance", this->Output("Ln1Variance"));
op->SetOutput(framework::GradVarName("Ln2Bias"), op->SetOutput(framework::GradVarName("Ln1Scale"),
this->InputGrad("Ln2Bias")); this->InputGrad("Ln1Scale"));
op->SetOutput(framework::GradVarName("Ln1Bias"),
this->InputGrad("Ln1Bias"));
} else {
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
this->InputGrad("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Bias"),
this->InputGrad("Ln2Bias"));
}
op->SetOutput(framework::GradVarName("Linear1Weight"), op->SetOutput(framework::GradVarName("Linear1Weight"),
this->InputGrad("Linear1Weight")); this->InputGrad("Linear1Weight"));
op->SetOutput(framework::GradVarName("Linear1Bias"), op->SetOutput(framework::GradVarName("Linear1Bias"),
...@@ -334,8 +348,6 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -334,8 +348,6 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Linear2Bias"), op->SetOutput(framework::GradVarName("Linear2Bias"),
this->InputGrad("Linear2Bias")); this->InputGrad("Linear2Bias"));
} }
op->SetAttrMap(this->Attrs());
} }
}; };
......
...@@ -113,26 +113,40 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -113,26 +113,40 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias"); auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias");
auto* linear2_weight = context.Input<framework::Tensor>("Linear2Weight"); auto* linear2_weight = context.Input<framework::Tensor>("Linear2Weight");
auto* linear2_bias = context.Input<framework::Tensor>("Linear2Bias"); auto* linear2_bias = context.Input<framework::Tensor>("Linear2Bias");
auto* ln1_scale = context.Input<framework::Tensor>("Ln1Scale"); const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
auto* ln1_bias = context.Input<framework::Tensor>("Ln1Bias");
auto* ln2_scale = context.Input<framework::Tensor>("Ln2Scale"); auto* ln1_scale =
auto* ln2_bias = context.Input<framework::Tensor>("Ln2Bias"); pre_layer_norm ? context.Input<framework::Tensor>("Ln1Scale") : nullptr;
auto* ln1_bias =
auto* ln1_mean = context.Output<framework::Tensor>("Ln1Mean"); pre_layer_norm ? context.Input<framework::Tensor>("Ln1Bias") : nullptr;
auto* ln1_variance = context.Output<framework::Tensor>("Ln1Variance"); auto* ln2_scale = !pre_layer_norm
auto* ln2_mean = context.Output<framework::Tensor>("Ln2Mean"); ? context.Input<framework::Tensor>("Ln2Scale")
auto* ln2_variance = context.Output<framework::Tensor>("Ln2Variance"); : nullptr;
auto* ln2_bias =
!pre_layer_norm ? context.Input<framework::Tensor>("Ln2Bias") : nullptr;
auto* ln1_mean =
pre_layer_norm ? context.Output<framework::Tensor>("Ln1Mean") : nullptr;
auto* ln1_variance = pre_layer_norm
? context.Output<framework::Tensor>("Ln1Variance")
: nullptr;
auto* ln2_mean = !pre_layer_norm
? context.Output<framework::Tensor>("Ln2Mean")
: nullptr;
auto* ln2_variance = !pre_layer_norm
? context.Output<framework::Tensor>("Ln2Variance")
: nullptr;
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
auto* dropout1_mask = context.Output<framework::Tensor>("Dropout1Mask"); auto* dropout1_mask = context.Output<framework::Tensor>("Dropout1Mask");
auto* dropout2_mask = context.Output<framework::Tensor>("Dropout2Mask"); auto* dropout2_mask = context.Output<framework::Tensor>("Dropout2Mask");
auto* linear1_out = context.Output<framework::Tensor>("Linear1Out"); auto* linear1_out = context.Output<framework::Tensor>("Linear1Out");
auto* ln1_out = context.Output<framework::Tensor>("Ln1Out"); auto* ln1_out =
pre_layer_norm ? context.Output<framework::Tensor>("Ln1Out") : nullptr;
auto* dropout1_out = context.Output<framework::Tensor>("Dropout1Out"); auto* dropout1_out = context.Output<framework::Tensor>("Dropout1Out");
auto* dropout2_out = context.Output<framework::Tensor>("Dropout2Out"); auto* dropout2_out = context.Output<framework::Tensor>("Dropout2Out");
const std::string act_method = context.Attr<std::string>("act_method"); const std::string act_method = context.Attr<std::string>("act_method");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
const float epsilon1 = context.Attr<float>("ln1_epsilon"); const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon"); const float epsilon2 = context.Attr<float>("ln2_epsilon");
...@@ -144,12 +158,16 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> { ...@@ -144,12 +158,16 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
out->mutable_data<T>(place); out->mutable_data<T>(place);
dropout1_mask->mutable_data<uint8_t>(place); dropout1_mask->mutable_data<uint8_t>(place);
dropout2_mask->mutable_data<uint8_t>(place); dropout2_mask->mutable_data<uint8_t>(place);
ln1_mean->mutable_data<U>(place); if (pre_layer_norm) {
ln1_variance->mutable_data<U>(place); ln1_mean->mutable_data<U>(place);
ln2_mean->mutable_data<U>(place); ln1_variance->mutable_data<U>(place);
ln2_variance->mutable_data<U>(place); ln1_out->mutable_data<T>(place);
} else {
ln2_mean->mutable_data<U>(place);
ln2_variance->mutable_data<U>(place);
}
linear1_out->mutable_data<T>(place); linear1_out->mutable_data<T>(place);
ln1_out->mutable_data<T>(place);
dropout1_out->mutable_data<T>(place); dropout1_out->mutable_data<T>(place);
dropout2_out->mutable_data<T>(place); dropout2_out->mutable_data<T>(place);
...@@ -193,16 +211,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -193,16 +211,16 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const framework::Tensor& d_out, const framework::Tensor& x, const framework::Tensor& d_out, const framework::Tensor& x,
const framework::Tensor& dropout1_mask, const framework::Tensor& dropout1_mask,
const framework::Tensor& dropout2_mask, const framework::Tensor& dropout2_mask,
const framework::Tensor& linear1_out, const framework::Tensor& ln1_out, const framework::Tensor& linear1_out, const framework::Tensor* ln1_out,
const framework::Tensor& dropout1_out, const framework::Tensor& dropout1_out,
const framework::Tensor& dropout2_out, const framework::Tensor& dropout2_out,
const framework::Tensor& linear1_weight, const framework::Tensor& linear1_weight,
const framework::Tensor* linear1_bias, const framework::Tensor* linear1_bias,
const framework::Tensor& linear2_weight, const framework::Tensor& linear2_weight,
const framework::Tensor* ln1_gamma, const framework::Tensor* ln1_beta, const framework::Tensor* ln1_gamma, const framework::Tensor* ln1_beta,
const framework::Tensor& ln1_mean, const framework::Tensor& ln1_variance, const framework::Tensor* ln1_mean, const framework::Tensor* ln1_variance,
const framework::Tensor* ln2_gamma, const framework::Tensor* ln2_beta, const framework::Tensor* ln2_gamma, const framework::Tensor* ln2_beta,
const framework::Tensor& ln2_mean, const framework::Tensor& ln2_variance, const framework::Tensor* ln2_mean, const framework::Tensor* ln2_variance,
framework::Tensor* d_x, framework::Tensor* d_linear1_weight, framework::Tensor* d_x, framework::Tensor* d_linear1_weight,
framework::Tensor* d_linear1_bias, framework::Tensor* d_linear2_weight, framework::Tensor* d_linear1_bias, framework::Tensor* d_linear2_weight,
framework::Tensor* d_linear2_bias, framework::Tensor* d_ln1_gamma, framework::Tensor* d_linear2_bias, framework::Tensor* d_ln1_gamma,
...@@ -252,8 +270,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -252,8 +270,8 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
} else { } else {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad( fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx, d_out.data<T>(), dropout2_out.data<T>(), ctx, d_out.data<T>(), dropout2_out.data<T>(),
dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean.data<U>(), dropout2_mask.data<uint8_t>(), ln2_gamma_ptr, ln2_mean->data<U>(),
ln2_variance.data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr, ln2_variance->data<U>(), d_dropout2_out.data<T>(), d_ln2_gamma_ptr,
d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr, d_ln2_beta_ptr, d_linear2_out.data<T>(), d_linear2_bias_ptr,
d_residual.data<T>()); d_residual.data<T>());
} }
...@@ -273,13 +291,13 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -273,13 +291,13 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
if (pre_layer_norm) { if (pre_layer_norm) {
framework::Tensor d_ln1_out; framework::Tensor d_ln1_out;
d_ln1_out.mutable_data<T>({bsz_seq, d_model}, place); d_ln1_out.mutable_data<T>({bsz_seq, d_model}, place);
MatMulGrad(ctx, d_linear1_out, ln1_out, linear1_weight, &d_ln1_out, MatMulGrad(ctx, d_linear1_out, *ln1_out, linear1_weight, &d_ln1_out,
d_linear1_weight); d_linear1_weight);
pre_layernorm_helper.LayerNormGrad(ctx, d_ln1_out.data<T>(), x.data<T>(), pre_layernorm_helper.LayerNormGrad(
ln1_gamma_ptr, ln1_mean.data<U>(), ctx, d_ln1_out.data<T>(), x.data<T>(), ln1_gamma_ptr,
ln1_variance.data<U>(), d_x->data<T>(), ln1_mean->data<U>(), ln1_variance->data<U>(), d_x->data<T>(),
d_ln1_gamma_ptr, d_ln1_beta_ptr); d_ln1_gamma_ptr, d_ln1_beta_ptr);
} else { } else {
MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight); MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
} }
...@@ -290,33 +308,52 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -290,33 +308,52 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
auto d_out = auto d_out =
*context.Input<framework::Tensor>(framework::GradVarName("Out")); *context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto x = *context.Input<framework::Tensor>("X"); auto x = *context.Input<framework::Tensor>("X");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
auto dropout1_mask = *context.Input<framework::Tensor>("Dropout1Mask"); auto dropout1_mask = *context.Input<framework::Tensor>("Dropout1Mask");
auto dropout2_mask = *context.Input<framework::Tensor>("Dropout2Mask"); auto dropout2_mask = *context.Input<framework::Tensor>("Dropout2Mask");
auto linear1_out = *context.Input<framework::Tensor>("Linear1Out"); auto linear1_out = *context.Input<framework::Tensor>("Linear1Out");
auto ln1_out = *context.Input<framework::Tensor>("Ln1Out"); auto* ln1_out =
pre_layer_norm ? context.Input<framework::Tensor>("Ln1Out") : nullptr;
auto dropout1_out = *context.Input<framework::Tensor>("Dropout1Out"); auto dropout1_out = *context.Input<framework::Tensor>("Dropout1Out");
auto dropout2_out = *context.Input<framework::Tensor>("Dropout2Out"); auto dropout2_out = *context.Input<framework::Tensor>("Dropout2Out");
auto linear1_weight = *context.Input<framework::Tensor>("Linear1Weight"); auto linear1_weight = *context.Input<framework::Tensor>("Linear1Weight");
auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias"); auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias");
auto linear2_weight = *context.Input<framework::Tensor>("Linear2Weight"); auto linear2_weight = *context.Input<framework::Tensor>("Linear2Weight");
auto ln1_mean = *context.Input<framework::Tensor>("Ln1Mean"); auto* ln1_mean =
auto ln1_variance = *context.Input<framework::Tensor>("Ln1Variance"); pre_layer_norm ? context.Input<framework::Tensor>("Ln1Mean") : nullptr;
auto* ln1_scale = context.Input<framework::Tensor>("Ln1Scale"); auto* ln1_variance = pre_layer_norm
auto* ln1_bias = context.Input<framework::Tensor>("Ln1Bias"); ? context.Input<framework::Tensor>("Ln1Variance")
auto ln2_mean = *context.Input<framework::Tensor>("Ln2Mean"); : nullptr;
auto ln2_variance = *context.Input<framework::Tensor>("Ln2Variance"); auto* ln1_scale =
auto* ln2_scale = context.Input<framework::Tensor>("Ln2Scale"); pre_layer_norm ? context.Input<framework::Tensor>("Ln1Scale") : nullptr;
auto* ln2_bias = context.Input<framework::Tensor>("Ln2Bias"); auto* ln1_bias =
pre_layer_norm ? context.Input<framework::Tensor>("Ln1Bias") : nullptr;
auto* ln2_mean =
!pre_layer_norm ? context.Input<framework::Tensor>("Ln2Mean") : nullptr;
auto* ln2_variance = !pre_layer_norm
? context.Input<framework::Tensor>("Ln2Variance")
: nullptr;
auto* ln2_scale = !pre_layer_norm
? context.Input<framework::Tensor>("Ln2Scale")
: nullptr;
auto* ln2_bias =
!pre_layer_norm ? context.Input<framework::Tensor>("Ln2Bias") : nullptr;
auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X")); auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* d_ln1_scale = auto* d_ln1_scale = pre_layer_norm
context.Output<framework::Tensor>(framework::GradVarName("Ln1Scale")); ? context.Output<framework::Tensor>(
auto* d_ln1_bias = framework::GradVarName("Ln1Scale"))
context.Output<framework::Tensor>(framework::GradVarName("Ln1Bias")); : nullptr;
auto* d_ln1_bias = pre_layer_norm
? context.Output<framework::Tensor>(
framework::GradVarName("Ln1Bias"))
: nullptr;
auto* d_ln2_scale = auto* d_ln2_scale =
context.Output<framework::Tensor>(framework::GradVarName("Ln2Scale")); pre_layer_norm ? nullptr : context.Output<framework::Tensor>(
framework::GradVarName("Ln2Scale"));
auto* d_ln2_bias = auto* d_ln2_bias =
context.Output<framework::Tensor>(framework::GradVarName("Ln2Bias")); pre_layer_norm ? nullptr : context.Output<framework::Tensor>(
framework::GradVarName("Ln2Bias"));
auto* d_linear1_weight = context.Output<framework::Tensor>( auto* d_linear1_weight = context.Output<framework::Tensor>(
framework::GradVarName("Linear1Weight")); framework::GradVarName("Linear1Weight"));
auto* d_linear1_bias = context.Output<framework::Tensor>( auto* d_linear1_bias = context.Output<framework::Tensor>(
...@@ -328,7 +365,6 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> { ...@@ -328,7 +365,6 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const float epsilon1 = context.Attr<float>("ln1_epsilon"); const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon"); const float epsilon2 = context.Attr<float>("ln2_epsilon");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
const std::string act_method = context.Attr<std::string>("act_method"); const std::string act_method = context.Attr<std::string>("act_method");
DropoutParam dropout_param1(context, 1); DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2); DropoutParam dropout_param2(context, 2);
......
...@@ -65,7 +65,7 @@ class TestFusedAttentionOp(OpTest): ...@@ -65,7 +65,7 @@ class TestFusedAttentionOp(OpTest):
def config(self): def config(self):
self.x_type = np.float32 self.x_type = np.float32
self.attn_mask_type = np.float64 self.attn_mask_type = np.float64
self.pre_layer_norm = True self.pre_layer_norm = False
self.training = True self.training = True
self.batch_size = 8 self.batch_size = 8
...@@ -213,11 +213,40 @@ class TestFusedAttentionOp(OpTest): ...@@ -213,11 +213,40 @@ class TestFusedAttentionOp(OpTest):
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5) x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5)
class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
class TestFusedAttentionOpFp16(TestFusedAttentionOp): class TestFusedAttentionOpFp16(TestFusedAttentionOp):
def config(self): def config(self):
self.x_type = np.float16 self.x_type = np.float16
self.attn_mask_type = np.float64 self.attn_mask_type = np.float64
self.pre_layer_norm = True self.pre_layer_norm = False
self.training = True self.training = True
self.batch_size = 8 self.batch_size = 8
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册