Unverified · Commit 7b56bd25 authored by Sonder, committed by GitHub

Move fused_attention op to phi [Migrate XPU OpKernel] [test=kunlun] (#53011)

* trans fused attention to phi

* add optional parm

* trans fused_attention_grad to phi

* add fused attention grad register info

* fix include

* test=kunlun

* add fused attention to static build list

* add remove

* update remove
Parent 543efcc5
...
@@ -33,8 +33,6 @@ std::set<std::string> OpsWithFluidKernelNeedMoveToPhi = {
     "cudnn_lstm",
     "dequantize",
     "distributed_fused_lamb",
-    "fused_attention",
-    "fused_attention_grad",
     "fused_batch_norm_act",
     "fused_batch_norm_act_grad",
     "fusion_group",
...
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/xpu_fused_common_function.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/fluid/operators/xpu_api_wrapper.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using XPUTypeT = typename XPUTypeTrait<T>::Type;
// inputs tensor
auto *input_x = ctx.Input<phi::DenseTensor>("X");
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
// shape [3, num_head, dim_head, dim_embed]
auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
// shape [3 , num_head, dim_head]
auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
// shape [batch_size, 1, 1, seq_len]
auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
// shape [dim_embed, dim_embed]
auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
// shape [dim_embed]
auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
const phi::DenseTensor *ln_scale = nullptr;
const phi::DenseTensor *ln_bias = nullptr;
float epsilon = 0.0f;
if (pre_layer_norm) {
ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
ln_bias = ctx.Input<phi::DenseTensor>("LnBias");
epsilon = ctx.Attr<float>("epsilon");
} else {
ln_scale = ctx.Input<phi::DenseTensor>("Ln2Scale");
ln_bias = ctx.Input<phi::DenseTensor>("Ln2Bias");
epsilon = ctx.Attr<float>("ln_epsilon");
}
// outputs tensor
// qkv values, after the transpose has been applied
// shape [3, batch_size, num_head, seq_len, dim_head]
auto *TransposeOut2 = ctx.Output<phi::DenseTensor>("TransposeOut2");
// shape [batch_size, num_head, seq_len, seq_len]
auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
// shape [batch_size, num_head, seq_len, seq_len]
auto *attn_dropout_mask_out =
ctx.Output<phi::DenseTensor>("AttnDropoutMaskOut");
// shape [batch_size, num_head, seq_len, seq_len]
auto *attn_dropout_out = ctx.Output<phi::DenseTensor>("AttnDropoutOut");
// shape [[batch_size, seq_len, num_head, dim_head]]
auto *fmha_out = ctx.Output<phi::DenseTensor>("FMHAOut");
// shape [batch_size, seq_len, dim_embed]
auto *dropout_mask_out = ctx.Output<phi::DenseTensor>("DropoutMaskOut");
// final output
// shape [batch_size, seq_len, dim_embed]
auto *out = ctx.Output<phi::DenseTensor>("Y");
// The following tensors do not need to be returned, but the new dygraph mode requires them
auto *QKOut = ctx.Output<phi::DenseTensor>("QKOut");
QKOut->mutable_data<T>(ctx.GetPlace());
auto *QKTVOut = ctx.Output<phi::DenseTensor>("QKTVOut");
QKTVOut->mutable_data<T>(ctx.GetPlace());
auto *OutLinearOut = ctx.Output<phi::DenseTensor>("OutLinearOut");
OutLinearOut->mutable_data<T>(ctx.GetPlace());
auto *QKVBiasOut = ctx.Output<phi::DenseTensor>("QKVBiasOut");
QKVBiasOut->mutable_data<T>(ctx.GetPlace());
auto *SrcMaskOut = ctx.Output<phi::DenseTensor>("SrcMaskOut");
SrcMaskOut->mutable_data<T>(ctx.GetPlace());
auto *qkv_out = ctx.Output<phi::DenseTensor>("QKVOut");
qkv_out->mutable_data<T>(ctx.GetPlace());
phi::DenseTensor *bias_dropout_residual_out = nullptr;
phi::DenseTensor *ln_mean = nullptr;
phi::DenseTensor *ln_var = nullptr;
phi::DenseTensor *ln_out = nullptr;
if (pre_layer_norm) {
ln_mean = ctx.Output<phi::DenseTensor>("LnMean");
ln_var = ctx.Output<phi::DenseTensor>("LnVariance");
ln_out = ctx.Output<phi::DenseTensor>("LnOut");
} else {
ln_mean = ctx.Output<phi::DenseTensor>("Ln2Mean");
ln_var = ctx.Output<phi::DenseTensor>("Ln2Variance");
bias_dropout_residual_out =
ctx.Output<phi::DenseTensor>("BiasDropoutResidualOut");
}
// dropout info
float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
bool is_test_1 = ctx.Attr<bool>("is_test");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("attn_dropout_implementation");
bool is_upscale_in_train_1 =
(dropout_implementation_1 == "upscale_in_train");
auto *seed_1 =
ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
XPUDropoutParam attn_dropout_param;
attn_dropout_param.initXPUDropoutParam(attn_dropout_rate,
is_upscale_in_train_1,
is_test_1,
is_fix_seed_1,
seed_1,
seed_val_1);
XPUDropoutParam dropout_param(ctx, 0);
// compute the dimensions first
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
int batch_size = input_x_dims[0];
int seq_len = input_x_dims[1];
int embed_dims = input_x_dims[2];
int num_heads = qkv_w_dims[1];
int head_dims = qkv_w_dims[2];
// input pointers
const XPUTypeT *input_x_ptr =
reinterpret_cast<const XPUTypeT *>(input_x->data<T>());
const XPUTypeT *qkv_weight_ptr =
reinterpret_cast<const XPUTypeT *>(qkv_weight->data<T>());
const XPUTypeT *qkv_bias_ptr =
reinterpret_cast<const XPUTypeT *>(qkv_bias->data<T>());
const XPUTypeT *src_mask_ptr =
(src_mask == nullptr)
? (nullptr)
: (reinterpret_cast<const XPUTypeT *>(src_mask->data<T>()));
const XPUTypeT *out_linear_weight_ptr =
reinterpret_cast<const XPUTypeT *>(out_linear_weight->data<T>());
const XPUTypeT *out_linear_bias_ptr =
reinterpret_cast<const XPUTypeT *>(out_linear_bias->data<T>());
const float *ln_scale_ptr =
(ln_scale == nullptr) ? (nullptr) : ln_scale->data<float>();
const float *ln_bias_ptr =
(ln_bias == nullptr) ? (nullptr) : ln_bias->data<float>();
// output pointers
XPUTypeT *qkv_transpose_out_ptr = reinterpret_cast<XPUTypeT *>(
TransposeOut2->mutable_data<T>(ctx.GetPlace()));
XPUTypeT *softmax_out_ptr = reinterpret_cast<XPUTypeT *>(
softmax_out->mutable_data<T>(ctx.GetPlace()));
XPUTypeT *attn_dropout_mask_out_ptr = reinterpret_cast<XPUTypeT *>(
attn_dropout_mask_out->mutable_data<T>(ctx.GetPlace()));
XPUTypeT *attn_dropout_out_ptr = reinterpret_cast<XPUTypeT *>(
attn_dropout_out->mutable_data<T>(ctx.GetPlace()));
XPUTypeT *fmha_out_ptr =
reinterpret_cast<XPUTypeT *>(fmha_out->mutable_data<T>(ctx.GetPlace()));
XPUTypeT *dropout_mask_out_ptr = reinterpret_cast<XPUTypeT *>(
dropout_mask_out->mutable_data<T>(ctx.GetPlace()));
XPUTypeT *out_ptr =
reinterpret_cast<XPUTypeT *>(out->mutable_data<T>(ctx.GetPlace()));
XPUTypeT *bias_dropout_residual_out_ptr =
(bias_dropout_residual_out == nullptr)
? (nullptr)
: (reinterpret_cast<XPUTypeT *>(
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace())));
float *ln_mean_ptr = (ln_mean == nullptr)
? (nullptr)
: ln_mean->mutable_data<float>(ctx.GetPlace());
float *ln_var_ptr = (ln_var == nullptr)
? (nullptr)
: ln_var->mutable_data<float>(ctx.GetPlace());
XPUTypeT *ln_out_ptr = (ln_out == nullptr)
? (nullptr)
: (reinterpret_cast<XPUTypeT *>(
ln_out->mutable_data<T>(ctx.GetPlace())));
auto &dev_ctx = ctx.template device_context<DeviceContext>();
xpu::Context *xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
int l3_total_size = xpu_ctx->_l3_mgr.get_size();
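// Scratch-buffer sharing: all four intermediates below alias one GM buffer
// sized for the largest requirement; if the L3 cache can hold one of the
// candidate sizes, every intermediate that fits is redirected to a single
// L3 allocation instead (see the loop below).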
XPUTypeT *qkv_before_transpos_ptr =
NULL; // x2[batch_size, seq_len, 3, num_heads,head_dims]
XPUTypeT *qk_ptr = NULL; // qk [batch_size, num_heads, seq_len, seq_len]
XPUTypeT *qkv_ptr = NULL; // qkv[batch_size, num_heads, seq_len, head_dims]
XPUTypeT *linear_out_ptr =
NULL; // x4, x5 [batch_size, seq_len, embed_dims]
int temp_size_1 = batch_size * seq_len * 3 * num_heads * head_dims;
int temp_size_2 = batch_size * num_heads * seq_len * seq_len;
int temp_size_3 = batch_size * num_heads * seq_len * head_dims;
int temp_size_4 = batch_size * seq_len * embed_dims;
std::vector<int> temp_vec = {
temp_size_1, temp_size_2, temp_size_3, temp_size_4};
std::sort(temp_vec.begin(), temp_vec.end(), std::greater<int>());
XPUTypeT *max_gm_ptr = RAII_GUARD.alloc<XPUTypeT>(temp_vec[0]);
PADDLE_ENFORCE_XDNN_NOT_NULL(max_gm_ptr);
qkv_before_transpos_ptr = max_gm_ptr;
qk_ptr = max_gm_ptr;
qkv_ptr = max_gm_ptr;
linear_out_ptr = max_gm_ptr;
int sizeof_t = sizeof(XPUTypeT);
for (size_t i = 0; i < temp_vec.size(); ++i) {
if (l3_total_size >= temp_vec[i] * sizeof_t) {
XPUTypeT *l3_ptr = RAII_GUARD.alloc_l3<XPUTypeT>(temp_vec[i]);
qkv_before_transpos_ptr =
(temp_size_1 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
qk_ptr = (temp_size_2 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
qkv_ptr = (temp_size_3 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
linear_out_ptr = (temp_size_4 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
break;
}
}
int r = 0;
const XPUTypeT *x_cacl_ptr = input_x_ptr;
if (pre_layer_norm) {
r = xpu::layer_norm(xpu_ctx,
input_x_ptr,
ln_out_ptr,
batch_size * seq_len,
embed_dims,
epsilon,
ln_scale_ptr,
ln_bias_ptr,
ln_mean_ptr,
ln_var_ptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
x_cacl_ptr = ln_out_ptr;
}
// fc
phi::XpuFcInfo qkv_fc_info;
qkv_fc_info.InitFcInfo(0,
batch_size * seq_len,
3 * num_heads * head_dims,
embed_dims,
false,
true,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
x_cacl_ptr,
qkv_weight_ptr,
qkv_before_transpos_ptr,
qkv_fc_info,
1.0f);
// bias
r = xpu::broadcast_add(xpu_ctx,
qkv_before_transpos_ptr,
qkv_bias_ptr,
qkv_before_transpos_ptr,
{batch_size * seq_len, 3 * num_heads * head_dims},
{3 * num_heads * head_dims});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
// transpose
r = xpu::transpose(xpu_ctx,
qkv_before_transpos_ptr,
qkv_transpose_out_ptr,
{batch_size, seq_len, 3, num_heads, head_dims},
{2, 0, 3, 1, 4});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
int qkv_every_size = batch_size * seq_len * num_heads * head_dims;
{
float alpha = 1.0 / sqrt(head_dims);
r = scale(xpu_ctx,
qkv_transpose_out_ptr,
qkv_transpose_out_ptr,
qkv_every_size,
false,
alpha,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
// begin fmha
// 1. qk  2. qk + mask  3. softmax  4. dropout  5. qkv  6. transpose
{
const XPUTypeT *q_ptr = qkv_transpose_out_ptr;
const XPUTypeT *k_ptr = q_ptr + qkv_every_size;
const XPUTypeT *v_ptr = k_ptr + qkv_every_size;
phi::XpuFcInfo qk_fc_info;
qk_fc_info.InitFcInfo(batch_size * num_heads,
seq_len,
seq_len,
head_dims,
false,
true,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, q_ptr, k_ptr, qk_ptr, qk_fc_info, 1.0f);
if (src_mask_ptr) {
r = xpu::broadcast_add(xpu_ctx,
qk_ptr,
src_mask_ptr,
qk_ptr,
{batch_size, num_heads, seq_len, seq_len},
{batch_size, 1, 1, seq_len});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
}
// do softmax
r = xpu::softmax(xpu_ctx,
qk_ptr,
softmax_out_ptr,
{batch_size, num_heads, seq_len, seq_len},
3);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
// do dropout
Dropout<XPUTypeT>(xpu_ctx,
softmax_out_ptr,
attn_dropout_mask_out_ptr,
attn_dropout_out_ptr,
attn_dropout_param,
batch_size * num_heads * seq_len * seq_len);
phi::XpuFcInfo qktv_fc_info;
qktv_fc_info.InitFcInfo(batch_size * num_heads,
seq_len,
head_dims,
seq_len,
false,
false,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, attn_dropout_out_ptr, v_ptr, qkv_ptr, qktv_fc_info, 1.0f);
r = xpu::transpose(xpu_ctx,
qkv_ptr,
fmha_out_ptr,
{batch_size, num_heads, seq_len, head_dims},
{0, 2, 1, 3});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
}
// linear_out
phi::XpuFcInfo linear_fc_info;
linear_fc_info.InitFcInfo(0,
batch_size * seq_len,
embed_dims,
embed_dims,
false,
false,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
fmha_out_ptr,
out_linear_weight_ptr,
linear_out_ptr,
linear_fc_info,
1.0f);
// out_linear_bias_ptr
r = xpu::broadcast_add(xpu_ctx,
linear_out_ptr,
out_linear_bias_ptr,
linear_out_ptr,
{batch_size * seq_len, embed_dims},
{embed_dims});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
Dropout(xpu_ctx,
linear_out_ptr,
dropout_mask_out_ptr,
linear_out_ptr,
dropout_param,
batch_size * seq_len * embed_dims);
XPUTypeT *real_out_ptr = out_ptr;
if (pre_layer_norm == false) {
real_out_ptr = bias_dropout_residual_out_ptr;
}
r = xpu::add(xpu_ctx,
linear_out_ptr,
input_x_ptr,
real_out_ptr,
batch_size * seq_len * embed_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
if (pre_layer_norm == false) {
r = xpu::layer_norm(xpu_ctx,
real_out_ptr,
out_ptr,
batch_size * seq_len,
embed_dims,
epsilon,
ln_scale_ptr,
ln_bias_ptr,
ln_mean_ptr,
ln_var_ptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
}
}
};
// template <typename T>
template <typename DeviceContext, typename T>
class FusedAttentionGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using XPUTypeT = typename XPUTypeTrait<T>::Type;
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
// dropout info
float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
bool is_test_1 = ctx.Attr<bool>("is_test");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("attn_dropout_implementation");
bool is_upscale_in_train_1 =
(dropout_implementation_1 == "upscale_in_train");
auto *seed_1 =
ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
XPUDropoutParam attn_dropout_param;
attn_dropout_param.initXPUDropoutParam(attn_dropout_prob,
is_upscale_in_train_1,
is_test_1,
is_fix_seed_1,
seed_1,
seed_val_1);
XPUDropoutParam dropout_param(ctx, 0);
// get inputs.
auto *d_y = ctx.Input<phi::DenseTensor>(framework::GradVarName("Y"));
const XPUTypeT *d_y_ptr =
reinterpret_cast<const XPUTypeT *>(d_y->data<T>());
// forward tensors needed for the backward pass
auto *input_x = ctx.Input<phi::DenseTensor>("X");
const XPUTypeT *input_x_ptr =
reinterpret_cast<const XPUTypeT *>(input_x->data<T>());
auto *qkv_transpose_out = ctx.Input<phi::DenseTensor>("TransposeOut2");
const XPUTypeT *qkv_transpose_out_ptr =
reinterpret_cast<const XPUTypeT *>(qkv_transpose_out->data<T>());
auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
const XPUTypeT *qkv_weight_ptr =
reinterpret_cast<const XPUTypeT *>(qkv_weight->data<T>());
auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
const XPUTypeT *softmax_out_ptr =
reinterpret_cast<const XPUTypeT *>(softmax_out->data<T>());
auto *attn_dropout_out = ctx.Input<phi::DenseTensor>("AttnDropoutOut");
const XPUTypeT *attn_dropout_out_ptr =
reinterpret_cast<const XPUTypeT *>(attn_dropout_out->data<T>());
auto *attn_dropout_mask = ctx.Input<phi::DenseTensor>("AttnDropoutMaskOut");
const XPUTypeT *attn_dropout_mask_ptr =
reinterpret_cast<const XPUTypeT *>(attn_dropout_mask->data<T>());
auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
const XPUTypeT *fmha_out_ptr =
reinterpret_cast<const XPUTypeT *>(fmha_out->data<T>());
auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
const XPUTypeT *out_linear_weight_ptr =
reinterpret_cast<const XPUTypeT *>(out_linear_weight->data<T>());
auto *dropout_mask_out = ctx.Input<phi::DenseTensor>("DropoutMaskOut");
const XPUTypeT *dropout_mask_out_ptr =
reinterpret_cast<const XPUTypeT *>(dropout_mask_out->data<T>());
// gradients to be computed
auto *d_qkv_weight =
ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVW"));
XPUTypeT *d_qkv_weight_ptr = reinterpret_cast<XPUTypeT *>(
d_qkv_weight->mutable_data<T>(ctx.GetPlace()));
auto *d_qkv_bias =
ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBias"));
XPUTypeT *d_qkv_bias_ptr = reinterpret_cast<XPUTypeT *>(
d_qkv_bias->mutable_data<T>(ctx.GetPlace()));
auto *d_out_linear_weight =
ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearW"));
XPUTypeT *d_out_linear_weight_ptr = reinterpret_cast<XPUTypeT *>(
d_out_linear_weight->mutable_data<T>(ctx.GetPlace()));
auto *d_out_linear_bias =
ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearBias"));
XPUTypeT *d_out_linear_bias_ptr = reinterpret_cast<XPUTypeT *>(
d_out_linear_bias->mutable_data<T>(ctx.GetPlace()));
// may be needed
auto *d_src_mask_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("SrcMaskOut"));
XPUTypeT *d_src_mask_out_ptr =
(d_src_mask_out == nullptr)
? (nullptr)
: (reinterpret_cast<XPUTypeT *>(
d_src_mask_out->mutable_data<T>(ctx.GetPlace())));
// output dx
auto *d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
XPUTypeT *d_x_ptr =
reinterpret_cast<XPUTypeT *>(d_x->mutable_data<T>(ctx.GetPlace()));
const phi::DenseTensor *ln_out = nullptr;
const phi::DenseTensor *bias_dropout_residual_out = nullptr;
const phi::DenseTensor *ln_scale = nullptr;
const phi::DenseTensor *ln_mean = nullptr;
const phi::DenseTensor *ln_var = nullptr;
phi::DenseTensor *d_ln_scale = nullptr;
phi::DenseTensor *d_ln_bias = nullptr;
const XPUTypeT *ln_out_ptr = NULL;
const float *ln_scale_ptr = NULL;
const float *ln_mean_ptr = NULL;
const float *ln_var_ptr = NULL;
const XPUTypeT *bias_dropout_residual_out_ptr = NULL;
float *d_ln_scale_ptr = nullptr;
float *d_ln_bias_ptr = nullptr;
float epsilon = 0.0f;
if (pre_layer_norm) {
ln_out = ctx.Input<phi::DenseTensor>("LnOut");
ln_out_ptr = reinterpret_cast<const XPUTypeT *>(ln_out->data<T>());
ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
ln_mean = ctx.Input<phi::DenseTensor>("LnMean");
ln_var = ctx.Input<phi::DenseTensor>("LnVariance");
epsilon = ctx.Attr<float>("epsilon");
d_ln_scale =
ctx.Output<phi::DenseTensor>(framework::GradVarName("LnScale"));
d_ln_bias =
ctx.Output<phi::DenseTensor>(framework::GradVarName("LnBias"));
} else {
ln_scale = ctx.Input<phi::DenseTensor>("Ln2Scale");
ln_mean = ctx.Input<phi::DenseTensor>("Ln2Mean");
ln_var = ctx.Input<phi::DenseTensor>("Ln2Variance");
epsilon = ctx.Attr<float>("ln_epsilon");
d_ln_scale =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Scale"));
d_ln_bias =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Bias"));
bias_dropout_residual_out =
ctx.Input<phi::DenseTensor>("BiasDropoutResidualOut");
bias_dropout_residual_out_ptr = reinterpret_cast<const XPUTypeT *>(
bias_dropout_residual_out->data<T>());
}
ln_scale_ptr = ln_scale->data<float>();
ln_mean_ptr = ln_mean->data<float>();
ln_var_ptr = ln_var->data<float>();
d_ln_scale_ptr = d_ln_scale->mutable_data<float>(ctx.GetPlace());
d_ln_bias_ptr = d_ln_bias->mutable_data<float>(ctx.GetPlace());
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
int batch_size = input_x_dims[0];
int seq_len = input_x_dims[1];
int embed_dims = input_x_dims[2];
int num_heads = qkv_w_dims[1];
int head_dims = qkv_w_dims[2];
auto &dev_ctx = ctx.template device_context<DeviceContext>();
xpu::Context *xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
int r = 0;
// int l3_total_size = xpu_ctx->_l3_mgr.get_size();
XPUTypeT *d_ln_grad_ptr = NULL; // dx5 [batch_size, seq_len, hidden]
XPUTypeT *d_dropout_grad_ptr = NULL; // dx5 [batch_size, seq_len, hidden]
XPUTypeT *d_fmha_out_ptr =
NULL; // d_fmha_out [batch_size, seq_len, num_heads, head_dims]
XPUTypeT *d_fmha_out_transpos_tmp_ptr =
NULL; // d_fmha_out_transpos [batch_size, seq_len, num_heads,
// head_dims]
XPUTypeT *d_qk_ptr =
NULL; // d_qk_ptr[batch_size, num_heads, seq_len, seq_len]
XPUTypeT *d_combination_qkv_ptr =
NULL; // d_combination_qkv_ptr[3, batch_size, num_heads, seq_len,
// head_dims]
XPUTypeT *d_transpos_qkv_ptr =
NULL; // dx2 [batch_size, seq_len, 3, num_heads, head_dims]
XPUTypeT *d_last_layernorm_grad_ptr =
NULL; // d_layer_out [batch_size, seq_len, embed_dims]
const XPUTypeT *dy_input_ptr = d_y_ptr;
d_ln_grad_ptr =
RAII_GUARD.alloc<XPUTypeT>(batch_size * seq_len * embed_dims);
d_dropout_grad_ptr =
RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len * embed_dims);
d_fmha_out_ptr = RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len *
num_heads * head_dims);
d_combination_qkv_ptr =
RAII_GUARD.alloc<XPUTypeT>(batch_size * seq_len * embed_dims * 3);
d_transpos_qkv_ptr = RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(
batch_size * seq_len * embed_dims * 3);
d_fmha_out_transpos_tmp_ptr =
RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len * embed_dims);
d_qk_ptr = RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len *
seq_len * num_heads);
d_last_layernorm_grad_ptr =
RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len * embed_dims);
if (pre_layer_norm == false) {
r = xpu::layer_norm_grad(xpu_ctx,
bias_dropout_residual_out_ptr,
d_y_ptr,
d_ln_grad_ptr,
batch_size * seq_len,
embed_dims,
epsilon,
ln_scale_ptr,
ln_mean_ptr,
ln_var_ptr,
d_ln_scale_ptr,
d_ln_bias_ptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
dy_input_ptr = d_ln_grad_ptr;
}
// dropout_grad
DropoutGrad<XPUTypeT>(xpu_ctx,
dy_input_ptr,
dropout_mask_out_ptr,
d_dropout_grad_ptr,
dropout_param,
batch_size * num_heads * seq_len * head_dims);
// linear_out
phi::XpuFcInfo linear_fc_info;
linear_fc_info.InitFcInfo(0,
batch_size * seq_len,
embed_dims,
embed_dims,
false,
false,
nullptr,
nullptr,
nullptr);
const XPUTypeT *a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
XPUTypeT *c_1 = d_fmha_out_ptr;
XPUTypeT *c_2 = d_out_linear_weight_ptr;
phi::XpuFcInfo info_dfmha;
phi::XpuFcInfo info_dlinear_w;
std::tuple<phi::XpuFcInfo,
phi::XpuFcInfo,
const XPUTypeT *,
const XPUTypeT *,
const XPUTypeT *,
const XPUTypeT *>
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
linear_fc_info,
false,
false,
fmha_out_ptr,
out_linear_weight_ptr,
d_dropout_grad_ptr);
std::tie(info_dfmha, info_dlinear_w, a_1, b_1, a_2, b_2) = fc_info;
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_2, b_2, c_2, info_dlinear_w, 1.0f, true);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_1, b_1, c_1, info_dfmha, 1.0f, true);
// dlinear_bias
r = xpu::reduce_sum(xpu_ctx,
d_dropout_grad_ptr,
d_out_linear_bias_ptr,
{batch_size * seq_len, embed_dims},
{0});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
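// Backward of the fmha block: d_qk and d_v come from the grad of the
// softmax(qk) x v matmul, then dropout_grad and softmax_grad are applied to
// d_qk, and finally d_q and d_k come from the grad of the q x k^T matmul
// (d_q is scaled by 1/sqrt(head_dims) to match the forward scaling).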
{
int qkv_size = batch_size * seq_len * num_heads * head_dims;
const XPUTypeT *q_out_ptr = qkv_transpose_out_ptr;
const XPUTypeT *k_out_ptr = q_out_ptr + qkv_size;
const XPUTypeT *v_out_ptr = k_out_ptr + qkv_size;
XPUTypeT *d_q_out_ptr = d_combination_qkv_ptr;
XPUTypeT *d_k_out_ptr = d_q_out_ptr + qkv_size;
XPUTypeT *d_v_out_ptr = d_k_out_ptr + qkv_size;
r = xpu::transpose<XPUTypeT>(xpu_ctx,
d_fmha_out_ptr,
d_fmha_out_transpos_tmp_ptr,
{batch_size, seq_len, num_heads, head_dims},
{0, 2, 1, 3});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
phi::XpuFcInfo qktv_fc_info;
qktv_fc_info.InitFcInfo(batch_size * num_heads,
seq_len,
head_dims,
seq_len,
false,
false,
nullptr,
nullptr,
nullptr);
const XPUTypeT *a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
XPUTypeT *c_1 = d_qk_ptr;
XPUTypeT *c_2 = d_v_out_ptr;
phi::XpuFcInfo info_d_qk;
phi::XpuFcInfo info_d_v;
std::tuple<phi::XpuFcInfo,
phi::XpuFcInfo,
const XPUTypeT *,
const XPUTypeT *,
const XPUTypeT *,
const XPUTypeT *>
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
qktv_fc_info,
false,
false,
attn_dropout_out_ptr,
v_out_ptr,
d_fmha_out_transpos_tmp_ptr);
std::tie(info_d_qk, info_d_v, a_1, b_1, a_2, b_2) = fc_info;
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_1, b_1, c_1, info_d_qk, 1.0f, true);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_2, b_2, c_2, info_d_v, 1.0f, true);
DropoutGrad<XPUTypeT>(xpu_ctx,
d_qk_ptr,
attn_dropout_mask_ptr,
d_qk_ptr,
attn_dropout_param,
batch_size * seq_len * seq_len * num_heads);
r = xpu::softmax_grad<XPUTypeT>(xpu_ctx,
softmax_out_ptr,
d_qk_ptr,
d_qk_ptr,
{batch_size, num_heads, seq_len, seq_len},
3);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad");
if (d_src_mask_out_ptr) {
r = xpu::copy<XPUTypeT>(xpu_ctx,
d_qk_ptr,
d_src_mask_out_ptr,
batch_size * seq_len * seq_len * num_heads);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
}
phi::XpuFcInfo qk_fc_info;
qk_fc_info.InitFcInfo(batch_size * num_heads,
seq_len,
seq_len,
head_dims,
false,
true,
nullptr,
nullptr,
nullptr);
a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
c_1 = d_q_out_ptr;
c_2 = d_k_out_ptr;
phi::XpuFcInfo info_d_q;
phi::XpuFcInfo info_d_k;
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
qk_fc_info,
false,
true,
q_out_ptr,
k_out_ptr,
d_qk_ptr);
std::tie(info_d_q, info_d_k, a_1, b_1, a_2, b_2) = fc_info;
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_1, b_1, c_1, info_d_q, 1.0f / sqrt(head_dims), true);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_2, b_2, c_2, info_d_k, 1.0f, true);
}
//
r = xpu::transpose<XPUTypeT>(xpu_ctx,
d_combination_qkv_ptr,
d_transpos_qkv_ptr,
{3, batch_size, num_heads, seq_len, head_dims},
{1, 3, 0, 2, 4});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
// dx and d_qkv_w
phi::XpuFcInfo qkv_fc_info;
qkv_fc_info.InitFcInfo(0,
batch_size * seq_len,
3 * num_heads * head_dims,
embed_dims,
false,
true,
nullptr,
nullptr,
nullptr);
a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
c_1 = (pre_layer_norm == true) ? d_last_layernorm_grad_ptr : d_x_ptr;
c_2 = d_qkv_weight_ptr;
phi::XpuFcInfo info_d_x;
phi::XpuFcInfo info_d_qkv_w;
const XPUTypeT *use_calc_input_x_ptr =
(pre_layer_norm == true) ? ln_out_ptr : input_x_ptr;
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
qkv_fc_info,
false,
true,
use_calc_input_x_ptr,
qkv_weight_ptr,
d_transpos_qkv_ptr);
std::tie(info_d_x, info_d_qkv_w, a_1, b_1, a_2, b_2) = fc_info;
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_1, b_1, c_1, info_d_x, 1.0f, true);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_2, b_2, c_2, info_d_qkv_w, 1.0f, true);
// d_qkv_bias
r = xpu::reduce_sum(xpu_ctx,
d_transpos_qkv_ptr,
d_qkv_bias_ptr,
{batch_size * seq_len, 3 * embed_dims},
{0});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
if (pre_layer_norm) {
r = xpu::layer_norm_grad(xpu_ctx,
input_x_ptr,
c_1,
d_x_ptr,
batch_size * seq_len,
embed_dims,
epsilon,
ln_scale_ptr,
ln_mean_ptr,
ln_var_ptr,
d_ln_scale_ptr,
d_ln_bias_ptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
}
// add residual dy
r = xpu::add(xpu_ctx,
dy_input_ptr,
d_x_ptr,
d_x_ptr,
batch_size * seq_len * embed_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(
fused_attention,
ops::FusedAttentionOpKernel<phi::XPUContext, float>,
ops::FusedAttentionOpKernel<phi::XPUContext, paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
fused_attention_grad,
ops::FusedAttentionGradXPUKernel<phi::XPUContext, float>,
ops::FusedAttentionGradXPUKernel<phi::XPUContext,
paddle::platform::float16>);
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void FusedAttentionGradKernel(
const Context &dev_ctx,
const DenseTensor &out_grad,
const DenseTensor &x,
const DenseTensor &qkv_weight,
const paddle::optional<DenseTensor> &qkv_bias,
const paddle::optional<DenseTensor> &qkv_bias_out,
const paddle::optional<DenseTensor> &src_mask,
const paddle::optional<DenseTensor> &src_mask_out,
const DenseTensor &out_linear_weight,
const paddle::optional<DenseTensor> &out_linear_bias,
const paddle::optional<DenseTensor> &ln_scale,
const paddle::optional<DenseTensor> &ln_bias,
const paddle::optional<DenseTensor> &ln_scale_2,
const paddle::optional<DenseTensor> &ln_bias_2,
const paddle::optional<DenseTensor> &ln_out,
const paddle::optional<DenseTensor> &ln_mean,
const paddle::optional<DenseTensor> &ln_var,
const paddle::optional<DenseTensor> &ln_mean_2,
const paddle::optional<DenseTensor> &ln_var_2,
const paddle::optional<DenseTensor> &bias_dropout_residual_out,
const DenseTensor &qkv_out,
const DenseTensor &transpose_out_2,
const DenseTensor &qk_out,
const DenseTensor &qktv_out,
const DenseTensor &softmax_out,
const DenseTensor &attn_dropout_mask_out,
const DenseTensor &attn_dropout_out,
const DenseTensor &fmha_out,
const DenseTensor &out_linear_out,
const DenseTensor &dropout_mask_out,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string &attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string &dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
DenseTensor *qkv_bias_grad,
DenseTensor *qkv_bias_out_grad,
DenseTensor *src_mask_out_grad,
DenseTensor *out_linear_bias_grad,
DenseTensor *ln_scale_grad,
DenseTensor *ln_bias_grad,
DenseTensor *ln_scale_2_grad,
DenseTensor *ln_bias_2_grad,
DenseTensor *x_grad,
DenseTensor *qkv_weight_grad,
DenseTensor *out_linear_weight_grad,
DenseTensor *ln_out_grad,
DenseTensor *bias_dropout_residual_out_grad,
DenseTensor *qkv_out_grad,
DenseTensor *qktv_out_grad,
DenseTensor *transpose_out_2_grad,
DenseTensor *qk_out_grad,
DenseTensor *softmax_out_grad,
DenseTensor *attn_dropout_out_grad,
DenseTensor *fmha_out_grad,
DenseTensor *out_linear_out_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
/**
* @brief Fused Attention Kernel.
* @param ctx device context
* @param x The input tensor.
* @param ln_scale (optional) Scale is a 1-dimensional tensor of size
* H. Here, H represents the last dimension of its
* input tensor.
* @param ln_bias (optional) Bias is a 1-dimensional tensor of size
* H. Here, H represents the last dimension of its
* input tensor.
* @param qkv_weight The qkv weight tensor.
* @param qkv_bias The qkv bias tensor.
* @param cache_kv (optional) The cache KV for generation inference.
* @param src_mask (optional) The attention mask tensor in fmha.
* @param out_linear_w The out_linear weight tensor.
* @param out_linear_bias (optional) The out_linear bias tensor.
* @param ln_scale_2 (optional) Scale is a 1-dimensional tensor of
* size H. Here, H represents the last dimension of its input tensor.
* @param ln_bias_2 (optional) Bias is a 1-dimensional tensor of size
* H. Here, H represents the last dimension of its
* input tensor.
* @param num_heads The number head for multi_head_attention.
* @param transpose_qkv_wb Whether qkv_w is stored with shape (h, 3h) and
* should be transposed before use.
* @param pre_layer_norm if true, the attention op uses the pre_layer_norm
* architecture, else it uses the post_layer_norm
* architecture. [default false].
* @param epsilon Constant for numerical stability [default 1e-5].
* @param attn_dropout_rate Probability of setting units to zero.
* @param is_test (bool, default false) Set to true for inference
* only, false for training. Some layers may run
* faster when this is true.
* @param attn_dropout_fix_seed A flag indicating whether to use a fixed seed
* to generate the random mask. NOTE: DO NOT set this flag to true in
* training. Setting this flag to true is only useful in unit tests or for
* debugging, so that the same output units are always dropped.
* @param attn_dropout_seed Dropout random seed.
* @param attn_dropout_implementation ["downgrade_in_infer"|"upscale_in_train"]
* There are two ways to implement dropout (the mask below is a tensor with
* the same shape as the input; its values are 0 or 1, and the ratio of 0s is
* dropout_rate):
* 1. downgrade_in_infer (default): downgrade the outcome at inference time.
*    train:     out = input * mask
*    inference: out = input * (1.0 - dropout_rate)
* 2. upscale_in_train: upscale the outcome at training time, do nothing at
*    inference.
*    train:     out = input * mask / (1.0 - dropout_rate)
*    inference: out = input
*    In this mode the dropout op can be removed from the inference program,
*    making it more efficient.
* @param dropout_rate Probability of setting units to zero.
* @param dropout_fix_seed A flag indicating whether to use a fixed seed to
* generate the random mask. NOTE: DO NOT set this flag to true in training.
* Setting this flag to true is only useful in unit tests or for debugging,
* so that the same output units are always dropped.
* @param dropout_seed Dropout random seed.
* @param dropout_implementation ["downgrade_in_infer"|"upscale_in_train"]
* The meaning is the same as 'attn_dropout_implementation'.
* @param ln_epsilon Constant for numerical stability [default 1e-5].
* @param add_residual Whether to add residual.
* @param ring_id ring id for tensor model parallelism in distributed
* training and inference.
* @param ln_mean Mean of the current mini batch.
* @param ln_var Variance of the current mini batch.
* @param ln_out The output tensor after layer_norm.
* @param qkv_out Result after qkv.
* @param qkv_bias_out Result after qkv and bias op.
* @param transpose_out_2 Result in fmha.
* @param qk_out Result in fmha.
* @param qktv_out Result in fmha.
* @param softmax_out Result in fmha.
* @param attn_dropout_mask_out Result in fmha.
* @param attn_dropout_out Result in fmha.
* @param src_mask_out Result in fmha.
* @param fmha_out Result in fmha.
* @param out_linear_out Result after out_linear.
* @param dropout_mask_out The random sampled dropout mask.
* @param ln_mean_2 Mean of the current mini batch.
* @param ln_var_2 Variance of the current mini batch.
* @param bias_dropout_residual_out Result of residual + dropout(src + bias).
* @param cache_kv_out The update cache KV.
* @param y Result after attention.
*/
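// A minimal numeric sketch (hypothetical values, not from the op itself) of
// the two dropout modes documented above, assuming dropout_rate = 0.5,
// mask = {1, 0, 1, 0} and input = {2, 2, 2, 2}:
//   downgrade_in_infer: train out = {2, 0, 2, 0}; inference out = {1, 1, 1, 1}
//   upscale_in_train:   train out = {4, 0, 4, 0}; inference out = {2, 2, 2, 2}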
template <typename T, typename Context>
void FusedAttentionKernel(const Context &dev_ctx,
const DenseTensor &x,
const paddle::optional<DenseTensor> &ln_scale,
const paddle::optional<DenseTensor> &ln_bias,
const DenseTensor &qkv_weight,
const paddle::optional<DenseTensor> &qkv_bias,
const paddle::optional<DenseTensor> &cache_kv,
const paddle::optional<DenseTensor> &src_mask,
const DenseTensor &out_linear_weight,
const paddle::optional<DenseTensor> &out_linear_bias,
const paddle::optional<DenseTensor> &ln_scale_2,
const paddle::optional<DenseTensor> &ln_bias_2,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string &attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string &dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
DenseTensor *ln_mean,
DenseTensor *ln_var,
DenseTensor *ln_out,
DenseTensor *qkv_out,
DenseTensor *qkv_bias_out,
DenseTensor *transpose_out_2,
DenseTensor *qk_out,
DenseTensor *qktv_out,
DenseTensor *softmax_out,
DenseTensor *attn_dropout_mask_out,
DenseTensor *attn_dropout_out,
DenseTensor *src_mask_out,
DenseTensor *fmha_out,
DenseTensor *out_linear_out,
DenseTensor *dropout_mask_out,
DenseTensor *ln_mean_2,
DenseTensor *ln_var_2,
DenseTensor *bias_dropout_residual_out,
DenseTensor *cache_kv_out,
DenseTensor *out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fused_attention_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
#include "paddle/phi/kernels/xpu/xpu_fused_common_function.h"
namespace phi {
template <typename T, typename Context>
void FusedAttentionGradKernel(
const Context &dev_ctx,
const DenseTensor &out_grad,
const DenseTensor &x,
const DenseTensor &qkv_weight,
const paddle::optional<DenseTensor> &qkv_bias,
const paddle::optional<DenseTensor> &qkv_bias_out,
const paddle::optional<DenseTensor> &src_mask,
const paddle::optional<DenseTensor> &src_mask_out,
const DenseTensor &out_linear_weight,
const paddle::optional<DenseTensor> &out_linear_bias,
const paddle::optional<DenseTensor> &ln_scale,
const paddle::optional<DenseTensor> &ln_bias,
const paddle::optional<DenseTensor> &ln_scale_2,
const paddle::optional<DenseTensor> &ln_bias_2,
const paddle::optional<DenseTensor> &ln_out,
const paddle::optional<DenseTensor> &ln_mean,
const paddle::optional<DenseTensor> &ln_var,
const paddle::optional<DenseTensor> &ln_mean_2,
const paddle::optional<DenseTensor> &ln_var_2,
const paddle::optional<DenseTensor> &bias_dropout_residual_out,
const DenseTensor &qkv_out,
const DenseTensor &transpose_out_2,
const DenseTensor &qk_out,
const DenseTensor &qktv_out,
const DenseTensor &softmax_out,
const DenseTensor &attn_dropout_mask,
const DenseTensor &attn_dropout_out,
const DenseTensor &fmha_out,
const DenseTensor &out_linear_out,
const DenseTensor &dropout_mask_out,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string &attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string &dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
DenseTensor *qkv_bias_grad,
DenseTensor *qkv_bias_out_grad,
DenseTensor *src_mask_out_grad,
DenseTensor *out_linear_bias_grad,
DenseTensor *ln_scale_grad,
DenseTensor *ln_bias_grad,
DenseTensor *ln_scale_2_grad,
DenseTensor *ln_bias_2_grad,
DenseTensor *x_grad,
DenseTensor *qkv_weight_grad,
DenseTensor *out_linear_weight_grad,
DenseTensor *ln_out_grad,
DenseTensor *bias_dropout_residual_out_grad,
DenseTensor *qkv_out_grad,
DenseTensor *qktv_out_grad,
DenseTensor *transpose_out_2_grad,
DenseTensor *qk_out_grad,
DenseTensor *softmax_out_grad,
DenseTensor *attn_dropout_out_grad,
DenseTensor *fmha_out_grad,
DenseTensor *out_linear_out_grad) {
using XPUTypeT = typename XPUTypeTrait<T>::Type;
bool is_upscale_in_train_1 =
(attn_dropout_implementation == "upscale_in_train");
const phi::DenseTensor *seed_1 = nullptr;
phi::XPUDropoutParam attn_dropout_param;
attn_dropout_param.initXPUDropoutParam(attn_dropout_rate,
is_upscale_in_train_1,
is_test,
attn_dropout_fix_seed,
seed_1,
attn_dropout_seed);
phi::XPUDropoutParam dropout_param;
dropout_param.initXPUDropoutParam(dropout_rate,
is_upscale_in_train_1,
is_test,
dropout_fix_seed,
seed_1,
dropout_seed);
// get inputs.
const XPUTypeT *d_y_ptr =
reinterpret_cast<const XPUTypeT *>(out_grad.data<T>());
// forward tensors needed for the backward pass
const XPUTypeT *input_x_ptr = reinterpret_cast<const XPUTypeT *>(x.data<T>());
const XPUTypeT *qkv_transpose_out_ptr =
reinterpret_cast<const XPUTypeT *>(transpose_out_2.data<T>());
const XPUTypeT *qkv_weight_ptr =
reinterpret_cast<const XPUTypeT *>(qkv_weight.data<T>());
const XPUTypeT *softmax_out_ptr =
reinterpret_cast<const XPUTypeT *>(softmax_out.data<T>());
const XPUTypeT *attn_dropout_out_ptr =
reinterpret_cast<const XPUTypeT *>(attn_dropout_out.data<T>());
const XPUTypeT *attn_dropout_mask_ptr =
reinterpret_cast<const XPUTypeT *>(attn_dropout_mask.data<T>());
const XPUTypeT *fmha_out_ptr =
reinterpret_cast<const XPUTypeT *>(fmha_out.data<T>());
const XPUTypeT *out_linear_weight_ptr =
reinterpret_cast<const XPUTypeT *>(out_linear_weight.data<T>());
const XPUTypeT *dropout_mask_out_ptr =
reinterpret_cast<const XPUTypeT *>(dropout_mask_out.data<T>());
// gradients to be computed
auto *d_qkv_weight = qkv_weight_grad;
XPUTypeT *d_qkv_weight_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(d_qkv_weight));
auto *d_qkv_bias = qkv_bias_grad;
XPUTypeT *d_qkv_bias_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(d_qkv_bias));
auto *d_out_linear_weight = out_linear_weight_grad;
XPUTypeT *d_out_linear_weight_ptr = reinterpret_cast<XPUTypeT *>(
dev_ctx.template Alloc<T>(d_out_linear_weight));
auto *d_out_linear_bias = out_linear_bias_grad;
XPUTypeT *d_out_linear_bias_ptr = reinterpret_cast<XPUTypeT *>(
dev_ctx.template Alloc<T>(d_out_linear_bias));
// may be needed
auto *d_src_mask_out = src_mask_out_grad;
XPUTypeT *d_src_mask_out_ptr =
(d_src_mask_out == nullptr)
? (nullptr)
: (reinterpret_cast<XPUTypeT *>(
dev_ctx.template Alloc<T>(d_src_mask_out)));
// output dx
auto *d_x = x_grad;
XPUTypeT *d_x_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(d_x));
const phi::DenseTensor *ln_out_p = ln_out.get_ptr();
const phi::DenseTensor *bias_dropout_residual_out_p =
bias_dropout_residual_out.get_ptr();
const phi::DenseTensor *ln_scale_p = nullptr;
const phi::DenseTensor *ln_mean_p = nullptr;
const phi::DenseTensor *ln_var_p = nullptr;
phi::DenseTensor *d_ln_scale = nullptr;
phi::DenseTensor *d_ln_bias = nullptr;
const XPUTypeT *ln_out_ptr = NULL;
const float *ln_scale_ptr = NULL;
const float *ln_mean_ptr = NULL;
const float *ln_var_ptr = NULL;
const XPUTypeT *bias_dropout_residual_out_ptr = NULL;
float *d_ln_scale_ptr = nullptr;
float *d_ln_bias_ptr = nullptr;
if (pre_layer_norm) {
ln_out_ptr = reinterpret_cast<const XPUTypeT *>(ln_out_p->data<T>());
ln_scale_p = ln_scale.get_ptr();
ln_mean_p = ln_mean.get_ptr();
ln_var_p = ln_var.get_ptr();
d_ln_scale = ln_scale_grad;
d_ln_bias = ln_bias_grad;
} else {
ln_scale_p = ln_scale_2.get_ptr();
ln_mean_p = ln_mean_2.get_ptr();
ln_var_p = ln_var_2.get_ptr();
epsilon = ln_epsilon;
d_ln_scale = ln_scale_2_grad;
d_ln_bias = ln_bias_2_grad;
bias_dropout_residual_out_ptr = reinterpret_cast<const XPUTypeT *>(
bias_dropout_residual_out_p->data<T>());
}
ln_scale_ptr = ln_scale_p->data<float>();
ln_mean_ptr = ln_mean_p->data<float>();
ln_var_ptr = ln_var_p->data<float>();
d_ln_scale_ptr = dev_ctx.template Alloc<float>(d_ln_scale);
d_ln_bias_ptr = dev_ctx.template Alloc<float>(d_ln_bias);
const auto input_x_dims = x.dims();
const auto qkv_w_dims = qkv_weight.dims();
int batch_size = input_x_dims[0];
int seq_len = input_x_dims[1];
int embed_dims = input_x_dims[2];
num_heads = qkv_w_dims[1];
int head_dims = qkv_w_dims[2];
xpu::Context *xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
int r = 0;
// int l3_total_size = xpu_ctx->_l3_mgr.get_size();
XPUTypeT *d_ln_grad_ptr = NULL; // dx5 [batch_size, seq_len, hidden]
XPUTypeT *d_dropout_grad_ptr = NULL; // dx5 [batch_size, seq_len, hidden]
XPUTypeT *d_fmha_out_ptr =
NULL; // d_fmha_out [batch_size, seq_len, num_heads, head_dims]
XPUTypeT *d_fmha_out_transpos_tmp_ptr =
NULL; // d_fmha_out_transpos [batch_size, seq_len, num_heads,
// head_dims]
XPUTypeT *d_qk_ptr =
NULL; // d_qk_ptr[batch_size, num_heads, seq_len, seq_len]
XPUTypeT *d_combination_qkv_ptr =
NULL; // d_combination_qkv_ptr[3, batch_size, num_heads, seq_len,
// head_dims]
XPUTypeT *d_transpos_qkv_ptr =
NULL; // dx2 [batch_size, seq_len, 3, num_heads, head_dims]
XPUTypeT *d_last_layernorm_grad_ptr =
NULL; // d_layer_out [batch_size, seq_len, embed_dims]
const XPUTypeT *dy_input_ptr = d_y_ptr;
d_ln_grad_ptr = RAII_GUARD.alloc<XPUTypeT>(batch_size * seq_len * embed_dims);
d_dropout_grad_ptr =
RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len * embed_dims);
d_fmha_out_ptr = RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len *
num_heads * head_dims);
d_combination_qkv_ptr =
RAII_GUARD.alloc<XPUTypeT>(batch_size * seq_len * embed_dims * 3);
d_transpos_qkv_ptr = RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(
batch_size * seq_len * embed_dims * 3);
d_fmha_out_transpos_tmp_ptr =
RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len * embed_dims);
d_qk_ptr = RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len *
seq_len * num_heads);
d_last_layernorm_grad_ptr =
RAII_GUARD.alloc_l3_or_gm<XPUTypeT>(batch_size * seq_len * embed_dims);
if (pre_layer_norm == false) {
r = xpu::layer_norm_grad(xpu_ctx,
bias_dropout_residual_out_ptr,
d_y_ptr,
d_ln_grad_ptr,
batch_size * seq_len,
embed_dims,
epsilon,
ln_scale_ptr,
ln_mean_ptr,
ln_var_ptr,
d_ln_scale_ptr,
d_ln_bias_ptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
dy_input_ptr = d_ln_grad_ptr;
}
// dropout_grad
DropoutGrad<XPUTypeT>(xpu_ctx,
dy_input_ptr,
dropout_mask_out_ptr,
d_dropout_grad_ptr,
dropout_param,
batch_size * num_heads * seq_len * head_dims);
// linear_out
phi::XpuFcInfo linear_fc_info;
linear_fc_info.InitFcInfo(0,
batch_size * seq_len,
embed_dims,
embed_dims,
false,
false,
nullptr,
nullptr,
nullptr);
const XPUTypeT *a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
XPUTypeT *c_1 = d_fmha_out_ptr;
XPUTypeT *c_2 = d_out_linear_weight_ptr;
phi::XpuFcInfo info_dfmha;
phi::XpuFcInfo info_dlinear_w;
std::tuple<phi::XpuFcInfo,
phi::XpuFcInfo,
const XPUTypeT *,
const XPUTypeT *,
const XPUTypeT *,
const XPUTypeT *>
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
linear_fc_info,
false,
false,
fmha_out_ptr,
out_linear_weight_ptr,
d_dropout_grad_ptr);
std::tie(info_dfmha, info_dlinear_w, a_1, b_1, a_2, b_2) = fc_info;
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_2, b_2, c_2, info_dlinear_w, 1.0f, true);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_1, b_1, c_1, info_dfmha, 1.0f, true);
// dlinear_bias
r = xpu::reduce_sum(xpu_ctx,
d_dropout_grad_ptr,
d_out_linear_bias_ptr,
{batch_size * seq_len, embed_dims},
{0});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
{
int qkv_size = batch_size * seq_len * num_heads * head_dims;
const XPUTypeT *q_out_ptr = qkv_transpose_out_ptr;
const XPUTypeT *k_out_ptr = q_out_ptr + qkv_size;
const XPUTypeT *v_out_ptr = k_out_ptr + qkv_size;
XPUTypeT *d_q_out_ptr = d_combination_qkv_ptr;
XPUTypeT *d_k_out_ptr = d_q_out_ptr + qkv_size;
XPUTypeT *d_v_out_ptr = d_k_out_ptr + qkv_size;
r = xpu::transpose<XPUTypeT>(xpu_ctx,
d_fmha_out_ptr,
d_fmha_out_transpos_tmp_ptr,
{batch_size, seq_len, num_heads, head_dims},
{0, 2, 1, 3});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
phi::XpuFcInfo qktv_fc_info;
qktv_fc_info.InitFcInfo(batch_size * num_heads,
seq_len,
head_dims,
seq_len,
false,
false,
nullptr,
nullptr,
nullptr);
const XPUTypeT *a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
const XPUTypeT *b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
XPUTypeT *c_1 = d_qk_ptr;
XPUTypeT *c_2 = d_v_out_ptr;
phi::XpuFcInfo info_d_qk;
phi::XpuFcInfo info_d_v;
std::tuple<phi::XpuFcInfo,
phi::XpuFcInfo,
const XPUTypeT *,
const XPUTypeT *,
const XPUTypeT *,
const XPUTypeT *>
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
qktv_fc_info,
false,
false,
attn_dropout_out_ptr,
v_out_ptr,
d_fmha_out_transpos_tmp_ptr);
std::tie(info_d_qk, info_d_v, a_1, b_1, a_2, b_2) = fc_info;
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_1, b_1, c_1, info_d_qk, 1.0f, true);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_2, b_2, c_2, info_d_v, 1.0f, true);
DropoutGrad<XPUTypeT>(xpu_ctx,
d_qk_ptr,
attn_dropout_mask_ptr,
d_qk_ptr,
attn_dropout_param,
batch_size * seq_len * seq_len * num_heads);
r = xpu::softmax_grad<XPUTypeT>(xpu_ctx,
softmax_out_ptr,
d_qk_ptr,
d_qk_ptr,
{batch_size, num_heads, seq_len, seq_len},
3);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad");
if (d_src_mask_out_ptr) {
r = xpu::copy<XPUTypeT>(xpu_ctx,
d_qk_ptr,
d_src_mask_out_ptr,
batch_size * seq_len * seq_len * num_heads);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
}
phi::XpuFcInfo qk_fc_info;
qk_fc_info.InitFcInfo(batch_size * num_heads,
seq_len,
seq_len,
head_dims,
false,
true,
nullptr,
nullptr,
nullptr);
a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
c_1 = d_q_out_ptr;
c_2 = d_k_out_ptr;
phi::XpuFcInfo info_d_q;
phi::XpuFcInfo info_d_k;
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
qk_fc_info,
false,
true,
q_out_ptr,
k_out_ptr,
d_qk_ptr);
std::tie(info_d_q, info_d_k, a_1, b_1, a_2, b_2) = fc_info;
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_1, b_1, c_1, info_d_q, 1.0f / sqrt(head_dims), true);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_2, b_2, c_2, info_d_k, 1.0f, true);
}
//
r = xpu::transpose<XPUTypeT>(xpu_ctx,
d_combination_qkv_ptr,
d_transpos_qkv_ptr,
{3, batch_size, num_heads, seq_len, head_dims},
{1, 3, 0, 2, 4});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
// dx and d_qkv_w
phi::XpuFcInfo qkv_fc_info;
qkv_fc_info.InitFcInfo(0,
batch_size * seq_len,
3 * num_heads * head_dims,
embed_dims,
false,
true,
nullptr,
nullptr,
nullptr);
a_1 = reinterpret_cast<const XPUTypeT *>(NULL);
b_1 = reinterpret_cast<const XPUTypeT *>(NULL);
a_2 = reinterpret_cast<const XPUTypeT *>(NULL);
b_2 = reinterpret_cast<const XPUTypeT *>(NULL);
c_1 = (pre_layer_norm == true) ? d_last_layernorm_grad_ptr : d_x_ptr;
c_2 = d_qkv_weight_ptr;
phi::XpuFcInfo info_d_x;
phi::XpuFcInfo info_d_qkv_w;
const XPUTypeT *use_calc_input_x_ptr =
(pre_layer_norm == true) ? ln_out_ptr : input_x_ptr;
fc_info = phi::MatmulGradFcInfo(xpu_ctx,
&RAII_GUARD,
qkv_fc_info,
false,
true,
use_calc_input_x_ptr,
qkv_weight_ptr,
d_transpos_qkv_ptr);
std::tie(info_d_x, info_d_qkv_w, a_1, b_1, a_2, b_2) = fc_info;
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_1, b_1, c_1, info_d_x, 1.0f, true);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, a_2, b_2, c_2, info_d_qkv_w, 1.0f, true);
// d_qkv_bias
r = xpu::reduce_sum(xpu_ctx,
d_transpos_qkv_ptr,
d_qkv_bias_ptr,
{batch_size * seq_len, 3 * embed_dims},
{0});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
if (pre_layer_norm) {
r = xpu::layer_norm_grad(xpu_ctx,
input_x_ptr,
c_1,
d_x_ptr,
batch_size * seq_len,
embed_dims,
epsilon,
ln_scale_ptr,
ln_mean_ptr,
ln_var_ptr,
d_ln_scale_ptr,
d_ln_bias_ptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_grad");
}
// add residual dy
r = xpu::add(xpu_ctx,
dy_input_ptr,
d_x_ptr,
d_x_ptr,
batch_size * seq_len * embed_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
}
} // namespace phi
PD_REGISTER_KERNEL(fused_attention_grad,
XPU,
ALL_LAYOUT,
phi::FusedAttentionGradKernel,
float,
phi::dtype::float16) {
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
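// Note: assuming the OutputAt indices follow the output order of
// FusedAttentionGradKernel above, outputs 4-7 are ln_scale_grad,
// ln_bias_grad, ln_scale_2_grad and ln_bias_2_grad; they are pinned to
// FLOAT32 because the kernel computes the layer_norm scale/bias gradients
// in float even when T is float16.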
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fused_attention_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
#include "paddle/phi/kernels/xpu/xpu_fused_common_function.h"
namespace phi {
template <typename T, typename Context>
void FusedAttentionKernel(const Context &dev_ctx,
const DenseTensor &x,
const paddle::optional<DenseTensor> &ln_scale,
const paddle::optional<DenseTensor> &ln_bias,
const DenseTensor &qkv_weight,
const paddle::optional<DenseTensor> &qkv_bias,
const paddle::optional<DenseTensor> &cache_kv,
const paddle::optional<DenseTensor> &src_mask,
const DenseTensor &out_linear_weight,
const paddle::optional<DenseTensor> &out_linear_bias,
const paddle::optional<DenseTensor> &ln_scale_2,
const paddle::optional<DenseTensor> &ln_bias_2,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string &attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string &dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
DenseTensor *ln_mean,
DenseTensor *ln_var,
DenseTensor *ln_out,
DenseTensor *qkv_out,
DenseTensor *qkv_bias_out,
DenseTensor *transpose_out_2,
DenseTensor *qk_out,
DenseTensor *qktv_out,
DenseTensor *softmax_out,
DenseTensor *attn_dropout_mask_out,
DenseTensor *attn_dropout_out,
DenseTensor *src_mask_out,
DenseTensor *fmha_out,
DenseTensor *out_linear_out,
DenseTensor *dropout_mask_out,
DenseTensor *ln_mean_2,
DenseTensor *ln_var_2,
DenseTensor *bias_dropout_residual_out,
DenseTensor *cache_kv_out,
DenseTensor *out) {
using XPUTypeT = typename XPUTypeTrait<T>::Type;
// shape [batch_size, 1, 1, seq_len]
const phi::DenseTensor *src_mask_p = src_mask.get_ptr();
const phi::DenseTensor *ln_scale_p = nullptr;
const phi::DenseTensor *ln_bias_p = nullptr;
if (pre_layer_norm) {
ln_scale_p = ln_scale.get_ptr();
ln_bias_p = ln_bias.get_ptr();
} else {
ln_scale_p = ln_scale_2.get_ptr();
ln_bias_p = ln_bias_2.get_ptr();
epsilon = ln_epsilon;
}
dev_ctx.template Alloc<T>(qk_out);
dev_ctx.template Alloc<T>(qktv_out);
dev_ctx.template Alloc<T>(out_linear_out);
dev_ctx.template Alloc<T>(qkv_bias_out);
dev_ctx.template Alloc<T>(src_mask_out);
dev_ctx.template Alloc<T>(qkv_out);
bool is_upscale_in_train_1 =
(attn_dropout_implementation == "upscale_in_train");
const phi::DenseTensor *seed_1 = nullptr;
phi::XPUDropoutParam attn_dropout_param;
attn_dropout_param.initXPUDropoutParam(attn_dropout_rate,
is_upscale_in_train_1,
is_test,
attn_dropout_fix_seed,
seed_1,
attn_dropout_seed);
phi::XPUDropoutParam dropout_param;
dropout_param.initXPUDropoutParam(dropout_rate,
is_upscale_in_train_1,
is_test,
dropout_fix_seed,
seed_1,
dropout_seed);
// compute the dimensions first
const auto input_x_dims = x.dims();
const auto qkv_w_dims = qkv_weight.dims();
int batch_size = input_x_dims[0];
int seq_len = input_x_dims[1];
int embed_dims = input_x_dims[2];
num_heads = qkv_w_dims[1];
int head_dims = qkv_w_dims[2];
// input pointers
const XPUTypeT *input_x_ptr = reinterpret_cast<const XPUTypeT *>(x.data<T>());
const XPUTypeT *qkv_weight_ptr =
reinterpret_cast<const XPUTypeT *>(qkv_weight.data<T>());
const DenseTensor *qkv_bias_p = qkv_bias.get_ptr();
const XPUTypeT *qkv_bias_ptr =
reinterpret_cast<const XPUTypeT *>(qkv_bias_p->data<T>());
const XPUTypeT *src_mask_ptr =
(src_mask_p == nullptr)
? (nullptr)
: (reinterpret_cast<const XPUTypeT *>(src_mask_p->data<T>()));
const XPUTypeT *out_linear_weight_ptr =
reinterpret_cast<const XPUTypeT *>(out_linear_weight.data<T>());
const DenseTensor *out_linear_bias_p = out_linear_bias.get_ptr();
const XPUTypeT *out_linear_bias_ptr =
reinterpret_cast<const XPUTypeT *>(out_linear_bias_p->data<T>());
const float *ln_scale_ptr =
(ln_scale_p == nullptr) ? (nullptr) : ln_scale_p->data<float>();
const float *ln_bias_ptr =
(ln_bias_p == nullptr) ? (nullptr) : ln_bias_p->data<float>();
// output pointers
XPUTypeT *qkv_transpose_out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(transpose_out_2));
XPUTypeT *softmax_out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(softmax_out));
XPUTypeT *attn_dropout_mask_out_ptr = reinterpret_cast<XPUTypeT *>(
dev_ctx.template Alloc<T>(attn_dropout_mask_out));
XPUTypeT *attn_dropout_out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(attn_dropout_out));
XPUTypeT *fmha_out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(fmha_out));
XPUTypeT *dropout_mask_out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(dropout_mask_out));
XPUTypeT *out_ptr =
reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(out));
XPUTypeT *bias_dropout_residual_out_ptr =
(bias_dropout_residual_out == nullptr)
? (nullptr)
: (reinterpret_cast<XPUTypeT *>(
dev_ctx.template Alloc<T>(bias_dropout_residual_out)));
// layer-norm statistics are float regardless of T, so allocate them as float
float *ln_mean_ptr =
    (ln_mean == nullptr) ? nullptr
                         : dev_ctx.template Alloc<float>(ln_mean);
float *ln_var_ptr =
    (ln_var == nullptr) ? nullptr
                        : dev_ctx.template Alloc<float>(ln_var);
XPUTypeT *ln_out_ptr =
(ln_out == nullptr)
? (nullptr)
: (reinterpret_cast<XPUTypeT *>(dev_ctx.template Alloc<T>(ln_out)));
xpu::Context *xpu_ctx = dev_ctx.x_context();
xpu::ctx_guard RAII_GUARD(xpu_ctx);
int l3_total_size = xpu_ctx->_l3_mgr.get_size();
XPUTypeT *qkv_before_transpos_ptr =
NULL; // x2[batch_size, seq_len, 3, num_heads,head_dims]
XPUTypeT *qk_ptr = NULL; // qk [batch_size, num_heads, seq_len, seq_len]
XPUTypeT *qkv_ptr = NULL; // qkv[batch_size, num_heads, seq_len, head_dims]
XPUTypeT *linear_out_ptr = NULL; // x4, x5 [batch_size, seq_len, embed_dims]
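// Element counts of the four scratch buffers above; the largest one decides
// how much global memory has to be reserved.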
int temp_size_1 = batch_size * seq_len * 3 * num_heads * head_dims;
int temp_size_2 = batch_size * num_heads * seq_len * seq_len;
int temp_size_3 = batch_size * num_heads * seq_len * head_dims;
int temp_size_4 = batch_size * seq_len * embed_dims;
std::vector<int> temp_vec = {
temp_size_1, temp_size_2, temp_size_3, temp_size_4};
std::sort(temp_vec.begin(), temp_vec.end(), std::greater<int>());
XPUTypeT *max_gm_ptr = RAII_GUARD.alloc<XPUTypeT>(temp_vec[0]);
PADDLE_ENFORCE_XDNN_NOT_NULL(max_gm_ptr);
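// By default all four scratch pointers alias this single GM allocation; the
// buffers are used one stage at a time, so they can share storage.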
qkv_before_transpos_ptr = max_gm_ptr;
qk_ptr = max_gm_ptr;
qkv_ptr = max_gm_ptr;
linear_out_ptr = max_gm_ptr;
int sizeof_t = sizeof(XPUTypeT);
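// If L3 can hold some of the buffers, allocate one L3 block of the largest
// size that fits; every buffer no larger than that shares it, the rest stay
// in GM.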
for (size_t i = 0; i < temp_vec.size(); ++i) {
if (l3_total_size >= temp_vec[i] * sizeof_t) {
XPUTypeT *l3_ptr = RAII_GUARD.alloc_l3<XPUTypeT>(temp_vec[i]);
qkv_before_transpos_ptr =
(temp_size_1 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
qk_ptr = (temp_size_2 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
qkv_ptr = (temp_size_3 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
linear_out_ptr = (temp_size_4 <= temp_vec[i]) ? l3_ptr : max_gm_ptr;
break;
}
}
int r = 0;
const XPUTypeT *x_cacl_ptr = input_x_ptr;
if (pre_layer_norm) {
r = xpu::layer_norm(xpu_ctx,
input_x_ptr,
ln_out_ptr,
batch_size * seq_len,
embed_dims,
epsilon,
ln_scale_ptr,
ln_bias_ptr,
ln_mean_ptr,
ln_var_ptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
x_cacl_ptr = ln_out_ptr;
}
// QKV projection (fc): x * qkv_weight^T
phi::XpuFcInfo qkv_fc_info;
qkv_fc_info.InitFcInfo(0,
batch_size * seq_len,
3 * num_heads * head_dims,
embed_dims,
false,
true,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
x_cacl_ptr,
qkv_weight_ptr,
qkv_before_transpos_ptr,
qkv_fc_info,
1.0f);
// add QKV bias (broadcast over rows)
r = xpu::broadcast_add(xpu_ctx,
qkv_before_transpos_ptr,
qkv_bias_ptr,
qkv_before_transpos_ptr,
{batch_size * seq_len, 3 * num_heads * head_dims},
{3 * num_heads * head_dims});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
// transpose to [3, batch_size, num_heads, seq_len, head_dims]
r = xpu::transpose(xpu_ctx,
qkv_before_transpos_ptr,
qkv_transpose_out_ptr,
{batch_size, seq_len, 3, num_heads, head_dims},
{2, 0, 3, 1, 4});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
int qkv_every_size = batch_size * seq_len * num_heads * head_dims;
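// Scale Q (the first third of the transposed buffer) by 1/sqrt(head_dims)
// before computing QK^T.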
{
float alpha = 1.0 / sqrt(head_dims);
r = scale(xpu_ctx,
qkv_transpose_out_ptr,
qkv_transpose_out_ptr,
qkv_every_size,
false,
alpha,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
// begin fmha
// 1. qk 2. qk + mask 3. softmax 4. dropout 5. qkv 6. transpose
{
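// After the transpose Q, K and V are three consecutive chunks of
// qkv_transpose_out.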
const XPUTypeT *q_ptr = qkv_transpose_out_ptr;
const XPUTypeT *k_ptr = q_ptr + qkv_every_size;
const XPUTypeT *v_ptr = k_ptr + qkv_every_size;
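// QK^T: batched matmul over batch_size * num_heads heads,
// [seq_len, head_dims] x [seq_len, head_dims]^T -> [seq_len, seq_len].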
phi::XpuFcInfo qk_fc_info;
qk_fc_info.InitFcInfo(batch_size * num_heads,
seq_len,
seq_len,
head_dims,
false,
true,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, q_ptr, k_ptr, qk_ptr, qk_fc_info, 1.0f);
if (src_mask_ptr) {
r = xpu::broadcast_add(xpu_ctx,
qk_ptr,
src_mask_ptr,
qk_ptr,
{batch_size, num_heads, seq_len, seq_len},
{batch_size, 1, 1, seq_len});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
}
// do softmax
r = xpu::softmax(xpu_ctx,
qk_ptr,
softmax_out_ptr,
{batch_size, num_heads, seq_len, seq_len},
3);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
// do dropout
phi::Dropout<XPUTypeT>(xpu_ctx,
softmax_out_ptr,
attn_dropout_mask_out_ptr,
attn_dropout_out_ptr,
attn_dropout_param,
batch_size * num_heads * seq_len * seq_len);
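// Attention weights x V per head: [seq_len, seq_len] x [seq_len, head_dims],
// then transpose back to [batch_size, seq_len, num_heads, head_dims].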
phi::XpuFcInfo qktv_fc_info;
qktv_fc_info.InitFcInfo(batch_size * num_heads,
seq_len,
head_dims,
seq_len,
false,
false,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(
xpu_ctx, attn_dropout_out_ptr, v_ptr, qkv_ptr, qktv_fc_info, 1.0f);
r = xpu::transpose(xpu_ctx,
qkv_ptr,
fmha_out_ptr,
{batch_size, num_heads, seq_len, head_dims},
{0, 2, 1, 3});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
}
// output projection (out_linear)
phi::XpuFcInfo linear_fc_info;
linear_fc_info.InitFcInfo(0,
batch_size * seq_len,
embed_dims,
embed_dims,
false,
false,
nullptr,
nullptr,
nullptr);
phi::MatMulXPUFunction<XPUTypeT>(xpu_ctx,
fmha_out_ptr,
out_linear_weight_ptr,
linear_out_ptr,
linear_fc_info,
1.0f);
// add out_linear bias
r = xpu::broadcast_add(xpu_ctx,
linear_out_ptr,
out_linear_bias_ptr,
linear_out_ptr,
{batch_size * seq_len, embed_dims},
{embed_dims});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
Dropout(xpu_ctx,
linear_out_ptr,
dropout_mask_out_ptr,
linear_out_ptr,
dropout_param,
batch_size * seq_len * embed_dims);
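// Residual add with the original input. With pre_layer_norm the sum is the
// final output; otherwise it goes to bias_dropout_residual_out and is then
// layer-normalized into out.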
XPUTypeT *real_out_ptr = out_ptr;
if (pre_layer_norm == false) {
real_out_ptr = bias_dropout_residual_out_ptr;
}
r = xpu::add(xpu_ctx,
linear_out_ptr,
input_x_ptr,
real_out_ptr,
batch_size * seq_len * embed_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
if (pre_layer_norm == false) {
r = xpu::layer_norm(xpu_ctx,
real_out_ptr,
out_ptr,
batch_size * seq_len,
embed_dims,
epsilon,
ln_scale_ptr,
ln_bias_ptr,
ln_mean_ptr,
ln_var_ptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
}
}
} // namespace phi
PD_REGISTER_KERNEL(fused_attention,
XPU,
ALL_LAYOUT,
phi::FusedAttentionKernel,
float,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
struct XPUDropoutParam {
float dropout_prob;
bool is_upscale_in_train;
bool is_test;
bool fix_seed;
const phi::DenseTensor *tensor_seed;
int seed_val;
XPUDropoutParam() {
fix_seed = false;
is_test = false;
is_upscale_in_train = false;
dropout_prob = 0.5;
tensor_seed = nullptr;
seed_val = 0;
}
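// A seed tensor, when given, takes priority; otherwise the attribute seed is
// honored only when fix_seed is set.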
void initXPUDropoutParam(float dropout_prob_,
bool is_upscale_in_train_,
bool is_test_,
bool fix_seed_,
const phi::DenseTensor *tensor_seed,
int seed_val_) {
dropout_prob = dropout_prob_;
is_upscale_in_train = is_upscale_in_train_;
is_test = is_test_;
fix_seed = fix_seed_;
if (tensor_seed) {
seed_val = *(tensor_seed->data<int>());
} else {
seed_val = fix_seed ? seed_val_ : 0;
}
}
};
/******************
 * check whether an address is in L3
 ******************/
static bool is_in_l3(const void *addr) {
int64_t addr_int = (int64_t)addr;
int addr_int_high = addr_int >> 32;
return (addr_int_high == 0);
}
/*************************
* dropout
*************************/
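// Forward dropout helper: copies the input when dropout_prob is 0, zeros the
// output and mask when training with dropout_prob == 1, otherwise calls
// xpu::dropout in training; in inference it scales by (1 - dropout_prob)
// unless upscale_in_train is set (then the scale is 1).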
template <typename T>
void Dropout(xpu::Context *xpu_ctx,
const T *x,
T *mask,
T *y,
const XPUDropoutParam &param,
int len) {
using XPUType = typename XPUTypeTrait<T>::Type;
int r = XPU_SUCCESS;
if (param.dropout_prob == 0.0f) {
r = xpu::copy(xpu_ctx,
reinterpret_cast<const XPUType *>(x),
reinterpret_cast<XPUType *>(y),
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
if (!param.is_test) {
if (param.dropout_prob == 1.0f) {
r = xpu::constant(
xpu_ctx, reinterpret_cast<XPUType *>(y), len, XPUType(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
r = xpu::constant(
xpu_ctx, reinterpret_cast<XPUType *>(mask), len, XPUType(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
} else {
r = xpu::dropout(xpu_ctx,
reinterpret_cast<const XPUType *>(x),
reinterpret_cast<XPUType *>(y),
reinterpret_cast<XPUType *>(mask),
param.seed_val,
len,
param.is_upscale_in_train,
param.dropout_prob);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
}
} else {
float scale = (param.is_upscale_in_train)
? (1.0)
: (static_cast<float>(1.0f - param.dropout_prob));
r = xpu::scale(xpu_ctx,
reinterpret_cast<const XPUType *>(x),
reinterpret_cast<XPUType *>(y),
len,
false,
scale,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
}
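// Backward dropout helper: copies dy when dropout_prob is 0; when
// is_upscale_in_train is false it multiplies dy by the mask, otherwise it
// calls xpu::dropout_grad.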
template <typename T>
void DropoutGrad(xpu::Context *xpu_ctx,
const T *dy,
const T *mask,
T *dx,
const XPUDropoutParam &param,
int len) {
using XPUType = typename XPUTypeTrait<T>::Type;
if (param.dropout_prob == 0.0f) {
int r = xpu::copy(xpu_ctx,
reinterpret_cast<const XPUType *>(dy),
reinterpret_cast<XPUType *>(dx),
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
if (!param.is_upscale_in_train) {
int r = xpu::mul(xpu_ctx,
reinterpret_cast<const XPUType *>(dy),
reinterpret_cast<const XPUType *>(mask),
reinterpret_cast<XPUType *>(dx),
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
} else {
int r = xpu::dropout_grad(xpu_ctx,
reinterpret_cast<const XPUType *>(mask),
reinterpret_cast<const XPUType *>(dy),
reinterpret_cast<XPUType *>(dx),
param.dropout_prob,
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad");
}
}
} // namespace phi
...@@ -1160,6 +1160,8 @@ set(STATIC_BUILD_TESTS
test_eigh_op
test_fake_quantize_op
test_fetch_lod_tensor_array
test_fused_attention_op
test_fused_attention_op_api
test_imperative_optimizer
test_lamb_op
test_layer_norm_op
...@@ -1186,6 +1188,11 @@ set(STATIC_BUILD_TESTS
test_while_op
test_one_hot_v2_op)
if(NOT WITH_GPU)
list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op)
list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op_api)
endif()
foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
py_test_modules(
${STATIC_BUILD_TEST}_static_build MODULES ${STATIC_BUILD_TEST} ENVS
...