diff --git a/paddle/fluid/framework/new_executor/interpreter/static_build.cc b/paddle/fluid/framework/new_executor/interpreter/static_build.cc index d2bc0124cbd3988bf0569920982df0188419753a..17b9381048cf3d13cbd4fc271df22f53668f7e6e 100644 --- a/paddle/fluid/framework/new_executor/interpreter/static_build.cc +++ b/paddle/fluid/framework/new_executor/interpreter/static_build.cc @@ -33,8 +33,6 @@ std::set OpsWithFluidKernelNeedMoveToPhi = { "cudnn_lstm", "dequantize", "distributed_fused_lamb", - "fused_attention", - "fused_attention_grad", "fused_batch_norm_act", "fused_batch_norm_act_grad", "fusion_group", diff --git a/paddle/fluid/operators/fused/fused_attention_op_xpu.cc b/paddle/fluid/operators/fused/fused_attention_op_xpu.cc deleted file mode 100644 index bbfa48f1dca7824ffe98940280d143e13c5eca50..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/fused/fused_attention_op_xpu.cc +++ /dev/null @@ -1,948 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/fused/xpu_fused_common_function.h" -#include "paddle/fluid/operators/matmul_v2_op.h" -#include "paddle/fluid/operators/xpu_api_wrapper.h" -#include "paddle/fluid/platform/device/device_wrapper.h" - -namespace paddle { -namespace operators { - -template -class FusedAttentionOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - using XPUTypeT = typename XPUTypeTrait::Type; - - // inputs tensor - auto *input_x = ctx.Input("X"); - - const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); - - // shape [3, num_head, dim_head, dim_embed] - auto *qkv_weight = ctx.Input("QKVW"); - // shape [3, num_head, dim_head] - auto *qkv_bias = ctx.Input("QKVBias"); - - // shape [batch_size, 1, 1, seq_len] - auto *src_mask = ctx.Input("SrcMask"); - - // shape [dim_embed, dim_embed] - auto *out_linear_weight = ctx.Input("OutLinearW"); - // shape [dim_embed] - auto *out_linear_bias = ctx.Input("OutLinearBias"); - - const phi::DenseTensor *ln_scale = nullptr; - const phi::DenseTensor *ln_bias = nullptr; - float epsilon = 0.0f; - - if (pre_layer_norm) { - ln_scale = ctx.Input("LnScale"); - ln_bias = ctx.Input("LnBias"); - epsilon = ctx.Attr("epsilon"); - } else { - ln_scale = ctx.Input("Ln2Scale"); - ln_bias = ctx.Input("Ln2Bias"); - epsilon = ctx.Attr("ln_epsilon"); - } - - // outputs tensor - // the qkv values, after the transpose has been applied - // shape [3, batch_size, num_head, seq_len, dim_head] - auto *TransposeOut2 = ctx.Output("TransposeOut2"); - - // shape [batch_size, num_head, seq_len, seq_len] - auto *softmax_out = ctx.Output("SoftmaxOut"); - // shape [batch_size, num_head, seq_len, seq_len] - auto *attn_dropout_mask_out = - ctx.Output("AttnDropoutMaskOut"); - // shape [batch_size, num_head, seq_len, seq_len] - auto *attn_dropout_out
= ctx.Output("AttnDropoutOut"); - - // shape [batch_size, seq_len, num_head, dim_head] - auto *fmha_out = ctx.Output("FMHAOut"); - - // shape [batch_size, seq_len, dim_embed] - auto *dropout_mask_out = ctx.Output("DropoutMaskOut"); - - // final output - // shape [batch_size, seq_len, dim_embed] - auto *out = ctx.Output("Y"); - - // The following tensors need not be returned, but the new dygraph mode needs them - auto *QKOut = ctx.Output("QKOut"); - QKOut->mutable_data(ctx.GetPlace()); - auto *QKTVOut = ctx.Output("QKTVOut"); - QKTVOut->mutable_data(ctx.GetPlace()); - auto *OutLinearOut = ctx.Output("OutLinearOut"); - OutLinearOut->mutable_data(ctx.GetPlace()); - auto *QKVBiasOut = ctx.Output("QKVBiasOut"); - QKVBiasOut->mutable_data(ctx.GetPlace()); - auto *SrcMaskOut = ctx.Output("SrcMaskOut"); - SrcMaskOut->mutable_data(ctx.GetPlace()); - auto *qkv_out = ctx.Output("QKVOut"); - qkv_out->mutable_data(ctx.GetPlace()); - - phi::DenseTensor *bias_dropout_residual_out = nullptr; - phi::DenseTensor *ln_mean = nullptr; - phi::DenseTensor *ln_var = nullptr; - phi::DenseTensor *ln_out = nullptr; - - if (pre_layer_norm) { - ln_mean = ctx.Output("LnMean"); - ln_var = ctx.Output("LnVariance"); - ln_out = ctx.Output("LnOut"); - } else { - ln_mean = ctx.Output("Ln2Mean"); - ln_var = ctx.Output("Ln2Variance"); - bias_dropout_residual_out = - ctx.Output("BiasDropoutResidualOut"); - } - - // dropout info - float attn_dropout_rate = ctx.Attr("attn_dropout_rate"); - - bool is_test_1 = ctx.Attr("is_test"); - - auto &dropout_implementation_1 = - ctx.Attr("attn_dropout_implementation"); - - bool is_upscale_in_train_1 = - (dropout_implementation_1 == "upscale_in_train"); - auto *seed_1 = - ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; - - bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed"); - - int seed_val_1 = ctx.Attr("attn_dropout_seed"); - - XPUDropoutParam attn_dropout_param; - attn_dropout_param.initXPUDropoutParam(attn_dropout_rate, - is_upscale_in_train_1, - is_test_1, - is_fix_seed_1, - seed_1, - seed_val_1); - - XPUDropoutParam dropout_param(ctx, 0); - - // compute the dimensions first - const auto input_x_dims = input_x->dims(); - const auto qkv_w_dims = qkv_weight->dims(); - - int batch_size = input_x_dims[0]; - int seq_len = input_x_dims[1]; - int embed_dims = input_x_dims[2]; - int num_heads = qkv_w_dims[1]; - int head_dims = qkv_w_dims[2]; - - // input pointers - const XPUTypeT *input_x_ptr = - reinterpret_cast(input_x->data()); - - const XPUTypeT *qkv_weight_ptr = - reinterpret_cast(qkv_weight->data()); - const XPUTypeT *qkv_bias_ptr = - reinterpret_cast(qkv_bias->data()); - const XPUTypeT *src_mask_ptr = - (src_mask == nullptr) - ? (nullptr) - : (reinterpret_cast(src_mask->data())); - - const XPUTypeT *out_linear_weight_ptr = - reinterpret_cast(out_linear_weight->data()); - - const XPUTypeT *out_linear_bias_ptr = - reinterpret_cast(out_linear_bias->data()); - - const float *ln_scale_ptr = - (ln_scale == nullptr) ? (nullptr) : ln_scale->data(); - - const float *ln_bias_ptr = - (ln_bias == nullptr) ?
(nullptr) : ln_bias->data(); - - // output pointers - XPUTypeT *qkv_transpose_out_ptr = reinterpret_cast( - TransposeOut2->mutable_data(ctx.GetPlace())); - - XPUTypeT *softmax_out_ptr = reinterpret_cast( - softmax_out->mutable_data(ctx.GetPlace())); - - XPUTypeT *attn_dropout_mask_out_ptr = reinterpret_cast( - attn_dropout_mask_out->mutable_data(ctx.GetPlace())); - - XPUTypeT *attn_dropout_out_ptr = reinterpret_cast( - attn_dropout_out->mutable_data(ctx.GetPlace())); - - XPUTypeT *fmha_out_ptr = - reinterpret_cast(fmha_out->mutable_data(ctx.GetPlace())); - - XPUTypeT *dropout_mask_out_ptr = reinterpret_cast( - dropout_mask_out->mutable_data(ctx.GetPlace())); - - XPUTypeT *out_ptr = - reinterpret_cast(out->mutable_data(ctx.GetPlace())); - - XPUTypeT *bias_dropout_residual_out_ptr = - (bias_dropout_residual_out == nullptr) - ? (nullptr) - : (reinterpret_cast( - bias_dropout_residual_out->mutable_data(ctx.GetPlace()))); - - float *ln_mean_ptr = (ln_mean == nullptr) - ? (nullptr) - : ln_mean->mutable_data(ctx.GetPlace()); - - float *ln_var_ptr = (ln_var == nullptr) - ? (nullptr) - : ln_var->mutable_data(ctx.GetPlace()); - - XPUTypeT *ln_out_ptr = (ln_out == nullptr) - ? (nullptr) - : (reinterpret_cast( - ln_out->mutable_data(ctx.GetPlace()))); - - auto &dev_ctx = ctx.template device_context(); - - xpu::Context *xpu_ctx = dev_ctx.x_context(); - - xpu::ctx_guard RAII_GUARD(xpu_ctx); - - int l3_total_size = xpu_ctx->_l3_mgr.get_size(); - - XPUTypeT *qkv_before_transpos_ptr = - NULL; // x2 [batch_size, seq_len, 3, num_heads, head_dims] - XPUTypeT *qk_ptr = NULL; // qk [batch_size, num_heads, seq_len, seq_len] - XPUTypeT *qkv_ptr = NULL; // qkv[batch_size, num_heads, seq_len, head_dims] - XPUTypeT *linear_out_ptr = - NULL; // x4, x5 [batch_size, seq_len, embed_dims] - - int temp_size_1 = batch_size * seq_len * 3 * num_heads * head_dims; - int temp_size_2 = batch_size * num_heads * seq_len * seq_len; - int temp_size_3 = batch_size * num_heads * seq_len * head_dims; - int temp_size_4 = batch_size * seq_len * embed_dims; - - std::vector temp_vec = { - temp_size_1, temp_size_2, temp_size_3, temp_size_4}; - std::sort(temp_vec.begin(), temp_vec.end(), std::greater()); - XPUTypeT *max_gm_ptr = RAII_GUARD.alloc(temp_vec[0]); - PADDLE_ENFORCE_XDNN_NOT_NULL(max_gm_ptr); - qkv_before_transpos_ptr = max_gm_ptr; - qk_ptr = max_gm_ptr; - qkv_ptr = max_gm_ptr; - linear_out_ptr = max_gm_ptr; - int sizeof_t = sizeof(XPUTypeT); - for (size_t i = 0; i < temp_vec.size(); ++i) { - if (l3_total_size >= temp_vec[i] * sizeof_t) { - XPUTypeT *l3_ptr = RAII_GUARD.alloc_l3(temp_vec[i]); - qkv_before_transpos_ptr = - (temp_size_1 <= temp_vec[i]) ? l3_ptr : max_gm_ptr; - qk_ptr = (temp_size_2 <= temp_vec[i]) ? l3_ptr : max_gm_ptr; - qkv_ptr = (temp_size_3 <= temp_vec[i]) ? l3_ptr : max_gm_ptr; - linear_out_ptr = (temp_size_4 <= temp_vec[i]) ?
l3_ptr : max_gm_ptr; - break; - } - } - - int r = 0; - const XPUTypeT *x_calc_ptr = input_x_ptr; - if (pre_layer_norm) { - r = xpu::layer_norm(xpu_ctx, - input_x_ptr, - ln_out_ptr, - batch_size * seq_len, - embed_dims, - epsilon, - ln_scale_ptr, - ln_bias_ptr, - ln_mean_ptr, - ln_var_ptr); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm"); - x_calc_ptr = ln_out_ptr; - } - - // fc - phi::XpuFcInfo qkv_fc_info; - qkv_fc_info.InitFcInfo(0, - batch_size * seq_len, - 3 * num_heads * head_dims, - embed_dims, - false, - true, - nullptr, - nullptr, - nullptr); - - phi::MatMulXPUFunction(xpu_ctx, - x_calc_ptr, - qkv_weight_ptr, - qkv_before_transpos_ptr, - qkv_fc_info, - 1.0f); - - // bias - r = xpu::broadcast_add(xpu_ctx, - qkv_before_transpos_ptr, - qkv_bias_ptr, - qkv_before_transpos_ptr, - {batch_size * seq_len, 3 * num_heads * head_dims}, - {3 * num_heads * head_dims}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); - - // transpose - r = xpu::transpose(xpu_ctx, - qkv_before_transpos_ptr, - qkv_transpose_out_ptr, - {batch_size, seq_len, 3, num_heads, head_dims}, - {2, 0, 3, 1, 4}); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - - int qkv_every_size = batch_size * seq_len * num_heads * head_dims; - { - float alpha = 1.0 / sqrt(head_dims); - r = scale(xpu_ctx, - qkv_transpose_out_ptr, - qkv_transpose_out_ptr, - qkv_every_size, - false, - alpha, - 0.0f); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); - } - - // begin fmha - // 1. qk 2. qk + mask 3. softmax 4. dropout 5. qkv 6. transpose - { - const XPUTypeT *q_ptr = qkv_transpose_out_ptr; - const XPUTypeT *k_ptr = q_ptr + qkv_every_size; - const XPUTypeT *v_ptr = k_ptr + qkv_every_size; - phi::XpuFcInfo qk_fc_info; - qk_fc_info.InitFcInfo(batch_size * num_heads, - seq_len, - seq_len, - head_dims, - false, - true, - nullptr, - nullptr, - nullptr); - phi::MatMulXPUFunction( - xpu_ctx, q_ptr, k_ptr, qk_ptr, qk_fc_info, 1.0f); - - if (src_mask_ptr) { - r = xpu::broadcast_add(xpu_ctx, - qk_ptr, - src_mask_ptr, - qk_ptr, - {batch_size, num_heads, seq_len, seq_len}, - {batch_size, 1, 1, seq_len}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); - } - // do softmax - r = xpu::softmax(xpu_ctx, - qk_ptr, - softmax_out_ptr, - {batch_size, num_heads, seq_len, seq_len}, - 3); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax"); - - // do dropout - Dropout(xpu_ctx, - softmax_out_ptr, - attn_dropout_mask_out_ptr, - attn_dropout_out_ptr, - attn_dropout_param, - batch_size * num_heads * seq_len * seq_len); - - phi::XpuFcInfo qktv_fc_info; - qktv_fc_info.InitFcInfo(batch_size * num_heads, - seq_len, - head_dims, - seq_len, - false, - false, - nullptr, - nullptr, - nullptr); - phi::MatMulXPUFunction( - xpu_ctx, attn_dropout_out_ptr, v_ptr, qkv_ptr, qktv_fc_info, 1.0f); - r = xpu::transpose(xpu_ctx, - qkv_ptr, - fmha_out_ptr, - {batch_size, num_heads, seq_len, head_dims}, - {0, 2, 1, 3}); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - } - - // linear_out - phi::XpuFcInfo linear_fc_info; - linear_fc_info.InitFcInfo(0, - batch_size * seq_len, - embed_dims, - embed_dims, - false, - false, - nullptr, - nullptr, - nullptr); - phi::MatMulXPUFunction(xpu_ctx, - fmha_out_ptr, - out_linear_weight_ptr, - linear_out_ptr, - linear_fc_info, - 1.0f); - - // out_linear_bias_ptr - r = xpu::broadcast_add(xpu_ctx, - linear_out_ptr, - out_linear_bias_ptr, - linear_out_ptr, - {batch_size * seq_len, embed_dims}, - {embed_dims}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); - - Dropout(xpu_ctx, - linear_out_ptr, - dropout_mask_out_ptr, - linear_out_ptr, - dropout_param, - batch_size
* seq_len * embed_dims); - - XPUTypeT *real_out_ptr = out_ptr; - if (pre_layer_norm == false) { - real_out_ptr = bias_dropout_residual_out_ptr; - } - - r = xpu::add(xpu_ctx, - linear_out_ptr, - input_x_ptr, - real_out_ptr, - batch_size * seq_len * embed_dims); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "add"); - - if (pre_layer_norm == false) { - r = xpu::layer_norm(xpu_ctx, - real_out_ptr, - out_ptr, - batch_size * seq_len, - embed_dims, - epsilon, - ln_scale_ptr, - ln_bias_ptr, - ln_mean_ptr, - ln_var_ptr); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm"); - } - } -}; - -// template -template -class FusedAttentionGradXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - using XPUTypeT = typename XPUTypeTrait::Type; - const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); - - // dropout info - float attn_dropout_prob = ctx.Attr("attn_dropout_rate"); - bool is_test_1 = ctx.Attr("is_test"); - auto &dropout_implementation_1 = - ctx.Attr("attn_dropout_implementation"); - bool is_upscale_in_train_1 = - (dropout_implementation_1 == "upscale_in_train"); - auto *seed_1 = - ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; - bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed"); - int seed_val_1 = ctx.Attr("attn_dropout_seed"); - - XPUDropoutParam attn_dropout_param; - attn_dropout_param.initXPUDropoutParam(attn_dropout_prob, - is_upscale_in_train_1, - is_test_1, - is_fix_seed_1, - seed_1, - seed_val_1); - - XPUDropoutParam dropout_param(ctx, 0); - // get inputs. - auto *d_y = ctx.Input(framework::GradVarName("Y")); - const XPUTypeT *d_y_ptr = - reinterpret_cast(d_y->data()); - // parameters needed from the forward pass - auto *input_x = ctx.Input("X"); - const XPUTypeT *input_x_ptr = - reinterpret_cast(input_x->data()); - auto *qkv_transpose_out = ctx.Input("TransposeOut2"); - const XPUTypeT *qkv_transpose_out_ptr = - reinterpret_cast(qkv_transpose_out->data()); - auto *qkv_weight = ctx.Input("QKVW"); - const XPUTypeT *qkv_weight_ptr = - reinterpret_cast(qkv_weight->data()); - - auto *softmax_out = ctx.Input("SoftmaxOut"); - const XPUTypeT *softmax_out_ptr = - reinterpret_cast(softmax_out->data()); - auto *attn_dropout_out = ctx.Input("AttnDropoutOut"); - const XPUTypeT *attn_dropout_out_ptr = - reinterpret_cast(attn_dropout_out->data()); - - auto *attn_dropout_mask = ctx.Input("AttnDropoutMaskOut"); - const XPUTypeT *attn_dropout_mask_ptr = - reinterpret_cast(attn_dropout_mask->data()); - auto *fmha_out = ctx.Input("FMHAOut"); - const XPUTypeT *fmha_out_ptr = - reinterpret_cast(fmha_out->data()); - - auto *out_linear_weight = ctx.Input("OutLinearW"); - const XPUTypeT *out_linear_weight_ptr = - reinterpret_cast(out_linear_weight->data()); - - auto *dropout_mask_out = ctx.Input("DropoutMaskOut"); - const XPUTypeT *dropout_mask_out_ptr = - reinterpret_cast(dropout_mask_out->data()); - // gradients that need to be computed - auto *d_qkv_weight = - ctx.Output(framework::GradVarName("QKVW")); - XPUTypeT *d_qkv_weight_ptr = reinterpret_cast( - d_qkv_weight->mutable_data(ctx.GetPlace())); - - auto *d_qkv_bias = - ctx.Output(framework::GradVarName("QKVBias")); - XPUTypeT *d_qkv_bias_ptr = reinterpret_cast( - d_qkv_bias->mutable_data(ctx.GetPlace())); - auto *d_out_linear_weight = - ctx.Output(framework::GradVarName("OutLinearW")); - - XPUTypeT *d_out_linear_weight_ptr = reinterpret_cast( - d_out_linear_weight->mutable_data(ctx.GetPlace())); - - auto *d_out_linear_bias = - ctx.Output(framework::GradVarName("OutLinearBias")); - XPUTypeT *d_out_linear_bias_ptr = reinterpret_cast( -
d_out_linear_bias->mutable_data(ctx.GetPlace())); - // possibly needed - auto *d_src_mask_out = - ctx.Output(framework::GradVarName("SrcMaskOut")); - XPUTypeT *d_src_mask_out_ptr = - (d_src_mask_out == nullptr) - ? (nullptr) - : (reinterpret_cast( - d_src_mask_out->mutable_data(ctx.GetPlace()))); - // output dx - auto *d_x = ctx.Output(framework::GradVarName("X")); - XPUTypeT *d_x_ptr = - reinterpret_cast(d_x->mutable_data(ctx.GetPlace())); - - const phi::DenseTensor *ln_out = nullptr; - const phi::DenseTensor *bias_dropout_residual_out = nullptr; - const phi::DenseTensor *ln_scale = nullptr; - const phi::DenseTensor *ln_mean = nullptr; - const phi::DenseTensor *ln_var = nullptr; - phi::DenseTensor *d_ln_scale = nullptr; - phi::DenseTensor *d_ln_bias = nullptr; - - const XPUTypeT *ln_out_ptr = NULL; - const float *ln_scale_ptr = NULL; - const float *ln_mean_ptr = NULL; - const float *ln_var_ptr = NULL; - const XPUTypeT *bias_dropout_residual_out_ptr = NULL; - float *d_ln_scale_ptr = nullptr; - float *d_ln_bias_ptr = nullptr; - - float epsilon = 0.0f; - - if (pre_layer_norm) { - ln_out = ctx.Input("LnOut"); - ln_out_ptr = reinterpret_cast(ln_out->data()); - ln_scale = ctx.Input("LnScale"); - ln_mean = ctx.Input("LnMean"); - ln_var = ctx.Input("LnVariance"); - epsilon = ctx.Attr("epsilon"); - d_ln_scale = - ctx.Output(framework::GradVarName("LnScale")); - d_ln_bias = - ctx.Output(framework::GradVarName("LnBias")); - - } else { - ln_scale = ctx.Input("Ln2Scale"); - ln_mean = ctx.Input("Ln2Mean"); - ln_var = ctx.Input("Ln2Variance"); - epsilon = ctx.Attr("ln_epsilon"); - d_ln_scale = - ctx.Output(framework::GradVarName("Ln2Scale")); - d_ln_bias = - ctx.Output(framework::GradVarName("Ln2Bias")); - bias_dropout_residual_out = - ctx.Input("BiasDropoutResidualOut"); - bias_dropout_residual_out_ptr = reinterpret_cast( - bias_dropout_residual_out->data()); - } - - ln_scale_ptr = ln_scale->data(); - ln_mean_ptr = ln_mean->data(); - ln_var_ptr = ln_var->data(); - d_ln_scale_ptr = d_ln_scale->mutable_data(ctx.GetPlace()); - d_ln_bias_ptr = d_ln_bias->mutable_data(ctx.GetPlace()); - - const auto input_x_dims = input_x->dims(); - const auto qkv_w_dims = qkv_weight->dims(); - - int batch_size = input_x_dims[0]; - int seq_len = input_x_dims[1]; - int embed_dims = input_x_dims[2]; - int num_heads = qkv_w_dims[1]; - int head_dims = qkv_w_dims[2]; - - auto &dev_ctx = ctx.template device_context(); - xpu::Context *xpu_ctx = dev_ctx.x_context(); - xpu::ctx_guard RAII_GUARD(xpu_ctx); - - int r = 0; - // int l3_total_size = xpu_ctx->_l3_mgr.get_size(); - XPUTypeT *d_ln_grad_ptr = NULL; // dx5 [batch_size, seq_len, hidden] - XPUTypeT *d_dropout_grad_ptr = NULL; // dx5 [batch_size, seq_len, hidden] - - XPUTypeT *d_fmha_out_ptr = - NULL; // d_fmha_out [batch_size, seq_len, num_heads, head_dims] - XPUTypeT *d_fmha_out_transpos_tmp_ptr = - NULL; // d_fmha_out_transpos [batch_size, seq_len, num_heads, - // head_dims] - - XPUTypeT *d_qk_ptr = - NULL; // d_qk_ptr[batch_size, num_heads, seq_len, seq_len] - - XPUTypeT *d_combination_qkv_ptr = - NULL; // d_combination_qkv_ptr[3, batch_size, num_heads, seq_len, - // head_dims] - XPUTypeT *d_transpos_qkv_ptr = - NULL; // dx2 [batch_size, seq_len, 3, num_heads, head_dims] - - XPUTypeT *d_last_layernorm_grad_ptr = - NULL; // d_layer_out [batch_size, seq_len, embed_dims] - - const XPUTypeT *dy_input_ptr = d_y_ptr; - - d_ln_grad_ptr = - RAII_GUARD.alloc(batch_size * seq_len * embed_dims); - d_dropout_grad_ptr = - RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * embed_dims); - d_fmha_out_ptr =
RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * - num_heads * head_dims); - d_combination_qkv_ptr = - RAII_GUARD.alloc(batch_size * seq_len * embed_dims * 3); - d_transpos_qkv_ptr = RAII_GUARD.alloc_l3_or_gm( - batch_size * seq_len * embed_dims * 3); - d_fmha_out_transpos_tmp_ptr = - RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * embed_dims); - d_qk_ptr = RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * - seq_len * num_heads); - d_last_layernorm_grad_ptr = - RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * embed_dims); - - if (pre_layer_norm == false) { - r = xpu::layer_norm_grad(xpu_ctx, - bias_dropout_residual_out_ptr, - d_y_ptr, - d_ln_grad_ptr, - batch_size * seq_len, - embed_dims, - epsilon, - ln_scale_ptr, - ln_mean_ptr, - ln_var_ptr, - d_ln_scale_ptr, - d_ln_bias_ptr); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad"); - dy_input_ptr = d_ln_grad_ptr; - } - // dropout_grad - DropoutGrad(xpu_ctx, - dy_input_ptr, - dropout_mask_out_ptr, - d_dropout_grad_ptr, - dropout_param, - batch_size * num_heads * seq_len * head_dims); - - // linear_out - phi::XpuFcInfo linear_fc_info; - linear_fc_info.InitFcInfo(0, - batch_size * seq_len, - embed_dims, - embed_dims, - false, - false, - nullptr, - nullptr, - nullptr); - const XPUTypeT *a_1 = reinterpret_cast(NULL); - const XPUTypeT *b_1 = reinterpret_cast(NULL); - const XPUTypeT *a_2 = reinterpret_cast(NULL); - const XPUTypeT *b_2 = reinterpret_cast(NULL); - - XPUTypeT *c_1 = d_fmha_out_ptr; - XPUTypeT *c_2 = d_out_linear_weight_ptr; - phi::XpuFcInfo info_dfmha; - phi::XpuFcInfo info_dlinear_w; - - std::tuple - fc_info = phi::MatmulGradFcInfo(xpu_ctx, - &RAII_GUARD, - linear_fc_info, - false, - false, - fmha_out_ptr, - out_linear_weight_ptr, - d_dropout_grad_ptr); - - std::tie(info_dfmha, info_dlinear_w, a_1, b_1, a_2, b_2) = fc_info; - phi::MatMulXPUFunction( - xpu_ctx, a_2, b_2, c_2, info_dlinear_w, 1.0f, true); - - phi::MatMulXPUFunction( - xpu_ctx, a_1, b_1, c_1, info_dfmha, 1.0f, true); - - // dlinear_bias - r = xpu::reduce_sum(xpu_ctx, - d_dropout_grad_ptr, - d_out_linear_bias_ptr, - {batch_size * seq_len, embed_dims}, - {0}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); - { - int qkv_size = batch_size * seq_len * num_heads * head_dims; - const XPUTypeT *q_out_ptr = qkv_transpose_out_ptr; - const XPUTypeT *k_out_ptr = q_out_ptr + qkv_size; - const XPUTypeT *v_out_ptr = k_out_ptr + qkv_size; - XPUTypeT *d_q_out_ptr = d_combination_qkv_ptr; - XPUTypeT *d_k_out_ptr = d_q_out_ptr + qkv_size; - XPUTypeT *d_v_out_ptr = d_k_out_ptr + qkv_size; - r = xpu::transpose(xpu_ctx, - d_fmha_out_ptr, - d_fmha_out_transpos_tmp_ptr, - {batch_size, seq_len, num_heads, head_dims}, - {0, 2, 1, 3}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - - phi::XpuFcInfo qktv_fc_info; - qktv_fc_info.InitFcInfo(batch_size * num_heads, - seq_len, - head_dims, - seq_len, - false, - false, - nullptr, - nullptr, - nullptr); - - const XPUTypeT *a_1 = reinterpret_cast(NULL); - const XPUTypeT *b_1 = reinterpret_cast(NULL); - const XPUTypeT *a_2 = reinterpret_cast(NULL); - const XPUTypeT *b_2 = reinterpret_cast(NULL); - XPUTypeT *c_1 = d_qk_ptr; - XPUTypeT *c_2 = d_v_out_ptr; - phi::XpuFcInfo info_d_qk; - phi::XpuFcInfo info_d_v; - - std::tuple - fc_info = phi::MatmulGradFcInfo(xpu_ctx, - &RAII_GUARD, - qktv_fc_info, - false, - false, - attn_dropout_out_ptr, - v_out_ptr, - d_fmha_out_transpos_tmp_ptr); - - std::tie(info_d_qk, info_d_v, a_1, b_1, a_2, b_2) = fc_info; - phi::MatMulXPUFunction( - xpu_ctx, a_1, b_1, c_1, info_d_qk, 1.0f, true); - phi::MatMulXPUFunction( - 
xpu_ctx, a_2, b_2, c_2, info_d_v, 1.0f, true); - - DropoutGrad(xpu_ctx, - d_qk_ptr, - attn_dropout_mask_ptr, - d_qk_ptr, - attn_dropout_param, - batch_size * seq_len * seq_len * num_heads); - - r = xpu::softmax_grad(xpu_ctx, - softmax_out_ptr, - d_qk_ptr, - d_qk_ptr, - {batch_size, num_heads, seq_len, seq_len}, - 3); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad"); - - if (d_src_mask_out_ptr) { - r = xpu::copy(xpu_ctx, - d_qk_ptr, - d_src_mask_out_ptr, - batch_size * seq_len * seq_len * num_heads); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - } - phi::XpuFcInfo qk_fc_info; - qk_fc_info.InitFcInfo(batch_size * num_heads, - seq_len, - seq_len, - head_dims, - false, - true, - nullptr, - nullptr, - nullptr); - - a_1 = reinterpret_cast(NULL); - b_1 = reinterpret_cast(NULL); - a_2 = reinterpret_cast(NULL); - b_2 = reinterpret_cast(NULL); - c_1 = d_q_out_ptr; - c_2 = d_k_out_ptr; - phi::XpuFcInfo info_d_q; - phi::XpuFcInfo info_d_k; - - fc_info = phi::MatmulGradFcInfo(xpu_ctx, - &RAII_GUARD, - qk_fc_info, - false, - true, - q_out_ptr, - k_out_ptr, - d_qk_ptr); - - std::tie(info_d_q, info_d_k, a_1, b_1, a_2, b_2) = fc_info; - - phi::MatMulXPUFunction( - xpu_ctx, a_1, b_1, c_1, info_d_q, 1.0f / sqrt(head_dims), true); - - phi::MatMulXPUFunction( - xpu_ctx, a_2, b_2, c_2, info_d_k, 1.0f, true); - } - - // - r = xpu::transpose(xpu_ctx, - d_combination_qkv_ptr, - d_transpos_qkv_ptr, - {3, batch_size, num_heads, seq_len, head_dims}, - {1, 3, 0, 2, 4}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - // dx and d_qkv_w - phi::XpuFcInfo qkv_fc_info; - qkv_fc_info.InitFcInfo(0, - batch_size * seq_len, - 3 * num_heads * head_dims, - embed_dims, - false, - true, - nullptr, - nullptr, - nullptr); - - a_1 = reinterpret_cast(NULL); - b_1 = reinterpret_cast(NULL); - a_2 = reinterpret_cast(NULL); - b_2 = reinterpret_cast(NULL); - c_1 = (pre_layer_norm == true) ? d_last_layernorm_grad_ptr : d_x_ptr; - c_2 = d_qkv_weight_ptr; - phi::XpuFcInfo info_d_x; - phi::XpuFcInfo info_d_qkv_w; - - const XPUTypeT *use_calc_input_x_ptr = - (pre_layer_norm == true) ? 
ln_out_ptr : input_x_ptr; - - fc_info = phi::MatmulGradFcInfo(xpu_ctx, - &RAII_GUARD, - qkv_fc_info, - false, - true, - use_calc_input_x_ptr, - qkv_weight_ptr, - d_transpos_qkv_ptr); - - std::tie(info_d_x, info_d_qkv_w, a_1, b_1, a_2, b_2) = fc_info; - phi::MatMulXPUFunction( - xpu_ctx, a_1, b_1, c_1, info_d_x, 1.0f, true); - phi::MatMulXPUFunction( - xpu_ctx, a_2, b_2, c_2, info_d_qkv_w, 1.0f, true); - - // d_qkv_bias - r = xpu::reduce_sum(xpu_ctx, - d_transpos_qkv_ptr, - d_qkv_bias_ptr, - {batch_size * seq_len, 3 * embed_dims}, - {0}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); - - if (pre_layer_norm) { - r = xpu::layer_norm_grad(xpu_ctx, - input_x_ptr, - c_1, - d_x_ptr, - batch_size * seq_len, - embed_dims, - epsilon, - ln_scale_ptr, - ln_mean_ptr, - ln_var_ptr, - d_ln_scale_ptr, - d_ln_bias_ptr); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad"); - } - - // add residual dy - r = xpu::add(xpu_ctx, - dy_input_ptr, - d_x_ptr, - d_x_ptr, - batch_size * seq_len * embed_dims); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "add"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_XPU_KERNEL( - fused_attention, - ops::FusedAttentionOpKernel, - ops::FusedAttentionOpKernel); - -REGISTER_OP_XPU_KERNEL( - fused_attention_grad, - ops::FusedAttentionGradXPUKernel, - ops::FusedAttentionGradXPUKernel); - -#endif diff --git a/paddle/phi/kernels/fused_attention_grad_kernel.h b/paddle/phi/kernels/fused_attention_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d6b5f77c724eda31b5bc5564a3a24336b6c0ec23 --- /dev/null +++ b/paddle/phi/kernels/fused_attention_grad_kernel.h @@ -0,0 +1,89 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
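+// A minimal reference sketch (not part of the kernel code; names here are
+// illustrative) of the matmul-gradient identity that the paired
+// phi::MatmulGradFcInfo / phi::MatMulXPUFunction calls in the grad kernels
+// realize on XPU: for a row-major product C = A * B with A (m x k) and
+// B (k x n), the backward pass computes dA = dC * B^T and dB = A^T * dC.
+//
+//   void matmul_grad_ref(const float *A, const float *B, const float *dC,
+//                        float *dA, float *dB, int m, int k, int n) {
+//     for (int i = 0; i < m; ++i)      // dA = dC * B^T, shape (m x k)
+//       for (int p = 0; p < k; ++p) {
+//         float acc = 0.0f;
+//         for (int j = 0; j < n; ++j) acc += dC[i * n + j] * B[p * n + j];
+//         dA[i * k + p] = acc;
+//       }
+//     for (int p = 0; p < k; ++p)      // dB = A^T * dC, shape (k x n)
+//       for (int j = 0; j < n; ++j) {
+//         float acc = 0.0f;
+//         for (int i = 0; i < m; ++i) acc += A[i * k + p] * dC[i * n + j];
+//         dB[p * n + j] = acc;
+//       }
+//   }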
+#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void FusedAttentionGradKernel( + const Context &dev_ctx, + const DenseTensor &out_grad, + const DenseTensor &x, + const DenseTensor &qkv_weight, + const paddle::optional &qkv_bias, + const paddle::optional &qkv_bias_out, + const paddle::optional &src_mask, + const paddle::optional &src_mask_out, + const DenseTensor &out_linear_weight, + const paddle::optional &out_linear_bias, + const paddle::optional &ln_scale, + const paddle::optional &ln_bias, + const paddle::optional &ln_scale_2, + const paddle::optional &ln_bias_2, + const paddle::optional &ln_out, + const paddle::optional &ln_mean, + const paddle::optional &ln_var, + const paddle::optional &ln_mean_2, + const paddle::optional &ln_var_2, + const paddle::optional &bias_dropout_residual_out, + const DenseTensor &qkv_out, + const DenseTensor &transpose_out_2, + const DenseTensor &qk_out, + const DenseTensor &qktv_out, + const DenseTensor &softmax_out, + const DenseTensor &attn_dropout_mask_out, + const DenseTensor &attn_dropout_out, + const DenseTensor &fmha_out, + const DenseTensor &out_linear_out, + const DenseTensor &dropout_mask_out, + int num_heads, + bool transpose_qkv_wb, + bool pre_layer_norm, + float epsilon, + float attn_dropout_rate, + bool is_test, + bool attn_dropout_fix_seed, + int attn_dropout_seed, + const std::string &attn_dropout_implementation, + float dropout_rate, + bool dropout_fix_seed, + int dropout_seed, + const std::string &dropout_implementation, + float ln_epsilon, + bool add_residual, + int ring_id, + DenseTensor *qkv_bias_grad, + DenseTensor *qkv_bias_out_grad, + DenseTensor *src_mask_out_grad, + DenseTensor *out_linear_bias_grad, + DenseTensor *ln_scale_grad, + DenseTensor *ln_bias_grad, + DenseTensor *ln_scale_2_grad, + DenseTensor *ln_bias_2_grad, + DenseTensor *x_grad, + DenseTensor *qkv_weight_grad, + DenseTensor *out_linear_weight_grad, + DenseTensor *ln_out_grad, + DenseTensor *bias_dropout_residual_out_grad, + DenseTensor *qkv_out_grad, + DenseTensor *qktv_out_grad, + DenseTensor *transpose_out_2_grad, + DenseTensor *qk_out_grad, + DenseTensor *softmax_out_grad, + DenseTensor *attn_dropout_out_grad, + DenseTensor *fmha_out_grad, + DenseTensor *out_linear_out_grad); +} // namespace phi diff --git a/paddle/phi/kernels/fused_attention_kernel.h b/paddle/phi/kernels/fused_attention_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..529c49514066d4dade09d93bc023d14dd1a1c93a --- /dev/null +++ b/paddle/phi/kernels/fused_attention_kernel.h @@ -0,0 +1,156 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +/** + * @brief Fused Attention Kernel. + * @param ctx device context + * @param x The input tensor. 
+ * @param ln_scale (optional) Scale is a 1-dimensional tensor of size + * H, where H is the last dimension of its input + * tensor. + * @param ln_bias (optional) Bias is a 1-dimensional tensor of size + * H, where H is the last dimension of its input + * tensor. + * @param qkv_weight The qkv weight tensor. + * @param qkv_bias The qkv bias tensor. + * @param cache_kv (optional) The cache KV for generation inference. + * @param src_mask (optional) The attention mask tensor in fmha. + * @param out_linear_w The out_linear weight tensor. + * @param out_linear_bias (optional) The out_linear bias tensor. + * @param ln_scale_2 (optional) Scale is a 1-dimensional tensor of + * size H, where H is the last dimension of its input tensor. + * @param ln_bias_2 (optional) Bias is a 1-dimensional tensor of size + * H, where H is the last dimension of its input + * tensor. + * @param num_heads The number of heads for multi_head_attention. + * @param transpose_qkv_wb Whether the qkv weight is stored in (h, 3h) + * layout and must be transposed before use. + * @param pre_layer_norm If true, the attention op uses the + * pre_layer_norm architecture; otherwise it uses the + * post_layer_norm architecture. [default false]. + * @param epsilon Constant for numerical stability [default 1e-5]. + * @param attn_dropout_rate Probability of setting units to zero. + * @param is_test (bool, default false) Set to true for inference + * only, false for training. Some layers may run + * faster when this is true. + * @param attn_dropout_fix_seed A flag indicating whether to use a fixed seed + * to generate the random mask. NOTE: DO NOT set this + * flag to true in training. Setting it to true is only + * useful in unit tests or for debugging, so that the same output units are + * always dropped. + * @param attn_dropout_seed Dropout random seed. + * @param attn_dropout_implementation ["downgrade_in_infer"|"upscale_in_train"] + * There are two ways to implement dropout (the mask + * below is a tensor with the same shape as the input; + * its values are 0 or 1, and the ratio of zeros is + * dropout_rate): + * 1. downgrade_in_infer (default): downgrade the + * outcome at inference time; train: out = input * + * mask; inference: out = input * (1.0 - dropout_rate). + * 2. upscale_in_train: upscale the outcome at + * training time and do nothing at inference; train: + * out = input * mask / (1.0 - dropout_rate); inference: out = input, so the + * dropout op can be removed from the program and the program becomes more + * efficient. (An illustrative sketch of both conventions is given at the top + * of the XPU grad kernel implementation.) + * @param dropout_rate Probability of setting units to zero. + * @param dropout_fix_seed A flag indicating whether to use a fixed seed + * to generate the random mask. NOTE: DO NOT set this + * flag to true in training. Setting it to true is only + * useful in unit tests or for debugging, so that the same output units are + * always dropped. + * @param dropout_seed Dropout random seed. + * @param dropout_implementation dropout_implementation + * ["downgrade_in_infer"|"upscale_in_train"] The + * meaning is the same as + * 'attn_dropout_implementation'. + * @param ln_epsilon Constant for numerical stability [default 1e-5]. + * @param add_residual Whether to add the residual. + * @param ring_id Ring id for tensor model parallelism in + * distributed training and inference. + * @param ln_mean Mean of the current mini batch. + * @param ln_var Variance of the current mini batch. + * @param ln_out The output tensor after layer_norm. + * @param qkv_out Result after qkv. + * @param qkv_bias_out Result after qkv and bias op. + * @param transpose_out_2 Result in fmha.
+ * @param qk_out Result in fmha. + * @param qktv_out Result in fmha. + * @param softmax_out Result in fmha. + * @param attn_dropout_mask_out Result in fmha. + * @param attn_dropout_out Result in fmha. + * @param src_mask_out Result in fmha. + * @param fmha_out Result in fmha. + * @param out_linear_out Result after out_linear. + * @param dropout_mask_out The random sampled dropout mask. + * @param ln_mean_2 Mean of the current mini batch. + * @param ln_var_2 Variance of the current mini batch. + * @param bias_dropout_residual_out Result of residual + dropout(src + bias). + * @param cache_kv_out The updated cache KV. + * @param out Result after attention. + */ +template +void FusedAttentionKernel(const Context &dev_ctx, + const DenseTensor &x, + const paddle::optional &ln_scale, + const paddle::optional &ln_bias, + const DenseTensor &qkv_weight, + const paddle::optional &qkv_bias, + const paddle::optional &cache_kv, + const paddle::optional &src_mask, + const DenseTensor &out_linear_weight, + const paddle::optional &out_linear_bias, + const paddle::optional &ln_scale_2, + const paddle::optional &ln_bias_2, + int num_heads, + bool transpose_qkv_wb, + bool pre_layer_norm, + float epsilon, + float attn_dropout_rate, + bool is_test, + bool attn_dropout_fix_seed, + int attn_dropout_seed, + const std::string &attn_dropout_implementation, + float dropout_rate, + bool dropout_fix_seed, + int dropout_seed, + const std::string &dropout_implementation, + float ln_epsilon, + bool add_residual, + int ring_id, + DenseTensor *ln_mean, + DenseTensor *ln_var, + DenseTensor *ln_out, + DenseTensor *qkv_out, + DenseTensor *qkv_bias_out, + DenseTensor *transpose_out_2, + DenseTensor *qk_out, + DenseTensor *qktv_out, + DenseTensor *softmax_out, + DenseTensor *attn_dropout_mask_out, + DenseTensor *attn_dropout_out, + DenseTensor *src_mask_out, + DenseTensor *fmha_out, + DenseTensor *out_linear_out, + DenseTensor *dropout_mask_out, + DenseTensor *ln_mean_2, + DenseTensor *ln_var_2, + DenseTensor *bias_dropout_residual_out, + DenseTensor *cache_kv_out, + DenseTensor *out); + +} // namespace phi diff --git a/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc b/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..c4432f82d9b265a4c2129ab1d86bd4732300d60d --- /dev/null +++ b/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc @@ -0,0 +1,542 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
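+// A minimal sketch of the two dropout conventions named by
+// attn_dropout_implementation / dropout_implementation in
+// fused_attention_kernel.h (illustrative only; the fused XPU primitives do
+// the real work, and mask[i] is 0 with probability p, 1 otherwise):
+//
+//   float dropout_train(float x, float mask, float p, bool upscale_in_train) {
+//     // upscale_in_train: rescale kept units by 1/(1-p) during training so
+//     // inference becomes the identity; downgrade_in_infer: keep units
+//     // unscaled in training and multiply by (1-p) at inference instead.
+//     return upscale_in_train ? x * mask / (1.0f - p) : x * mask;
+//   }
+//   float dropout_infer(float x, float p, bool upscale_in_train) {
+//     return upscale_in_train ? x : x * (1.0f - p);
+//   }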
+ +#include "paddle/phi/kernels/fused_attention_grad_kernel.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" +#include "paddle/phi/kernels/xpu/xpu_fused_common_function.h" + +namespace phi { + +template +void FusedAttentionGradKernel( + const Context &dev_ctx, + const DenseTensor &out_grad, + const DenseTensor &x, + const DenseTensor &qkv_weight, + const paddle::optional &qkv_bias, + const paddle::optional &qkv_bias_out, + const paddle::optional &src_mask, + const paddle::optional &src_mask_out, + const DenseTensor &out_linear_weight, + const paddle::optional &out_linear_bias, + const paddle::optional &ln_scale, + const paddle::optional &ln_bias, + const paddle::optional &ln_scale_2, + const paddle::optional &ln_bias_2, + const paddle::optional &ln_out, + const paddle::optional &ln_mean, + const paddle::optional &ln_var, + const paddle::optional &ln_mean_2, + const paddle::optional &ln_var_2, + const paddle::optional &bias_dropout_residual_out, + const DenseTensor &qkv_out, + const DenseTensor &transpose_out_2, + const DenseTensor &qk_out, + const DenseTensor &qktv_out, + const DenseTensor &softmax_out, + const DenseTensor &attn_dropout_mask, + const DenseTensor &attn_dropout_out, + const DenseTensor &fmha_out, + const DenseTensor &out_linear_out, + const DenseTensor &dropout_mask_out, + int num_heads, + bool transpose_qkv_wb, + bool pre_layer_norm, + float epsilon, + float attn_dropout_rate, + bool is_test, + bool attn_dropout_fix_seed, + int attn_dropout_seed, + const std::string &attn_dropout_implementation, + float dropout_rate, + bool dropout_fix_seed, + int dropout_seed, + const std::string &dropout_implementation, + float ln_epsilon, + bool add_residual, + int ring_id, + DenseTensor *qkv_bias_grad, + DenseTensor *qkv_bias_out_grad, + DenseTensor *src_mask_out_grad, + DenseTensor *out_linear_bias_grad, + DenseTensor *ln_scale_grad, + DenseTensor *ln_bias_grad, + DenseTensor *ln_scale_2_grad, + DenseTensor *ln_bias_2_grad, + DenseTensor *x_grad, + DenseTensor *qkv_weight_grad, + DenseTensor *out_linear_weight_grad, + DenseTensor *ln_out_grad, + DenseTensor *bias_dropout_residual_out_grad, + DenseTensor *qkv_out_grad, + DenseTensor *qktv_out_grad, + DenseTensor *transpose_out_2_grad, + DenseTensor *qk_out_grad, + DenseTensor *softmax_out_grad, + DenseTensor *attn_dropout_out_grad, + DenseTensor *fmha_out_grad, + DenseTensor *out_linear_out_grad) { + using XPUTypeT = typename XPUTypeTrait::Type; + + bool is_upscale_in_train_1 = + (attn_dropout_implementation == "upscale_in_train"); + const phi::DenseTensor *seed_1 = nullptr; + + phi::XPUDropoutParam attn_dropout_param; + attn_dropout_param.initXPUDropoutParam(attn_dropout_rate, + is_upscale_in_train_1, + is_test, + attn_dropout_fix_seed, + seed_1, + attn_dropout_seed); + + phi::XPUDropoutParam dropout_param; + dropout_param.initXPUDropoutParam(dropout_rate, + is_upscale_in_train_1, + is_test, + dropout_fix_seed, + seed_1, + dropout_seed); + // get inputs. 
+ // get inputs. + const XPUTypeT *d_y_ptr = + reinterpret_cast(out_grad.data()); + // parameters needed from the forward pass + const XPUTypeT *input_x_ptr = reinterpret_cast(x.data()); + const XPUTypeT *qkv_transpose_out_ptr = + reinterpret_cast(transpose_out_2.data()); + const XPUTypeT *qkv_weight_ptr = + reinterpret_cast(qkv_weight.data()); + + const XPUTypeT *softmax_out_ptr = + reinterpret_cast(softmax_out.data()); + const XPUTypeT *attn_dropout_out_ptr = + reinterpret_cast(attn_dropout_out.data()); + + const XPUTypeT *attn_dropout_mask_ptr = + reinterpret_cast(attn_dropout_mask.data()); + const XPUTypeT *fmha_out_ptr = + reinterpret_cast(fmha_out.data()); + + const XPUTypeT *out_linear_weight_ptr = + reinterpret_cast(out_linear_weight.data()); + + const XPUTypeT *dropout_mask_out_ptr = + reinterpret_cast(dropout_mask_out.data()); + // gradients that need to be computed + auto *d_qkv_weight = qkv_weight_grad; + XPUTypeT *d_qkv_weight_ptr = + reinterpret_cast(dev_ctx.template Alloc(d_qkv_weight)); + + auto *d_qkv_bias = qkv_bias_grad; + XPUTypeT *d_qkv_bias_ptr = + reinterpret_cast(dev_ctx.template Alloc(d_qkv_bias)); + auto *d_out_linear_weight = out_linear_weight_grad; + + XPUTypeT *d_out_linear_weight_ptr = reinterpret_cast( + dev_ctx.template Alloc(d_out_linear_weight)); + + auto *d_out_linear_bias = out_linear_bias_grad; + XPUTypeT *d_out_linear_bias_ptr = reinterpret_cast( + dev_ctx.template Alloc(d_out_linear_bias)); + // possibly needed + auto *d_src_mask_out = src_mask_out_grad; + XPUTypeT *d_src_mask_out_ptr = + (d_src_mask_out == nullptr) + ? (nullptr) + : (reinterpret_cast( + dev_ctx.template Alloc(d_src_mask_out))); + // output dx + auto *d_x = x_grad; + XPUTypeT *d_x_ptr = + reinterpret_cast(dev_ctx.template Alloc(d_x)); + + const phi::DenseTensor *ln_out_p = ln_out.get_ptr(); + const phi::DenseTensor *bias_dropout_residual_out_p = + bias_dropout_residual_out.get_ptr(); + + const phi::DenseTensor *ln_scale_p = nullptr; + const phi::DenseTensor *ln_mean_p = nullptr; + const phi::DenseTensor *ln_var_p = nullptr; + phi::DenseTensor *d_ln_scale = nullptr; + phi::DenseTensor *d_ln_bias = nullptr; + + const XPUTypeT *ln_out_ptr = NULL; + const float *ln_scale_ptr = NULL; + const float *ln_mean_ptr = NULL; + const float *ln_var_ptr = NULL; + const XPUTypeT *bias_dropout_residual_out_ptr = NULL; + float *d_ln_scale_ptr = nullptr; + float *d_ln_bias_ptr = nullptr; + + if (pre_layer_norm) { + ln_out_ptr = reinterpret_cast(ln_out_p->data()); + ln_scale_p = ln_scale.get_ptr(); + ln_mean_p = ln_mean.get_ptr(); + ln_var_p = ln_var.get_ptr(); + d_ln_scale = ln_scale_grad; + d_ln_bias = ln_bias_grad; + } else { + ln_scale_p = ln_scale_2.get_ptr(); + ln_mean_p = ln_mean_2.get_ptr(); + ln_var_p = ln_var_2.get_ptr(); + epsilon = ln_epsilon; + d_ln_scale = ln_scale_2_grad; + d_ln_bias = ln_bias_2_grad; + bias_dropout_residual_out_ptr = reinterpret_cast( + bias_dropout_residual_out_p->data()); + } + + ln_scale_ptr = ln_scale_p->data(); + ln_mean_ptr = ln_mean_p->data(); + ln_var_ptr = ln_var_p->data(); + d_ln_scale_ptr = dev_ctx.template Alloc(d_ln_scale); + d_ln_bias_ptr = dev_ctx.template Alloc(d_ln_bias); + + const auto input_x_dims = x.dims(); + const auto qkv_w_dims = qkv_weight.dims(); + + int batch_size = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int embed_dims = input_x_dims[2]; + num_heads = qkv_w_dims[1]; + int head_dims = qkv_w_dims[2]; + + xpu::Context *xpu_ctx = dev_ctx.x_context(); + xpu::ctx_guard RAII_GUARD(xpu_ctx); + + int r = 0; + // int l3_total_size = xpu_ctx->_l3_mgr.get_size(); + XPUTypeT *d_ln_grad_ptr = NULL; // dx5 [batch_size, seq_len, hidden]
+ XPUTypeT *d_dropout_grad_ptr = NULL; // dx5 [batch_size, seq_len, hidden] + + XPUTypeT *d_fmha_out_ptr = + NULL; // d_fmha_out [batch_size, seq_len, num_heads, head_dims] + XPUTypeT *d_fmha_out_transpos_tmp_ptr = + NULL; // d_fmha_out_transpos [batch_size, seq_len, num_heads, + // head_dims] + + XPUTypeT *d_qk_ptr = + NULL; // d_qk_ptr[batch_size, num_heads, seq_len, seq_len] + + XPUTypeT *d_combination_qkv_ptr = + NULL; // d_combination_qkv_ptr[3, batch_size, num_heads, seq_len, + // head_dims] + XPUTypeT *d_transpos_qkv_ptr = + NULL; // dx2 [batch_size, seq_len, 3, num_heads, head_dims] + + XPUTypeT *d_last_layernorm_grad_ptr = + NULL; // d_layer_out [batch_size, seq_len, embed_dims] + + const XPUTypeT *dy_input_ptr = d_y_ptr; + + d_ln_grad_ptr = RAII_GUARD.alloc(batch_size * seq_len * embed_dims); + d_dropout_grad_ptr = + RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * embed_dims); + d_fmha_out_ptr = RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * + num_heads * head_dims); + d_combination_qkv_ptr = + RAII_GUARD.alloc(batch_size * seq_len * embed_dims * 3); + d_transpos_qkv_ptr = RAII_GUARD.alloc_l3_or_gm( + batch_size * seq_len * embed_dims * 3); + d_fmha_out_transpos_tmp_ptr = + RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * embed_dims); + d_qk_ptr = RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * + seq_len * num_heads); + d_last_layernorm_grad_ptr = + RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * embed_dims); + + if (pre_layer_norm == false) { + r = xpu::layer_norm_grad(xpu_ctx, + bias_dropout_residual_out_ptr, + d_y_ptr, + d_ln_grad_ptr, + batch_size * seq_len, + embed_dims, + epsilon, + ln_scale_ptr, + ln_mean_ptr, + ln_var_ptr, + d_ln_scale_ptr, + d_ln_bias_ptr); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad"); + dy_input_ptr = d_ln_grad_ptr; + } + // dropout_grad + DropoutGrad(xpu_ctx, + dy_input_ptr, + dropout_mask_out_ptr, + d_dropout_grad_ptr, + dropout_param, + batch_size * num_heads * seq_len * head_dims); + + // linear_out + phi::XpuFcInfo linear_fc_info; + linear_fc_info.InitFcInfo(0, + batch_size * seq_len, + embed_dims, + embed_dims, + false, + false, + nullptr, + nullptr, + nullptr); + const XPUTypeT *a_1 = reinterpret_cast(NULL); + const XPUTypeT *b_1 = reinterpret_cast(NULL); + const XPUTypeT *a_2 = reinterpret_cast(NULL); + const XPUTypeT *b_2 = reinterpret_cast(NULL); + + XPUTypeT *c_1 = d_fmha_out_ptr; + XPUTypeT *c_2 = d_out_linear_weight_ptr; + phi::XpuFcInfo info_dfmha; + phi::XpuFcInfo info_dlinear_w; + + std::tuple + fc_info = phi::MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + linear_fc_info, + false, + false, + fmha_out_ptr, + out_linear_weight_ptr, + d_dropout_grad_ptr); + + std::tie(info_dfmha, info_dlinear_w, a_1, b_1, a_2, b_2) = fc_info; + phi::MatMulXPUFunction( + xpu_ctx, a_2, b_2, c_2, info_dlinear_w, 1.0f, true); + + phi::MatMulXPUFunction( + xpu_ctx, a_1, b_1, c_1, info_dfmha, 1.0f, true); + + // dlinear_bias + r = xpu::reduce_sum(xpu_ctx, + d_dropout_grad_ptr, + d_out_linear_bias_ptr, + {batch_size * seq_len, embed_dims}, + {0}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); + { + int qkv_size = batch_size * seq_len * num_heads * head_dims; + const XPUTypeT *q_out_ptr = qkv_transpose_out_ptr; + const XPUTypeT *k_out_ptr = q_out_ptr + qkv_size; + const XPUTypeT *v_out_ptr = k_out_ptr + qkv_size; + XPUTypeT *d_q_out_ptr = d_combination_qkv_ptr; + XPUTypeT *d_k_out_ptr = d_q_out_ptr + qkv_size; + XPUTypeT *d_v_out_ptr = d_k_out_ptr + qkv_size; + r = xpu::transpose(xpu_ctx, + d_fmha_out_ptr, + d_fmha_out_transpos_tmp_ptr, + {batch_size, 
seq_len, num_heads, head_dims}, + {0, 2, 1, 3}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + + phi::XpuFcInfo qktv_fc_info; + qktv_fc_info.InitFcInfo(batch_size * num_heads, + seq_len, + head_dims, + seq_len, + false, + false, + nullptr, + nullptr, + nullptr); + + const XPUTypeT *a_1 = reinterpret_cast(NULL); + const XPUTypeT *b_1 = reinterpret_cast(NULL); + const XPUTypeT *a_2 = reinterpret_cast(NULL); + const XPUTypeT *b_2 = reinterpret_cast(NULL); + XPUTypeT *c_1 = d_qk_ptr; + XPUTypeT *c_2 = d_v_out_ptr; + phi::XpuFcInfo info_d_qk; + phi::XpuFcInfo info_d_v; + + std::tuple + fc_info = phi::MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + qktv_fc_info, + false, + false, + attn_dropout_out_ptr, + v_out_ptr, + d_fmha_out_transpos_tmp_ptr); + + std::tie(info_d_qk, info_d_v, a_1, b_1, a_2, b_2) = fc_info; + phi::MatMulXPUFunction( + xpu_ctx, a_1, b_1, c_1, info_d_qk, 1.0f, true); + phi::MatMulXPUFunction( + xpu_ctx, a_2, b_2, c_2, info_d_v, 1.0f, true); + + DropoutGrad(xpu_ctx, + d_qk_ptr, + attn_dropout_mask_ptr, + d_qk_ptr, + attn_dropout_param, + batch_size * seq_len * seq_len * num_heads); + + r = xpu::softmax_grad(xpu_ctx, + softmax_out_ptr, + d_qk_ptr, + d_qk_ptr, + {batch_size, num_heads, seq_len, seq_len}, + 3); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad"); + + if (d_src_mask_out_ptr) { + r = xpu::copy(xpu_ctx, + d_qk_ptr, + d_src_mask_out_ptr, + batch_size * seq_len * seq_len * num_heads); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); + } + phi::XpuFcInfo qk_fc_info; + qk_fc_info.InitFcInfo(batch_size * num_heads, + seq_len, + seq_len, + head_dims, + false, + true, + nullptr, + nullptr, + nullptr); + + a_1 = reinterpret_cast(NULL); + b_1 = reinterpret_cast(NULL); + a_2 = reinterpret_cast(NULL); + b_2 = reinterpret_cast(NULL); + c_1 = d_q_out_ptr; + c_2 = d_k_out_ptr; + phi::XpuFcInfo info_d_q; + phi::XpuFcInfo info_d_k; + + fc_info = phi::MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + qk_fc_info, + false, + true, + q_out_ptr, + k_out_ptr, + d_qk_ptr); + + std::tie(info_d_q, info_d_k, a_1, b_1, a_2, b_2) = fc_info; + + phi::MatMulXPUFunction( + xpu_ctx, a_1, b_1, c_1, info_d_q, 1.0f / sqrt(head_dims), true); + + phi::MatMulXPUFunction( + xpu_ctx, a_2, b_2, c_2, info_d_k, 1.0f, true); + } + + // + r = xpu::transpose(xpu_ctx, + d_combination_qkv_ptr, + d_transpos_qkv_ptr, + {3, batch_size, num_heads, seq_len, head_dims}, + {1, 3, 0, 2, 4}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + // dx and d_qkv_w + phi::XpuFcInfo qkv_fc_info; + qkv_fc_info.InitFcInfo(0, + batch_size * seq_len, + 3 * num_heads * head_dims, + embed_dims, + false, + true, + nullptr, + nullptr, + nullptr); + + a_1 = reinterpret_cast(NULL); + b_1 = reinterpret_cast(NULL); + a_2 = reinterpret_cast(NULL); + b_2 = reinterpret_cast(NULL); + c_1 = (pre_layer_norm == true) ? d_last_layernorm_grad_ptr : d_x_ptr; + c_2 = d_qkv_weight_ptr; + phi::XpuFcInfo info_d_x; + phi::XpuFcInfo info_d_qkv_w; + + const XPUTypeT *use_calc_input_x_ptr = + (pre_layer_norm == true) ? 
ln_out_ptr : input_x_ptr; + + fc_info = phi::MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + qkv_fc_info, + false, + true, + use_calc_input_x_ptr, + qkv_weight_ptr, + d_transpos_qkv_ptr); + + std::tie(info_d_x, info_d_qkv_w, a_1, b_1, a_2, b_2) = fc_info; + phi::MatMulXPUFunction( + xpu_ctx, a_1, b_1, c_1, info_d_x, 1.0f, true); + phi::MatMulXPUFunction( + xpu_ctx, a_2, b_2, c_2, info_d_qkv_w, 1.0f, true); + + // d_qkv_bias + r = xpu::reduce_sum(xpu_ctx, + d_transpos_qkv_ptr, + d_qkv_bias_ptr, + {batch_size * seq_len, 3 * embed_dims}, + {0}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum"); + + if (pre_layer_norm) { + r = xpu::layer_norm_grad(xpu_ctx, + input_x_ptr, + c_1, + d_x_ptr, + batch_size * seq_len, + embed_dims, + epsilon, + ln_scale_ptr, + ln_mean_ptr, + ln_var_ptr, + d_ln_scale_ptr, + d_ln_bias_ptr); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad"); + } + + // add residual dy + r = xpu::add(xpu_ctx, + dy_input_ptr, + d_x_ptr, + d_x_ptr, + batch_size * seq_len * embed_dims); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "add"); +} +} // namespace phi + +PD_REGISTER_KERNEL(fused_attention_grad, + XPU, + ALL_LAYOUT, + phi::FusedAttentionGradKernel, + float, + phi::dtype::float16) { + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32); +} diff --git a/paddle/phi/kernels/xpu/fused_attention_kernel.cc b/paddle/phi/kernels/xpu/fused_attention_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..e91c109b375025a2e6461121e5cbb19fb9864b01 --- /dev/null +++ b/paddle/phi/kernels/xpu/fused_attention_kernel.cc @@ -0,0 +1,434 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
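+// In the fused_attention_grad registration above, OutputAt(4) through
+// OutputAt(7) are ln_scale_grad, ln_bias_grad, ln_scale_2_grad and
+// ln_bias_2_grad (see the output order in fused_attention_grad_kernel.h).
+// They are pinned to FLOAT32 because the grad kernel writes the layer-norm
+// scale/bias gradients through float pointers (d_ln_scale_ptr /
+// d_ln_bias_ptr) even when T is float16. A hypothetical sketch of that
+// mixed-precision accumulation pattern (dy_fp16 is an illustrative name):
+//
+//   float acc = 0.0f;  // fp32 accumulator over fp16 gradient values
+//   for (int i = 0; i < n; ++i) acc += static_cast<float>(dy_fp16[i]);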
+ +#include "paddle/phi/kernels/fused_attention_kernel.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" +#include "paddle/phi/kernels/xpu/xpu_fused_common_function.h" + +namespace phi { + +template +void FusedAttentionKernel(const Context &dev_ctx, + const DenseTensor &x, + const paddle::optional &ln_scale, + const paddle::optional &ln_bias, + const DenseTensor &qkv_weight, + const paddle::optional &qkv_bias, + const paddle::optional &cache_kv, + const paddle::optional &src_mask, + const DenseTensor &out_linear_weight, + const paddle::optional &out_linear_bias, + const paddle::optional &ln_scale_2, + const paddle::optional &ln_bias_2, + int num_heads, + bool transpose_qkv_wb, + bool pre_layer_norm, + float epsilon, + float attn_dropout_rate, + bool is_test, + bool attn_dropout_fix_seed, + int attn_dropout_seed, + const std::string &attn_dropout_implementation, + float dropout_rate, + bool dropout_fix_seed, + int dropout_seed, + const std::string &dropout_implementation, + float ln_epsilon, + bool add_residual, + int ring_id, + DenseTensor *ln_mean, + DenseTensor *ln_var, + DenseTensor *ln_out, + DenseTensor *qkv_out, + DenseTensor *qkv_bias_out, + DenseTensor *transpose_out_2, + DenseTensor *qk_out, + DenseTensor *qktv_out, + DenseTensor *softmax_out, + DenseTensor *attn_dropout_mask_out, + DenseTensor *attn_dropout_out, + DenseTensor *src_mask_out, + DenseTensor *fmha_out, + DenseTensor *out_linear_out, + DenseTensor *dropout_mask_out, + DenseTensor *ln_mean_2, + DenseTensor *ln_var_2, + DenseTensor *bias_dropout_residual_out, + DenseTensor *cache_kv_out, + DenseTensor *out) { + using XPUTypeT = typename XPUTypeTrait::Type; + + // shape [batch_size, 1, 1, seq_len] + const phi::DenseTensor *src_mask_p = src_mask.get_ptr(); + + const phi::DenseTensor *ln_scale_p = nullptr; + const phi::DenseTensor *ln_bias_p = nullptr; + + if (pre_layer_norm) { + ln_scale_p = ln_scale.get_ptr(); + ln_bias_p = ln_bias.get_ptr(); + } else { + ln_scale_p = ln_scale_2.get_ptr(); + ln_bias_p = ln_bias_2.get_ptr(); + epsilon = ln_epsilon; + } + + dev_ctx.template Alloc(qk_out); + dev_ctx.template Alloc(qktv_out); + dev_ctx.template Alloc(out_linear_out); + dev_ctx.template Alloc(qkv_bias_out); + dev_ctx.template Alloc(src_mask_out); + dev_ctx.template Alloc(qkv_out); + + bool is_upscale_in_train_1 = + (attn_dropout_implementation == "upscale_in_train"); + const phi::DenseTensor *seed_1 = nullptr; + + phi::XPUDropoutParam attn_dropout_param; + attn_dropout_param.initXPUDropoutParam(attn_dropout_rate, + is_upscale_in_train_1, + is_test, + attn_dropout_fix_seed, + seed_1, + attn_dropout_seed); + + phi::XPUDropoutParam dropout_param; + dropout_param.initXPUDropoutParam(dropout_rate, + is_upscale_in_train_1, + is_test, + dropout_fix_seed, + seed_1, + dropout_seed); + + // 先计算纬度 + const auto input_x_dims = x.dims(); + const auto qkv_w_dims = qkv_weight.dims(); + + int batch_size = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int embed_dims = input_x_dims[2]; + num_heads = qkv_w_dims[1]; + int head_dims = qkv_w_dims[2]; + + // 输入指针 + const XPUTypeT *input_x_ptr = reinterpret_cast(x.data()); + + const XPUTypeT *qkv_weight_ptr = + reinterpret_cast(qkv_weight.data()); + const DenseTensor *qkv_bias_p = qkv_bias.get_ptr(); + const XPUTypeT *qkv_bias_ptr = + reinterpret_cast(qkv_bias_p->data()); + const XPUTypeT *src_mask_ptr = + (src_mask_p == nullptr) + ? 
+
+  // input pointers
+  const XPUTypeT *input_x_ptr =
+      reinterpret_cast<const XPUTypeT *>(x.data<T>());
+
+  const XPUTypeT *qkv_weight_ptr =
+      reinterpret_cast<const XPUTypeT *>(qkv_weight.data<T>());
+  const DenseTensor *qkv_bias_p = qkv_bias.get_ptr();
+  const XPUTypeT *qkv_bias_ptr =
+      reinterpret_cast<const XPUTypeT *>(qkv_bias_p->data<T>());
+  const XPUTypeT *src_mask_ptr =
+      (src_mask_p == nullptr)
+          ? (nullptr)
+          : (reinterpret_cast<const XPUTypeT *>(src_mask_p->data<T>()));
+
+  const XPUTypeT *out_linear_weight_ptr =
+      reinterpret_cast<const XPUTypeT *>(out_linear_weight.data<T>());
+
+  const DenseTensor *out_linear_bias_p = out_linear_bias.get_ptr();
+  const XPUTypeT *out_linear_bias_ptr =
+      reinterpret_cast<const XPUTypeT *>(out_linear_bias_p->data<T>());
+
+  const float *ln_scale_ptr =
+      (ln_scale_p == nullptr) ? (nullptr) : ln_scale_p->data<float>();
+
+  const float *ln_bias_ptr =
+      (ln_bias_p == nullptr) ? (nullptr) : ln_bias_p->data<float>();
+
+  // output pointers
+  XPUTypeT *qkv_transpose_out_ptr = reinterpret_cast<XPUTypeT *>(
+      dev_ctx.template Alloc<T>(transpose_out_2));
+
+  XPUTypeT *softmax_out_ptr =
+      reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(softmax_out));
+
+  XPUTypeT *attn_dropout_mask_out_ptr = reinterpret_cast<XPUTypeT *>(
+      dev_ctx.template Alloc<T>(attn_dropout_mask_out));
+
+  XPUTypeT *attn_dropout_out_ptr = reinterpret_cast<XPUTypeT *>(
+      dev_ctx.template Alloc<T>(attn_dropout_out));
+
+  XPUTypeT *fmha_out_ptr =
+      reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(fmha_out));
+
+  XPUTypeT *dropout_mask_out_ptr = reinterpret_cast<XPUTypeT *>(
+      dev_ctx.template Alloc<T>(dropout_mask_out));
+
+  XPUTypeT *out_ptr =
+      reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(out));
+
+  XPUTypeT *bias_dropout_residual_out_ptr =
+      (bias_dropout_residual_out == nullptr)
+          ? (nullptr)
+          : (reinterpret_cast<XPUTypeT *>(
+                dev_ctx.template Alloc<T>(bias_dropout_residual_out)));
+
+  float *ln_mean_ptr =
+      (ln_mean == nullptr)
+          ? (nullptr)
+          : reinterpret_cast<float *>(dev_ctx.template Alloc<float>(ln_mean));
+
+  float *ln_var_ptr =
+      (ln_var == nullptr)
+          ? (nullptr)
+          : reinterpret_cast<float *>(dev_ctx.template Alloc<float>(ln_var));
+
+  XPUTypeT *ln_out_ptr =
+      (ln_out == nullptr)
+          ? (nullptr)
+          : (reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(ln_out)));
+
+  xpu::Context *xpu_ctx = dev_ctx.x_context();
+
+  xpu::ctx_guard RAII_GUARD(xpu_ctx);
+
+  int l3_total_size = xpu_ctx->_l3_mgr.get_size();
+
+  XPUTypeT *qkv_before_transpos_ptr =
+      NULL;  // x2 [batch_size, seq_len, 3, num_heads, head_dims]
+  XPUTypeT *qk_ptr = NULL;   // qk [batch_size, num_heads, seq_len, seq_len]
+  XPUTypeT *qkv_ptr = NULL;  // qkv [batch_size, num_heads, seq_len, head_dims]
+  XPUTypeT *linear_out_ptr = NULL;  // x4, x5 [batch_size, seq_len, embed_dims]
+
+  int temp_size_1 = batch_size * seq_len * 3 * num_heads * head_dims;
+  int temp_size_2 = batch_size * num_heads * seq_len * seq_len;
+  int temp_size_3 = batch_size * num_heads * seq_len * head_dims;
+  int temp_size_4 = batch_size * seq_len * embed_dims;
+
+  std::vector<int> temp_vec = {
+      temp_size_1, temp_size_2, temp_size_3, temp_size_4};
+  std::sort(temp_vec.begin(), temp_vec.end(), std::greater<int>());
+  XPUTypeT *max_gm_ptr = RAII_GUARD.alloc<XPUTypeT>(temp_vec[0]);
+  PADDLE_ENFORCE_XDNN_NOT_NULL(max_gm_ptr);
+  qkv_before_transpos_ptr = max_gm_ptr;
+  qk_ptr = max_gm_ptr;
+  qkv_ptr = max_gm_ptr;
+  linear_out_ptr = max_gm_ptr;
+  int sizeof_t = sizeof(XPUTypeT);
+  for (size_t i = 0; i < temp_vec.size(); ++i) {
+    if (l3_total_size >= temp_vec[i] * sizeof_t) {
+      XPUTypeT *l3_ptr = RAII_GUARD.alloc_l3<XPUTypeT>(temp_vec[i]);
+      qkv_before_transpos_ptr =
+          (temp_size_1 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
+      qk_ptr = (temp_size_2 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
+      qkv_ptr = (temp_size_3 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
+      linear_out_ptr = (temp_size_4 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
+      break;
+    }
+  }
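
[editor's note] All four intermediates time-share a single GM scratch buffer sized for the largest of them; the loop then finds the largest temporary that also fits in on-chip L3 and points every temporary no bigger than that at the L3 allocation. A standalone sketch of the policy (hypothetical helper, element counts in, placements out):

    #include <algorithm>
    #include <vector>

    enum class Mem { L3, GM };

    // Mirror of the selection above: largest first-fit goes to L3, rest to GM.
    std::vector<Mem> PlaceScratch(const std::vector<int> &sizes,
                                  int l3_bytes,
                                  int elem_bytes) {
      std::vector<int> sorted(sizes);
      std::sort(sorted.begin(), sorted.end(), std::greater<int>());
      int l3_elems = 0;
      for (int s : sorted) {
        if (l3_bytes >= s * elem_bytes) {
          l3_elems = s;  // largest temporary that fits in L3
          break;
        }
      }
      std::vector<Mem> placement;
      for (int s : sizes) {
        placement.push_back((l3_elems > 0 && s <= l3_elems) ? Mem::L3 : Mem::GM);
      }
      return placement;
    }
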
+
+  int r = 0;
+  const XPUTypeT *x_cacl_ptr = input_x_ptr;
+  if (pre_layer_norm) {
+    r = xpu::layer_norm(xpu_ctx,
+                        input_x_ptr,
+                        ln_out_ptr,
+                        batch_size * seq_len,
+                        embed_dims,
+                        epsilon,
+                        ln_scale_ptr,
+                        ln_bias_ptr,
+                        ln_mean_ptr,
+                        ln_var_ptr);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
+    x_cacl_ptr = ln_out_ptr;
+  }
+
+  // fc
+  phi::XpuFcInfo qkv_fc_info;
+  qkv_fc_info.InitFcInfo(0,
+                         batch_size * seq_len,
+                         3 * num_heads * head_dims,
+                         embed_dims,
+                         false,
+                         true,
+                         nullptr,
+                         nullptr,
+                         nullptr);
+
+  phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
+                                   x_cacl_ptr,
+                                   qkv_weight_ptr,
+                                   qkv_before_transpos_ptr,
+                                   qkv_fc_info,
+                                   1.0f);
+
+  // bias
+  r = xpu::broadcast_add(xpu_ctx,
+                         qkv_before_transpos_ptr,
+                         qkv_bias_ptr,
+                         qkv_before_transpos_ptr,
+                         {batch_size * seq_len, 3 * num_heads * head_dims},
+                         {3 * num_heads * head_dims});
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+
+  // transpose
+  r = xpu::transpose(xpu_ctx,
+                     qkv_before_transpos_ptr,
+                     qkv_transpose_out_ptr,
+                     {batch_size, seq_len, 3, num_heads, head_dims},
+                     {2, 0, 3, 1, 4});
+
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+
+  int qkv_every_size = batch_size * seq_len * num_heads * head_dims;
+  {
+    // scale Q (the first third of the transposed buffer) by 1/sqrt(head_dims);
+    // K and V are left untouched
+    float alpha = 1.0f / sqrt(head_dims);
+    r = xpu::scale(xpu_ctx,
+                   qkv_transpose_out_ptr,
+                   qkv_transpose_out_ptr,
+                   qkv_every_size,
+                   false,
+                   alpha,
+                   0.0f);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
+  }
+
+  // begin fmha
+  // 1. qk  2. qk + mask  3. softmax  4. dropout  5. qkv  6. transpose
+  {
+    const XPUTypeT *q_ptr = qkv_transpose_out_ptr;
+    const XPUTypeT *k_ptr = q_ptr + qkv_every_size;
+    const XPUTypeT *v_ptr = k_ptr + qkv_every_size;
+    phi::XpuFcInfo qk_fc_info;
+    qk_fc_info.InitFcInfo(batch_size * num_heads,
+                          seq_len,
+                          seq_len,
+                          head_dims,
+                          false,
+                          true,
+                          nullptr,
+                          nullptr,
+                          nullptr);
+    phi::MatMulXPUFunction<XPUTypeT>(
+        xpu_ctx, q_ptr, k_ptr, qk_ptr, qk_fc_info, 1.0f);
+
+    if (src_mask_ptr) {
+      r = xpu::broadcast_add(xpu_ctx,
+                             qk_ptr,
+                             src_mask_ptr,
+                             qk_ptr,
+                             {batch_size, num_heads, seq_len, seq_len},
+                             {batch_size, 1, 1, seq_len});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+    }
+    // do softmax
+    r = xpu::softmax(xpu_ctx,
+                     qk_ptr,
+                     softmax_out_ptr,
+                     {batch_size, num_heads, seq_len, seq_len},
+                     3);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
+
+    // do dropout
+    phi::Dropout(xpu_ctx,
+                 softmax_out_ptr,
+                 attn_dropout_mask_out_ptr,
+                 attn_dropout_out_ptr,
+                 attn_dropout_param,
+                 batch_size * num_heads * seq_len * seq_len);
+
+    phi::XpuFcInfo qktv_fc_info;
+    qktv_fc_info.InitFcInfo(batch_size * num_heads,
+                            seq_len,
+                            head_dims,
+                            seq_len,
+                            false,
+                            false,
+                            nullptr,
+                            nullptr,
+                            nullptr);
+    phi::MatMulXPUFunction<XPUTypeT>(
+        xpu_ctx, attn_dropout_out_ptr, v_ptr, qkv_ptr, qktv_fc_info, 1.0f);
+    r = xpu::transpose(xpu_ctx,
+                       qkv_ptr,
+                       fmha_out_ptr,
+                       {batch_size, num_heads, seq_len, head_dims},
+                       {0, 2, 1, 3});
+
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+  }
+
+  // linear_out
+  phi::XpuFcInfo linear_fc_info;
+  linear_fc_info.InitFcInfo(0,
+                            batch_size * seq_len,
+                            embed_dims,
+                            embed_dims,
+                            false,
+                            false,
+                            nullptr,
+                            nullptr,
+                            nullptr);
+  phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
+                                   fmha_out_ptr,
+                                   out_linear_weight_ptr,
+                                   linear_out_ptr,
+                                   linear_fc_info,
+                                   1.0f);
+
+  // out_linear_bias_ptr
+  r = xpu::broadcast_add(xpu_ctx,
+                         linear_out_ptr,
+                         out_linear_bias_ptr,
+                         linear_out_ptr,
+                         {batch_size * seq_len, embed_dims},
+                         {embed_dims});
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+
+  Dropout(xpu_ctx,
+          linear_out_ptr,
+          dropout_mask_out_ptr,
+          linear_out_ptr,
+          dropout_param,
+          batch_size * seq_len * embed_dims);
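
[editor's note] The block above is the textbook attention pipeline, with 1/sqrt(head_dims) folded into Q beforehand, the additive mask applied to the raw logits, and dropout taken on the softmax probabilities. For cross-checking shapes and numerics only, a naive single-head CPU reference (editor's sketch; mask, dropout, and batching omitted):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    using Mat = std::vector<std::vector<float>>;  // [rows][cols]

    // out[i] = sum_j softmax(q * k^T / sqrt(d_k))[i][j] * v[j]
    void NaiveAttention(const Mat &q, const Mat &k, const Mat &v, Mat *out) {
      const size_t d_k = q[0].size(), d_v = v[0].size();
      const float inv_sqrt_dk = 1.0f / std::sqrt(static_cast<float>(d_k));
      out->assign(q.size(), std::vector<float>(d_v, 0.0f));
      for (size_t i = 0; i < q.size(); ++i) {
        std::vector<float> logits(k.size(), 0.0f);
        float max_logit = -1e30f, denom = 0.0f;
        for (size_t j = 0; j < k.size(); ++j) {
          for (size_t d = 0; d < d_k; ++d) logits[j] += q[i][d] * k[j][d];
          logits[j] *= inv_sqrt_dk;
          max_logit = std::max(max_logit, logits[j]);
        }
        for (size_t j = 0; j < k.size(); ++j) {
          logits[j] = std::exp(logits[j] - max_logit);  // stable softmax
          denom += logits[j];
        }
        for (size_t j = 0; j < k.size(); ++j)
          for (size_t d = 0; d < d_v; ++d)
            (*out)[i][d] += (logits[j] / denom) * v[j][d];
      }
    }
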
+
+  XPUTypeT *real_out_ptr = out_ptr;
+  if (pre_layer_norm == false) {
+    real_out_ptr = bias_dropout_residual_out_ptr;
+  }
+
+  r = xpu::add(xpu_ctx,
+               linear_out_ptr,
+               input_x_ptr,
+               real_out_ptr,
+               batch_size * seq_len * embed_dims);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
+
+  if (pre_layer_norm == false) {
+    r = xpu::layer_norm(xpu_ctx,
+                        real_out_ptr,
+                        out_ptr,
+                        batch_size * seq_len,
+                        embed_dims,
+                        epsilon,
+                        ln_scale_ptr,
+                        ln_bias_ptr,
+                        ln_mean_ptr,
+                        ln_var_ptr);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
+  }
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_attention,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::FusedAttentionKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/xpu_fused_common_function.h b/paddle/phi/kernels/xpu/xpu_fused_common_function.h
new file mode 100644
index 0000000000000000000000000000000000000000..1aac7ff1392a351b558b96c7e52619f38adcd175
--- /dev/null
+++ b/paddle/phi/kernels/xpu/xpu_fused_common_function.h
@@ -0,0 +1,153 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+struct XPUDropoutParam {
+  float dropout_prob;
+  bool is_upscale_in_train;
+  bool is_test;
+  bool fix_seed;
+  const phi::DenseTensor *tensor_seed;
+  int seed_val;
+
+  XPUDropoutParam() {
+    fix_seed = false;
+    is_test = false;
+    is_upscale_in_train = false;
+    dropout_prob = 0.5;
+    tensor_seed = nullptr;
+    seed_val = 0;
+  }
+
+  void initXPUDropoutParam(float dropout_prob_,
+                           bool is_upscale_in_train_,
+                           bool is_test_,
+                           bool fix_seed_,
+                           const phi::DenseTensor *tensor_seed,
+                           int seed_val_) {
+    dropout_prob = dropout_prob_;
+    is_upscale_in_train = is_upscale_in_train_;
+    is_test = is_test_;
+    fix_seed = fix_seed_;
+    if (tensor_seed) {
+      seed_val = *(tensor_seed->data<int>());
+    } else {
+      seed_val = fix_seed ? seed_val_ : 0;
+    }
+  }
+};
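
[editor's note] Seed precedence in initXPUDropoutParam: an explicit seed tensor always wins; otherwise fix_seed selects the attribute seed, and the fallback of 0 lets the runtime pick. A usage sketch against the struct above:

    phi::XPUDropoutParam p;
    // No seed tensor, fix_seed == true: seed 42 is used, so the generated
    // dropout mask is reproducible across runs.
    p.initXPUDropoutParam(/*dropout_prob_=*/0.1f,
                          /*is_upscale_in_train_=*/true,
                          /*is_test_=*/false,
                          /*fix_seed_=*/true,
                          /*tensor_seed=*/nullptr,
                          /*seed_val_=*/42);
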
+
+/******************
+ * check whether an address resides in L3
+ ******************/
+
+static bool is_in_l3(const void *addr) {
+  int64_t addr_int = (int64_t)addr;
+  int addr_int_high = addr_int >> 32;
+  return (addr_int_high == 0);
+}
+
+/*************************
+ * dropout
+ *************************/
+
+template <typename T>
+void Dropout(xpu::Context *xpu_ctx,
+             const T *x,
+             T *mask,
+             T *y,
+             const XPUDropoutParam &param,
+             int len) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  int r = XPU_SUCCESS;
+  if (param.dropout_prob == 0.0f) {
+    r = xpu::copy(xpu_ctx,
+                  reinterpret_cast<const XPUType *>(x),
+                  reinterpret_cast<XPUType *>(y),
+                  len);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+    return;
+  }
+  if (!param.is_test) {
+    if (param.dropout_prob == 1.0f) {
+      r = xpu::constant(
+          xpu_ctx, reinterpret_cast<XPUType *>(y), len, XPUType(0));
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+      r = xpu::constant(
+          xpu_ctx, reinterpret_cast<XPUType *>(mask), len, XPUType(0));
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    } else {
+      r = xpu::dropout(xpu_ctx,
+                       reinterpret_cast<const XPUType *>(x),
+                       reinterpret_cast<XPUType *>(y),
+                       reinterpret_cast<XPUType *>(mask),
+                       param.seed_val,
+                       len,
+                       param.is_upscale_in_train,
+                       param.dropout_prob);
+
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
+    }
+  } else {
+    float scale = (param.is_upscale_in_train)
+                      ? (1.0f)
+                      : (static_cast<float>(1.0f - param.dropout_prob));
+    r = xpu::scale(xpu_ctx,
+                   reinterpret_cast<const XPUType *>(x),
+                   reinterpret_cast<XPUType *>(y),
+                   len,
+                   false,
+                   scale,
+                   0.0f);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
+  }
+}
+
+template <typename T>
+void DropoutGrad(xpu::Context *xpu_ctx,
+                 const T *dy,
+                 const T *mask,
+                 T *dx,
+                 const XPUDropoutParam &param,
+                 int len) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  if (param.dropout_prob == 0.0f) {
+    int r = xpu::copy(xpu_ctx,
+                      reinterpret_cast<const XPUType *>(dy),
+                      reinterpret_cast<XPUType *>(dx),
+                      len);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+    return;
+  }
+  if (!param.is_upscale_in_train) {
+    int r = xpu::mul(xpu_ctx,
+                     reinterpret_cast<const XPUType *>(dy),
+                     reinterpret_cast<const XPUType *>(mask),
+                     reinterpret_cast<XPUType *>(dx),
+                     len);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
+  } else {
+    int r = xpu::dropout_grad(xpu_ctx,
+                              reinterpret_cast<const XPUType *>(mask),
+                              reinterpret_cast<const XPUType *>(dy),
+                              reinterpret_cast<XPUType *>(dx),
+                              param.dropout_prob,
+                              len);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad");
+  }
+}
+}  // namespace phi
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 406b8d5f55bc615bf97194b4a9c1321e00c4b250..3d2ca74e43ad30ef3f66c6fab05a85a58ea2b226 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1160,6 +1160,8 @@ set(STATIC_BUILD_TESTS
     test_eigh_op
     test_fake_quantize_op
     test_fetch_lod_tensor_array
+    test_fused_attention_op
+    test_fused_attention_op_api
     test_imperative_optimizer
     test_lamb_op
     test_layer_norm_op
@@ -1186,6 +1188,11 @@ set(STATIC_BUILD_TESTS
     test_while_op
     test_one_hot_v2_op)
 
+if(NOT WITH_GPU)
+  list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op)
+  list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op_api)
+endif()
+
 foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
   py_test_modules(
     ${STATIC_BUILD_TEST}_static_build MODULES ${STATIC_BUILD_TEST} ENVS