未验证 提交 0019ef0c 编写于 作者: MarDino's avatar MarDino 提交者: GitHub

revert back ffn2 (#49392)

上级 e0ee7403
...@@ -215,11 +215,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> { ...@@ -215,11 +215,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight"); auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight");
auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias"); auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
auto ffn2_linear_bias_residual = CublasFusedMLP<T>(dev_ctx); auto ffn2_linear_compute = AttnMatMul<T>(
ffn2_linear_bias_residual.Setup( dev_ctx, false, false, bsz_seq, dim_embed, dim_ffn, false);
ffn1_out.dims(), ffn2_weights[0]->dims(), false, false);
// 8. ffn2 Layernorm // 8. ffn2 Layernorm residual bias
DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper( FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
...@@ -509,16 +508,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> { ...@@ -509,16 +508,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// step7. ffn2 matmul // step7. ffn2 matmul
if (pre_layer_norm) { if (pre_layer_norm) {
ffn2_linear_bias_residual.ComputeForward(&ffn1_out, ffn2_linear_compute.ComputeForward(
ffn2_weights[i], ffn2_weights[i], &ffn1_out, nullptr, buf1, nullptr);
ffn2_biases[i],
&bias_dropout_residual_out,
buf1,
"none");
} else { } else {
ffn2_linear_bias_residual.ComputeForward( ffn2_linear_compute.ComputeForward(
&ffn1_out, ffn2_weights[i], ffn2_biases[i], buf1, buf0, "none"); ffn2_weights[i], &ffn1_out, nullptr, buf0, nullptr);
} }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
...@@ -534,30 +528,48 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> { ...@@ -534,30 +528,48 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
VLOG(0) << "step7.1"; VLOG(0) << "step7.1";
#endif #endif
// step8. layer norm or do nothing // step8. layer norm + bias_add + residual
// because bias_add + residual has been fused into cublasFusedMLP
if (pre_layer_norm) { if (pre_layer_norm) {
// TODO(wangxi): remove dropout mask in inference
if (i < layers - 1) { if (i < layers - 1) {
auto *ln_scale_data = ln_scales[i + 1]->data<U>(); auto *ln_scale_data = ln_scales[i + 1]->data<U>();
auto *ln_bias_data = ln_biases[i + 1]->data<U>(); auto *ln_bias_data = ln_biases[i + 1]->data<U>();
ffn2_fused_dropout_helper.LayerNorm(dev_ctx, ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
buf1->data<T>(), dev_ctx,
ln_scale_data, buf1->data<T>(),
ln_bias_data, bias_dropout_residual_out_data,
buf0->data<T>(), ffn2_biases[i]->data<T>(),
ln_mean_data, ln_scale_data,
ln_var_data); ln_bias_data,
buf1->data<T>(),
dropout_mask_out_data,
buf0->data<T>(),
ln_mean_data,
ln_var_data);
} else {
ffn2_fused_dropout_helper.ResidualDropoutBias(
dev_ctx,
buf1->data<T>(),
bias_dropout_residual_out_data,
ffn2_biases[i]->data<T>(),
buf1->data<T>(),
dropout_mask_out_data);
} }
} else { } else {
auto *ln_scale_data = ffn_ln_scales[i]->data<U>(); auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
auto *ln_bias_data = ffn_ln_biases[i]->data<U>(); auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
ffn2_fused_dropout_helper.LayerNorm(dev_ctx, ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
buf0->data<T>(), dev_ctx,
ln_scale_data, buf0->data<T>(),
ln_bias_data, buf1->data<T>(),
buf1->data<T>(), ffn2_biases[i]->data<T>(),
ln_mean_data, ln_scale_data,
ln_var_data); ln_bias_data,
buf0->data<T>(),
dropout_mask_out_data,
buf1->data<T>(),
ln_mean_data,
ln_var_data);
} }
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册