未验证 提交 0019ef0c 编写于 作者: MarDino 提交者: GitHub

revert back ffn2 (#49392)

上级 e0ee7403
......@@ -215,11 +215,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight");
auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
auto ffn2_linear_bias_residual = CublasFusedMLP<T>(dev_ctx);
ffn2_linear_bias_residual.Setup(
ffn1_out.dims(), ffn2_weights[0]->dims(), false, false);
auto ffn2_linear_compute = AttnMatMul<T>(
dev_ctx, false, false, bsz_seq, dim_embed, dim_ffn, false);
// 8. ffn2 Layernorm
// 8. ffn2 Layernorm residual bias
DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
......@@ -509,16 +508,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// step7. ffn2 matmul
if (pre_layer_norm) {
ffn2_linear_bias_residual.ComputeForward(&ffn1_out,
ffn2_weights[i],
ffn2_biases[i],
&bias_dropout_residual_out,
buf1,
"none");
ffn2_linear_compute.ComputeForward(
ffn2_weights[i], &ffn1_out, nullptr, buf1, nullptr);
} else {
ffn2_linear_bias_residual.ComputeForward(
&ffn1_out, ffn2_weights[i], ffn2_biases[i], buf1, buf0, "none");
ffn2_linear_compute.ComputeForward(
ffn2_weights[i], &ffn1_out, nullptr, buf0, nullptr);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
......@@ -534,30 +528,48 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
VLOG(0) << "step7.1";
#endif
// step8. layer norm or do nothing
// because bias_add + residual has been fused into cublasFusedMLP
// step8. layer norm + bias_add + residual
if (pre_layer_norm) {
// TODO(wangxi): remove dropout mask in inference
if (i < layers - 1) {
auto *ln_scale_data = ln_scales[i + 1]->data<U>();
auto *ln_bias_data = ln_biases[i + 1]->data<U>();
ffn2_fused_dropout_helper.LayerNorm(dev_ctx,
buf1->data<T>(),
ln_scale_data,
ln_bias_data,
buf0->data<T>(),
ln_mean_data,
ln_var_data);
ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
dev_ctx,
buf1->data<T>(),
bias_dropout_residual_out_data,
ffn2_biases[i]->data<T>(),
ln_scale_data,
ln_bias_data,
buf1->data<T>(),
dropout_mask_out_data,
buf0->data<T>(),
ln_mean_data,
ln_var_data);
} else {
ffn2_fused_dropout_helper.ResidualDropoutBias(
dev_ctx,
buf1->data<T>(),
bias_dropout_residual_out_data,
ffn2_biases[i]->data<T>(),
buf1->data<T>(),
dropout_mask_out_data);
}
} else {
auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
ffn2_fused_dropout_helper.LayerNorm(dev_ctx,
buf0->data<T>(),
ln_scale_data,
ln_bias_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
dev_ctx,
buf0->data<T>(),
buf1->data<T>(),
ffn2_biases[i]->data<T>(),
ln_scale_data,
ln_bias_data,
buf0->data<T>(),
dropout_mask_out_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册