diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
index 9eab7c6bcb5d369435133884099c19858cbac07e..69ab2362bd5a25e74bca3eebcb1ae9254f5d5e57 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -215,11 +215,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     auto ffn2_weights = ctx.MultiInput<phi::DenseTensor>("FFN2Weight");
     auto ffn2_biases = ctx.MultiInput<phi::DenseTensor>("FFN2Bias");
 
-    auto ffn2_linear_bias_residual = CublasFusedMLP<T>(dev_ctx);
-    ffn2_linear_bias_residual.Setup(
-        ffn1_out.dims(), ffn2_weights[0]->dims(), false, false);
+    auto ffn2_linear_compute = AttnMatMul<T>(
+        dev_ctx, false, false, bsz_seq, dim_embed, dim_ffn, false);
 
-    // 8. ffn2 Layernorm
+    // 8. ffn2 Layernorm residual bias
     DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
     FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
         dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
@@ -509,16 +508,11 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
 
       // step7. ffn2 matmul
       if (pre_layer_norm) {
-        ffn2_linear_bias_residual.ComputeForward(&ffn1_out,
-                                                 ffn2_weights[i],
-                                                 ffn2_biases[i],
-                                                 &bias_dropout_residual_out,
-                                                 buf1,
-                                                 "none");
-
+        ffn2_linear_compute.ComputeForward(
+            ffn2_weights[i], &ffn1_out, nullptr, buf1, nullptr);
       } else {
-        ffn2_linear_bias_residual.ComputeForward(
-            &ffn1_out, ffn2_weights[i], ffn2_biases[i], buf1, buf0, "none");
+        ffn2_linear_compute.ComputeForward(
+            ffn2_weights[i], &ffn1_out, nullptr, buf0, nullptr);
       }
 
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
@@ -534,30 +528,48 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
       VLOG(0) << "step7.1";
 #endif
 
-      // step8. layer norm or do nothing
-      // because bias_add + residual has been fused into cublasFusedMLP
+      // step8. layer norm + bias_add + residual
       if (pre_layer_norm) {
+        // TODO(wangxi): remove dropout mask in inference
         if (i < layers - 1) {
           auto *ln_scale_data = ln_scales[i + 1]->data<U>();
           auto *ln_bias_data = ln_biases[i + 1]->data<U>();
-          ffn2_fused_dropout_helper.LayerNorm(dev_ctx,
-                                              buf1->data<T>(),
-                                              ln_scale_data,
-                                              ln_bias_data,
-                                              buf0->data<T>(),
-                                              ln_mean_data,
-                                              ln_var_data);
+          ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
+              dev_ctx,
+              buf1->data<T>(),
+              bias_dropout_residual_out_data,
+              ffn2_biases[i]->data<T>(),
+              ln_scale_data,
+              ln_bias_data,
+              buf1->data<T>(),
+              dropout_mask_out_data,
+              buf0->data<T>(),
+              ln_mean_data,
+              ln_var_data);
+        } else {
+          ffn2_fused_dropout_helper.ResidualDropoutBias(
+              dev_ctx,
+              buf1->data<T>(),
+              bias_dropout_residual_out_data,
+              ffn2_biases[i]->data<T>(),
+              buf1->data<T>(),
+              dropout_mask_out_data);
         }
       } else {
         auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
         auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
-        ffn2_fused_dropout_helper.LayerNorm(dev_ctx,
-                                            buf0->data<T>(),
-                                            ln_scale_data,
-                                            ln_bias_data,
-                                            buf1->data<T>(),
-                                            ln_mean_data,
-                                            ln_var_data);
+        ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
+            dev_ctx,
+            buf0->data<T>(),
+            buf1->data<T>(),
+            ffn2_biases[i]->data<T>(),
+            ln_scale_data,
+            ln_bias_data,
+            buf0->data<T>(),
+            dropout_mask_out_data,
+            buf1->data<T>(),
+            ln_mean_data,
+            ln_var_data);
       }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
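
Note on the change: the removed path routed ffn2 through the CublasFusedMLP helper, whose cuBLASLt-based epilogue folded the bias add and the residual add into the GEMM, leaving step8 with only a plain LayerNorm. The new path runs a bare AttnMatMul GEMM (no bias, no residual) and lets FusedDropoutLayerNormHelper apply bias + dropout + residual (+ LayerNorm) in one fused kernel afterwards; for the last layer in the pre_layer_norm branch, ResidualDropoutBias is used because no trailing LayerNorm is needed. Since the dropout probability passed in DropoutParam is 0.0 and the helper is built for inference, the dropout is effectively a no-op here. The sketch below is a hypothetical host-side reference of that step8 math for a single token, written only to make the dataflow explicit; the name Ffn2EpilogueReference and its plain std::vector interface are made up for this sketch and are not part of Paddle.

// Hypothetical reference (not the Paddle CUDA kernels): what the step8 epilogue
// computes per token once dropout degenerates to identity (prob 0.0, inference):
//   out = LayerNorm(residual + x + bias)
#include <cmath>
#include <cstddef>
#include <vector>

// x        : one row of the ffn2 GEMM output (what buf1/buf0 hold)
// residual : the residual input (bias_dropout_residual_out in the kernel)
// bias     : the FFN2Bias vector
// gamma, beta, eps : LayerNorm scale/bias and the op's epsilon attribute
std::vector<float> Ffn2EpilogueReference(const std::vector<float>& x,
                                         const std::vector<float>& residual,
                                         const std::vector<float>& bias,
                                         const std::vector<float>& gamma,
                                         const std::vector<float>& beta,
                                         float eps) {
  const std::size_t n = x.size();
  std::vector<float> pre_ln(n), out(n);
  // bias add + residual add; the fused kernel also materializes this
  // intermediate, which is why the call takes two output buffers.
  for (std::size_t i = 0; i < n; ++i) pre_ln[i] = residual[i] + x[i] + bias[i];
  // LayerNorm over the embedding dimension.
  float mean = 0.f;
  for (float v : pre_ln) mean += v;
  mean /= static_cast<float>(n);
  float var = 0.f;
  for (float v : pre_ln) var += (v - mean) * (v - mean);
  var /= static_cast<float>(n);
  const float inv_std = 1.f / std::sqrt(var + eps);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = (pre_ln[i] - mean) * inv_std * gamma[i] + beta[i];
  }
  return out;
}

For the last transformer layer in the pre_layer_norm branch the helper stops after the bias + residual add (ResidualDropoutBias), i.e. pre_ln above is already the layer output.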