Unverified commit 3bac6264, authored by Sonder, committed by GitHub

Move fused_attention op to phi [migrate the backward GPU OpKernel] (#51909)

* add kernel functions

* update kernel functions

* update func parameters' name

* create codes for gpu device

* adjust file locations

* fix include error

* remove dependent files to phi/

* restore fused_attention_op.cu

* fix dependence errors

* fix dependence errors

* fix include error

* fix all dependence errors [build success]

* remove useless include

* recover useless include

* use phi::ToNCCLDataType

* fix namespace

* update new register code

* fix error in fused_gemm_epilogue_utils

* fix error in FusedAttentionKernel param

* finish fused_attention register code [build success]

* add paddle::optional

* add sig file

* fix build error

* fix a include error

* restore forward-pass code

* update CMakeList

* trans Compute function to phi [build success]

* add register code and fix include error [build success]

* fix parameter sequence

* add include file

* update #if before include

* update #if before include

* fix grammar error

* update codes for DropoutParam

* remove const cast

* trans some fluid api to phi api

* remove const cast

* trans some fluid api to phi api

* add #if

* update test code

* update test codes

* recover test codes

* fix namespace and remove fluid include

* recover random seed

* remove fluid quant_helper

* fix include error

* include utils in funcs

* change include file

* move grad codes back to fluid folder

* move grad codes back to fluid folder

* fix sig file error

* update include

* recover codes to develop

* update register codes

* fix build error

* recover fluid include

* remove some fluid include

* remove some fluid include

* Update fused_attention_op.cu

* remove fluid include

* add some fluid include

* Update fused_attention_op.cu

* Update fused_attention_op.cu

* Update fused_attention_op.cu

* Update fused_attention_op.cu

* remove useless include
Parent ab163063
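The migration described above replaces a fluid OpKernel, which pulls every input and attribute out of an ExecutionContext, with a free phi kernel function that receives the device context, tensors and attributes explicitly and is registered through PD_REGISTER_KERNEL plus a KernelSignature mapping for the legacy operator. The following is only a minimal sketch of that pattern under an invented op name (my_fused), not code from this commit; the real kernels and registrations appear in the diff below.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace fusion {

// Hypothetical example: a phi kernel is a templated free function taking the
// device context, inputs, attributes and output pointers directly.
template <typename T, typename Context>
void MyFusedKernel(const Context &dev_ctx,
                   const DenseTensor &x,
                   float scale,
                   DenseTensor *out) {
  dev_ctx.template Alloc<T>(out);
  // ... launch the device computation here ...
}

}  // namespace fusion
}  // namespace phi

// Registration replaces REGISTER_OP_CUDA_KERNEL; the trailing body may adjust
// output dtypes per kernel_key, as the real registrations below do.
PD_REGISTER_KERNEL(my_fused,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::MyFusedKernel,
                   float,
                   double) {}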
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,458 +17,25 @@ limitations under the License. */
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_attention_utils.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
// for phi fused attention
// fluid includes will be removed after the fused attention grad kernel is merged
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h"
#include "paddle/phi/kernels/fusion/gpu/attn_gemm.h"
#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
namespace paddle {
namespace operators {
template <typename T>
class FusedAttentionGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const int num_heads = ctx.Attr<int>("num_heads");
const bool transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
const float epsilon = ctx.Attr<float>("epsilon");
const float ln2epsilon = ctx.Attr<float>("ln_epsilon");
const float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
const bool has_attn_dropout = (attn_dropout_prob != 0.0f);
DropoutParam dropout_param2(ctx, 0);
const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
bool is_test_1 = ctx.Attr<bool>("is_test");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("attn_dropout_implementation");
bool is_upscale_in_train_1 =
(dropout_implementation_1 == "upscale_in_train");
auto *seed_1 =
ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
int ring_id = ctx.Attr<int>("ring_id");
// get inputs.
auto *d_y = ctx.Input<phi::DenseTensor>(framework::GradVarName("Y"));
auto *d_y_data = d_y->data<T>();
// fw input
auto *input_x = ctx.Input<phi::DenseTensor>("X");
auto *ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
auto *ln_2_scale = ctx.Input<phi::DenseTensor>("Ln2Scale");
auto *x_data = input_x->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_2_scale_data =
(ln_2_scale == nullptr ? nullptr : ln_2_scale->data<U>());
// fw parameters.
auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data =
(out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
// fw output
auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
auto *transpose_out_2 = ctx.Input<phi::DenseTensor>("TransposeOut2");
auto *qk_out = ctx.Input<phi::DenseTensor>("QKOut");
auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
auto *attn_dropout_mask_out =
ctx.Input<phi::DenseTensor>("AttnDropoutMaskOut");
auto *attn_dropout_out = ctx.Input<phi::DenseTensor>("AttnDropoutOut");
auto *src_mask_out = ctx.Input<phi::DenseTensor>("SrcMaskOut");
auto *ln_2_mean = ctx.Input<phi::DenseTensor>("Ln2Mean");
auto *ln_2_var = ctx.Input<phi::DenseTensor>("Ln2Variance");
auto *dropout_mask_out = ctx.Input<phi::DenseTensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Input<phi::DenseTensor>("BiasDropoutResidualOut");
auto *fmha_out_data = fmha_out->data<T>();
auto *transpose_out_2_data = transpose_out_2->data<T>();
auto *softmax_out_data = softmax_out->data<T>();
auto *src_mask_out_data =
(src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
auto *dropout_mask_out_data =
has_dropout ? dropout_mask_out->data<uint8_t>() : nullptr;
// output's grad
auto *d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
auto *d_qkv_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVOut"));
auto *d_qkv_bias_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBiasOut"));
auto *d_qktv_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("QKTVOut"));
auto *d_transpose_out_2 =
ctx.Output<phi::DenseTensor>(framework::GradVarName("TransposeOut2"));
auto *d_qk_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("QKOut"));
auto *d_softmax_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("SoftmaxOut"));
auto *d_attn_dropout_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("AttnDropoutOut"));
auto *d_src_mask_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("SrcMaskOut"));
auto *d_fmha_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("FMHAOut"));
auto *d_out_linear_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearOut"));
auto *d_bias_dropout_residual_out = ctx.Output<phi::DenseTensor>(
framework::GradVarName("BiasDropoutResidualOut"));
auto *d_x_data = dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T));
// when qkv_bias is not nullptr, d_qkv_out is equal to d_qkv_bias_out, so the
// buffer can be reused.
auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
? nullptr
: dev_ctx.template Alloc<T>(
d_qkv_out, d_qkv_out->numel() * sizeof(T));
auto *d_qkv_bias_out_data =
(d_qkv_bias_out == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(d_qkv_bias_out,
d_qkv_bias_out->numel() * sizeof(T));
auto *d_qktv_out_data =
dev_ctx.template Alloc<T>(d_qktv_out, d_qktv_out->numel() * sizeof(T));
auto *d_transpose_out_2_data = dev_ctx.template Alloc<T>(
d_transpose_out_2, d_transpose_out_2->numel() * sizeof(T));
auto *d_qk_out_data =
dev_ctx.template Alloc<T>(d_qk_out, d_qk_out->numel() * sizeof(T));
auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
d_softmax_out, d_softmax_out->numel() * sizeof(T));
auto *d_attn_dropout_out_data =
has_attn_dropout
? dev_ctx.template Alloc<T>(d_attn_dropout_out,
d_attn_dropout_out->numel() * sizeof(T))
: nullptr;
auto *d_src_mask_out_data =
(src_mask == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(d_src_mask_out,
d_src_mask_out->numel() * sizeof(T));
auto *d_fmha_out_data =
dev_ctx.template Alloc<T>(d_fmha_out, d_fmha_out->numel() * sizeof(T));
auto *d_out_linear_out_data = dev_ctx.template Alloc<T>(
d_out_linear_out, d_out_linear_out->numel() * sizeof(T));
// parameter grad
auto *d_qkv_weight =
ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVW"));
auto *d_qkv_bias =
ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBias"));
auto *d_out_linear_weight =
ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearW"));
auto *d_out_linear_bias =
ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearBias"));
auto *d_ln_2_scale =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Scale"));
auto *d_ln_2_bias =
ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Bias"));
auto *d_qkv_weight_data =
(d_qkv_weight == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(d_qkv_weight,
d_qkv_weight->numel() * sizeof(T));
auto *d_qkv_bias_data =
(d_qkv_bias == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(d_qkv_bias,
d_qkv_bias->numel() * sizeof(T));
auto *d_out_linear_weight_data =
(d_out_linear_weight == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(
d_out_linear_weight,
d_out_linear_weight->numel() * sizeof(T));
auto *d_out_linear_bias_data =
(d_out_linear_bias == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(d_out_linear_bias,
d_out_linear_bias->numel() * sizeof(T));
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
int batch_size = input_x_dims[0];
int max_seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int num_head;
int dim_head;
int nranks = 1;
if (!transpose_qkv_wb) {
num_head = qkv_w_dims[1];
dim_head = qkv_w_dims[2];
} else {
nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
num_head = num_heads;
dim_head = dim_embed / (num_head * nranks);
}
int bsz_seq = batch_size * max_seq_len;
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;
bool add_residual = ctx.Attr<bool>("add_residual");
phi::DenseTensor d_residual;
T *d_residual_data = nullptr;
if (add_residual) {
d_residual.Resize(input_x_dims);
d_residual_data = dev_ctx.template Alloc<T>(
&d_residual, d_residual.numel() * sizeof(T));
}
bool transA = false;
bool transB = transpose_qkv_wb ? false : true;
bool compute_qkv_bias = qkv_bias ? true : false;
auto layer_norm_compute = AttnLayerNorm<T>(
ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
transA,
transB,
bsz_seq,
output_size,
input_size,
compute_qkv_bias);
AttnDropoutParam attn_dropout_param(is_test_1,
dropout_implementation_1,
attn_dropout_prob,
is_upscale_in_train_1,
is_fix_seed_1,
seed_val_1,
seed_1);
auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
batch_size,
max_seq_len,
num_head,
dim_head,
attn_dropout_param);
output_size = hidden_size;
transA = false;
transB = false;
bool compute_bias = false;
// (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
transA,
transB,
bsz_seq,
input_size,
output_size,
compute_bias);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(),
bsz_seq,
dim_embed,
dropout_param2,
ln2epsilon);
if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
ctx.cuda_device_context(),
d_y_data,
dropout_mask_out_data,
d_out_linear_out_data,
d_residual_data,
d_out_linear_bias_data);
} else {
auto *ln_2_mean_data = ln_2_mean->data<U>();
auto *ln_2_var_data = ln_2_var->data<U>();
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->data<T>();
auto *d_ln_2_scale_data =
(d_ln_2_scale == nullptr
? nullptr
: dev_ctx.template Alloc<U>(d_ln_2_scale,
d_ln_2_scale->numel() * sizeof(U)));
auto *d_ln_2_bias_data =
(d_ln_2_bias == nullptr
? nullptr
: dev_ctx.template Alloc<U>(d_ln_2_bias,
d_ln_2_bias->numel() * sizeof(U)));
auto *d_bias_dropout_residual_out_data = dev_ctx.template Alloc<T>(
d_bias_dropout_residual_out,
d_bias_dropout_residual_out->numel() * sizeof(T));
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx.cuda_device_context(),
d_y_data,
bias_dropout_residual_out_data,
dropout_mask_out_data,
ln_2_scale_data,
ln_2_mean_data,
ln_2_var_data,
d_bias_dropout_residual_out_data,
d_ln_2_scale_data,
d_ln_2_bias_data,
d_out_linear_out_data,
d_out_linear_bias_data,
d_residual_data);
}
out_linear_compute.ComputeBackward(fmha_out,
out_linear_weight,
d_out_linear_out,
d_fmha_out,
d_out_linear_weight,
nullptr);
if (transpose_qkv_wb) {
if (compute_qkv_bias) {
d_qkv_bias_out->Resize(
{batch_size, max_seq_len, 3, num_head, dim_head});
} else {
d_qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
}
}
if (qkv_bias != nullptr) {
fmha_ref_compute.ComputeBackward(*transpose_out_2,
has_attn_dropout ? src_mask : nullptr,
*softmax_out,
*attn_dropout_mask_out,
*attn_dropout_out,
*qk_out,
*src_mask_out,
*d_fmha_out,
d_qktv_out,
d_attn_dropout_out,
d_softmax_out,
d_src_mask_out,
d_qk_out,
d_transpose_out_2,
nullptr,
d_qkv_bias_out);
} else {
fmha_ref_compute.ComputeBackward(*transpose_out_2,
has_attn_dropout ? src_mask : nullptr,
*softmax_out,
*attn_dropout_mask_out,
*attn_dropout_out,
*qk_out,
*src_mask_out,
*d_fmha_out,
d_qktv_out,
d_attn_dropout_out,
d_softmax_out,
d_src_mask_out,
d_qk_out,
d_transpose_out_2,
nullptr,
d_qkv_out);
}
if (transpose_qkv_wb) {
if (compute_qkv_bias) {
d_qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
} else {
d_qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
}
}
if (pre_layer_norm) {
auto *ln_mean = ctx.Input<phi::DenseTensor>("LnMean");
auto *ln_var = ctx.Input<phi::DenseTensor>("LnVariance");
auto *ln_out = ctx.Input<phi::DenseTensor>("LnOut");
auto *ln_mean_data = ln_mean->data<U>();
auto *ln_var_data = ln_var->data<U>();
auto *ln_out_data = ln_out->data<T>();
auto *d_ln_out =
ctx.Output<phi::DenseTensor>(framework::GradVarName("LnOut"));
auto *d_ln_scale =
ctx.Output<phi::DenseTensor>(framework::GradVarName("LnScale"));
auto *d_ln_bias =
ctx.Output<phi::DenseTensor>(framework::GradVarName("LnBias"));
auto *d_ln_out_data =
dev_ctx.template Alloc<T>(d_ln_out, d_ln_out->numel() * sizeof(T));
auto *d_ln_scale_data =
(d_ln_scale == nullptr
? nullptr
: dev_ctx.template Alloc<U>(d_ln_scale,
d_ln_scale->numel() * sizeof(U)));
auto *d_ln_bias_data =
(d_ln_bias == nullptr
? nullptr
: dev_ctx.template Alloc<U>(d_ln_bias,
d_ln_bias->numel() * sizeof(U)));
if (qkv_bias != nullptr) {
qkv_compute.ComputeBackward(ln_out,
qkv_weight,
d_qkv_bias_out,
d_ln_out,
d_qkv_weight,
d_qkv_bias);
} else {
qkv_compute.ComputeBackward(
ln_out, qkv_weight, d_qkv_out, d_ln_out, d_qkv_weight, d_qkv_bias);
}
// tensor model parallel
phi::fusion::AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
layer_norm_compute.ComputeBackward(x_data,
d_ln_out_data,
ln_scale_data,
ln_mean_data,
ln_var_data,
d_x_data,
d_ln_scale_data,
d_ln_bias_data);
} else {
if (qkv_bias != nullptr) {
qkv_compute.ComputeBackward(
input_x, qkv_weight, d_qkv_bias_out, d_x, d_qkv_weight, d_qkv_bias);
} else {
qkv_compute.ComputeBackward(
input_x, qkv_weight, d_qkv_out, d_x, d_qkv_weight, d_qkv_bias);
}
// tensor model parallel
phi::fusion::AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
}
if (add_residual) {
// gradient accumulation
std::vector<const phi::DenseTensor *> ins = {&d_residual, d_x};
std::vector<phi::DenseTensor *> outs = {d_x};
phi::funcs::ElementwiseKernel<T>(
ctx.cuda_device_context(), ins, &outs, phi::funcs::AddFunctor<T>());
}
}
};
} // namespace operators
} // namespace paddle
namespace phi {
namespace fusion {
@@ -799,6 +366,443 @@ void FusedAttentionKernel(const Context &dev_ctx,
}
}
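// Backward kernel for fused_attention: it takes the gradient of Y together
// with the forward intermediates saved by FusedAttentionKernel (LnOut, QKVOut,
// TransposeOut2, QKOut, SoftmaxOut, the dropout masks, FMHAOut, OutLinearOut,
// BiasDropoutResidualOut, ...) and produces gradients for X, the QKV and
// out-linear weights/biases and the layer-norm parameters.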
template <typename T, typename Context>
void FusedAttentionGradKernel(
const Context &dev_ctx,
const DenseTensor &out_grad,
const DenseTensor &x,
const DenseTensor &qkv_weight,
const paddle::optional<DenseTensor> &qkv_bias,
const paddle::optional<DenseTensor> &qkv_bias_out,
const paddle::optional<DenseTensor> &src_mask,
const paddle::optional<DenseTensor> &src_mask_out,
const DenseTensor &out_linear_weight,
const paddle::optional<DenseTensor> &out_linear_bias,
const paddle::optional<DenseTensor> &ln_scale,
const paddle::optional<DenseTensor> &ln_bias,
const paddle::optional<DenseTensor> &ln_scale_2,
const paddle::optional<DenseTensor> &ln_bias_2,
const paddle::optional<DenseTensor> &ln_out,
const paddle::optional<DenseTensor> &ln_mean,
const paddle::optional<DenseTensor> &ln_var,
const paddle::optional<DenseTensor> &ln_mean_2,
const paddle::optional<DenseTensor> &ln_var_2,
const paddle::optional<DenseTensor> &bias_dropout_residual_out,
const DenseTensor &qkv_out,
const DenseTensor &transpose_out_2,
const DenseTensor &qk_out,
const DenseTensor &qktv_out,
const DenseTensor &softmax_out,
const DenseTensor &attn_dropout_mask_out,
const DenseTensor &attn_dropout_out,
const DenseTensor &fmha_out,
const DenseTensor &out_linear_out,
const DenseTensor &dropout_mask_out,
int num_heads,
bool transpose_qkv_wb,
bool pre_layer_norm,
float epsilon,
float attn_dropout_rate,
bool is_test,
bool attn_dropout_fix_seed,
int attn_dropout_seed,
const std::string &attn_dropout_implementation,
float dropout_rate,
bool dropout_fix_seed,
int dropout_seed,
const std::string &dropout_implementation,
float ln_epsilon,
bool add_residual,
int ring_id,
DenseTensor *qkv_bias_grad,
DenseTensor *qkv_bias_out_grad,
DenseTensor *src_mask_out_grad,
DenseTensor *out_linear_bias_grad,
DenseTensor *ln_scale_grad,
DenseTensor *ln_bias_grad,
DenseTensor *ln_scale_2_grad,
DenseTensor *ln_bias_2_grad,
DenseTensor *x_grad,
DenseTensor *qkv_weight_grad,
DenseTensor *out_linear_weight_grad,
DenseTensor *ln_out_grad,
DenseTensor *bias_dropout_residual_out_grad,
DenseTensor *qkv_out_grad,
DenseTensor *qktv_out_grad,
DenseTensor *transpose_out_2_grad,
DenseTensor *qk_out_grad,
DenseTensor *softmax_out_grad,
DenseTensor *attn_dropout_out_grad,
DenseTensor *fmha_out_grad,
DenseTensor *out_linear_out_grad) {
using U = phi::fusion::LayerNormParamType<T>;
const bool has_attn_dropout = (attn_dropout_rate != 0.0f);
const bool is_upscale_in_train =
(dropout_implementation == "upscale_in_train");
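  // Rebuild the configuration for the second (post out-linear) dropout from
  // plain attributes; the fluid kernel built it from the ExecutionContext via
  // DropoutParam(ctx, 0). A zero dropout_rate disables the mask handling below.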
phi::fusion::DropoutParam dropout_param2(dropout_fix_seed,
0,
is_test,
is_upscale_in_train,
dropout_rate,
nullptr,
dropout_seed);
const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
bool is_upscale_in_train_1 =
(attn_dropout_implementation == "upscale_in_train");
phi::DenseTensor *seed_1 = nullptr;
// get inputs.
auto *d_y = &out_grad;
auto *d_y_data = d_y->data<T>();
// fw input
auto *input_x = &x;
auto *ln_scale_p = ln_scale.get_ptr();
auto *ln_scale_2_p = ln_scale_2.get_ptr();
auto *x_data = input_x->data<T>();
auto *ln_scale_data =
(ln_scale_p == nullptr ? nullptr : ln_scale_p->data<U>());
auto *ln_2_scale_data =
(ln_scale_2_p == nullptr ? nullptr : ln_scale_2_p->data<U>());
// fw parameters.
auto *src_mask_p = src_mask.get_ptr();
auto *qkv_weight_p = &qkv_weight;
auto *qkv_bias_p = qkv_bias.get_ptr();
auto *out_linear_weight_p = &out_linear_weight;
auto *out_linear_bias_p = out_linear_bias.get_ptr();
auto *qkv_weight_data = qkv_weight_p->data<T>();
auto *qkv_bias_data =
(qkv_bias_p == nullptr) ? nullptr : qkv_bias_p->data<T>();
auto *out_linear_weight_data = out_linear_weight_p->data<T>();
auto *out_linear_bias_data =
(out_linear_bias_p == nullptr) ? nullptr : out_linear_bias_p->data<T>();
// fw output
auto *fmha_out_p = &fmha_out;
auto *transpose_out_2_p = &transpose_out_2;
auto *qk_out_p = &qk_out;
auto *softmax_out_p = &softmax_out;
auto *attn_dropout_mask_out_p = &attn_dropout_mask_out;
auto *attn_dropout_out_p = &attn_dropout_out;
auto *src_mask_out_p = src_mask_out.get_ptr();
auto *ln_mean_2_p = ln_mean_2.get_ptr();
auto *ln_var_2_p = ln_var_2.get_ptr();
auto *dropout_mask_out_p = &dropout_mask_out;
auto *bias_dropout_residual_out_p = bias_dropout_residual_out.get_ptr();
auto *fmha_out_data = fmha_out_p->data<T>();
auto *transpose_out_2_data = transpose_out_2_p->data<T>();
auto *softmax_out_data = softmax_out_p->data<T>();
auto *src_mask_out_data =
(src_mask_p == nullptr) ? nullptr : src_mask_out_p->data<T>();
auto *dropout_mask_out_data =
has_dropout ? dropout_mask_out_p->data<uint8_t>() : nullptr;
auto *d_x_data =
dev_ctx.template Alloc<T>(x_grad, x_grad->numel() * sizeof(T));
// when qkv_bias_p is not nullptr, qkv_out_grad is equal to qkv_bias_out_grad,
// so the buffer can be reused.
auto *d_qkv_out_data =
(qkv_bias_out_grad != nullptr)
? nullptr
: dev_ctx.template Alloc<T>(qkv_out_grad,
qkv_out_grad->numel() * sizeof(T));
auto *d_qkv_bias_out_data =
(qkv_bias_out_grad == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(qkv_bias_out_grad,
qkv_bias_out_grad->numel() * sizeof(T));
auto *d_qktv_out_data = dev_ctx.template Alloc<T>(
qktv_out_grad, qktv_out_grad->numel() * sizeof(T));
auto *d_transpose_out_2_data = dev_ctx.template Alloc<T>(
transpose_out_2_grad, transpose_out_2_grad->numel() * sizeof(T));
auto *d_qk_out_data =
dev_ctx.template Alloc<T>(qk_out_grad, qk_out_grad->numel() * sizeof(T));
auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
softmax_out_grad, softmax_out_grad->numel() * sizeof(T));
auto *d_attn_dropout_out_data =
has_attn_dropout ? dev_ctx.template Alloc<T>(
attn_dropout_out_grad,
attn_dropout_out_grad->numel() * sizeof(T))
: nullptr;
auto *d_src_mask_out_data =
(src_mask_p == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(src_mask_out_grad,
src_mask_out_grad->numel() * sizeof(T));
auto *d_fmha_out_data = dev_ctx.template Alloc<T>(
fmha_out_grad, fmha_out_grad->numel() * sizeof(T));
auto *d_out_linear_out_data = dev_ctx.template Alloc<T>(
out_linear_out_grad, out_linear_out_grad->numel() * sizeof(T));
// parameter grad
auto *d_qkv_weight_data =
(qkv_weight_grad == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(qkv_weight_grad,
qkv_weight_grad->numel() * sizeof(T));
auto *d_qkv_bias_data =
(qkv_bias_grad == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(qkv_bias_grad,
qkv_bias_grad->numel() * sizeof(T));
auto *d_out_linear_weight_data =
(out_linear_weight_grad == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(
out_linear_weight_grad,
out_linear_weight_grad->numel() * sizeof(T));
auto *d_out_linear_bias_data =
(out_linear_bias_grad == nullptr)
? nullptr
: dev_ctx.template Alloc<T>(
out_linear_bias_grad,
out_linear_bias_grad->numel() * sizeof(T));
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight_p->dims();
int batch_size = input_x_dims[0];
int max_seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int num_head;
int dim_head;
int nranks = 1;
if (!transpose_qkv_wb) {
num_head = qkv_w_dims[1];
dim_head = qkv_w_dims[2];
} else {
nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
num_head = num_heads;
dim_head = dim_embed / (num_head * nranks);
}
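  // With transpose_qkv_wb == false the head layout is read from the QKV weight
  // dims; otherwise num_head comes from the num_heads attribute and dim_head is
  // derived from dim_embed and the (possibly tensor-parallel sharded) weight
  // shape.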
int bsz_seq = batch_size * max_seq_len;
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;
phi::DenseTensor d_residual;
T *d_residual_data = nullptr;
if (add_residual) {
d_residual.Resize(input_x_dims);
d_residual_data =
dev_ctx.template Alloc<T>(&d_residual, d_residual.numel() * sizeof(T));
}
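  // The gradient flowing through the residual branch is kept in a separate
  // buffer and accumulated into x_grad at the end of this function (see the
  // ElementwiseKernel call below).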
bool transA = false;
bool transB = transpose_qkv_wb ? false : true;
bool compute_qkv_bias = qkv_bias_p ? true : false;
auto layer_norm_compute =
phi::fusion::AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
auto qkv_compute = phi::fusion::AttnMatMul<T>(dev_ctx,
transA,
transB,
bsz_seq,
output_size,
input_size,
compute_qkv_bias);
phi::fusion::AttnDropoutParam attn_dropout_param(is_test,
attn_dropout_implementation,
attn_dropout_rate,
is_upscale_in_train_1,
attn_dropout_fix_seed,
attn_dropout_seed,
seed_1);
auto fmha_ref_compute = phi::fusion::FMHARef<T>(
dev_ctx, batch_size, max_seq_len, num_head, dim_head, attn_dropout_param);
output_size = hidden_size;
transA = false;
transB = false;
bool compute_bias = false;
// (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
auto out_linear_compute = phi::fusion::AttnMatMul<T>(
dev_ctx, transA, transB, bsz_seq, input_size, output_size, compute_bias);
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
fused_dropout_layernorm_helper(
dev_ctx, bsz_seq, dim_embed, dropout_param2, ln_epsilon);
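  // Backward pass, step 1: differentiate the bias + dropout (+ residual /
  // second layer norm) epilogue. With pre_layer_norm there is no second layer
  // norm, so only the residual-dropout-bias path is needed.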
if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
dev_ctx,
d_y_data,
dropout_mask_out_data,
d_out_linear_out_data,
d_residual_data,
d_out_linear_bias_data);
} else {
auto *ln_mean_2_data = ln_mean_2_p->data<U>();
auto *ln_var_2_data = ln_var_2_p->data<U>();
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out_p->data<T>();
auto *d_ln_2_scale_data =
(ln_scale_2_grad == nullptr
? nullptr
: dev_ctx.template Alloc<U>(ln_scale_2_grad,
ln_scale_2_grad->numel() * sizeof(U)));
auto *d_ln_bias_2_data =
(ln_bias_2_grad == nullptr
? nullptr
: dev_ctx.template Alloc<U>(ln_bias_2_grad,
ln_bias_2_grad->numel() * sizeof(U)));
auto *d_bias_dropout_residual_out_data = dev_ctx.template Alloc<T>(
bias_dropout_residual_out_grad,
bias_dropout_residual_out_grad->numel() * sizeof(T));
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
dev_ctx,
d_y_data,
bias_dropout_residual_out_data,
dropout_mask_out_data,
ln_2_scale_data,
ln_mean_2_data,
ln_var_2_data,
d_bias_dropout_residual_out_data,
d_ln_2_scale_data,
d_ln_bias_2_data,
d_out_linear_out_data,
d_out_linear_bias_data,
d_residual_data);
}
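  // Step 2: backward through the out-linear projection. The bias gradient was
  // already produced by the dropout/layer-norm helper above, so the last
  // argument is nullptr.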
out_linear_compute.ComputeBackward(fmha_out_p,
out_linear_weight_p,
out_linear_out_grad,
fmha_out_grad,
out_linear_weight_grad,
nullptr);
if (transpose_qkv_wb) {
if (compute_qkv_bias) {
qkv_bias_out_grad->Resize(
{batch_size, max_seq_len, 3, num_head, dim_head});
} else {
qkv_out_grad->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
}
}
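  // Step 3: backward through the core attention (FMHARef). When
  // transpose_qkv_wb is set, the QKV gradient buffers were just reshaped to
  // [batch, seq, 3, num_head, dim_head] for this call and are restored to the
  // flat [batch, seq, 3 * hidden] layout right after it.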
if (qkv_bias_p != nullptr) {
fmha_ref_compute.ComputeBackward(*transpose_out_2_p,
has_attn_dropout ? src_mask_p : nullptr,
*softmax_out_p,
*attn_dropout_mask_out_p,
*attn_dropout_out_p,
*qk_out_p,
*src_mask_out_p,
*fmha_out_grad,
qktv_out_grad,
attn_dropout_out_grad,
softmax_out_grad,
src_mask_out_grad,
qk_out_grad,
transpose_out_2_grad,
nullptr,
qkv_bias_out_grad);
} else {
fmha_ref_compute.ComputeBackward(*transpose_out_2_p,
has_attn_dropout ? src_mask_p : nullptr,
*softmax_out_p,
*attn_dropout_mask_out_p,
*attn_dropout_out_p,
*qk_out_p,
*src_mask_out_p,
*fmha_out_grad,
qktv_out_grad,
attn_dropout_out_grad,
softmax_out_grad,
src_mask_out_grad,
qk_out_grad,
transpose_out_2_grad,
nullptr,
qkv_out_grad);
}
if (transpose_qkv_wb) {
if (compute_qkv_bias) {
qkv_bias_out_grad->Resize({batch_size, max_seq_len, 3 * hidden_size});
} else {
qkv_out_grad->Resize({batch_size, max_seq_len, 3 * hidden_size});
}
}
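  // Step 4: backward through the QKV GEMM and, when pre_layer_norm is set, the
  // leading layer norm. Under tensor model parallelism the input gradient is
  // all-reduced across the ring given by ring_id.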
if (pre_layer_norm) {
auto *ln_mean_p = ln_mean.get_ptr();
auto *ln_var_p = ln_var.get_ptr();
auto *ln_out_p = ln_out.get_ptr();
auto *ln_mean_data = ln_mean_p->data<U>();
auto *ln_var_data = ln_var_p->data<U>();
auto *ln_out_data = ln_out_p->data<T>();
auto *d_ln_out_data = dev_ctx.template Alloc<T>(
ln_out_grad, ln_out_grad->numel() * sizeof(T));
auto *d_ln_scale_data =
(ln_scale_grad == nullptr
? nullptr
: dev_ctx.template Alloc<U>(ln_scale_grad,
ln_scale_grad->numel() * sizeof(U)));
auto *d_ln_bias_data =
(ln_bias_grad == nullptr
? nullptr
: dev_ctx.template Alloc<U>(ln_bias_grad,
ln_bias_grad->numel() * sizeof(U)));
if (qkv_bias_p != nullptr) {
qkv_compute.ComputeBackward(ln_out_p,
qkv_weight_p,
qkv_bias_out_grad,
ln_out_grad,
qkv_weight_grad,
qkv_bias_grad);
} else {
qkv_compute.ComputeBackward(ln_out_p,
qkv_weight_p,
qkv_out_grad,
ln_out_grad,
qkv_weight_grad,
qkv_bias_grad);
}
// tensor model parallel
phi::fusion::AllReduce<T>(*ln_out_grad, ring_id, dev_ctx);
layer_norm_compute.ComputeBackward(x_data,
d_ln_out_data,
ln_scale_data,
ln_mean_data,
ln_var_data,
d_x_data,
d_ln_scale_data,
d_ln_bias_data);
} else {
if (qkv_bias_p != nullptr) {
qkv_compute.ComputeBackward(input_x,
qkv_weight_p,
qkv_bias_out_grad,
x_grad,
qkv_weight_grad,
qkv_bias_grad);
} else {
qkv_compute.ComputeBackward(input_x,
qkv_weight_p,
qkv_out_grad,
x_grad,
qkv_weight_grad,
qkv_bias_grad);
}
// tensor model parallel
phi::fusion::AllReduce<T>(*x_grad, ring_id, dev_ctx);
}
if (add_residual) {
// gradient accumulation
std::vector<const phi::DenseTensor *> ins = {&d_residual, x_grad};
std::vector<phi::DenseTensor *> outs = {x_grad};
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, phi::funcs::AddFunctor<T>());
}
}
} // namespace fusion
} // namespace phi
@@ -809,24 +813,27 @@ PD_REGISTER_KERNEL(fused_attention,
phi::dtype::float16,
double,
float) {
  // Old registration body (removed by this commit): output dtypes were forced
  // for every input dtype.
  phi::DataType data_type;
  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
      kernel_key.dtype() == phi::DataType::FLOAT32) {
    data_type = phi::DataType::FLOAT32;
  } else {
    data_type = phi::DataType::FLOAT64;
  }
  kernel->OutputAt(0).SetDataType(data_type);
  kernel->OutputAt(1).SetDataType(data_type);
  kernel->OutputAt(3).SetDataType(data_type);
  kernel->OutputAt(4).SetDataType(data_type);
  kernel->OutputAt(15).SetDataType(data_type);
  kernel->OutputAt(16).SetDataType(data_type);
  // New registration body (added by this commit): only float16 overrides the
  // output dtypes.
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(15).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(16).SetDataType(phi::DataType::FLOAT32);
  }
}
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
ops::FusedAttentionGradKernel<float>,
ops::FusedAttentionGradKernel<double>,
ops::FusedAttentionGradKernel<plat::float16>);
PD_REGISTER_KERNEL(fused_attention_grad,
GPU,
ALL_LAYOUT,
phi::fusion::FusedAttentionGradKernel,
phi::dtype::float16,
double,
float) {
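  // Outputs 4-7 are LnScale@GRAD, LnBias@GRAD, Ln2Scale@GRAD and Ln2Bias@GRAD
  // (see the argument mapping below); layer-norm parameter gradients are kept
  // in float even when the kernel runs in float16.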
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
}
}
@@ -58,7 +58,83 @@ KernelSignature AttentionFuseOpArgumentMapping(
"CacheKVOut", "Y"});
}
KernelSignature AttentionGradFuseOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fused_attention_grad",
{"Y@GRAD",
"X",
"QKVW",
"QKVBias",
"QKVBiasOut",
"SrcMask",
"SrcMaskOut",
"OutLinearW",
"OutLinearBias",
"LnScale",
"LnBias",
"Ln2Scale",
"Ln2Bias",
"LnOut",
"LnMean",
"LnVariance",
"Ln2Mean",
"Ln2Variance",
"BiasDropoutResidualOut",
"QKVOut",
"TransposeOut2",
"QKOut",
"QKTVOut",
"SoftmaxOut",
"AttnDropoutMaskOut",
"AttnDropoutOut",
"FMHAOut",
"OutLinearOut",
"DropoutMaskOut"},
{"num_heads",
"transpose_qkv_wb",
"pre_layer_norm",
"epsilon",
"attn_dropout_rate",
"is_test",
"attn_dropout_fix_seed",
"attn_dropout_seed",
"attn_dropout_implementation",
"dropout_rate",
"dropout_fix_seed",
"dropout_seed",
"dropout_implementation",
"ln_epsilon",
"add_residual",
"ring_id"},
{
"QKVBias@GRAD",
"QKVBiasOut@GRAD",
"SrcMaskOut@GRAD",
"OutLinearBias@GRAD",
"LnScale@GRAD",
"LnBias@GRAD",
"Ln2Scale@GRAD",
"Ln2Bias@GRAD",
"X@GRAD",
"QKVW@GRAD",
"OutLinearW@GRAD",
"LnOut@GRAD",
"BiasDropoutResidualOut@GRAD",
"QKVOut@GRAD",
"QKTVOut@GRAD",
"TransposeOut2@GRAD",
"QKOut@GRAD",
"SoftmaxOut@GRAD",
"AttnDropoutOut@GRAD",
"FMHAOut@GRAD",
"OutLinearOut@GRAD",
});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(fused_attention,
phi::AttentionFuseOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(fused_attention_grad,
phi::AttentionGradFuseOpArgumentMapping);