diff --git a/paddle/fluid/operators/fused/fused_feedforward_op_xpu.cc b/paddle/fluid/operators/fused/fused_feedforward_op_xpu.cc deleted file mode 100644 index 4b9ba95143345685057feaeed6050900c5a035c0..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/fused/fused_feedforward_op_xpu.cc +++ /dev/null @@ -1,832 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/operators/matmul_v2_op.h" -#include "paddle/fluid/operators/xpu_api_wrapper.h" -#include "paddle/fluid/platform/device/device_wrapper.h" -#include "paddle/phi/core/ddim.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" - -#include "paddle/fluid/operators/fused/xpu_fused_common_function.h" - -namespace paddle { -namespace operators { - -template -class FusedFeedForwardXPUKernel : public framework::OpKernel { - using XPUTypeT = typename XPUTypeTrait::Type; - - public: - void FFN(const phi::XPUContext& dev_ctx, - const phi::DenseTensor* x, - const phi::DenseTensor* linear1_weight, - const phi::DenseTensor* linear1_bias, - const phi::DenseTensor* linear2_weight, - const phi::DenseTensor* linear2_bias, - const phi::DenseTensor* ln_scale, - const phi::DenseTensor* ln_bias, - phi::DenseTensor* out, - phi::DenseTensor* dropout1_mask, - phi::DenseTensor* dropout2_mask, - phi::DenseTensor* ln_mean, - phi::DenseTensor* ln_variance, - phi::DenseTensor* linear1_out, - phi::DenseTensor* ln1_out, - phi::DenseTensor* dropout1_out, - phi::DenseTensor* dropout2_out, - const int bsz_seq, - const int d_model, - const int dim_feedforward, - const std::string& act_method, - const bool pre_layer_norm, - const float epsilon1, - const float epsilon2, - const XPUDropoutParam& dropout_param1, - const XPUDropoutParam& dropout_param2, - int ring_id) const { - xpu::Context* xpu_ctx = dev_ctx.x_context(); - xpu::ctx_guard RAII_GUARD(xpu_ctx); - - int r = xpu::SUCCESS; - - const XPUTypeT* x_ptr = reinterpret_cast(x->data()); - const XPUTypeT* residual_ptr = x_ptr; - const XPUTypeT* linear1_weight_ptr = - reinterpret_cast(linear1_weight->data()); - const XPUTypeT* linear1_bias_ptr = - reinterpret_cast(linear1_bias->data()); - const XPUTypeT* linear2_weight_ptr = - reinterpret_cast(linear2_weight->data()); - const XPUTypeT* linear2_bias_ptr = - reinterpret_cast(linear2_bias->data()); - - const float* ln_scale_ptr = ln_scale->data(); - - const float* ln_bias_ptr = ln_bias->data(); - - // out - XPUTypeT* out_ptr = reinterpret_cast(out->data()); - XPUTypeT* linear1_out_ptr = - reinterpret_cast(linear1_out->data()); - XPUTypeT* dropout1_mask_ptr = - reinterpret_cast(dropout1_mask->data()); - XPUTypeT* dropout2_mask_ptr = - reinterpret_cast(dropout2_mask->data()); - float* ln_mean_ptr = ln_mean->data(); - float* ln_variance_ptr = ln_variance->data(); - - XPUTypeT* dropout1_out_ptr = - reinterpret_cast(dropout1_out->data()); - XPUTypeT* dropout2_out_ptr = - reinterpret_cast(dropout2_out->data()); - - size_t l3_total_size = xpu_ctx->_l3_mgr.get_size(); - XPUTypeT* linear2_before_tmp_ptr = NULL; // dim_feedforward * bsz_seq - XPUTypeT* linear2_after_tmp_ptr = NULL; // d_model * bsz_seq - if (l3_total_size >= dim_feedforward * bsz_seq * sizeof(T)) { - XPUTypeT* l3_ptr = - RAII_GUARD.alloc_l3(dim_feedforward * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(l3_ptr); - linear2_before_tmp_ptr = linear2_after_tmp_ptr = l3_ptr; - } else if ((l3_total_size < dim_feedforward * bsz_seq * sizeof(T)) && - (l3_total_size >= d_model * bsz_seq * sizeof(T))) { - XPUTypeT* l3_ptr = RAII_GUARD.alloc_l3(d_model * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(l3_ptr); - linear2_after_tmp_ptr = l3_ptr; - linear2_before_tmp_ptr = - RAII_GUARD.alloc(dim_feedforward * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(linear2_before_tmp_ptr); - - } else { - XPUTypeT* gm_ptr = RAII_GUARD.alloc(dim_feedforward * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(gm_ptr); - linear2_before_tmp_ptr = linear2_after_tmp_ptr = gm_ptr; - } - - // layernorm - if (pre_layer_norm) { - XPUTypeT* ln1_out_ptr = reinterpret_cast(ln1_out->data()); - r = xpu::layer_norm(xpu_ctx, - x_ptr, - ln1_out_ptr, - bsz_seq, - d_model, - epsilon1, - ln_scale_ptr, - ln_bias_ptr, - ln_mean_ptr, - ln_variance_ptr); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm "); - x_ptr = ln1_out_ptr; - } - - // fc - phi::XpuFcInfo linear1_fc_info; - linear1_fc_info.InitFcInfo(0, - bsz_seq, - dim_feedforward, - d_model, - false, - false, - nullptr, - nullptr, - nullptr); - phi::MatMulXPUFunction(xpu_ctx, - x_ptr, - linear1_weight_ptr, - linear2_before_tmp_ptr, - linear1_fc_info, - 1.0f); - - // bias - r = xpu::broadcast_add(xpu_ctx, - linear2_before_tmp_ptr, - linear1_bias_ptr, - linear1_out_ptr, - {bsz_seq, dim_feedforward}, - {dim_feedforward}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); - - // act - if (act_method == "gelu") { - r = xpu::gelu(xpu_ctx, - linear1_out_ptr, - linear2_before_tmp_ptr, - linear1_out->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu"); - } else if (act_method == "relu") { - r = xpu::relu(xpu_ctx, - linear1_out_ptr, - linear2_before_tmp_ptr, - linear1_out->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu"); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently only supports gelu or relu activation functions!")); - } - - // dropout1 - Dropout(xpu_ctx, - linear2_before_tmp_ptr, - dropout1_mask_ptr, - dropout1_out_ptr, - dropout_param1, - dropout1_out->numel()); - - // fc - phi::XpuFcInfo linear2_fc_info; - linear2_fc_info.InitFcInfo(0, - bsz_seq, - d_model, - dim_feedforward, - false, - false, - nullptr, - nullptr, - nullptr); - phi::MatMulXPUFunction(xpu_ctx, - dropout1_out_ptr, - linear2_weight_ptr, - dropout2_out_ptr, - linear2_fc_info, - 1.0f); - - // bias - r = xpu::broadcast_add(xpu_ctx, - dropout2_out_ptr, - linear2_bias_ptr, - dropout2_out_ptr, - {bsz_seq, d_model}, - {d_model}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); - - // dropout2 - Dropout(xpu_ctx, - dropout2_out_ptr, - dropout2_mask_ptr, - dropout2_out_ptr, - dropout_param2, - dropout2_out->numel()); - - // residual_ptr + dropout_out - XPUTypeT* residual_add_out_ptr = out_ptr; - if (pre_layer_norm == false) { - residual_add_out_ptr = dropout2_out_ptr; - } - r = xpu::broadcast_add(xpu_ctx, - residual_ptr, - dropout2_out_ptr, - residual_add_out_ptr, - {bsz_seq, d_model}, - {bsz_seq, d_model}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); - - if (pre_layer_norm == false) { - r = xpu::layer_norm(xpu_ctx, - residual_add_out_ptr, - out_ptr, - bsz_seq, - d_model, - epsilon2, - ln_scale_ptr, - ln_bias_ptr, - ln_mean_ptr, - ln_variance_ptr); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm"); - } - } - - void Compute(const framework::ExecutionContext& context) const override { - auto place = context.GetPlace(); - - auto* x = context.Input("X"); - - auto* linear1_weight = context.Input("Linear1Weight"); - auto* linear1_bias = context.Input("Linear1Bias"); - auto* linear2_weight = context.Input("Linear2Weight"); - auto* linear2_bias = context.Input("Linear2Bias"); - const bool pre_layer_norm = context.Attr("pre_layer_norm"); - - const phi::DenseTensor* ln_scale = nullptr; - const phi::DenseTensor* ln_bias = nullptr; - phi::DenseTensor* ln_mean = nullptr; - phi::DenseTensor* ln_variance = nullptr; - phi::DenseTensor* ln1_out = nullptr; - - if (pre_layer_norm) { - ln_scale = context.Input("Ln1Scale"); - ln_bias = context.Input("Ln1Bias"); - ln_mean = context.Output("Ln1Mean"); - ln_variance = context.Output("Ln1Variance"); - ln1_out = context.Output("Ln1Out"); - ln1_out->mutable_data(place); - } else { - ln_scale = context.Input("Ln2Scale"); - ln_bias = context.Input("Ln2Bias"); - ln_mean = context.Output("Ln2Mean"); - ln_variance = context.Output("Ln2Variance"); - } - - auto* out = context.Output("Out"); - auto* dropout1_mask = context.Output("Dropout1Mask"); - auto* dropout2_mask = context.Output("Dropout2Mask"); - auto* linear1_out = context.Output("Linear1Out"); - - auto* dropout1_out = context.Output("Dropout1Out"); - auto* dropout2_out = context.Output("Dropout2Out"); - - const std::string act_method = context.Attr("act_method"); - - const int ring_id = context.Attr("ring_id"); - const float epsilon1 = context.Attr("ln1_epsilon"); - const float epsilon2 = context.Attr("ln2_epsilon"); - XPUDropoutParam dropout_param1; - dropout_param1.initXPUDropoutParam(context, 1); - XPUDropoutParam dropout_param2; - dropout_param2.initXPUDropoutParam(context, 2); - - ln_mean->mutable_data(place); - ln_variance->mutable_data(place); - out->mutable_data(place); - dropout1_mask->mutable_data(place); - dropout2_mask->mutable_data(place); - dropout1_out->mutable_data(place); - dropout2_out->mutable_data(place); - linear1_out->mutable_data(place); - - auto x_dim = x->dims(); - auto mat_dim_x = phi::funcs::CreateMatrixDescriptor( - RowMatrixFromVector(x_dim), 0, false); - - auto dim = linear1_weight->dims(); - int d_model = dim[0]; - int dim_feedforward = dim[dim.size() - 1]; - int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; - - auto& dev_ctx = context.template device_context(); - FFN(dev_ctx, - x, - linear1_weight, - linear1_bias, - linear2_weight, - linear2_bias, - ln_scale, - ln_bias, - out, - dropout1_mask, - dropout2_mask, - ln_mean, - ln_variance, - linear1_out, - ln1_out, - dropout1_out, - dropout2_out, - bsz_seq, - d_model, - dim_feedforward, - act_method, - pre_layer_norm, - epsilon1, - epsilon2, - dropout_param1, - dropout_param2, - ring_id); - } -}; - -template -class FusedFeedForwardGradXPUKernel : public framework::OpKernel { - using XPUTypeT = typename XPUTypeTrait::Type; - - public: - void FFNGrad(const phi::XPUContext& dev_ctx, - const phi::DenseTensor* d_out, - const phi::DenseTensor* x, - const phi::DenseTensor* dropout1_mask, - const phi::DenseTensor* dropout2_mask, - const phi::DenseTensor* linear1_out, - const phi::DenseTensor* ln1_out, - const phi::DenseTensor* dropout1_out, - const phi::DenseTensor* dropout2_out, - const phi::DenseTensor* linear1_weight, - const phi::DenseTensor* linear2_weight, - const phi::DenseTensor* ln_scale, - const phi::DenseTensor* ln_mean, - const phi::DenseTensor* ln_variance, - phi::DenseTensor* d_x, - phi::DenseTensor* d_linear1_weight, - phi::DenseTensor* d_linear1_bias, - phi::DenseTensor* d_linear2_weight, - phi::DenseTensor* d_linear2_bias, - phi::DenseTensor* d_ln_scale, - phi::DenseTensor* d_ln_bias, - const int bsz_seq, - const int d_model, - const int dim_feedforward, - const XPUDropoutParam& dropout_param1, - const XPUDropoutParam& dropout_param2, - const std::string& act_method, - const bool pre_layer_norm, - const float epsilon, - const int ring_id) const { - xpu::Context* xpu_ctx = dev_ctx.x_context(); - xpu::ctx_guard RAII_GUARD(xpu_ctx); - int r = xpu::SUCCESS; - - // inputs ptr - const XPUTypeT* d_out_ptr = - reinterpret_cast(d_out->data()); - const XPUTypeT* x_ptr = reinterpret_cast(x->data()); - const XPUTypeT* dropout1_mask_ptr = - reinterpret_cast(dropout1_mask->data()); - const XPUTypeT* dropout2_mask_ptr = - reinterpret_cast(dropout2_mask->data()); - const XPUTypeT* linear1_out_ptr = - reinterpret_cast(linear1_out->data()); - const XPUTypeT* dropout1_out_ptr = - reinterpret_cast(dropout1_out->data()); - const XPUTypeT* linear1_weight_ptr = - reinterpret_cast(linear1_weight->data()); - const XPUTypeT* linear2_weight_ptr = - reinterpret_cast(linear2_weight->data()); - const float* ln_scale_ptr = ln_scale->data(); - - const float* ln_mean_ptr = ln_mean->data(); - const float* ln_variance_ptr = ln_variance->data(); - // outputs ptr - XPUTypeT* d_x_ptr = reinterpret_cast(d_x->data()); - XPUTypeT* d_linear1_weight_ptr = - reinterpret_cast(d_linear1_weight->data()); - XPUTypeT* d_linear1_bias_ptr = - reinterpret_cast(d_linear1_bias->data()); - XPUTypeT* d_linear2_weight_ptr = - reinterpret_cast(d_linear2_weight->data()); - XPUTypeT* d_linear2_bias_ptr = - reinterpret_cast(d_linear2_bias->data()); - float* d_ln_scale_ptr = d_ln_scale->data(); - float* d_ln_bias_ptr = d_ln_bias->data(); - - size_t l3_total_size = xpu_ctx->_l3_mgr.get_size(); - - XPUTypeT* big_tmp_l3_ptr = NULL; // dim_feedforward * bsz_seq - XPUTypeT* small_tmp_l3_ptr = NULL; // d_model * bsz_seq - XPUTypeT* big_tmp_gm_ptr = NULL; // dim_feedforward * bsz_seq - XPUTypeT* small_tmp_gm_ptr = NULL; // d_model * bsz_seq - - XPUTypeT* d_layernorm_out_ptr = NULL; // dx9 - XPUTypeT* d_dropout2_out_ptr = NULL; // dx7 - - XPUTypeT* d_linear2_out_ptr = NULL; // dx5 - XPUTypeT* d_dropout1_out_ptr = NULL; // dx4 - XPUTypeT* d_act_out_ptr = NULL; // dx3 - - XPUTypeT* d_linear1_out_ptr = NULL; // dx1 - - const XPUTypeT* d_residual_ptr = d_out_ptr; - - if (l3_total_size >= (dim_feedforward * bsz_seq * sizeof(T) + - d_model * bsz_seq * sizeof(T))) { - big_tmp_l3_ptr = RAII_GUARD.alloc_l3(dim_feedforward * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_l3_ptr); - small_tmp_l3_ptr = RAII_GUARD.alloc_l3(d_model * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_l3_ptr); - d_layernorm_out_ptr = small_tmp_l3_ptr; - d_dropout2_out_ptr = small_tmp_l3_ptr; - d_linear2_out_ptr = big_tmp_l3_ptr; - d_dropout1_out_ptr = big_tmp_l3_ptr; - d_act_out_ptr = big_tmp_l3_ptr; - d_linear1_out_ptr = small_tmp_l3_ptr; - } else if (l3_total_size >= dim_feedforward * bsz_seq * sizeof(T)) { - big_tmp_l3_ptr = RAII_GUARD.alloc_l3(dim_feedforward * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_l3_ptr); - small_tmp_l3_ptr = big_tmp_l3_ptr; - big_tmp_gm_ptr = RAII_GUARD.alloc(dim_feedforward * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_gm_ptr); - small_tmp_gm_ptr = RAII_GUARD.alloc(d_model * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_gm_ptr); - - d_layernorm_out_ptr = small_tmp_l3_ptr; - d_dropout2_out_ptr = small_tmp_gm_ptr; - d_linear2_out_ptr = big_tmp_l3_ptr; - d_dropout1_out_ptr = big_tmp_l3_ptr; - d_act_out_ptr = big_tmp_gm_ptr; - d_linear1_out_ptr = small_tmp_l3_ptr; - - } else if (l3_total_size >= d_model * bsz_seq * sizeof(T)) { - big_tmp_gm_ptr = RAII_GUARD.alloc(dim_feedforward * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_gm_ptr); - small_tmp_l3_ptr = RAII_GUARD.alloc_l3(d_model * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_l3_ptr); - - d_layernorm_out_ptr = small_tmp_l3_ptr; - d_dropout2_out_ptr = small_tmp_l3_ptr; - d_linear2_out_ptr = big_tmp_gm_ptr; - d_dropout1_out_ptr = big_tmp_gm_ptr; - d_act_out_ptr = big_tmp_gm_ptr; - d_linear1_out_ptr = small_tmp_l3_ptr; - } else { - big_tmp_gm_ptr = RAII_GUARD.alloc(dim_feedforward * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_gm_ptr); - small_tmp_gm_ptr = RAII_GUARD.alloc(d_model * bsz_seq); - PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_gm_ptr); - d_layernorm_out_ptr = small_tmp_gm_ptr; - d_dropout2_out_ptr = small_tmp_gm_ptr; - d_linear2_out_ptr = big_tmp_gm_ptr; - d_dropout1_out_ptr = big_tmp_gm_ptr; - d_act_out_ptr = big_tmp_gm_ptr; - d_linear1_out_ptr = small_tmp_gm_ptr; - } - - if (pre_layer_norm == false) { - const XPUTypeT* dropout2_out_ptr = - reinterpret_cast(dropout2_out->data()); - r = xpu::layer_norm_grad(xpu_ctx, - dropout2_out_ptr, - d_out_ptr, - d_layernorm_out_ptr, - bsz_seq, - d_model, - epsilon, - ln_scale_ptr, - ln_mean_ptr, - ln_variance_ptr, - d_ln_scale_ptr, - d_ln_bias_ptr); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad"); - d_residual_ptr = d_layernorm_out_ptr; - } - DropoutGrad(xpu_ctx, - d_residual_ptr, - dropout2_mask_ptr, - d_dropout2_out_ptr, - dropout_param2, - bsz_seq * d_model); - // linear_grad2 - r = xpu::reduce_sum(xpu_ctx, - d_dropout2_out_ptr, - d_linear2_bias_ptr, - {bsz_seq, d_model}, - {0}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); - - phi::XpuFcInfo linear2_fc_info; - linear2_fc_info.InitFcInfo(0, - bsz_seq, - d_model, - dim_feedforward, - false, - false, - nullptr, - nullptr, - nullptr); - - const XPUTypeT* a_1 = reinterpret_cast(NULL); - const XPUTypeT* b_1 = reinterpret_cast(NULL); - const XPUTypeT* a_2 = reinterpret_cast(NULL); - const XPUTypeT* b_2 = reinterpret_cast(NULL); - XPUTypeT* c_1 = d_linear2_out_ptr; - XPUTypeT* c_2 = d_linear2_weight_ptr; - phi::XpuFcInfo info_d_dropout1; - phi::XpuFcInfo info_dw2; - - std::tuple - fc_info = phi::MatmulGradFcInfo(xpu_ctx, - &RAII_GUARD, - linear2_fc_info, - false, - false, - dropout1_out_ptr, - linear2_weight_ptr, - d_dropout2_out_ptr); - - std::tie(info_d_dropout1, info_dw2, a_1, b_1, a_2, b_2) = fc_info; - - // if l3_total_size >= dim_feedforward * bsz_seq * sizeof(T), first transpos - if (l3_total_size >= dim_feedforward * bsz_seq * sizeof(T) && - info_dw2.trans_x) { - r = xpu::transpose(xpu_ctx, - dropout1_out_ptr, - big_tmp_l3_ptr, - {bsz_seq, dim_feedforward}, - {1, 0}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - a_2 = big_tmp_l3_ptr; - info_dw2.trans_x = !info_dw2.trans_x; - info_dw2.stride_x = info_dw2.k; - } - - phi::MatMulXPUFunction( - xpu_ctx, a_1, b_1, c_1, info_d_dropout1, 1.0f, true); - - phi::MatMulXPUFunction( - xpu_ctx, a_2, b_2, c_2, info_dw2, 1.0f, true); - - // dropout_grad1 - DropoutGrad(xpu_ctx, - d_linear2_out_ptr, - dropout1_mask_ptr, - d_dropout1_out_ptr, - dropout_param1, - bsz_seq * dim_feedforward); - - // act_grad - if (act_method == "gelu") { - r = xpu::gelu_grad(xpu_ctx, - linear1_out_ptr, - linear1_out_ptr, - d_dropout1_out_ptr, - d_act_out_ptr, - bsz_seq * dim_feedforward); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad"); - } else if (act_method == "relu") { - r = xpu::relu_grad(xpu_ctx, - linear1_out_ptr, - linear1_out_ptr, - d_dropout1_out_ptr, - d_act_out_ptr, - bsz_seq * dim_feedforward); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu_grad"); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Currently only supports gelu or relu activation functions!")); - } - - // linear1_grad - r = xpu::reduce_sum(xpu_ctx, - d_act_out_ptr, - d_linear1_bias_ptr, - {bsz_seq, dim_feedforward}, - {0}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); - - phi::XpuFcInfo linear1_fc_info; - linear1_fc_info.InitFcInfo(0, - bsz_seq, - dim_feedforward, - d_model, - false, - false, - nullptr, - nullptr, - nullptr); - - a_1 = reinterpret_cast(NULL); - b_1 = reinterpret_cast(NULL); - a_2 = reinterpret_cast(NULL); - b_2 = reinterpret_cast(NULL); - - c_1 = (pre_layer_norm == true ? d_linear1_out_ptr : d_x_ptr); - c_2 = d_linear1_weight_ptr; - phi::XpuFcInfo info_dx; - phi::XpuFcInfo info_dw1; - - const XPUTypeT* linear1_x_ptr = - (pre_layer_norm == true - ? reinterpret_cast(ln1_out->data()) - : x_ptr); - - if (l3_total_size >= d_model * bsz_seq * sizeof(T) && info_dw1.trans_x) { - r = xpu::transpose( - xpu_ctx, linear1_x_ptr, small_tmp_l3_ptr, {bsz_seq, d_model}, {1, 0}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - a_2 = small_tmp_l3_ptr; - info_dw1.trans_x = !info_dw1.trans_x; - info_dw1.stride_x = info_dw1.k; - } - - fc_info = phi::MatmulGradFcInfo(xpu_ctx, - &RAII_GUARD, - linear1_fc_info, - false, - false, - linear1_x_ptr, - linear1_weight_ptr, - d_act_out_ptr); - - std::tie(info_dx, info_dw1, a_1, b_1, a_2, b_2) = fc_info; - - phi::MatMulXPUFunction( - xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f, true); - - phi::MatMulXPUFunction( - xpu_ctx, a_2, b_2, c_2, info_dw1, 1.0f, true); - - if (pre_layer_norm) { - r = xpu::layer_norm_grad(xpu_ctx, - x_ptr, - c_1, - c_1, - bsz_seq, - d_model, - epsilon, - ln_scale_ptr, - ln_mean_ptr, - ln_variance_ptr, - d_ln_scale_ptr, - d_ln_bias_ptr); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad"); - } - - r = xpu::add(xpu_ctx, c_1, d_residual_ptr, d_x_ptr, d_model * bsz_seq); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "add"); - } - - void Compute(const framework::ExecutionContext& context) const override { - auto place = context.GetPlace(); - const bool pre_layer_norm = context.Attr("pre_layer_norm"); - // inputs - auto* d_out = - context.Input(framework::GradVarName("Out")); - auto* x = context.Input("X"); - - auto* dropout1_mask = context.Input("Dropout1Mask"); - auto* dropout2_mask = context.Input("Dropout2Mask"); - auto* linear1_out = context.Input("Linear1Out"); - auto* ln1_out = - pre_layer_norm ? context.Input("Ln1Out") : nullptr; - - auto* dropout1_out = context.Input("Dropout1Out"); - auto* dropout2_out = context.Input("Dropout2Out"); - auto* linear1_weight = context.Input("Linear1Weight"); - auto* linear2_weight = context.Input("Linear2Weight"); - - const phi::DenseTensor* ln_mean = nullptr; - const phi::DenseTensor* ln_variance = nullptr; - const phi::DenseTensor* ln_scale = nullptr; - - if (pre_layer_norm) { - ln_mean = context.Input("Ln1Mean"); - ln_variance = context.Input("Ln1Variance"); - ln_scale = context.Input("Ln1Scale"); - } else { - ln_mean = context.Input("Ln2Mean"); - ln_variance = context.Input("Ln2Variance"); - ln_scale = context.Input("Ln2Scale"); - } - - // output - auto* d_x = context.Output(framework::GradVarName("X")); - - phi::DenseTensor* d_ln_scale = nullptr; - phi::DenseTensor* d_ln_bias = nullptr; - - if (pre_layer_norm) { - d_ln_scale = - context.Output(framework::GradVarName("Ln1Scale")); - d_ln_bias = - context.Output(framework::GradVarName("Ln1Bias")); - } else { - d_ln_scale = - context.Output(framework::GradVarName("Ln2Scale")); - d_ln_bias = - context.Output(framework::GradVarName("Ln2Bias")); - } - - auto* d_linear1_weight = context.Output( - framework::GradVarName("Linear1Weight")); - auto* d_linear1_bias = - context.Output(framework::GradVarName("Linear1Bias")); - auto* d_linear2_weight = context.Output( - framework::GradVarName("Linear2Weight")); - auto* d_linear2_bias = - context.Output(framework::GradVarName("Linear2Bias")); - - float epsilon = 0.0f; - if (pre_layer_norm) { - epsilon = context.Attr("ln1_epsilon"); - } else { - epsilon = context.Attr("ln2_epsilon"); - } - - const std::string act_method = context.Attr("act_method"); - - XPUDropoutParam dropout_param1(context, 1); - XPUDropoutParam dropout_param2(context, 2); - - const int ring_id = context.Attr("ring_id"); - - d_x->mutable_data(place); - d_ln_scale->mutable_data(place); - d_ln_bias->mutable_data(place); - d_linear1_bias->mutable_data(place); - d_linear2_bias->mutable_data(place); - d_linear1_weight->mutable_data(place); - d_linear2_weight->mutable_data(place); - - auto x_dim = x->dims(); - auto mat_dim_x = phi::funcs::CreateMatrixDescriptor( - RowMatrixFromVector(x_dim), 0, false); - - auto linear1_weight_dim = linear1_weight->dims(); - int d_model = linear1_weight_dim[0]; - int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1]; - int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; - auto& dev_ctx = context.template device_context(); - - FFNGrad(dev_ctx, - d_out, - x, - dropout1_mask, - dropout2_mask, - linear1_out, - ln1_out, - dropout1_out, - dropout2_out, - linear1_weight, - linear2_weight, - ln_scale, - ln_mean, - ln_variance, - d_x, - d_linear1_weight, - d_linear1_bias, - d_linear2_weight, - d_linear2_bias, - d_ln_scale, - d_ln_bias, - bsz_seq, - d_model, - dim_feedforward, - dropout_param1, - dropout_param2, - act_method, - pre_layer_norm, - epsilon, - ring_id); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - fused_feedforward, - ops::FusedFeedForwardXPUKernel, - ops::FusedFeedForwardXPUKernel); - -REGISTER_OP_XPU_KERNEL( - fused_feedforward_grad, - ops::FusedFeedForwardGradXPUKernel, - ops::FusedFeedForwardGradXPUKernel); - -#endif diff --git a/paddle/phi/kernels/fused_feedforward_grad_kernel.h b/paddle/phi/kernels/fused_feedforward_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9eee46a83987ed6aa626b60a2385f3a83f1f0b09 --- /dev/null +++ b/paddle/phi/kernels/fused_feedforward_grad_kernel.h @@ -0,0 +1,69 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void FusedFeedForwardGradKernel( + const Context& dev_ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& linear1_weight, + const DenseTensor& linear1_bias, + const DenseTensor& linear2_weight, + const DenseTensor& dropout1_mask, + const DenseTensor& dropout2_mask, + const DenseTensor& linear1_out, + const DenseTensor& dropout1_out, + const DenseTensor& dropout2_out, + const paddle::optional& ln1_scale, + const paddle::optional& ln1_bias, + const paddle::optional& ln1_out, + const paddle::optional& ln1_mean, + const paddle::optional& ln1_variance, + const paddle::optional& ln2_scale, + const paddle::optional& ln2_bias, + const paddle::optional& ln2_mean, + const paddle::optional& ln2_variance, + const paddle::optional& linear2_bias, + bool pre_layer_norm, + float ln1_epsilon, + float ln2_epsilon, + const std::string& act_method, + float dropout1_prob, + float dropout2_prob, + const std::string& dropout1_implementation, + const std::string& dropout2_implementation, + bool is_test, + bool dropout1_fix_seed, + bool dropout2_fix_seed, + int dropout1_seed_val, + int dropout2_seed_val, + bool add_residual, + int ring_id, + DenseTensor* x_grad, + DenseTensor* ln1_scale_grad, + DenseTensor* ln1_bias_grad, + DenseTensor* ln2_scale_grad, + DenseTensor* ln2_bias_grad, + DenseTensor* linear1_weight_grad, + DenseTensor* linear1_bias_grad, + DenseTensor* linear2_weight_grad, + DenseTensor* linear2_bias_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/fused_feedforward_kernel.h b/paddle/phi/kernels/fused_feedforward_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..cade7adc0c7c24778e82a99bac365e3cd1777f7c --- /dev/null +++ b/paddle/phi/kernels/fused_feedforward_kernel.h @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void FusedFeedForwardKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& dropout1_seed, + const paddle::optional& dropout2_seed, + const DenseTensor& linear1_weight, + const paddle::optional& linear1_bias, + const DenseTensor& linear2_weight, + const paddle::optional& linear2_bias, + const paddle::optional& ln1_scale, + const paddle::optional& ln1_bias, + const paddle::optional& ln2_scale, + const paddle::optional& ln2_bias, + bool pre_layer_norm, + float ln1_epsilon, + float ln2_epsilon, + const std::string& act_method, + float dropout1_prob, + float dropout2_prob, + const std::string& dropout1_implementation, + const std::string& dropout2_implementation, + bool is_test, + bool dropout1_fix_seed, + bool dropout2_fix_seed, + int dropout1_seed_val, + int dropout2_seed_val, + bool add_residual, + int ring_id, + DenseTensor* out, + DenseTensor* dropout1_mask, + DenseTensor* dropout2_mask, + DenseTensor* ln1_mean, + DenseTensor* ln1_variance, + DenseTensor* ln2_mean, + DenseTensor* ln2_variance, + DenseTensor* linear1_out, + DenseTensor* ln1_out, + DenseTensor* dropout1_out, + DenseTensor* dropout2_out); + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_xpu_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb10930dc9b3e2ede0f2b5ac1db68e51aac9d288 --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_xpu_kernel.cc @@ -0,0 +1,542 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" +#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" +#include "paddle/phi/kernels/xpu/xpu_fused_common_function.h" + +namespace phi { +namespace fusion { + +template +void FFNGrad(const phi::XPUContext& dev_ctx, + const phi::DenseTensor* d_out, + const phi::DenseTensor* x, + const phi::DenseTensor* dropout1_mask, + const phi::DenseTensor* dropout2_mask, + const phi::DenseTensor* linear1_out, + const phi::DenseTensor* ln1_out, + const phi::DenseTensor* dropout1_out, + const phi::DenseTensor* dropout2_out, + const phi::DenseTensor* linear1_weight, + const phi::DenseTensor* linear2_weight, + const phi::DenseTensor* ln_scale, + const phi::DenseTensor* ln_mean, + const phi::DenseTensor* ln_variance, + phi::DenseTensor* d_x, + phi::DenseTensor* d_linear1_weight, + phi::DenseTensor* d_linear1_bias, + phi::DenseTensor* d_linear2_weight, + phi::DenseTensor* d_linear2_bias, + phi::DenseTensor* d_ln_scale, + phi::DenseTensor* d_ln_bias, + const int bsz_seq, + const int d_model, + const int dim_feedforward, + const XPUDropoutParam& dropout_param1, + const XPUDropoutParam& dropout_param2, + const std::string& act_method, + const bool pre_layer_norm, + const float epsilon, + const int ring_id) { + using XPUTypeT = typename XPUTypeTrait::Type; + xpu::Context* xpu_ctx = dev_ctx.x_context(); + xpu::ctx_guard RAII_GUARD(xpu_ctx); + int r = xpu::SUCCESS; + + // inputs ptr + const XPUTypeT* d_out_ptr = + reinterpret_cast(d_out->data()); + const XPUTypeT* x_ptr = reinterpret_cast(x->data()); + const XPUTypeT* dropout1_mask_ptr = + reinterpret_cast(dropout1_mask->data()); + const XPUTypeT* dropout2_mask_ptr = + reinterpret_cast(dropout2_mask->data()); + const XPUTypeT* linear1_out_ptr = + reinterpret_cast(linear1_out->data()); + const XPUTypeT* dropout1_out_ptr = + reinterpret_cast(dropout1_out->data()); + const XPUTypeT* linear1_weight_ptr = + reinterpret_cast(linear1_weight->data()); + const XPUTypeT* linear2_weight_ptr = + reinterpret_cast(linear2_weight->data()); + const float* ln_scale_ptr = ln_scale->data(); + + const float* ln_mean_ptr = ln_mean->data(); + const float* ln_variance_ptr = ln_variance->data(); + // outputs ptr + XPUTypeT* d_x_ptr = reinterpret_cast(d_x->data()); + XPUTypeT* d_linear1_weight_ptr = + reinterpret_cast(d_linear1_weight->data()); + XPUTypeT* d_linear1_bias_ptr = + reinterpret_cast(d_linear1_bias->data()); + XPUTypeT* d_linear2_weight_ptr = + reinterpret_cast(d_linear2_weight->data()); + XPUTypeT* d_linear2_bias_ptr = + reinterpret_cast(d_linear2_bias->data()); + float* d_ln_scale_ptr = d_ln_scale->data(); + float* d_ln_bias_ptr = d_ln_bias->data(); + + size_t l3_total_size = xpu_ctx->_l3_mgr.get_size(); + + XPUTypeT* big_tmp_l3_ptr = NULL; // dim_feedforward * bsz_seq + XPUTypeT* small_tmp_l3_ptr = NULL; // d_model * bsz_seq + XPUTypeT* big_tmp_gm_ptr = NULL; // dim_feedforward * bsz_seq + XPUTypeT* small_tmp_gm_ptr = NULL; // d_model * bsz_seq + + XPUTypeT* d_layernorm_out_ptr = NULL; // dx9 + XPUTypeT* d_dropout2_out_ptr = NULL; // dx7 + + XPUTypeT* d_linear2_out_ptr = NULL; // dx5 + XPUTypeT* d_dropout1_out_ptr = NULL; // dx4 + XPUTypeT* d_act_out_ptr = NULL; // dx3 + + XPUTypeT* d_linear1_out_ptr = NULL; // dx1 + + const XPUTypeT* d_residual_ptr = d_out_ptr; + + if (l3_total_size >= + (dim_feedforward * bsz_seq * sizeof(T) + d_model * bsz_seq * sizeof(T))) { + big_tmp_l3_ptr = RAII_GUARD.alloc_l3(dim_feedforward * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_l3_ptr); + small_tmp_l3_ptr = RAII_GUARD.alloc_l3(d_model * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_l3_ptr); + d_layernorm_out_ptr = small_tmp_l3_ptr; + d_dropout2_out_ptr = small_tmp_l3_ptr; + d_linear2_out_ptr = big_tmp_l3_ptr; + d_dropout1_out_ptr = big_tmp_l3_ptr; + d_act_out_ptr = big_tmp_l3_ptr; + d_linear1_out_ptr = small_tmp_l3_ptr; + } else if (l3_total_size >= dim_feedforward * bsz_seq * sizeof(T)) { + big_tmp_l3_ptr = RAII_GUARD.alloc_l3(dim_feedforward * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_l3_ptr); + small_tmp_l3_ptr = big_tmp_l3_ptr; + big_tmp_gm_ptr = RAII_GUARD.alloc(dim_feedforward * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_gm_ptr); + small_tmp_gm_ptr = RAII_GUARD.alloc(d_model * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_gm_ptr); + + d_layernorm_out_ptr = small_tmp_l3_ptr; + d_dropout2_out_ptr = small_tmp_gm_ptr; + d_linear2_out_ptr = big_tmp_l3_ptr; + d_dropout1_out_ptr = big_tmp_l3_ptr; + d_act_out_ptr = big_tmp_gm_ptr; + d_linear1_out_ptr = small_tmp_l3_ptr; + + } else if (l3_total_size >= d_model * bsz_seq * sizeof(T)) { + big_tmp_gm_ptr = RAII_GUARD.alloc(dim_feedforward * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_gm_ptr); + small_tmp_l3_ptr = RAII_GUARD.alloc_l3(d_model * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_l3_ptr); + + d_layernorm_out_ptr = small_tmp_l3_ptr; + d_dropout2_out_ptr = small_tmp_l3_ptr; + d_linear2_out_ptr = big_tmp_gm_ptr; + d_dropout1_out_ptr = big_tmp_gm_ptr; + d_act_out_ptr = big_tmp_gm_ptr; + d_linear1_out_ptr = small_tmp_l3_ptr; + } else { + big_tmp_gm_ptr = RAII_GUARD.alloc(dim_feedforward * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(big_tmp_gm_ptr); + small_tmp_gm_ptr = RAII_GUARD.alloc(d_model * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(small_tmp_gm_ptr); + d_layernorm_out_ptr = small_tmp_gm_ptr; + d_dropout2_out_ptr = small_tmp_gm_ptr; + d_linear2_out_ptr = big_tmp_gm_ptr; + d_dropout1_out_ptr = big_tmp_gm_ptr; + d_act_out_ptr = big_tmp_gm_ptr; + d_linear1_out_ptr = small_tmp_gm_ptr; + } + + if (pre_layer_norm == false) { + const XPUTypeT* dropout2_out_ptr = + reinterpret_cast(dropout2_out->data()); + r = xpu::layer_norm_grad(xpu_ctx, + dropout2_out_ptr, + d_out_ptr, + d_layernorm_out_ptr, + bsz_seq, + d_model, + epsilon, + ln_scale_ptr, + ln_mean_ptr, + ln_variance_ptr, + d_ln_scale_ptr, + d_ln_bias_ptr); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad"); + d_residual_ptr = d_layernorm_out_ptr; + } + phi::DropoutGrad(xpu_ctx, + d_residual_ptr, + dropout2_mask_ptr, + d_dropout2_out_ptr, + dropout_param2, + bsz_seq * d_model); + // linear_grad2 + r = xpu::reduce_sum( + xpu_ctx, d_dropout2_out_ptr, d_linear2_bias_ptr, {bsz_seq, d_model}, {0}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); + + phi::XpuFcInfo linear2_fc_info; + linear2_fc_info.InitFcInfo(0, + bsz_seq, + d_model, + dim_feedforward, + false, + false, + nullptr, + nullptr, + nullptr); + + const XPUTypeT* a_1 = reinterpret_cast(NULL); + const XPUTypeT* b_1 = reinterpret_cast(NULL); + const XPUTypeT* a_2 = reinterpret_cast(NULL); + const XPUTypeT* b_2 = reinterpret_cast(NULL); + XPUTypeT* c_1 = d_linear2_out_ptr; + XPUTypeT* c_2 = d_linear2_weight_ptr; + phi::XpuFcInfo info_d_dropout1; + phi::XpuFcInfo info_dw2; + + std::tuple + fc_info = phi::MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + linear2_fc_info, + false, + false, + dropout1_out_ptr, + linear2_weight_ptr, + d_dropout2_out_ptr); + + std::tie(info_d_dropout1, info_dw2, a_1, b_1, a_2, b_2) = fc_info; + + // if l3_total_size >= dim_feedforward * bsz_seq * sizeof(T), first transpos + if (l3_total_size >= dim_feedforward * bsz_seq * sizeof(T) && + info_dw2.trans_x) { + r = xpu::transpose(xpu_ctx, + dropout1_out_ptr, + big_tmp_l3_ptr, + {bsz_seq, dim_feedforward}, + {1, 0}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + a_2 = big_tmp_l3_ptr; + info_dw2.trans_x = !info_dw2.trans_x; + info_dw2.stride_x = info_dw2.k; + } + + phi::MatMulXPUFunction( + xpu_ctx, a_1, b_1, c_1, info_d_dropout1, 1.0f, true); + + phi::MatMulXPUFunction( + xpu_ctx, a_2, b_2, c_2, info_dw2, 1.0f, true); + + // dropout_grad1 + DropoutGrad(xpu_ctx, + d_linear2_out_ptr, + dropout1_mask_ptr, + d_dropout1_out_ptr, + dropout_param1, + bsz_seq * dim_feedforward); + + // act_grad + if (act_method == "gelu") { + r = xpu::gelu_grad(xpu_ctx, + linear1_out_ptr, + linear1_out_ptr, + d_dropout1_out_ptr, + d_act_out_ptr, + bsz_seq * dim_feedforward); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu_grad"); + } else if (act_method == "relu") { + r = xpu::relu_grad(xpu_ctx, + linear1_out_ptr, + linear1_out_ptr, + d_dropout1_out_ptr, + d_act_out_ptr, + bsz_seq * dim_feedforward); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu_grad"); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Currently only supports gelu or relu activation functions!")); + } + + // linear1_grad + r = xpu::reduce_sum(xpu_ctx, + d_act_out_ptr, + d_linear1_bias_ptr, + {bsz_seq, dim_feedforward}, + {0}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); + + phi::XpuFcInfo linear1_fc_info; + linear1_fc_info.InitFcInfo(0, + bsz_seq, + dim_feedforward, + d_model, + false, + false, + nullptr, + nullptr, + nullptr); + + a_1 = reinterpret_cast(NULL); + b_1 = reinterpret_cast(NULL); + a_2 = reinterpret_cast(NULL); + b_2 = reinterpret_cast(NULL); + + c_1 = (pre_layer_norm == true ? d_linear1_out_ptr : d_x_ptr); + c_2 = d_linear1_weight_ptr; + phi::XpuFcInfo info_dx; + phi::XpuFcInfo info_dw1; + + const XPUTypeT* linear1_x_ptr = + (pre_layer_norm == true + ? reinterpret_cast(ln1_out->data()) + : x_ptr); + + if (l3_total_size >= d_model * bsz_seq * sizeof(T) && info_dw1.trans_x) { + r = xpu::transpose( + xpu_ctx, linear1_x_ptr, small_tmp_l3_ptr, {bsz_seq, d_model}, {1, 0}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + a_2 = small_tmp_l3_ptr; + info_dw1.trans_x = !info_dw1.trans_x; + info_dw1.stride_x = info_dw1.k; + } + + fc_info = phi::MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + linear1_fc_info, + false, + false, + linear1_x_ptr, + linear1_weight_ptr, + d_act_out_ptr); + + std::tie(info_dx, info_dw1, a_1, b_1, a_2, b_2) = fc_info; + + phi::MatMulXPUFunction(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f, true); + + phi::MatMulXPUFunction( + xpu_ctx, a_2, b_2, c_2, info_dw1, 1.0f, true); + + if (pre_layer_norm) { + r = xpu::layer_norm_grad(xpu_ctx, + x_ptr, + c_1, + c_1, + bsz_seq, + d_model, + epsilon, + ln_scale_ptr, + ln_mean_ptr, + ln_variance_ptr, + d_ln_scale_ptr, + d_ln_bias_ptr); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad"); + } + + r = xpu::add(xpu_ctx, c_1, d_residual_ptr, d_x_ptr, d_model * bsz_seq); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "add"); +} + +template +void FusedFeedForwardGradKernel( + const Context& dev_ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& linear1_weight, + const DenseTensor& linear1_bias, + const DenseTensor& linear2_weight, + const DenseTensor& dropout1_mask, + const DenseTensor& dropout2_mask, + const DenseTensor& linear1_out, + const DenseTensor& dropout1_out, + const DenseTensor& dropout2_out, + const paddle::optional& ln1_scale, + const paddle::optional& ln1_bias, + const paddle::optional& ln1_out, + const paddle::optional& ln1_mean, + const paddle::optional& ln1_variance, + const paddle::optional& ln2_scale, + const paddle::optional& ln2_bias, + const paddle::optional& ln2_mean, + const paddle::optional& ln2_variance, + const paddle::optional& linear2_bias, + bool pre_layer_norm, + float ln1_epsilon, + float ln2_epsilon, + const std::string& act_method, + float dropout1_prob, + float dropout2_prob, + const std::string& dropout1_implementation, + const std::string& dropout2_implementation, + bool is_test, + bool dropout1_fix_seed, + bool dropout2_fix_seed, + int dropout1_seed_val, + int dropout2_seed_val, + bool add_residual, + int ring_id, + DenseTensor* x_grad, + DenseTensor* ln1_scale_grad, + DenseTensor* ln1_bias_grad, + DenseTensor* ln2_scale_grad, + DenseTensor* ln2_bias_grad, + DenseTensor* linear1_weight_grad, + DenseTensor* linear1_bias_grad, + DenseTensor* linear2_weight_grad, + DenseTensor* linear2_bias_grad) { + // inputs + auto* d_out = &out_grad; + auto* x_ptr = &x; + + auto* dropout1_mask_ptr = &dropout1_mask; + auto* dropout2_mask_ptr = &dropout2_mask; + auto* linear1_out_ptr = &linear1_out; + auto* ln1_out_ptr = pre_layer_norm ? ln1_out.get_ptr() : nullptr; + + auto* dropout1_out_ptr = &dropout1_out; + auto* dropout2_out_ptr = &dropout2_out; + auto* linear1_weight_ptr = &linear1_weight; + auto* linear2_weight_ptr = &linear2_weight; + + const phi::DenseTensor* ln_mean = nullptr; + const phi::DenseTensor* ln_variance = nullptr; + const phi::DenseTensor* ln_scale = nullptr; + + if (pre_layer_norm) { + ln_mean = ln1_mean.get_ptr(); + ln_variance = ln1_variance.get_ptr(); + ln_scale = ln1_scale.get_ptr(); + } else { + ln_mean = ln2_mean.get_ptr(); + ln_variance = ln2_variance.get_ptr(); + ln_scale = ln2_scale.get_ptr(); + } + + // output + auto* d_x = x_grad; + + phi::DenseTensor* d_ln_scale = nullptr; + phi::DenseTensor* d_ln_bias = nullptr; + + if (pre_layer_norm) { + d_ln_scale = ln1_scale_grad; + d_ln_bias = ln1_bias_grad; + } else { + d_ln_scale = ln2_scale_grad; + d_ln_bias = ln2_bias_grad; + } + + auto* d_linear1_weight = linear1_weight_grad; + auto* d_linear1_bias = linear1_bias_grad; + auto* d_linear2_weight = linear2_weight_grad; + auto* d_linear2_bias = linear2_bias_grad; + + float epsilon = 0.0f; + if (pre_layer_norm) { + epsilon = ln1_epsilon; + } else { + epsilon = ln2_epsilon; + } + + bool is_upscale_in_train_1 = dropout1_implementation == "upscale_in_train"; + bool is_upscale_in_train_2 = dropout2_implementation == "upscale_in_train"; + + phi::XPUDropoutParam dropout_param1; + dropout_param1.initXPUDropoutParam(dropout1_prob, + is_upscale_in_train_1, + is_test, + dropout1_fix_seed, + nullptr, + dropout1_seed_val); + phi::XPUDropoutParam dropout_param2; + dropout_param2.initXPUDropoutParam(dropout2_prob, + is_upscale_in_train_2, + is_test, + dropout2_fix_seed, + nullptr, + dropout2_seed_val); + + dev_ctx.template Alloc(d_ln_scale); + dev_ctx.template Alloc(d_ln_bias); + dev_ctx.template Alloc(d_linear1_bias); + dev_ctx.template Alloc(d_linear2_bias); + dev_ctx.template Alloc(d_linear1_weight); + dev_ctx.template Alloc(d_linear2_weight); + + auto x_dim = x_ptr->dims(); + auto mat_dim_x = phi::funcs::CreateMatrixDescriptor( + phi::RowMatrixFromVector(x_dim), 0, false); + + auto linear1_weight_dim = linear1_weight_ptr->dims(); + int d_model = linear1_weight_dim[0]; + int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size() - 1]; + int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; + + FFNGrad(dev_ctx, + d_out, + x_ptr, + dropout1_mask_ptr, + dropout2_mask_ptr, + linear1_out_ptr, + ln1_out_ptr, + dropout1_out_ptr, + dropout2_out_ptr, + linear1_weight_ptr, + linear2_weight_ptr, + ln_scale, + ln_mean, + ln_variance, + d_x, + d_linear1_weight, + d_linear1_bias, + d_linear2_weight, + d_linear2_bias, + d_ln_scale, + d_ln_bias, + bsz_seq, + d_model, + dim_feedforward, + dropout_param1, + dropout_param2, + act_method, + pre_layer_norm, + epsilon, + ring_id); +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_feedward_grad, + XPU, + ALL_LAYOUT, + phi::fusion::FusedFeedForwardGradKernel, + float, + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); +} diff --git a/paddle/phi/kernels/fusion/xpu/fused_feedforward_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_feedforward_xpu_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..35039ba571e5761a4c3ee18e8ec2450ee6587f3a --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/fused_feedforward_xpu_kernel.cc @@ -0,0 +1,390 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/impl/matmul_grad_kernel_impl.h" +#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" +#include "paddle/phi/kernels/xpu/xpu_fused_common_function.h" + +namespace phi { +namespace fusion { + +template +void FFN(const phi::XPUContext& dev_ctx, + const phi::DenseTensor* x, + const phi::DenseTensor* linear1_weight, + const phi::DenseTensor* linear1_bias, + const phi::DenseTensor* linear2_weight, + const phi::DenseTensor* linear2_bias, + const phi::DenseTensor* ln_scale, + const phi::DenseTensor* ln_bias, + phi::DenseTensor* out, + phi::DenseTensor* dropout1_mask, + phi::DenseTensor* dropout2_mask, + phi::DenseTensor* ln_mean, + phi::DenseTensor* ln_variance, + phi::DenseTensor* linear1_out, + phi::DenseTensor* ln1_out, + phi::DenseTensor* dropout1_out, + phi::DenseTensor* dropout2_out, + const int bsz_seq, + const int d_model, + const int dim_feedforward, + const std::string& act_method, + const bool pre_layer_norm, + const float epsilon1, + const float epsilon2, + const phi::XPUDropoutParam& dropout_param1, + const phi::XPUDropoutParam& dropout_param2, + int ring_id) { + using XPUTypeT = typename XPUTypeTrait::Type; + xpu::Context* xpu_ctx = dev_ctx.x_context(); + xpu::ctx_guard RAII_GUARD(xpu_ctx); + + int r = xpu::SUCCESS; + + const XPUTypeT* x_ptr = reinterpret_cast(x->data()); + const XPUTypeT* residual_ptr = x_ptr; + const XPUTypeT* linear1_weight_ptr = + reinterpret_cast(linear1_weight->data()); + const XPUTypeT* linear1_bias_ptr = + reinterpret_cast(linear1_bias->data()); + const XPUTypeT* linear2_weight_ptr = + reinterpret_cast(linear2_weight->data()); + const XPUTypeT* linear2_bias_ptr = + reinterpret_cast(linear2_bias->data()); + + const float* ln_scale_ptr = ln_scale->data(); + + const float* ln_bias_ptr = ln_bias->data(); + + // out + XPUTypeT* out_ptr = reinterpret_cast(out->data()); + XPUTypeT* linear1_out_ptr = + reinterpret_cast(linear1_out->data()); + XPUTypeT* dropout1_mask_ptr = + reinterpret_cast(dropout1_mask->data()); + XPUTypeT* dropout2_mask_ptr = + reinterpret_cast(dropout2_mask->data()); + float* ln_mean_ptr = ln_mean->data(); + float* ln_variance_ptr = ln_variance->data(); + + XPUTypeT* dropout1_out_ptr = + reinterpret_cast(dropout1_out->data()); + XPUTypeT* dropout2_out_ptr = + reinterpret_cast(dropout2_out->data()); + + size_t l3_total_size = xpu_ctx->_l3_mgr.get_size(); + XPUTypeT* linear2_before_tmp_ptr = NULL; // dim_feedforward * bsz_seq + XPUTypeT* linear2_after_tmp_ptr = NULL; // d_model * bsz_seq + if (l3_total_size >= dim_feedforward * bsz_seq * sizeof(T)) { + XPUTypeT* l3_ptr = RAII_GUARD.alloc_l3(dim_feedforward * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(l3_ptr); + linear2_before_tmp_ptr = linear2_after_tmp_ptr = l3_ptr; + } else if ((l3_total_size < dim_feedforward * bsz_seq * sizeof(T)) && + (l3_total_size >= d_model * bsz_seq * sizeof(T))) { + XPUTypeT* l3_ptr = RAII_GUARD.alloc_l3(d_model * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(l3_ptr); + linear2_after_tmp_ptr = l3_ptr; + linear2_before_tmp_ptr = + RAII_GUARD.alloc(dim_feedforward * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(linear2_before_tmp_ptr); + + } else { + XPUTypeT* gm_ptr = RAII_GUARD.alloc(dim_feedforward * bsz_seq); + PADDLE_ENFORCE_XDNN_NOT_NULL(gm_ptr); + linear2_before_tmp_ptr = linear2_after_tmp_ptr = gm_ptr; + } + + // layernorm + if (pre_layer_norm) { + XPUTypeT* ln1_out_ptr = reinterpret_cast(ln1_out->data()); + r = xpu::layer_norm(xpu_ctx, + x_ptr, + ln1_out_ptr, + bsz_seq, + d_model, + epsilon1, + ln_scale_ptr, + ln_bias_ptr, + ln_mean_ptr, + ln_variance_ptr); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm "); + x_ptr = ln1_out_ptr; + } + + // fc + phi::XpuFcInfo linear1_fc_info; + linear1_fc_info.InitFcInfo(0, + bsz_seq, + dim_feedforward, + d_model, + false, + false, + nullptr, + nullptr, + nullptr); + phi::MatMulXPUFunction(xpu_ctx, + x_ptr, + linear1_weight_ptr, + linear2_before_tmp_ptr, + linear1_fc_info, + 1.0f); + + // bias + r = xpu::broadcast_add(xpu_ctx, + linear2_before_tmp_ptr, + linear1_bias_ptr, + linear1_out_ptr, + {bsz_seq, dim_feedforward}, + {dim_feedforward}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); + + // act + if (act_method == "gelu") { + r = xpu::gelu( + xpu_ctx, linear1_out_ptr, linear2_before_tmp_ptr, linear1_out->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "gelu"); + } else if (act_method == "relu") { + r = xpu::relu( + xpu_ctx, linear1_out_ptr, linear2_before_tmp_ptr, linear1_out->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu"); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Currently only supports gelu or relu activation functions!")); + } + + // dropout1 + phi::Dropout(xpu_ctx, + linear2_before_tmp_ptr, + dropout1_mask_ptr, + dropout1_out_ptr, + dropout_param1, + dropout1_out->numel()); + + // fc + phi::XpuFcInfo linear2_fc_info; + linear2_fc_info.InitFcInfo(0, + bsz_seq, + d_model, + dim_feedforward, + false, + false, + nullptr, + nullptr, + nullptr); + phi::MatMulXPUFunction(xpu_ctx, + dropout1_out_ptr, + linear2_weight_ptr, + dropout2_out_ptr, + linear2_fc_info, + 1.0f); + + // bias + r = xpu::broadcast_add(xpu_ctx, + dropout2_out_ptr, + linear2_bias_ptr, + dropout2_out_ptr, + {bsz_seq, d_model}, + {d_model}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); + + // dropout2 + phi::Dropout(xpu_ctx, + dropout2_out_ptr, + dropout2_mask_ptr, + dropout2_out_ptr, + dropout_param2, + dropout2_out->numel()); + + // residual_ptr + dropout_out + XPUTypeT* residual_add_out_ptr = out_ptr; + if (pre_layer_norm == false) { + residual_add_out_ptr = dropout2_out_ptr; + } + r = xpu::broadcast_add(xpu_ctx, + residual_ptr, + dropout2_out_ptr, + residual_add_out_ptr, + {bsz_seq, d_model}, + {bsz_seq, d_model}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); + + if (pre_layer_norm == false) { + r = xpu::layer_norm(xpu_ctx, + residual_add_out_ptr, + out_ptr, + bsz_seq, + d_model, + epsilon2, + ln_scale_ptr, + ln_bias_ptr, + ln_mean_ptr, + ln_variance_ptr); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm"); + } +} + +template +void FusedFeedForwardKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& dropout1_seed, + const paddle::optional& dropout2_seed, + const DenseTensor& linear1_weight, + const paddle::optional& linear1_bias, + const DenseTensor& linear2_weight, + const paddle::optional& linear2_bias, + const paddle::optional& ln1_scale, + const paddle::optional& ln1_bias, + const paddle::optional& ln2_scale, + const paddle::optional& ln2_bias, + bool pre_layer_norm, + float ln1_epsilon, + float ln2_epsilon, + const std::string& act_method, + float dropout1_prob, + float dropout2_prob, + const std::string& dropout1_implementation, + const std::string& dropout2_implementation, + bool is_test, + bool dropout1_fix_seed, + bool dropout2_fix_seed, + int dropout1_seed_val, + int dropout2_seed_val, + bool add_residual, + int ring_id, + DenseTensor* out, + DenseTensor* dropout1_mask, + DenseTensor* dropout2_mask, + DenseTensor* ln1_mean, + DenseTensor* ln1_variance, + DenseTensor* ln2_mean, + DenseTensor* ln2_variance, + DenseTensor* linear1_out, + DenseTensor* ln1_out, + DenseTensor* dropout1_out, + DenseTensor* dropout2_out) { + auto* x_ptr = &x; + auto* linear1_weight_ptr = &linear1_weight; + auto* linear1_bias_ptr = linear1_bias.get_ptr(); + auto* linear2_weight_ptr = &linear2_weight; + auto* linear2_bias_ptr = linear2_bias.get_ptr(); + + const phi::DenseTensor* ln_scale = nullptr; + const phi::DenseTensor* ln_bias = nullptr; + phi::DenseTensor* ln_mean = nullptr; + phi::DenseTensor* ln_variance = nullptr; + + if (pre_layer_norm) { + ln_scale = ln1_scale.get_ptr(); + ln_bias = ln1_bias.get_ptr(); + ln_mean = ln1_mean; + ln_variance = ln1_variance; + dev_ctx.template Alloc(ln1_out); + } else { + ln_scale = ln2_scale.get_ptr(); + ln_bias = ln2_bias.get_ptr(); + ln_mean = ln2_mean; + ln_variance = ln2_variance; + } + + const float epsilon1 = ln1_epsilon; + const float epsilon2 = ln2_epsilon; + + bool is_upscale_in_train_1 = dropout1_implementation == "upscale_in_train"; + bool is_upscale_in_train_2 = dropout2_implementation == "upscale_in_train"; + + auto* dropout1_seed_ptr = dropout1_seed.get_ptr(); + auto* dropout2_seed_ptr = dropout2_seed.get_ptr(); + phi::XPUDropoutParam dropout_param1; + dropout_param1.initXPUDropoutParam(dropout1_prob, + is_upscale_in_train_1, + is_test, + dropout1_fix_seed, + dropout1_seed_ptr, + dropout1_seed_val); + phi::XPUDropoutParam dropout_param2; + dropout_param2.initXPUDropoutParam(dropout2_prob, + is_upscale_in_train_2, + is_test, + dropout2_fix_seed, + dropout2_seed_ptr, + dropout2_seed_val); + + dev_ctx.template Alloc(ln_mean); + dev_ctx.template Alloc(ln_variance); + + dev_ctx.template Alloc(out); + dev_ctx.template Alloc(dropout1_mask); + dev_ctx.template Alloc(dropout2_mask); + dev_ctx.template Alloc(dropout1_out); + dev_ctx.template Alloc(dropout2_out); + dev_ctx.template Alloc(linear1_out); + + auto x_dim = x_ptr->dims(); + auto mat_dim_x = phi::funcs::CreateMatrixDescriptor( + phi::RowMatrixFromVector(x_dim), 0, false); + + auto dim = linear1_weight_ptr->dims(); + int d_model = dim[0]; + int dim_feedforward = dim[dim.size() - 1]; + int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_; + + phi::fusion::FFN(dev_ctx, + x_ptr, + linear1_weight_ptr, + linear1_bias_ptr, + linear2_weight_ptr, + linear2_bias_ptr, + ln_scale, + ln_bias, + out, + dropout1_mask, + dropout2_mask, + ln_mean, + ln_variance, + linear1_out, + ln1_out, + dropout1_out, + dropout2_out, + bsz_seq, + d_model, + dim_feedforward, + act_method, + pre_layer_norm, + epsilon1, + epsilon2, + dropout_param1, + dropout_param2, + ring_id); +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_feedward, + XPU, + ALL_LAYOUT, + phi::fusion::FusedFeedForwardKernel, + float, + phi::dtype::float16) { + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); +} diff --git a/paddle/phi/ops/compat/fused_feedforward_sig.cc b/paddle/phi/ops/compat/fused_feedforward_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..1dd78288deaf86677aacd64ab1af1790260d48cf --- /dev/null +++ b/paddle/phi/ops/compat/fused_feedforward_sig.cc @@ -0,0 +1,102 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature FeedForwardFuseOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("fused_feedforward", + {"X", + "Dropout1Seed", + "Dropout2Seed", + "Linear1Weight", + "Linear1Bias", + "Linear2Weight", + "Linear2Bias", + "Ln1Scale", + "Ln1Bias", + "Ln2Scale", + "Ln2Bias"}, + {"pre_layer_norm", + "ln1_epsilon", + "ln2_epsilon", + "act_method", + "dropout1_rate", + "dropout2_rate", + "dropout1_implementation", + "dropout2_implementation", + "is_test", + "dropout1_fix_seed", + "dropout2_fix_seed", + "dropout1_seed", + "dropout2_seed", + "add_residual", + "ring_id"}, + {"Out", + "Dropout1Mask", + "Dropout2Mask", + "Ln1Mean", + "Ln1Variance", + "Ln2Mean", + "Ln2Variance", + "Linear1Out", + "Ln1Out", + "Dropout1Out", + "Dropout2Out"}); +} + +KernelSignature FeedForwardGradFuseOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("fused_feedforward_grad", + {"Out@GRAD", "X", + "Linear1Weight", "Linear1Bias", + "Linear2Weight", "Dropout1Mask", + "Dropout2Mask", "Linear1Out", + "Dropout1Out", "Dropout2Out", + "Ln1Scale", "Ln1Bias", + "Ln1Out", "Ln1Mean", + "Ln1Variance", "Ln2Scale", + "Ln2Bias", "Ln2Mean", + "Ln2Variance", "Linear2Bias"}, + {"pre_layer_norm", + "ln1_epsilon", + "ln2_epsilon", + "act_method", + "dropout1_rate", + "dropout2_rate", + "dropout1_implementation", + "dropout2_implementation", + "is_test", + "dropout1_fix_seed", + "dropout2_fix_seed", + "dropout1_seed", + "dropout2_seed", + "add_residual", + "ring_id"}, + {"X@GRAD", + "Ln1Scale@GRAD", + "Ln1Bias@GRAD", + "Ln2Scale@GRAD", + "Ln2Bias@GRAD", + "Linear1Weight@GRAD", + "Linear1Bias@GRAD", + "Linear2Weight@GRAD", + "Linear2Bias@GRAD"}); +} +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(fused_feedforward, + phi::FeedForwardFuseOpArgumentMapping); + +PD_REGISTER_ARG_MAPPING_FN(fused_feedforward_grad, + phi::FeedForwardGradFuseOpArgumentMapping);