diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu
index de62fe38653b5d497e2189c114030c046aa65e65..b1212f410fe9f27f55ba6a4d3e7cce05b4ad339e 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cu
+++ b/paddle/fluid/operators/fused/fused_attention_op.cu
@@ -813,6 +813,8 @@ PD_REGISTER_KERNEL(fused_attention,
                    phi::dtype::float16,
                    double,
                    float) {
+  kernel->OutputAt(9).SetDataType(phi::DataType::UINT8);
+  kernel->OutputAt(14).SetDataType(phi::DataType::UINT8);
   if (kernel_key.dtype() == phi::DataType::FLOAT16) {
     kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu
index 2c3456d31025b5662824830af4809cc473ef7dde..ee40633e4252b3eaa66c25a3a9e99ddd59f03151 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cu
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -13,634 +13,668 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
-#include "paddle/fluid/operators/matmul_v2_op.h"
+#include "paddle/fluid/operators/fused/fused_attention_utils.h"
 #include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/core/errors.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
+#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
+#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T>
+void MatMul(const phi::GPUContext& dev_ctx,
+            const phi::DenseTensor& a,
+            const phi::DenseTensor& b,
+            phi::DenseTensor* c) {
+  auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
+  auto a_2d = phi::FoldInitDims(a);
+  auto b_2d = phi::FoldInitDims(b);
+  auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, false);
+  auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, false);
+  T alpha = static_cast<T>(1.0);
+  blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0));
+}
 
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/distributed/collective/process_group_nccl.h"
-#include "paddle/fluid/platform/collective_helper.h"
-#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
-#endif
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-static void AllReduce(phi::DenseTensor& tensor,  // NOLINT
-                      const int ring_id,
-                      const phi::GPUContext& ctx) {
-  if (ring_id == -1) return;
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
-
-  if (map->has(ring_id)) {
-    paddle::distributed::ProcessGroup* pg = map->get(ring_id);
-    auto pg_nccl = static_cast<distributed::ProcessGroupNCCL*>(pg);
-    paddle::distributed::AllreduceOptions opts;
-    opts.reduce_op = distributed::ReduceOp::SUM;
-    auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
-    task->Wait();
+template <typename T>
+void FFN(const phi::GPUContext& dev_ctx,
+         const phi::DenseTensor& x,
+         const phi::DenseTensor& linear1_weight,
+         const phi::DenseTensor* linear1_bias,
+         const phi::DenseTensor& linear2_weight,
+         const phi::DenseTensor* linear2_bias,
+         const phi::DenseTensor* ln1_scale,
+         const phi::DenseTensor* ln1_bias,
+         const phi::DenseTensor* ln2_scale,
+         const phi::DenseTensor* ln2_bias,
+         phi::DenseTensor* out,
+         phi::DenseTensor* dropout1_mask,
+         phi::DenseTensor* dropout2_mask,
+         phi::DenseTensor* ln1_mean,
+         phi::DenseTensor* ln1_variance,
+         phi::DenseTensor* ln2_mean,
+         phi::DenseTensor* ln2_variance,
+         phi::DenseTensor* linear1_out,
+         phi::DenseTensor* ln1_out,
+         phi::DenseTensor* dropout1_out,
+         phi::DenseTensor* dropout2_out,
+         const int bsz_seq,
+         const int d_model,
+         const int dim_feedforward,
+         const std::string& act_method,
+         const bool pre_layer_norm,
+         const float epsilon1,
+         const float epsilon2,
+         const bool add_residual,
+         const int ring_id,
+         const phi::fusion::DropoutParam& dropout_param1,
+         const phi::fusion::DropoutParam& dropout_param2) {
+  phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
+      bsz_seq, d_model, epsilon1);
+  phi::fusion::FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
+      dev_ctx, bsz_seq, dim_feedforward, dropout_param1);
+  phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
+      fused_dropout_layernorm_helper(
+          dev_ctx, bsz_seq, d_model, dropout_param2, epsilon2);
+
+  using U = phi::funcs::LayerNormParamType<T>;
+  const phi::DenseTensor* in = &x;
+
+  const U* ln1_scale_ptr =
+      ln1_scale == nullptr ? nullptr : ln1_scale->data<U>();
+  const U* ln1_bias_ptr = ln1_bias == nullptr ? nullptr : ln1_bias->data<U>();
+  const U* ln2_scale_ptr =
+      ln2_scale == nullptr ? nullptr : ln2_scale->data<U>();
+  const U* ln2_bias_ptr = ln2_bias == nullptr ? nullptr : ln2_bias->data<U>();
+  const T* linear1_bias_ptr =
+      linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
+  const T* linear2_bias_ptr =
+      linear2_bias == nullptr ? nullptr : linear2_bias->data<T>();
+
+  if (pre_layer_norm) {
+    pre_layernorm_helper.LayerNorm(dev_ctx,
+                                   x.data<T>(),
+                                   ln1_scale_ptr,
+                                   ln1_bias_ptr,
+                                   ln1_out->data<T>(),
+                                   ln1_mean->data<U>(),
+                                   ln1_variance->data<U>());
+    in = ln1_out;
+  }
+  MatMul<T>(dev_ctx, *in, linear1_weight, linear1_out);
+  fused_act_dropout_helper.DropoutActBias(dev_ctx,
+                                          linear1_out->data<T>(),
+                                          linear1_bias_ptr,
+                                          act_method,
+                                          dropout1_out->data<T>(),
+                                          dropout1_mask->data<uint8_t>());
+  phi::DenseTensor linear2_out;
+  linear2_out.Resize({bsz_seq, d_model});
+  dev_ctx.template Alloc<T>(&linear2_out, linear2_out.numel() * sizeof(T));
+  MatMul<T>(dev_ctx, *dropout1_out, linear2_weight, &linear2_out);
+
+  // tensor model parallel
+  phi::fusion::AllReduce<T>(linear2_out, ring_id, dev_ctx);
+
+  const T* residual_ptr = add_residual ? x.data<T>() : nullptr;
+  if (!pre_layer_norm) {
+    // TODO(Xreki): support post layer_norm case when add_residual is false.
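+    // The fused post-layer_norm kernel below computes
+    //   out = layer_norm(residual + dropout(linear2_out + bias)),
+    // so this branch always consumes the residual input.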
+    PADDLE_ENFORCE_EQ(add_residual,
+                      true,
+                      phi::errors::InvalidArgument(
+                          "Attribute add_residual is expected to be true "
+                          "when pre_layer_norm is false."));
+
+    fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
+        dev_ctx,
+        linear2_out.data<T>(),
+        residual_ptr,
+        linear2_bias_ptr,
+        ln2_scale_ptr,
+        ln2_bias_ptr,
+        dropout2_out->data<T>(),
+        dropout2_mask->data<uint8_t>(),
+        out->data<T>(),
+        ln2_mean->data<U>(),
+        ln2_variance->data<U>());
   } else {
-    auto dtype = platform::ToNCCLDataType(
-        framework::TransToProtoVarType(tensor.dtype()));
-    int64_t numel = tensor.numel();
-    const void* sendbuff = tensor.data<T>();
-    auto place = ctx.GetPlace();
-    void* recvbuff = ctx.Alloc<T>(&tensor, tensor.numel() * sizeof(T));
-    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
-    auto stream = ctx.stream();
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
-        sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
+    fused_dropout_layernorm_helper.ResidualDropoutBias(
+        dev_ctx,
+        linear2_out.data<T>(),
+        residual_ptr,
+        linear2_bias_ptr,
+        out->data<T>(),
+        dropout2_mask->data<uint8_t>());
   }
-#else
-  PADDLE_THROW(platform::errors::Unimplemented(
-      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
-      "parallel op."));
-#endif
 }
 
-template <typename T, typename DeviceContext>
-class FusedFeedForwardKernel : public framework::OpKernel<T> {
- public:
-  void MatMul(const phi::GPUContext& ctx,
-              const phi::DenseTensor& a,
-              const phi::DenseTensor& b,
-              phi::DenseTensor* c) const {
-    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(ctx);
-    auto a_2d = FoldInitDims(a);
-    auto b_2d = FoldInitDims(b);
-    auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, false);
-    auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, false);
-    T alpha = static_cast<T>(1.0);
-    blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0));
+template <typename T, typename Context>
+void FusedFeedForwardKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const paddle::optional<DenseTensor>& dropout1_seed,
+                            const paddle::optional<DenseTensor>& dropout2_seed,
+                            const DenseTensor& linear1_weight,
+                            const paddle::optional<DenseTensor>& linear1_bias,
+                            const DenseTensor& linear2_weight,
+                            const paddle::optional<DenseTensor>& linear2_bias,
+                            const paddle::optional<DenseTensor>& ln1_scale,
+                            const paddle::optional<DenseTensor>& ln1_bias,
+                            const paddle::optional<DenseTensor>& ln2_scale,
+                            const paddle::optional<DenseTensor>& ln2_bias,
+                            bool pre_layer_norm,
+                            float ln1_epsilon,
+                            float ln2_epsilon,
+                            const std::string& act_method,
+                            float dropout1_prob,
+                            float dropout2_prob,
+                            const std::string& dropout1_implementation,
+                            const std::string& dropout2_implementation,
+                            bool is_test,
+                            bool dropout1_fix_seed,
+                            bool dropout2_fix_seed,
+                            int dropout1_seed_val,
+                            int dropout2_seed_val,
+                            bool add_residual,
+                            int ring_id,
+                            DenseTensor* out,
+                            DenseTensor* dropout1_mask,
+                            DenseTensor* dropout2_mask,
+                            DenseTensor* ln1_mean,
+                            DenseTensor* ln1_variance,
+                            DenseTensor* ln2_mean,
+                            DenseTensor* ln2_variance,
+                            DenseTensor* linear1_out,
+                            DenseTensor* ln1_out,
+                            DenseTensor* dropout1_out,
+                            DenseTensor* dropout2_out) {
+  auto* x_ptr = &x;
+  auto* linear1_weight_ptr = &linear1_weight;
+  auto* linear1_bias_ptr = linear1_bias.get_ptr();
+  auto* linear2_weight_ptr = &linear2_weight;
+  auto* linear2_bias_ptr = linear2_bias.get_ptr();
+
+  auto* ln1_scale_ptr = pre_layer_norm ? ln1_scale.get_ptr() : nullptr;
+  auto* ln1_bias_ptr = pre_layer_norm ? ln1_bias.get_ptr() : nullptr;
+  auto* ln2_scale_ptr = !pre_layer_norm ? ln2_scale.get_ptr() : nullptr;
+  auto* ln2_bias_ptr = !pre_layer_norm ? ln2_bias.get_ptr() : nullptr;
+
+  if (!pre_layer_norm) {
+    ln1_mean = nullptr;
+    ln1_variance = nullptr;
+    ln1_out = nullptr;
+  } else {
+    ln2_mean = nullptr;
+    ln2_variance = nullptr;
   }
 
-  void FFN(const phi::GPUContext& ctx,
-           const phi::DenseTensor& x,
-           const phi::DenseTensor& linear1_weight,
-           const phi::DenseTensor* linear1_bias,
-           const phi::DenseTensor& linear2_weight,
-           const phi::DenseTensor* linear2_bias,
-           const phi::DenseTensor* ln1_scale,
-           const phi::DenseTensor* ln1_bias,
-           const phi::DenseTensor* ln2_scale,
-           const phi::DenseTensor* ln2_bias,
-           phi::DenseTensor* out,
-           phi::DenseTensor* dropout1_mask,
-           phi::DenseTensor* dropout2_mask,
-           phi::DenseTensor* ln1_mean,
-           phi::DenseTensor* ln1_variance,
-           phi::DenseTensor* ln2_mean,
-           phi::DenseTensor* ln2_variance,
-           phi::DenseTensor* linear1_out,
-           phi::DenseTensor* ln1_out,
-           phi::DenseTensor* dropout1_out,
-           phi::DenseTensor* dropout2_out,
-           const int bsz_seq,
-           const int d_model,
-           const int dim_feedforward,
-           const std::string& act_method,
-           const bool pre_layer_norm,
-           const float epsilon1,
-           const float epsilon2,
-           const bool add_residual,
-           const int ring_id,
-           const DropoutParam& dropout_param1,
-           const DropoutParam& dropout_param2) const {
-    FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
-        bsz_seq, d_model, epsilon1);
-    FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
-        ctx, bsz_seq, dim_feedforward, dropout_param1);
-    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
-        ctx, bsz_seq, d_model, dropout_param2, epsilon2);
-
-    using U = phi::funcs::LayerNormParamType<T>;
-    const phi::DenseTensor* in = &x;
-
-    const U* ln1_scale_ptr =
-        ln1_scale == nullptr ? nullptr : ln1_scale->data<U>();
-    const U* ln1_bias_ptr = ln1_bias == nullptr ? nullptr : ln1_bias->data<U>();
-    const U* ln2_scale_ptr =
-        ln2_scale == nullptr ? nullptr : ln2_scale->data<U>();
-    const U* ln2_bias_ptr = ln2_bias == nullptr ? nullptr : ln2_bias->data<U>();
-    const T* linear1_bias_ptr =
-        linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
-    const T* linear2_bias_ptr =
-        linear2_bias == nullptr ? nullptr : linear2_bias->data<T>();
-
-    if (pre_layer_norm) {
-      pre_layernorm_helper.LayerNorm(ctx,
-                                     x.data<T>(),
-                                     ln1_scale_ptr,
-                                     ln1_bias_ptr,
-                                     ln1_out->data<T>(),
-                                     ln1_mean->data<U>(),
-                                     ln1_variance->data<U>());
-      in = ln1_out;
-    }
-    MatMul(ctx, *in, linear1_weight, linear1_out);
-    fused_act_dropout_helper.DropoutActBias(ctx,
-                                            linear1_out->data<T>(),
-                                            linear1_bias_ptr,
-                                            act_method,
-                                            dropout1_out->data<T>(),
-                                            dropout1_mask->data<uint8_t>());
-    phi::DenseTensor linear2_out;
-    linear2_out.Resize({bsz_seq, d_model});
-    ctx.Alloc<T>(&linear2_out, linear2_out.numel() * sizeof(T));
-    MatMul(ctx, *dropout1_out, linear2_weight, &linear2_out);
+  bool is_upscale_in_train1 = dropout1_implementation == "upscale_in_train";
+  bool is_upscale_in_train2 = dropout2_implementation == "upscale_in_train";
+  auto* dropout1_seed_ptr = dropout1_seed.get_ptr();
+  auto* dropout2_seed_ptr = dropout2_seed.get_ptr();
+
+  phi::fusion::DropoutParam dropout_param1(dropout1_fix_seed,
+                                           0,
+                                           is_test,
+                                           is_upscale_in_train1,
+                                           dropout1_prob,
+                                           dropout1_seed_ptr,
+                                           dropout1_seed_val);
+  phi::fusion::DropoutParam dropout_param2(dropout2_fix_seed,
+                                           0,
+                                           is_test,
+                                           is_upscale_in_train2,
+                                           dropout2_prob,
+                                           dropout2_seed_ptr,
+                                           dropout2_seed_val);
+
+  using U = phi::funcs::LayerNormParamType<T>;
+  dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
+  dev_ctx.template Alloc<uint8_t>(dropout1_mask,
+                                  dropout1_mask->numel() * sizeof(uint8_t));
+  dev_ctx.template Alloc<uint8_t>(dropout2_mask,
+                                  dropout2_mask->numel() * sizeof(uint8_t));
+  if (pre_layer_norm) {
+    dev_ctx.template Alloc<U>(ln1_mean, ln1_mean->numel() * sizeof(U));
+    dev_ctx.template Alloc<U>(ln1_variance, ln1_variance->numel() * sizeof(U));
+    dev_ctx.template Alloc<T>(ln1_out, ln1_out->numel() * sizeof(T));
+  } else {
+    dev_ctx.template Alloc<U>(ln2_mean, ln2_mean->numel() * sizeof(U));
+    dev_ctx.template Alloc<U>(ln2_variance, ln2_variance->numel() * sizeof(U));
+  }
 
-    // tensor model parallel
-    AllReduce<T>(linear2_out, ring_id, ctx);
-
-    const T* residual_ptr = add_residual ? x.data<T>() : nullptr;
-    if (!pre_layer_norm) {
-      // TODO(Xreki): support post layer_norm case when add_residual is false.
-      PADDLE_ENFORCE_EQ(add_residual,
-                        true,
-                        platform::errors::InvalidArgument(
-                            "Attribute add_residual is expected to be true "
-                            "when pre_layer_norm is false."));
-
-      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
-          ctx,
-          linear2_out.data<T>(),
-          residual_ptr,
-          linear2_bias_ptr,
-          ln2_scale_ptr,
-          ln2_bias_ptr,
-          dropout2_out->data<T>(),
-          dropout2_mask->data<uint8_t>(),
-          out->data<T>(),
-          ln2_mean->data<U>(),
-          ln2_variance->data<U>());
-    } else {
-      fused_dropout_layernorm_helper.ResidualDropoutBias(
-          ctx,
-          linear2_out.data<T>(),
-          residual_ptr,
-          linear2_bias_ptr,
-          out->data<T>(),
-          dropout2_mask->data<uint8_t>());
-    }
+  dev_ctx.template Alloc<T>(linear1_out, linear1_out->numel() * sizeof(T));
+  dev_ctx.template Alloc<T>(dropout1_out, dropout1_out->numel() * sizeof(T));
+  dev_ctx.template Alloc<T>(dropout2_out, dropout2_out->numel() * sizeof(T));
+
+  auto x_dim = x_ptr->dims();
+  auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
+      phi::RowMatrixFromVector(x_dim), 0, false);
+
+  auto dim = linear1_weight_ptr->dims();
+  int d_model = dim[0];
+  int dim_feedforward = dim[dim.size() - 1];
+  int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
+
+  phi::fusion::FFN<T>(dev_ctx,
+                      x,
+                      linear1_weight,
+                      linear1_bias_ptr,
+                      linear2_weight,
+                      linear2_bias_ptr,
+                      ln1_scale_ptr,
+                      ln1_bias_ptr,
+                      ln2_scale_ptr,
+                      ln2_bias_ptr,
+                      out,
+                      dropout1_mask,
+                      dropout2_mask,
+                      ln1_mean,
+                      ln1_variance,
+                      ln2_mean,
+                      ln2_variance,
+                      linear1_out,
+                      ln1_out,
+                      dropout1_out,
+                      dropout2_out,
+                      bsz_seq,
+                      d_model,
+                      dim_feedforward,
+                      act_method,
+                      pre_layer_norm,
+                      ln1_epsilon,
+                      ln2_epsilon,
+                      add_residual,
+                      ring_id,
+                      dropout_param1,
+                      dropout_param2);
+}
+
+template <typename T>
+void MatMulGrad(const phi::GPUContext& dev_ctx,
+                const phi::DenseTensor& d_out,
+                const phi::DenseTensor& a,
+                const phi::DenseTensor& b,
+                phi::DenseTensor* d_a,
+                phi::DenseTensor* d_b) {
+  auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
+  auto a_2d = phi::FoldInitDims(a);
+  auto b_2d = phi::FoldInitDims(b);
+  auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, true);
+  auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, true);
+  auto mat_dim_dout =
+      phi::funcs::CreateMatrixDescriptor(d_out.dims(), 0, false);
+  T alpha = static_cast<T>(1.0);
+  blas.MatMul(d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T(0));
+  blas.MatMul(a, mat_dim_a, d_out, mat_dim_dout, alpha, d_b, T(0));
+}
+
+template <typename T>
+void FFNGrad(const phi::GPUContext& dev_ctx,
+             const phi::DenseTensor& d_out,
+             const phi::DenseTensor& x,
+             const phi::DenseTensor& dropout1_mask,
+             const phi::DenseTensor& dropout2_mask,
+             const phi::DenseTensor& linear1_out,
+             const phi::DenseTensor* ln1_out,
+             const phi::DenseTensor& dropout1_out,
+             const phi::DenseTensor* dropout2_out,
+             const phi::DenseTensor& linear1_weight,
+             const phi::DenseTensor* linear1_bias,
+             const phi::DenseTensor& linear2_weight,
+             const phi::DenseTensor* ln1_gamma,
+             const phi::DenseTensor* ln1_beta,
+             const phi::DenseTensor* ln1_mean,
+             const phi::DenseTensor* ln1_variance,
+             const phi::DenseTensor* ln2_gamma,
+             const phi::DenseTensor* ln2_beta,
+             const phi::DenseTensor* ln2_mean,
+             const phi::DenseTensor* ln2_variance,
+             phi::DenseTensor* d_x,
+             phi::DenseTensor* d_linear1_weight,
+             phi::DenseTensor* d_linear1_bias,
+             phi::DenseTensor* d_linear2_weight,
+             phi::DenseTensor* d_linear2_bias,
+             phi::DenseTensor* d_ln1_gamma,
+             phi::DenseTensor* d_ln1_beta,
+             phi::DenseTensor* d_ln2_gamma,
+             phi::DenseTensor* d_ln2_beta,
+             const int bsz_seq,
+             const int d_model,
+             const int dim_feedforward,
+             const phi::fusion::DropoutParam& dropout_param1,
+             const phi::fusion::DropoutParam& dropout_param2,
+             const std::string& act_method,
+             const bool pre_layer_norm,
+             const float epsilon1,
+             const float epsilon2,
+             const bool add_residual,
+             const int ring_id) {
+  phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
+      bsz_seq, d_model, epsilon1);
+  phi::fusion::FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
+      dev_ctx, bsz_seq, dim_feedforward, dropout_param1);
+  phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
+      fused_dropout_layernorm_helper(
+          dev_ctx, bsz_seq, d_model, dropout_param2, epsilon2);
+
+  using U = phi::funcs::LayerNormParamType<T>;
+  const U* ln1_gamma_ptr =
+      ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
+  const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
+  const U* ln2_gamma_ptr =
+      ln2_gamma == nullptr ? nullptr : ln2_gamma->data<U>();
+  const U* ln2_beta_ptr = ln2_beta == nullptr ? nullptr : ln2_beta->data<U>();
+  const T* linear1_bias_ptr =
+      linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
+  T* d_linear1_bias_ptr =
+      d_linear1_bias == nullptr ? nullptr : d_linear1_bias->data<T>();
+  T* d_linear2_bias_ptr =
+      d_linear2_bias == nullptr ? nullptr : d_linear2_bias->data<T>();
+  U* d_ln1_gamma_ptr =
+      d_ln1_gamma == nullptr ? nullptr : d_ln1_gamma->data<U>();
+  U* d_ln1_beta_ptr = d_ln1_beta == nullptr ? nullptr : d_ln1_beta->data<U>();
+  U* d_ln2_gamma_ptr =
+      d_ln2_gamma == nullptr ? nullptr : d_ln2_gamma->data<U>();
+  U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>();
+
+  phi::DenseTensor d_linear2_out, d_dropout2_out, d_residual;
+  d_linear2_out.Resize({bsz_seq, d_model});
+  dev_ctx.template Alloc<T>(&d_linear2_out, d_linear2_out.numel() * sizeof(T));
+  d_dropout2_out.Resize({bsz_seq, d_model});
+  dev_ctx.template Alloc<T>(&d_dropout2_out,
+                            d_dropout2_out.numel() * sizeof(T));
+
+  T* d_residual_ptr = nullptr;
+  if (add_residual) {
+    d_residual.Resize(d_x->dims());
+    d_residual_ptr =
+        dev_ctx.template Alloc<T>(&d_residual, d_residual.numel() * sizeof(T));
+  }
+  if (pre_layer_norm) {
+    fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
+        dev_ctx,
+        d_out.data<T>(),
+        dropout2_mask.data<uint8_t>(),
+        d_linear2_out.data<T>(),
+        d_residual_ptr,
+        d_linear2_bias_ptr);
+  } else {
+    fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
+        dev_ctx,
+        d_out.data<T>(),
+        dropout2_out->data<T>(),
+        dropout2_mask.data<uint8_t>(),
+        ln2_gamma_ptr,
+        ln2_mean->data<U>(),
+        ln2_variance->data<U>(),
+        d_dropout2_out.data<T>(),
+        d_ln2_gamma_ptr,
+        d_ln2_beta_ptr,
+        d_linear2_out.data<T>(),
+        d_linear2_bias_ptr,
+        d_residual_ptr);
   }
 
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<phi::DenseTensor>("X");
-    auto* linear1_weight = context.Input<phi::DenseTensor>("Linear1Weight");
-    auto* linear1_bias = context.Input<phi::DenseTensor>("Linear1Bias");
-    auto* linear2_weight = context.Input<phi::DenseTensor>("Linear2Weight");
-    auto* linear2_bias = context.Input<phi::DenseTensor>("Linear2Bias");
-    const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
-    auto& dev_ctx = context.template device_context<phi::GPUContext>();
-
-    auto* ln1_scale =
-        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Scale") : nullptr;
-    auto* ln1_bias =
-        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Bias") : nullptr;
-    auto* ln2_scale =
-        !pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Scale") : nullptr;
-    auto* ln2_bias =
-        !pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Bias") : nullptr;
-
-    auto* ln1_mean =
-        pre_layer_norm ? context.Output<phi::DenseTensor>("Ln1Mean") : nullptr;
-    auto* ln1_variance = pre_layer_norm
-                             ? context.Output<phi::DenseTensor>("Ln1Variance")
-                             : nullptr;
-    auto* ln2_mean =
-        !pre_layer_norm ? context.Output<phi::DenseTensor>("Ln2Mean") : nullptr;
-    auto* ln2_variance = !pre_layer_norm
-                             ? context.Output<phi::DenseTensor>("Ln2Variance")
-                             : nullptr;
-    auto* out = context.Output<phi::DenseTensor>("Out");
-    auto* dropout1_mask = context.Output<phi::DenseTensor>("Dropout1Mask");
-    auto* dropout2_mask = context.Output<phi::DenseTensor>("Dropout2Mask");
-    auto* linear1_out = context.Output<phi::DenseTensor>("Linear1Out");
-    auto* ln1_out =
-        pre_layer_norm ? context.Output<phi::DenseTensor>("Ln1Out") : nullptr;
-    auto* dropout1_out = context.Output<phi::DenseTensor>("Dropout1Out");
-    auto* dropout2_out = context.Output<phi::DenseTensor>("Dropout2Out");
-
-    const std::string act_method = context.Attr<std::string>("act_method");
-
-    const float epsilon1 = context.Attr<float>("ln1_epsilon");
-    const float epsilon2 = context.Attr<float>("ln2_epsilon");
-    const int ring_id = context.Attr<int>("ring_id");
-    const bool add_residual = context.Attr<bool>("add_residual");
-
-    DropoutParam dropout_param1(context, 1);
-    DropoutParam dropout_param2(context, 2);
-
-    using U = phi::funcs::LayerNormParamType<T>;
-    dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
-    dev_ctx.Alloc<uint8_t>(dropout1_mask,
-                           dropout1_mask->numel() * sizeof(uint8_t));
-    dev_ctx.Alloc<uint8_t>(dropout2_mask,
-                           dropout2_mask->numel() * sizeof(uint8_t));
-    if (pre_layer_norm) {
-      dev_ctx.Alloc<U>(ln1_mean, ln1_mean->numel() * sizeof(U));
-      dev_ctx.Alloc<U>(ln1_variance, ln1_variance->numel() * sizeof(U));
-      dev_ctx.Alloc<T>(ln1_out, ln1_out->numel() * sizeof(T));
-    } else {
-      dev_ctx.Alloc<U>(ln2_mean, ln2_mean->numel() * sizeof(U));
-      dev_ctx.Alloc<U>(ln2_variance, ln2_variance->numel() * sizeof(U));
-    }
-
-    dev_ctx.Alloc<T>(linear1_out, linear1_out->numel() * sizeof(T));
-    dev_ctx.Alloc<T>(dropout1_out, dropout1_out->numel() * sizeof(T));
-    dev_ctx.Alloc<T>(dropout2_out, dropout2_out->numel() * sizeof(T));
-
-    auto x_dim = x->dims();
-    auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
-        RowMatrixFromVector(x_dim), 0, false);
-
-    auto dim = linear1_weight->dims();
-    int d_model = dim[0];
-    int dim_feedforward = dim[dim.size() - 1];
-    int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
-
-    FFN(context.cuda_device_context(),
-        *x,
-        *linear1_weight,
-        linear1_bias,
-        *linear2_weight,
-        linear2_bias,
-        ln1_scale,
-        ln1_bias,
-        ln2_scale,
-        ln2_bias,
-        out,
-        dropout1_mask,
-        dropout2_mask,
-        ln1_mean,
-        ln1_variance,
-        ln2_mean,
-        ln2_variance,
-        linear1_out,
-        ln1_out,
-        dropout1_out,
-        dropout2_out,
-        bsz_seq,
-        d_model,
-        dim_feedforward,
-        act_method,
-        pre_layer_norm,
-        epsilon1,
-        epsilon2,
-        add_residual,
-        ring_id,
-        dropout_param1,
-        dropout_param2);
+  phi::DenseTensor d_dropout1_out;
+  d_dropout1_out.Resize({bsz_seq, dim_feedforward});
+  dev_ctx.template Alloc<T>(&d_dropout1_out,
+                            d_dropout1_out.numel() * sizeof(T));
+  MatMulGrad<T>(dev_ctx,
+                d_linear2_out,
+                dropout1_out,
+                linear2_weight,
+                &d_dropout1_out,
+                d_linear2_weight);
+
+  phi::DenseTensor d_linear1_out;
+  d_linear1_out.Resize({bsz_seq, dim_feedforward});
+  dev_ctx.template Alloc<T>(&d_linear1_out, d_linear1_out.numel() * sizeof(T));
+  fused_act_dropout_helper.DropoutActBiasGrad(dev_ctx,
+                                              d_dropout1_out.data<T>(),
+                                              linear1_out.data<T>(),
+                                              linear1_bias_ptr,
+                                              dropout1_mask.data<uint8_t>(),
+                                              d_linear1_out.data<T>(),
+                                              d_linear1_bias_ptr,
+                                              act_method);
+  if (pre_layer_norm) {
+    phi::DenseTensor d_ln1_out;
+    d_ln1_out.Resize({bsz_seq, d_model});
+    dev_ctx.template Alloc<T>(&d_ln1_out, d_ln1_out.numel() * sizeof(T));
+    MatMulGrad<T>(dev_ctx,
+                  d_linear1_out,
+                  *ln1_out,
+                  linear1_weight,
+                  &d_ln1_out,
+                  d_linear1_weight);
+    // tensor model parallel
+    phi::fusion::AllReduce<T>(d_ln1_out, ring_id, dev_ctx);
+    pre_layernorm_helper.LayerNormGrad(dev_ctx,
+                                       d_ln1_out.data<T>(),
+                                       x.data<T>(),
+                                       ln1_gamma_ptr,
+                                       ln1_mean->data<U>(),
+                                       ln1_variance->data<U>(),
+                                       d_x->data<T>(),
+                                       d_ln1_gamma_ptr,
+                                       d_ln1_beta_ptr);
+  } else {
+    MatMulGrad<T>(
+        dev_ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
+    // tensor model parallel
+    phi::fusion::AllReduce<T>(*d_x, ring_id, dev_ctx);
   }
-};
-
-template <typename T, typename DeviceContext>
-class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
- public:
-  void MatMulGrad(const phi::GPUContext& ctx,
-                  const phi::DenseTensor& d_out,
-                  const phi::DenseTensor& a,
-                  const phi::DenseTensor& b,
-                  phi::DenseTensor* d_a,
-                  phi::DenseTensor* d_b) const {
-    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(ctx);
-    auto a_2d = FoldInitDims(a);
-    auto b_2d = FoldInitDims(b);
-    auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(a_2d.dims(), 0, true);
-    auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(b_2d.dims(), 0, true);
-    auto mat_dim_dout =
-        phi::funcs::CreateMatrixDescriptor(d_out.dims(), 0, false);
-    T alpha = static_cast<T>(1.0);
-    blas.MatMul(d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T(0));
-    blas.MatMul(a, mat_dim_a, d_out, mat_dim_dout, alpha, d_b, T(0));
-  }
+
+  if (add_residual) {
+    // gradient accumulation
+    std::vector<const phi::DenseTensor*> ins = {&d_residual, d_x};
+    std::vector<phi::DenseTensor*> outs = {d_x};
+    phi::funcs::ElementwiseKernel<T>(
+        dev_ctx, ins, &outs, phi::funcs::AddFunctor<T>());
+  }
+}
 
-  void FFNGrad(const phi::GPUContext& ctx,
-               const phi::DenseTensor& d_out,
-               const phi::DenseTensor& x,
-               const phi::DenseTensor& dropout1_mask,
-               const phi::DenseTensor& dropout2_mask,
-               const phi::DenseTensor& linear1_out,
-               const phi::DenseTensor* ln1_out,
-               const phi::DenseTensor& dropout1_out,
-               const phi::DenseTensor* dropout2_out,
-               const phi::DenseTensor& linear1_weight,
-               const phi::DenseTensor* linear1_bias,
-               const phi::DenseTensor& linear2_weight,
-               const phi::DenseTensor* ln1_gamma,
-               const phi::DenseTensor* ln1_beta,
-               const phi::DenseTensor* ln1_mean,
-               const phi::DenseTensor* ln1_variance,
-               const phi::DenseTensor* ln2_gamma,
-               const phi::DenseTensor* ln2_beta,
-               const phi::DenseTensor* ln2_mean,
-               const phi::DenseTensor* ln2_variance,
-               phi::DenseTensor* d_x,
-               phi::DenseTensor* d_linear1_weight,
-               phi::DenseTensor* d_linear1_bias,
-               phi::DenseTensor* d_linear2_weight,
-               phi::DenseTensor* d_linear2_bias,
-               phi::DenseTensor* d_ln1_gamma,
-               phi::DenseTensor* d_ln1_beta,
-               phi::DenseTensor* d_ln2_gamma,
-               phi::DenseTensor* d_ln2_beta,
-               const int bsz_seq,
-               const int d_model,
-               const int dim_feedforward,
-               const DropoutParam& dropout_param1,
-               const DropoutParam& dropout_param2,
-               const std::string& act_method,
-               const bool pre_layer_norm,
-               const float epsilon1,
-               const float epsilon2,
-               const bool add_residual,
-               const int ring_id) const {
-    FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
-        bsz_seq, d_model, epsilon1);
-    FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
-        ctx, bsz_seq, dim_feedforward, dropout_param1);
-    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
-        ctx, bsz_seq, d_model, dropout_param2, epsilon2);
-
-    using U = phi::funcs::LayerNormParamType<T>;
-    const U* ln1_gamma_ptr =
-        ln1_gamma == nullptr ? nullptr : ln1_gamma->data<U>();
-    const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data<U>();
-    const U* ln2_gamma_ptr =
-        ln2_gamma == nullptr ? nullptr : ln2_gamma->data<U>();
-    const U* ln2_beta_ptr = ln2_beta == nullptr ? nullptr : ln2_beta->data<U>();
-    const T* linear1_bias_ptr =
-        linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
-    T* d_linear1_bias_ptr =
-        d_linear1_bias == nullptr ? nullptr : d_linear1_bias->data<T>();
-    T* d_linear2_bias_ptr =
-        d_linear2_bias == nullptr ? nullptr : d_linear2_bias->data<T>();
-    U* d_ln1_gamma_ptr =
-        d_ln1_gamma == nullptr ? nullptr : d_ln1_gamma->data<U>();
-    U* d_ln1_beta_ptr = d_ln1_beta == nullptr ? nullptr : d_ln1_beta->data<U>();
-    U* d_ln2_gamma_ptr =
-        d_ln2_gamma == nullptr ? nullptr : d_ln2_gamma->data<U>();
-    U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data<U>();
-
-    phi::DenseTensor d_linear2_out, d_dropout2_out, d_residual;
-    d_linear2_out.Resize({bsz_seq, d_model});
-    ctx.Alloc<T>(&d_linear2_out, d_linear2_out.numel() * sizeof(T));
-    d_dropout2_out.Resize({bsz_seq, d_model});
-    ctx.Alloc<T>(&d_dropout2_out, d_dropout2_out.numel() * sizeof(T));
-
-    T* d_residual_ptr = nullptr;
-    if (add_residual) {
-      d_residual.Resize(d_x->dims());
-      d_residual_ptr =
-          ctx.Alloc<T>(&d_residual, d_residual.numel() * sizeof(T));
-    }
-    if (pre_layer_norm) {
-      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
-          ctx,
-          d_out.data<T>(),
-          dropout2_mask.data<uint8_t>(),
-          d_linear2_out.data<T>(),
-          d_residual_ptr,
-          d_linear2_bias_ptr);
-    } else {
-      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
-          ctx,
-          d_out.data<T>(),
-          dropout2_out->data<T>(),
-          dropout2_mask.data<uint8_t>(),
-          ln2_gamma_ptr,
-          ln2_mean->data<U>(),
-          ln2_variance->data<U>(),
-          d_dropout2_out.data<T>(),
-          d_ln2_gamma_ptr,
-          d_ln2_beta_ptr,
-          d_linear2_out.data<T>(),
-          d_linear2_bias_ptr,
-          d_residual_ptr);
-    }
-
-    phi::DenseTensor d_dropout1_out;
-    d_dropout1_out.Resize({bsz_seq, dim_feedforward});
-    ctx.Alloc<T>(&d_dropout1_out, d_dropout1_out.numel() * sizeof(T));
-    MatMulGrad(ctx,
-               d_linear2_out,
-               dropout1_out,
-               linear2_weight,
-               &d_dropout1_out,
-               d_linear2_weight);
-
-    phi::DenseTensor d_linear1_out;
-    d_linear1_out.Resize({bsz_seq, dim_feedforward});
-    ctx.Alloc<T>(&d_linear1_out, d_linear1_out.numel() * sizeof(T));
-    fused_act_dropout_helper.DropoutActBiasGrad(ctx,
-                                                d_dropout1_out.data<T>(),
-                                                linear1_out.data<T>(),
-                                                linear1_bias_ptr,
-                                                dropout1_mask.data<uint8_t>(),
-                                                d_linear1_out.data<T>(),
-                                                d_linear1_bias_ptr,
-                                                act_method);
-
-    if (pre_layer_norm) {
-      phi::DenseTensor d_ln1_out;
-      d_ln1_out.Resize({bsz_seq, d_model});
-      ctx.Alloc<T>(&d_ln1_out, d_ln1_out.numel() * sizeof(T));
-      MatMulGrad(ctx,
-                 d_linear1_out,
-                 *ln1_out,
-                 linear1_weight,
-                 &d_ln1_out,
-                 d_linear1_weight);
-      // tensor model parallel
-      AllReduce<T>(d_ln1_out, ring_id, ctx);
-      pre_layernorm_helper.LayerNormGrad(ctx,
-                                         d_ln1_out.data<T>(),
-                                         x.data<T>(),
-                                         ln1_gamma_ptr,
-                                         ln1_mean->data<U>(),
-                                         ln1_variance->data<U>(),
-                                         d_x->data<T>(),
-                                         d_ln1_gamma_ptr,
-                                         d_ln1_beta_ptr);
-    } else {
-      MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
-      // tensor model parallel
-      AllReduce<T>(*d_x, ring_id, ctx);
-    }
-
-    if (add_residual) {
-      // gradient accumulation
-      std::vector<const phi::DenseTensor*> ins = {&d_residual, d_x};
-      std::vector<phi::DenseTensor*> outs = {d_x};
-      phi::funcs::ElementwiseKernel<T>(
-          ctx, ins, &outs, phi::funcs::AddFunctor<T>());
-    }
-  }
+template <typename T, typename Context>
+void FusedFeedForwardGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& out_grad,
+    const DenseTensor& x,
+    const DenseTensor& linear1_weight,
+    const paddle::optional<DenseTensor>& linear1_bias,
+    const DenseTensor& linear2_weight,
+    const DenseTensor& dropout1_mask,
+    const DenseTensor& dropout2_mask,
+    const DenseTensor& linear1_out,
+    const DenseTensor& dropout1_out,
+    const paddle::optional<DenseTensor>& dropout2_out,
+    const paddle::optional<DenseTensor>& ln1_scale,
+    const paddle::optional<DenseTensor>& ln1_bias,
+    const paddle::optional<DenseTensor>& ln1_out,
+    const paddle::optional<DenseTensor>& ln1_mean,
+    const paddle::optional<DenseTensor>& ln1_variance,
+    const paddle::optional<DenseTensor>& ln2_scale,
+    const paddle::optional<DenseTensor>& ln2_bias,
+    const paddle::optional<DenseTensor>& ln2_mean,
+    const paddle::optional<DenseTensor>& ln2_variance,
+    const paddle::optional<DenseTensor>& linear2_bias,
+    bool pre_layer_norm,
+    float ln1_epsilon,
+    float ln2_epsilon,
+    const std::string& act_method,
+    float dropout1_prob,
+    float dropout2_prob,
+    const std::string& dropout1_implementation,
+    const std::string& dropout2_implementation,
+    bool is_test,
+    bool dropout1_fix_seed,
+    bool dropout2_fix_seed,
+    int dropout1_seed_val,
+    int dropout2_seed_val,
+    bool add_residual,
+    int ring_id,
+    DenseTensor* x_grad,
+    DenseTensor* ln1_scale_grad,
+    DenseTensor* ln1_bias_grad,
+    DenseTensor* ln2_scale_grad,
+    DenseTensor* ln2_bias_grad,
+    DenseTensor* linear1_weight_grad,
+    DenseTensor* linear1_bias_grad,
+    DenseTensor* linear2_weight_grad,
+    DenseTensor* linear2_bias_grad) {
+  using U = phi::funcs::LayerNormParamType<T>;
+
+  auto* ln1_out_ptr = pre_layer_norm ? ln1_out.get_ptr() : nullptr;
+  auto* dropout2_out_ptr = dropout2_out.get_ptr();
+  auto* linear1_bias_ptr = linear1_bias.get_ptr();
+  auto* ln1_mean_ptr = pre_layer_norm ? ln1_mean.get_ptr() : nullptr;
+  auto* ln1_variance_ptr = pre_layer_norm ? ln1_variance.get_ptr() : nullptr;
+  auto* ln1_scale_ptr = pre_layer_norm ? ln1_scale.get_ptr() : nullptr;
+  auto* ln1_bias_ptr = pre_layer_norm ? ln1_bias.get_ptr() : nullptr;
+  auto* ln2_mean_ptr = !pre_layer_norm ? ln2_mean.get_ptr() : nullptr;
+  auto* ln2_variance_ptr = !pre_layer_norm ? ln2_variance.get_ptr() : nullptr;
+  auto* ln2_scale_ptr = !pre_layer_norm ? ln2_scale.get_ptr() : nullptr;
+  auto* ln2_bias_ptr = !pre_layer_norm ? ln2_bias.get_ptr() : nullptr;
+
+  auto* d_x = x_grad;
+  auto* d_ln1_scale = pre_layer_norm ? ln1_scale_grad : nullptr;
+  auto* d_ln1_bias = pre_layer_norm ? ln1_bias_grad : nullptr;
+  auto* d_ln2_scale = pre_layer_norm ? nullptr : ln2_scale_grad;
+  auto* d_ln2_bias = pre_layer_norm ? nullptr : ln2_bias_grad;
+  auto* d_linear1_weight = linear1_weight_grad;
+  auto* d_linear1_bias = linear1_bias_grad;
+  auto* d_linear2_weight = linear2_weight_grad;
+  auto* d_linear2_bias = linear2_bias_grad;
+
+  bool is_upscale_in_train1 = dropout1_implementation == "upscale_in_train";
+  bool is_upscale_in_train2 = dropout2_implementation == "upscale_in_train";
+
+  phi::fusion::DropoutParam dropout_param1(dropout1_fix_seed,
+                                           0,
+                                           is_test,
+                                           is_upscale_in_train1,
+                                           dropout1_prob,
+                                           nullptr,
+                                           dropout1_seed_val);
+  phi::fusion::DropoutParam dropout_param2(dropout2_fix_seed,
+                                           0,
+                                           is_test,
+                                           is_upscale_in_train2,
+                                           dropout2_prob,
+                                           nullptr,
+                                           dropout2_seed_val);
+
+  dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T));
+  if (d_ln1_scale) {
+    dev_ctx.template Alloc<U>(d_ln1_scale, d_ln1_scale->numel() * sizeof(U));
+  }
+  if (d_ln1_bias) {
+    dev_ctx.template Alloc<U>(d_ln1_bias, d_ln1_bias->numel() * sizeof(U));
+  }
+  if (d_ln2_scale) {
+    dev_ctx.template Alloc<U>(d_ln2_scale, d_ln2_scale->numel() * sizeof(U));
+  }
+  if (d_ln2_bias) {
+    dev_ctx.template Alloc<U>(d_ln2_bias, d_ln2_bias->numel() * sizeof(U));
+  }
+  if (d_linear1_bias) {
+    dev_ctx.template Alloc<T>(d_linear1_bias,
+                              d_linear1_bias->numel() * sizeof(T));
+  }
+  if (d_linear2_bias) {
+    dev_ctx.template Alloc<T>(d_linear2_bias,
+                              d_linear2_bias->numel() * sizeof(T));
+  }
+  dev_ctx.template Alloc<T>(d_linear1_weight,
+                            d_linear1_weight->numel() * sizeof(T));
+  dev_ctx.template Alloc<T>(d_linear2_weight,
+                            d_linear2_weight->numel() * sizeof(T));
+
+  auto x_dim = x.dims();
+  auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
+      phi::RowMatrixFromVector(x_dim), 0, false);
+
+  auto linear1_weight_dim = linear1_weight.dims();
+  int d_model = linear1_weight_dim[0];
+  int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
+  int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
+
+  FFNGrad<T>(dev_ctx,
+             out_grad,
+             x,
+             dropout1_mask,
+             dropout2_mask,
+             linear1_out,
+             ln1_out_ptr,
+             dropout1_out,
+             dropout2_out_ptr,
+             linear1_weight,
+             linear1_bias_ptr,
+             linear2_weight,
+             ln1_scale_ptr,
+             ln1_bias_ptr,
+             ln1_mean_ptr,
+             ln1_variance_ptr,
+             ln2_scale_ptr,
+             ln2_bias_ptr,
+             ln2_mean_ptr,
+             ln2_variance_ptr,
+             d_x,
+             d_linear1_weight,
+             d_linear1_bias,
+             d_linear2_weight,
+             d_linear2_bias,
+             d_ln1_scale,
+             d_ln1_bias,
+             d_ln2_scale,
+             d_ln2_bias,
+             bsz_seq,
+             d_model,
+             dim_feedforward,
+             dropout_param1,
+             dropout_param2,
+             act_method,
+             pre_layer_norm,
+             ln1_epsilon,
+             ln2_epsilon,
+             add_residual,
+             ring_id);
+}
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_feedforward,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedFeedForwardKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
+  kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
+  kernel->OutputAt(2).SetDataType(phi::DataType::UINT8);
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
+  }
+}
 
-  void Compute(const framework::ExecutionContext& context) const override {
-    using U = phi::funcs::LayerNormParamType<T>;
-    auto& dev_ctx = context.template device_context<phi::GPUContext>();
-    auto d_out =
-        *context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
-    auto x = *context.Input<phi::DenseTensor>("X");
-    const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
-    auto dropout1_mask = *context.Input<phi::DenseTensor>("Dropout1Mask");
-    auto dropout2_mask = *context.Input<phi::DenseTensor>("Dropout2Mask");
-    auto linear1_out = *context.Input<phi::DenseTensor>("Linear1Out");
-    auto* ln1_out =
-        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Out") : nullptr;
-    auto dropout1_out = *context.Input<phi::DenseTensor>("Dropout1Out");
-    auto* dropout2_out = context.Input<phi::DenseTensor>("Dropout2Out");
-    auto linear1_weight = *context.Input<phi::DenseTensor>("Linear1Weight");
-    auto* linear1_bias = context.Input<phi::DenseTensor>("Linear1Bias");
-    auto linear2_weight = *context.Input<phi::DenseTensor>("Linear2Weight");
-    auto* ln1_mean =
-        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Mean") : nullptr;
-    auto* ln1_variance = pre_layer_norm
-                             ? context.Input<phi::DenseTensor>("Ln1Variance")
-                             : nullptr;
-    auto* ln1_scale =
-        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Scale") : nullptr;
-    auto* ln1_bias =
-        pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Bias") : nullptr;
-    auto* ln2_mean =
-        !pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Mean") : nullptr;
-    auto* ln2_variance = !pre_layer_norm
-                             ? context.Input<phi::DenseTensor>("Ln2Variance")
-                             : nullptr;
-    auto* ln2_scale =
-        !pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Scale") : nullptr;
-    auto* ln2_bias =
-        !pre_layer_norm ? context.Input<phi::DenseTensor>("Ln2Bias") : nullptr;
-
-    auto* d_x = context.Output<phi::DenseTensor>(framework::GradVarName("X"));
-    auto* d_ln1_scale = pre_layer_norm ? context.Output<phi::DenseTensor>(
-                                             framework::GradVarName("Ln1Scale"))
-                                       : nullptr;
-    auto* d_ln1_bias = pre_layer_norm ? context.Output<phi::DenseTensor>(
-                                            framework::GradVarName("Ln1Bias"))
-                                      : nullptr;
-    auto* d_ln2_scale = pre_layer_norm
-                            ? nullptr
-                            : context.Output<phi::DenseTensor>(
-                                  framework::GradVarName("Ln2Scale"));
-    auto* d_ln2_bias = pre_layer_norm ? nullptr
-                                      : context.Output<phi::DenseTensor>(
-                                            framework::GradVarName("Ln2Bias"));
-    auto* d_linear1_weight = context.Output<phi::DenseTensor>(
-        framework::GradVarName("Linear1Weight"));
-    auto* d_linear1_bias =
-        context.Output<phi::DenseTensor>(framework::GradVarName("Linear1Bias"));
-    auto* d_linear2_weight = context.Output<phi::DenseTensor>(
-        framework::GradVarName("Linear2Weight"));
-    auto* d_linear2_bias =
-        context.Output<phi::DenseTensor>(framework::GradVarName("Linear2Bias"));
-
-    const float epsilon1 = context.Attr<float>("ln1_epsilon");
-    const float epsilon2 = context.Attr<float>("ln2_epsilon");
-    const bool add_residual = context.Attr<bool>("add_residual");
-    const int ring_id = context.Attr<int>("ring_id");
-    const std::string act_method = context.Attr<std::string>("act_method");
-    DropoutParam dropout_param1(context, 1);
-    DropoutParam dropout_param2(context, 2);
-
-    dev_ctx.Alloc<T>(d_x, d_x->numel() * sizeof(T));
-    if (d_ln1_scale) {
-      dev_ctx.Alloc<U>(d_ln1_scale, d_ln1_scale->numel() * sizeof(U));
-    }
-    if (d_ln1_bias) {
-      dev_ctx.Alloc<U>(d_ln1_bias, d_ln1_bias->numel() * sizeof(U));
-    }
-    if (d_ln2_scale) {
-      dev_ctx.Alloc<U>(d_ln2_scale, d_ln2_scale->numel() * sizeof(U));
-    }
-    if (d_ln2_bias) {
-      dev_ctx.Alloc<U>(d_ln2_bias, d_ln2_bias->numel() * sizeof(U));
-    }
-    if (d_linear1_bias) {
-      dev_ctx.Alloc<T>(d_linear1_bias, d_linear1_bias->numel() * sizeof(T));
-    }
-    if (d_linear2_bias) {
-      dev_ctx.Alloc<T>(d_linear2_bias, d_linear2_bias->numel() * sizeof(T));
-    }
-    dev_ctx.Alloc<T>(d_linear1_weight, d_linear1_weight->numel() * sizeof(T));
-    dev_ctx.Alloc<T>(d_linear2_weight, d_linear2_weight->numel() * sizeof(T));
-
-    auto x_dim = x.dims();
-    auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
-        RowMatrixFromVector(x_dim), 0, false);
-
-    auto linear1_weight_dim = linear1_weight.dims();
-    int d_model = linear1_weight_dim[0];
-    int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1];
-    int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
-
-    FFNGrad(context.cuda_device_context(),
-            d_out,
-            x,
-            dropout1_mask,
-            dropout2_mask,
-            linear1_out,
-            ln1_out,
-            dropout1_out,
-            dropout2_out,
-            linear1_weight,
-            linear1_bias,
-            linear2_weight,
-            ln1_scale,
-            ln1_bias,
-            ln1_mean,
-            ln1_variance,
-            ln2_scale,
-            ln2_bias,
-            ln2_mean,
-            ln2_variance,
-            d_x,
-            d_linear1_weight,
-            d_linear1_bias,
-            d_linear2_weight,
-            d_linear2_bias,
-            d_ln1_scale,
-            d_ln1_bias,
-            d_ln2_scale,
-            d_ln2_bias,
-            bsz_seq,
-            d_model,
-            dim_feedforward,
-            dropout_param1,
-            dropout_param2,
-            act_method,
-            pre_layer_norm,
-            epsilon1,
-            epsilon2,
-            add_residual,
-            ring_id);
+PD_REGISTER_KERNEL(fused_feedforward_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedFeedForwardGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
   }
-};
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-PD_REGISTER_STRUCT_KERNEL(fused_feedforward,
-                          GPU,
-                          ALL_LAYOUT,
-                          ops::FusedFeedForwardKernel,
-                          float,
-                          double,
-                          plat::float16) {}
-PD_REGISTER_STRUCT_KERNEL(fused_feedforward_grad,
-                          GPU,
-                          ALL_LAYOUT,
-                          ops::FusedFeedForwardGradKernel,
-                          float,
-                          double,
-                          plat::float16) {}
+}
diff --git a/paddle/phi/kernels/funcs/functors.h b/paddle/phi/kernels/funcs/functors.h
index 3c7ae5ed09af3ca3ac177a166d81857f9df3f33d..ce67f7f167199c4ce19c50c13a175174e93562ed 100644
--- a/paddle/phi/kernels/funcs/functors.h
+++ b/paddle/phi/kernels/funcs/functors.h
@@ -25,11 +25,6 @@ struct MulGradFunctor {
   inline HOSTDEVICE T Dy(T x, T y) { return x; }
 };
 
-template <typename T>
-struct MaxFunctor {
-  inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? b : a; }
-};
-
 template <typename T>
 struct AddGradFunctor {
   inline HOSTDEVICE T Dx(T x, T y) { return static_cast<T>(1.); }
diff --git a/paddle/phi/kernels/fused_feedforward_grad_kernel.h b/paddle/phi/kernels/fused_feedforward_grad_kernel.h
index 9eee46a83987ed6aa626b60a2385f3a83f1f0b09..79b175d45ee89f745cab7054be316c52c604f947 100644
--- a/paddle/phi/kernels/fused_feedforward_grad_kernel.h
+++ b/paddle/phi/kernels/fused_feedforward_grad_kernel.h
@@ -24,13 +24,13 @@ void FusedFeedForwardGradKernel(
     const DenseTensor& out_grad,
     const DenseTensor& x,
     const DenseTensor& linear1_weight,
-    const DenseTensor& linear1_bias,
+    const paddle::optional<DenseTensor>& linear1_bias,
     const DenseTensor& linear2_weight,
     const DenseTensor& dropout1_mask,
     const DenseTensor& dropout2_mask,
     const DenseTensor& linear1_out,
     const DenseTensor& dropout1_out,
-    const DenseTensor& dropout2_out,
+    const paddle::optional<DenseTensor>& dropout2_out,
     const paddle::optional<DenseTensor>& ln1_scale,
     const paddle::optional<DenseTensor>& ln1_bias,
     const paddle::optional<DenseTensor>& ln1_out,
diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h b/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h
index 57cbf678b92b3e8f448686f70785a358e567e010..c73a35d2265ce74938f0998cdd77605cf0477624 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h
+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h
@@ -30,6 +30,8 @@
 #include "paddle/phi/kernels/fusion/gpu/fused_residual_dropout_bias.h"
 #include "paddle/phi/kernels/layer_norm_kernel.h"
 
+PHI_DECLARE_bool(use_fast_math);
+
 namespace phi {
 namespace fusion {
 
@@ -292,21 +294,22 @@ class FusedDropoutHelper {
                           T* d_bias,
                           const std::string& act_method) {
     if (act_method == "gelu") {
-      phi::funcs::GeluGradFunctor<T> gelu_grad;
-      phi::fusion::
-          LaunchDropoutActBiasGrad<T, uint8_t, phi::funcs::GeluGradFunctor<T>>(
-              gelu_grad,
-              dout,
-              mask,
-              src,
-              bias,
-              dropout_param_.dropout_prob,
-              dropout_param_.is_upscale_in_train,
-              rows_,
-              cols_,
-              d_src,
-              d_bias,
-              ctx);
+      phi::fusion::GeluGradFunctor<T> gelu_grad;
+      phi::fusion::LaunchDropoutActBiasGrad<T,
+                                            uint8_t,
+                                            phi::fusion::GeluGradFunctor<T>>(
+          gelu_grad,
+          dout,
+          mask,
+          src,
+          bias,
+          dropout_param_.dropout_prob,
+          dropout_param_.is_upscale_in_train,
+          rows_,
+          cols_,
+          d_src,
+          d_bias,
+          ctx);
     } else if (act_method == "relu") {
       phi::funcs::ReluGradFunctor<T> relu_grad;
       phi::fusion::
diff --git a/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_kernel.cc
index 6798df360de198145ec602c9ec671bbe3c631677..3448efca7c3ab1610e60ef849f52dddd6cc9f01d 100644
--- a/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_kernel.cc
+++ b/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_kernel.cc
@@ -366,13 +366,13 @@ void FusedFeedForwardGradKernel(
     const DenseTensor& out_grad,
     const DenseTensor& x,
    const DenseTensor& linear1_weight,
-    const DenseTensor& linear1_bias,
+    const paddle::optional<DenseTensor>& linear1_bias,
     const DenseTensor& linear2_weight,
     const DenseTensor& dropout1_mask,
     const DenseTensor& dropout2_mask,
     const DenseTensor& linear1_out,
     const DenseTensor& dropout1_out,
-    const DenseTensor& dropout2_out,
+    const paddle::optional<DenseTensor>& dropout2_out,
     const paddle::optional<DenseTensor>& ln1_scale,
     const paddle::optional<DenseTensor>& ln1_bias,
     const paddle::optional<DenseTensor>& ln1_out,
@@ -417,7 +417,7 @@ void FusedFeedForwardGradKernel(
 
   auto* ln1_out_ptr = pre_layer_norm ? ln1_out.get_ptr() : nullptr;
   auto* dropout1_out_ptr = &dropout1_out;
-  auto* dropout2_out_ptr = &dropout2_out;
+  auto* dropout2_out_ptr = dropout2_out.get_ptr();
   auto* linear1_weight_ptr = &linear1_weight;
   auto* linear2_weight_ptr = &linear2_weight;
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index a383e1b0c062429e0190257e5b3c706bde5bc8fb..e7d11ed8b16d613eeb65fe458e2d7db71ec0a939 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1162,6 +1162,8 @@ set(STATIC_BUILD_TESTS
     test_fetch_lod_tensor_array
     test_fused_attention_op
     test_fused_attention_op_api
+    test_fused_feedforward_op
+    test_fused_feedforward_pass
     test_imperative_optimizer
     test_lamb_op
     test_layer_norm_op
@@ -1191,6 +1193,8 @@ set(STATIC_BUILD_TESTS
 
 if(NOT WITH_GPU)
   list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op)
   list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op_api)
+  list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_feedforward_op)
+  list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_feedforward_pass)
 endif()
 
 foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
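
To exercise the migrated forward and grad kernels end to end, a minimal
smoke test can be run from Python (a sketch: it assumes a CUDA build and the
existing public wrapper paddle.incubate.nn.functional.fused_feedforward;
the shapes and defaults below are illustrative):

    import paddle
    from paddle.incubate.nn.functional import fused_feedforward

    # [batch, seq_len, d_model] input; weights map d_model <-> dim_feedforward.
    x = paddle.randn([2, 8, 16], dtype="float32")
    x.stop_gradient = False
    w1 = paddle.randn([16, 32], dtype="float32")
    w2 = paddle.randn([32, 16], dtype="float32")

    # Dropout rates of 0 keep the run deterministic; pre_layer_norm=True takes
    # the branch that writes ln1_mean/ln1_variance (registered as FP32 outputs
    # for FP16 inputs in the kernel registration above).
    out = fused_feedforward(x, w1, w2,
                            dropout1_rate=0.0, dropout2_rate=0.0,
                            activation="relu", pre_layer_norm=True)
    out.sum().backward()  # dispatches to fused_feedforward_grad
    assert out.shape == x.shape and x.grad.shape == x.shape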