From 3bac6264a124cef28bff8f13d18fa8b66a9d1d6a Mon Sep 17 00:00:00 2001
From: Sonder <55493212+AndSonder@users.noreply.github.com>
Date: Fri, 14 Apr 2023 14:22:12 +0800
Subject: [PATCH] Move fused_attention op to phi [migrate backward GPU
 OpKernel] (#51909)

* add kernel functions

* update kernel functions

* update kernel function parameter names

* create codes for gpu device

* adjust file locations

* fix include error

* remove dependent files to phi/

* restore fused_attention_op.cu

* fix dependence errors

* fix dependence errors

* fix include error

* fix all dependence errors [build success]

* remove useless include

* recover useless include

* use phi::ToNCCLDataType

* fix namespace

* update new register code

* fix error in fused_gemm_epilogue_utils

* fix error in FusedAttentionKernel param

* finish fused_attention register code [build success]

* add paddle::optional

* add sig file

* fix build error

* fix an include error

* restore forward code

* update CMakeLists

* move Compute function to phi [build success]

* add register code and fix include error [build success]

* fix parameter order

* add include file

* update #if before include

* update #if before include

* fix grammar error

* update codes for DropoutParam

* remove const cast

* move some fluid APIs to phi APIs

* remove const cast

* move some fluid APIs to phi APIs

* add #if

* update test code

* update test codes

* recover test codes

* fix namespace and remove fluid include

* recover random seed

* remove fluid quant_helper

* fix include error

* include utils in funcs

* change include file

* move grad codes back to fluid folder

* move grad codes back to fluid folder

* fix sig file error

* update include

* recover codes to develop

* update register codes

* fix build error

* recover fluid include

* remove some fluid include

* remove some fluid include

* Update fused_attention_op.cu

* remove fluid include

* add some fluid include

* Update fused_attention_op.cu

* Update fused_attention_op.cu

* Update fused_attention_op.cu

* Update fused_attention_op.cu

* remove useless include
---
 .../operators/fused/fused_attention_op.cu    | 921 +++++++++---------
 paddle/phi/ops/compat/fused_attention_sig.cc |  76 ++
 2 files changed, 540 insertions(+), 457 deletions(-)
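A note for readers unfamiliar with the phi kernel style this patch migrates to: in fluid, a kernel is a class whose Compute() pulls inputs and attributes out of an ExecutionContext by name; in phi, a kernel is a free function template whose signature lists the device context, inputs, attributes, and outputs explicitly. The sketch below is for orientation only and is not part of this patch; the op name "my_relu" and its CPU registration are made up.

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

// Every argument is explicit: context first, then inputs, then outputs.
// Attributes (none here) would sit between the inputs and the outputs.
template <typename T, typename Context>
void MyReluKernel(const Context &dev_ctx,
                  const DenseTensor &x,
                  DenseTensor *out) {
  const T *x_data = x.data<T>();
  // Alloc replaces fluid's mutable_data: it types and sizes the output.
  T *out_data = dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
  for (int64_t i = 0; i < x.numel(); ++i) {
    out_data[i] =
        x_data[i] > static_cast<T>(0) ? x_data[i] : static_cast<T>(0);
  }
}

}  // namespace phi

// Registration replaces fluid's REGISTER_OP_CUDA_KERNEL; the trailing
// dtype list instantiates one kernel per type.
PD_REGISTER_KERNEL(
    my_relu, CPU, ALL_LAYOUT, phi::MyReluKernel, float, double) {}

This is why the diff below replaces every ctx.Input/ctx.Attr lookup in the old FusedAttentionGradKernel with a named parameter, and why paddle::optional wraps the inputs the op may omit (QKVBias, SrcMask, the LayerNorm tensors, and so on).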
diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu
index 28ef8d32cd2..de62fe38653 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cu
+++ b/paddle/fluid/operators/fused/fused_attention_op.cu
@@ -1,4 +1,4 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,458 +17,25 @@ limitations under the License.
 */
 
 #include <cuda_fp16.h>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/fused/attention_layer_norm.h"
-#include "paddle/fluid/operators/fused/attn_gemm.h"
-#include "paddle/fluid/operators/fused/fmha_ref.h"
 #include "paddle/fluid/operators/fused/fused_attention_utils.h"
-#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
-#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+
 #include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/functors.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
-#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
-
-// for phi fused attention
-// fluid include will be removed after fused attention grad kernel is merged
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/core/enforce.h"
-#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
+#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
 #include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h"
 #include "paddle/phi/kernels/fusion/gpu/attn_gemm.h"
 #include "paddle/phi/kernels/fusion/gpu/fmha_ref.h"
 #include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
 
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class FusedAttentionGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    using U = LayerNormParamType<T>;
-    const int num_heads = ctx.Attr<int>("num_heads");
-    const bool transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");
-    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
-    const float epsilon = ctx.Attr<float>("epsilon");
-    const float ln2epsilon = ctx.Attr<float>("ln_epsilon");
-
-    const float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
-    const bool has_attn_dropout = (attn_dropout_prob != 0.0f);
-    DropoutParam dropout_param2(ctx, 0);
-    const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
-
-    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
-    bool is_test_1 = ctx.Attr<bool>("is_test");
-    auto &dropout_implementation_1 =
-        ctx.Attr<std::string>("attn_dropout_implementation");
-    bool is_upscale_in_train_1 =
-        (dropout_implementation_1 == "upscale_in_train");
-    auto *seed_1 =
-        ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
-    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
-    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
-    int ring_id = ctx.Attr<int>("ring_id");
-
-    // get inputs.
-    auto *d_y = ctx.Input<phi::DenseTensor>(framework::GradVarName("Y"));
-    auto *d_y_data = d_y->data<T>();
-
-    // fw input
-    auto *input_x = ctx.Input<phi::DenseTensor>("X");
-    auto *ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
-    auto *ln_2_scale = ctx.Input<phi::DenseTensor>("Ln2Scale");
-    auto *x_data = input_x->data<T>();
-    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
-    auto *ln_2_scale_data =
-        (ln_2_scale == nullptr ? nullptr : ln_2_scale->data<U>());
-    // fw parameters.
-    auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
-    auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
-    auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
-    auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
-    auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
-    auto *qkv_weight_data = qkv_weight->data<T>();
-    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
-    auto *out_linear_weight_data = out_linear_weight->data<T>();
-    auto *out_linear_bias_data =
-        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
-
-    // fw output
-    auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
-    auto *transpose_out_2 = ctx.Input<phi::DenseTensor>("TransposeOut2");
-    auto *qk_out = ctx.Input<phi::DenseTensor>("QKOut");
-    auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
-    auto *attn_dropout_mask_out =
-        ctx.Input<phi::DenseTensor>("AttnDropoutMaskOut");
-    auto *attn_dropout_out = ctx.Input<phi::DenseTensor>("AttnDropoutOut");
-    auto *src_mask_out = ctx.Input<phi::DenseTensor>("SrcMaskOut");
-    auto *ln_2_mean = ctx.Input<phi::DenseTensor>("Ln2Mean");
-    auto *ln_2_var = ctx.Input<phi::DenseTensor>("Ln2Variance");
-    auto *dropout_mask_out = ctx.Input<phi::DenseTensor>("DropoutMaskOut");
-    auto *bias_dropout_residual_out =
-        ctx.Input<phi::DenseTensor>("BiasDropoutResidualOut");
-    auto *fmha_out_data = fmha_out->data<T>();
-    auto *transpose_out_2_data = transpose_out_2->data<T>();
-    auto *softmax_out_data = softmax_out->data<T>();
-    auto *src_mask_out_data =
-        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
-    auto *dropout_mask_out_data =
-        has_dropout ? dropout_mask_out->data<uint8_t>() : nullptr;
-
-    // output's grad
-    auto *d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
-    auto *d_qkv_out =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVOut"));
-    auto *d_qkv_bias_out =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBiasOut"));
-    auto *d_qktv_out =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKTVOut"));
-    auto *d_transpose_out_2 =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("TransposeOut2"));
-    auto *d_qk_out =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKOut"));
-    auto *d_softmax_out =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("SoftmaxOut"));
-    auto *d_attn_dropout_out =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("AttnDropoutOut"));
-    auto *d_src_mask_out =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("SrcMaskOut"));
-    auto *d_fmha_out =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("FMHAOut"));
-    auto *d_out_linear_out =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearOut"));
-    auto *d_bias_dropout_residual_out = ctx.Output<phi::DenseTensor>(
-        framework::GradVarName("BiasDropoutResidualOut"));
-    auto *d_x_data = dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T));
-    // when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the
-    // space can be reused.
-    auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
-                               ? nullptr
-                               : dev_ctx.template Alloc<T>(
-                                     d_qkv_out, d_qkv_out->numel() * sizeof(T));
-    auto *d_qkv_bias_out_data =
-        (d_qkv_bias_out == nullptr)
-            ? nullptr
-            : dev_ctx.template Alloc<T>(d_qkv_bias_out,
-                                        d_qkv_bias_out->numel() * sizeof(T));
-    auto *d_qktv_out_data =
-        dev_ctx.template Alloc<T>(d_qktv_out, d_qktv_out->numel() * sizeof(T));
-    auto *d_transpose_out_2_data = dev_ctx.template Alloc<T>(
-        d_transpose_out_2, d_transpose_out_2->numel() * sizeof(T));
-    auto *d_qk_out_data =
-        dev_ctx.template Alloc<T>(d_qk_out, d_qk_out->numel() * sizeof(T));
-    auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
-        d_softmax_out, d_softmax_out->numel() * sizeof(T));
-    auto *d_attn_dropout_out_data =
-        has_attn_dropout
-            ? dev_ctx.template Alloc<T>(d_attn_dropout_out,
-                                        d_attn_dropout_out->numel() * sizeof(T))
-            : nullptr;
-    auto *d_src_mask_out_data =
-        (src_mask == nullptr)
-            ? nullptr
-            : dev_ctx.template Alloc<T>(d_src_mask_out,
-                                        d_src_mask_out->numel() * sizeof(T));
-    auto *d_fmha_out_data =
-        dev_ctx.template Alloc<T>(d_fmha_out, d_fmha_out->numel() * sizeof(T));
-    auto *d_out_linear_out_data = dev_ctx.template Alloc<T>(
-        d_out_linear_out, d_out_linear_out->numel() * sizeof(T));
-
-    // parameter grad
-    auto *d_qkv_weight =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVW"));
-    auto *d_qkv_bias =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBias"));
-    auto *d_out_linear_weight =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearW"));
-    auto *d_out_linear_bias =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearBias"));
-    auto *d_ln_2_scale =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Scale"));
-    auto *d_ln_2_bias =
-        ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Bias"));
-
-    auto *d_qkv_weight_data =
-        (d_qkv_weight == nullptr)
-            ? nullptr
-            : dev_ctx.template Alloc<T>(d_qkv_weight,
-                                        d_qkv_weight->numel() * sizeof(T));
-
-    auto *d_qkv_bias_data =
-        (d_qkv_bias == nullptr)
-            ? nullptr
-            : dev_ctx.template Alloc<T>(d_qkv_bias,
-                                        d_qkv_bias->numel() * sizeof(T));
-    auto *d_out_linear_weight_data =
-        (d_out_linear_weight == nullptr)
-            ? nullptr
-            : dev_ctx.template Alloc<T>(
-                  d_out_linear_weight,
-                  d_out_linear_weight->numel() * sizeof(T));
-
-    auto *d_out_linear_bias_data =
-        (d_out_linear_bias == nullptr)
-            ? nullptr
-            : dev_ctx.template Alloc<T>(d_out_linear_bias,
-                                        d_out_linear_bias->numel() * sizeof(T));
-
-    const auto input_x_dims = input_x->dims();
-    const auto qkv_w_dims = qkv_weight->dims();
-
-    int batch_size = input_x_dims[0];
-    int max_seq_len = input_x_dims[1];
-    int dim_embed = input_x_dims[2];
-    int num_head;
-    int dim_head;
-    int nranks = 1;
-    if (!transpose_qkv_wb) {
-      num_head = qkv_w_dims[1];
-      dim_head = qkv_w_dims[2];
-    } else {
-      nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
-      num_head = num_heads;
-      dim_head = dim_embed / (num_head * nranks);
-    }
-
-    int bsz_seq = batch_size * max_seq_len;
-    int hidden_size = num_head * dim_head;
-    int output_size = 3 * hidden_size;
-    int input_size = dim_embed;
-
-    bool add_residual = ctx.Attr<bool>("add_residual");
-    phi::DenseTensor d_residual;
-    T *d_residual_data = nullptr;
-    if (add_residual) {
-      d_residual.Resize(input_x_dims);
-      d_residual_data = dev_ctx.template Alloc<T>(
-          &d_residual, d_residual.numel() * sizeof(T));
-    }
-
-    bool transA = false;
-    bool transB = transpose_qkv_wb ? false : true;
-    bool compute_qkv_bias = qkv_bias ? true : false;
-    auto layer_norm_compute = AttnLayerNorm<T>(
-        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
-    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
-                                     transA,
-                                     transB,
-                                     bsz_seq,
-                                     output_size,
-                                     input_size,
-                                     compute_qkv_bias);
-    AttnDropoutParam attn_dropout_param(is_test_1,
-                                        dropout_implementation_1,
-                                        attn_dropout_prob,
-                                        is_upscale_in_train_1,
-                                        is_fix_seed_1,
-                                        seed_val_1,
-                                        seed_1);
-    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
-                                       batch_size,
-                                       max_seq_len,
-                                       num_head,
-                                       dim_head,
-                                       attn_dropout_param);
-    output_size = hidden_size;
-    transA = false;
-    transB = false;
-    bool compute_bias = false;
-    // (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
-    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
-                                            transA,
-                                            transB,
-                                            bsz_seq,
-                                            input_size,
-                                            output_size,
-                                            compute_bias);
-    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
-        ctx.cuda_device_context(),
-        bsz_seq,
-        dim_embed,
-        dropout_param2,
-        ln2epsilon);
-
-    if (pre_layer_norm) {
-      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
-          ctx.cuda_device_context(),
-          d_y_data,
-          dropout_mask_out_data,
-          d_out_linear_out_data,
-          d_residual_data,
-          d_out_linear_bias_data);
-    } else {
-      auto *ln_2_mean_data = ln_2_mean->data<U>();
-      auto *ln_2_var_data = ln_2_var->data<U>();
-      auto *bias_dropout_residual_out_data =
-          bias_dropout_residual_out->data<T>();
-      auto *d_ln_2_scale_data =
-          (d_ln_2_scale == nullptr
-               ? nullptr
-               : dev_ctx.template Alloc<U>(d_ln_2_scale,
-                                           d_ln_2_scale->numel() * sizeof(U)));
-      auto *d_ln_2_bias_data =
-          (d_ln_2_bias == nullptr
-               ? nullptr
-               : dev_ctx.template Alloc<U>(d_ln_2_bias,
-                                           d_ln_2_bias->numel() * sizeof(U)));
-      auto *d_bias_dropout_residual_out_data = dev_ctx.template Alloc<T>(
-          d_bias_dropout_residual_out,
-          d_bias_dropout_residual_out->numel() * sizeof(T));
-
-      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
-          ctx.cuda_device_context(),
-          d_y_data,
-          bias_dropout_residual_out_data,
-          dropout_mask_out_data,
-          ln_2_scale_data,
-          ln_2_mean_data,
-          ln_2_var_data,
-          d_bias_dropout_residual_out_data,
-          d_ln_2_scale_data,
-          d_ln_2_bias_data,
-          d_out_linear_out_data,
-          d_out_linear_bias_data,
-          d_residual_data);
-    }
-
-    out_linear_compute.ComputeBackward(fmha_out,
-                                       out_linear_weight,
-                                       d_out_linear_out,
-                                       d_fmha_out,
-                                       d_out_linear_weight,
-                                       nullptr);
-
-    if (transpose_qkv_wb) {
-      if (compute_qkv_bias) {
-        d_qkv_bias_out->Resize(
-            {batch_size, max_seq_len, 3, num_head, dim_head});
-      } else {
-        d_qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
-      }
-    }
-
-    if (qkv_bias != nullptr) {
-      fmha_ref_compute.ComputeBackward(*transpose_out_2,
-                                       has_attn_dropout ? src_mask : nullptr,
-                                       *softmax_out,
-                                       *attn_dropout_mask_out,
-                                       *attn_dropout_out,
-                                       *qk_out,
-                                       *src_mask_out,
-                                       *d_fmha_out,
-                                       d_qktv_out,
-                                       d_attn_dropout_out,
-                                       d_softmax_out,
-                                       d_src_mask_out,
-                                       d_qk_out,
-                                       d_transpose_out_2,
-                                       nullptr,
-                                       d_qkv_bias_out);
-    } else {
-      fmha_ref_compute.ComputeBackward(*transpose_out_2,
-                                       has_attn_dropout ? src_mask : nullptr,
-                                       *softmax_out,
-                                       *attn_dropout_mask_out,
-                                       *attn_dropout_out,
-                                       *qk_out,
-                                       *src_mask_out,
-                                       *d_fmha_out,
-                                       d_qktv_out,
-                                       d_attn_dropout_out,
-                                       d_softmax_out,
-                                       d_src_mask_out,
-                                       d_qk_out,
-                                       d_transpose_out_2,
-                                       nullptr,
-                                       d_qkv_out);
-    }
-
-    if (transpose_qkv_wb) {
-      if (compute_qkv_bias) {
-        d_qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
-      } else {
-        d_qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
-      }
-    }
-
-    if (pre_layer_norm) {
-      auto *ln_mean = ctx.Input<phi::DenseTensor>("LnMean");
-      auto *ln_var = ctx.Input<phi::DenseTensor>("LnVariance");
-      auto *ln_out = ctx.Input<phi::DenseTensor>("LnOut");
-      auto *ln_mean_data = ln_mean->data<U>();
-      auto *ln_var_data = ln_var->data<U>();
-      auto *ln_out_data = ln_out->data<T>();
-
-      auto *d_ln_out =
-          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnOut"));
-      auto *d_ln_scale =
-          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnScale"));
-      auto *d_ln_bias =
-          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnBias"));
-      auto *d_ln_out_data =
-          dev_ctx.template Alloc<T>(d_ln_out, d_ln_out->numel() * sizeof(T));
-      auto *d_ln_scale_data =
-          (d_ln_scale == nullptr
-               ? nullptr
-               : dev_ctx.template Alloc<U>(d_ln_scale,
-                                           d_ln_scale->numel() * sizeof(U)));
-      auto *d_ln_bias_data =
-          (d_ln_bias == nullptr
-               ? nullptr
-               : dev_ctx.template Alloc<U>(d_ln_bias,
-                                           d_ln_bias->numel() * sizeof(U)));
-      if (qkv_bias != nullptr) {
-        qkv_compute.ComputeBackward(ln_out,
-                                    qkv_weight,
-                                    d_qkv_bias_out,
-                                    d_ln_out,
-                                    d_qkv_weight,
-                                    d_qkv_bias);
-      } else {
-        qkv_compute.ComputeBackward(
-            ln_out, qkv_weight, d_qkv_out, d_ln_out, d_qkv_weight, d_qkv_bias);
-      }
-      // tensor model parallel
-      phi::fusion::AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
-      layer_norm_compute.ComputeBackward(x_data,
-                                         d_ln_out_data,
-                                         ln_scale_data,
-                                         ln_mean_data,
-                                         ln_var_data,
-                                         d_x_data,
-                                         d_ln_scale_data,
-                                         d_ln_bias_data);
-    } else {
-      if (qkv_bias != nullptr) {
-        qkv_compute.ComputeBackward(
-            input_x, qkv_weight, d_qkv_bias_out, d_x, d_qkv_weight, d_qkv_bias);
-      } else {
-        qkv_compute.ComputeBackward(
-            input_x, qkv_weight, d_qkv_out, d_x, d_qkv_weight, d_qkv_bias);
-      }
-      // tensor model parallel
-      phi::fusion::AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
-    }
-
-    if (add_residual) {
-      // gradient accumulation
-      std::vector<const phi::DenseTensor *> ins = {&d_residual, d_x};
-      std::vector<phi::DenseTensor *> outs = {d_x};
-      phi::funcs::ElementwiseKernel<T>(
-          ctx.cuda_device_context(), ins, &outs, phi::funcs::AddFunctor<T>());
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace phi {
 namespace fusion {
 
@@ -799,6 +366,443 @@ void FusedAttentionKernel(const Context &dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void FusedAttentionGradKernel(
+    const Context &dev_ctx,
+    const DenseTensor &out_grad,
+    const DenseTensor &x,
+    const DenseTensor &qkv_weight,
+    const paddle::optional<DenseTensor> &qkv_bias,
+    const paddle::optional<DenseTensor> &qkv_bias_out,
+    const paddle::optional<DenseTensor> &src_mask,
+    const paddle::optional<DenseTensor> &src_mask_out,
+    const DenseTensor &out_linear_weight,
+    const paddle::optional<DenseTensor> &out_linear_bias,
+    const paddle::optional<DenseTensor> &ln_scale,
+    const paddle::optional<DenseTensor> &ln_bias,
+    const paddle::optional<DenseTensor> &ln_scale_2,
+    const paddle::optional<DenseTensor> &ln_bias_2,
+    const paddle::optional<DenseTensor> &ln_out,
+    const paddle::optional<DenseTensor> &ln_mean,
+    const paddle::optional<DenseTensor> &ln_var,
+    const paddle::optional<DenseTensor> &ln_mean_2,
+    const paddle::optional<DenseTensor> &ln_var_2,
+    const paddle::optional<DenseTensor> &bias_dropout_residual_out,
+    const DenseTensor &qkv_out,
+    const DenseTensor &transpose_out_2,
+    const DenseTensor &qk_out,
+    const DenseTensor &qktv_out,
+    const DenseTensor &softmax_out,
+    const DenseTensor &attn_dropout_mask_out,
+    const DenseTensor &attn_dropout_out,
+    const DenseTensor &fmha_out,
+    const DenseTensor &out_linear_out,
+    const DenseTensor &dropout_mask_out,
+    int num_heads,
+    bool transpose_qkv_wb,
+    bool pre_layer_norm,
+    float epsilon,
+    float attn_dropout_rate,
+    bool is_test,
+    bool attn_dropout_fix_seed,
+    int attn_dropout_seed,
+    const std::string &attn_dropout_implementation,
+    float dropout_rate,
+    bool dropout_fix_seed,
+    int dropout_seed,
+    const std::string &dropout_implementation,
+    float ln_epsilon,
+    bool add_residual,
+    int ring_id,
+    DenseTensor *qkv_bias_grad,
+    DenseTensor *qkv_bias_out_grad,
+    DenseTensor *src_mask_out_grad,
+    DenseTensor *out_linear_bias_grad,
+    DenseTensor *ln_scale_grad,
+    DenseTensor *ln_bias_grad,
+    DenseTensor *ln_scale_2_grad,
+    DenseTensor *ln_bias_2_grad,
+    DenseTensor *x_grad,
+    DenseTensor *qkv_weight_grad,
+    DenseTensor *out_linear_weight_grad,
+    DenseTensor *ln_out_grad,
+    DenseTensor *bias_dropout_residual_out_grad,
+    DenseTensor *qkv_out_grad,
+    DenseTensor *qktv_out_grad,
+    DenseTensor *transpose_out_2_grad,
+    DenseTensor *qk_out_grad,
+    DenseTensor *softmax_out_grad,
+    DenseTensor *attn_dropout_out_grad,
+    DenseTensor *fmha_out_grad,
+    DenseTensor *out_linear_out_grad) {
+  using U = phi::fusion::LayerNormParamType<T>;
+
+  const bool has_attn_dropout = (attn_dropout_rate != 0.0f);
+
+  const bool is_upscale_in_train =
+      (dropout_implementation == "upscale_in_train");
+  phi::fusion::DropoutParam dropout_param2(dropout_fix_seed,
+                                           0,
+                                           is_test,
+                                           is_upscale_in_train,
+                                           dropout_rate,
+                                           nullptr,
+                                           dropout_seed);
+  const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
+
+  bool is_upscale_in_train_1 =
+      (attn_dropout_implementation == "upscale_in_train");
+  phi::DenseTensor *seed_1 = nullptr;
+
+  // get inputs.
+  auto *d_y = &out_grad;
+  auto *d_y_data = d_y->data<T>();
+
+  // fw input
+  auto *input_x = &x;
+  auto *ln_scale_p = ln_scale.get_ptr();
+  auto *ln_scale_2_p = ln_scale_2.get_ptr();
+  auto *x_data = input_x->data<T>();
+  auto *ln_scale_data =
+      (ln_scale_p == nullptr ? nullptr : ln_scale_p->data<U>());
+  auto *ln_2_scale_data =
+      (ln_scale_2_p == nullptr ? nullptr : ln_scale_2_p->data<U>());
+  // fw parameters.
+  auto *src_mask_p = src_mask.get_ptr();
+  auto *qkv_weight_p = &qkv_weight;
+  auto *qkv_bias_p = qkv_bias.get_ptr();
+  auto *out_linear_weight_p = &out_linear_weight;
+  auto *out_linear_bias_p = out_linear_bias.get_ptr();
+  auto *qkv_weight_data = qkv_weight_p->data<T>();
+  auto *qkv_bias_data =
+      (qkv_bias_p == nullptr) ? nullptr : qkv_bias_p->data<T>();
+  auto *out_linear_weight_data = out_linear_weight_p->data<T>();
+  auto *out_linear_bias_data =
+      (out_linear_bias_p == nullptr) ? nullptr : out_linear_bias_p->data<T>();
+
+  // fw output
+  auto *fmha_out_p = &fmha_out;
+  auto *transpose_out_2_p = &transpose_out_2;
+  auto *qk_out_p = &qk_out;
+  auto *softmax_out_p = &softmax_out;
+  auto *attn_dropout_mask_out_p = &attn_dropout_mask_out;
+  auto *attn_dropout_out_p = &attn_dropout_out;
+  auto *src_mask_out_p = src_mask_out.get_ptr();
+  auto *ln_mean_2_p = ln_mean_2.get_ptr();
+  auto *ln_var_2_p = ln_var_2.get_ptr();
+  auto *dropout_mask_out_p = &dropout_mask_out;
+  auto *bias_dropout_residual_out_p = bias_dropout_residual_out.get_ptr();
+  auto *fmha_out_data = fmha_out_p->data<T>();
+  auto *transpose_out_2_data = transpose_out_2_p->data<T>();
+  auto *softmax_out_data = softmax_out_p->data<T>();
+  auto *src_mask_out_data =
+      (src_mask_p == nullptr) ? nullptr : src_mask_out_p->data<T>();
+  auto *dropout_mask_out_data =
+      has_dropout ? dropout_mask_out_p->data<uint8_t>() : nullptr;
+
+  auto *d_x_data =
+      dev_ctx.template Alloc<T>(x_grad, x_grad->numel() * sizeof(T));
+  // when qkv_bias_p is not nullptr, qkv_out_grad is equals to
+  // qkv_bias_out_grad, the space can be reused.
+  auto *d_qkv_out_data =
+      (qkv_bias_out_grad != nullptr)
+          ? nullptr
+          : dev_ctx.template Alloc<T>(qkv_out_grad,
+                                      qkv_out_grad->numel() * sizeof(T));
+  auto *d_qkv_bias_out_data =
+      (qkv_bias_out_grad == nullptr)
+          ? nullptr
+          : dev_ctx.template Alloc<T>(qkv_bias_out_grad,
+                                      qkv_bias_out_grad->numel() * sizeof(T));
+  auto *d_qktv_out_data = dev_ctx.template Alloc<T>(
+      qktv_out_grad, qktv_out_grad->numel() * sizeof(T));
+  auto *d_transpose_out_2_data = dev_ctx.template Alloc<T>(
+      transpose_out_2_grad, transpose_out_2_grad->numel() * sizeof(T));
+  auto *d_qk_out_data =
+      dev_ctx.template Alloc<T>(qk_out_grad, qk_out_grad->numel() * sizeof(T));
+  auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
+      softmax_out_grad, softmax_out_grad->numel() * sizeof(T));
+  auto *d_attn_dropout_out_data =
+      has_attn_dropout
+          ? dev_ctx.template Alloc<T>(
+                attn_dropout_out_grad,
+                attn_dropout_out_grad->numel() * sizeof(T))
+          : nullptr;
+  auto *d_src_mask_out_data =
+      (src_mask_p == nullptr)
+          ? nullptr
+          : dev_ctx.template Alloc<T>(src_mask_out_grad,
+                                      src_mask_out_grad->numel() * sizeof(T));
+  auto *d_fmha_out_data = dev_ctx.template Alloc<T>(
+      fmha_out_grad, fmha_out_grad->numel() * sizeof(T));
+  auto *d_out_linear_out_data = dev_ctx.template Alloc<T>(
+      out_linear_out_grad, out_linear_out_grad->numel() * sizeof(T));
+
+  // parameter grad
+  auto *d_qkv_weight_data =
+      (qkv_weight_grad == nullptr)
+          ? nullptr
+          : dev_ctx.template Alloc<T>(qkv_weight_grad,
+                                      qkv_weight_grad->numel() * sizeof(T));
+
+  auto *d_qkv_bias_data =
+      (qkv_bias_grad == nullptr)
+          ? nullptr
+          : dev_ctx.template Alloc<T>(qkv_bias_grad,
+                                      qkv_bias_grad->numel() * sizeof(T));
+  auto *d_out_linear_weight_data =
+      (out_linear_weight_grad == nullptr)
+          ? nullptr
+          : dev_ctx.template Alloc<T>(
+                out_linear_weight_grad,
+                out_linear_weight_grad->numel() * sizeof(T));
+
+  auto *d_out_linear_bias_data =
+      (out_linear_bias_grad == nullptr)
+          ? nullptr
+          : dev_ctx.template Alloc<T>(
+                out_linear_bias_grad,
+                out_linear_bias_grad->numel() * sizeof(T));
+
+  const auto input_x_dims = input_x->dims();
+  const auto qkv_w_dims = qkv_weight_p->dims();
+
+  int batch_size = input_x_dims[0];
+  int max_seq_len = input_x_dims[1];
+  int dim_embed = input_x_dims[2];
+  int num_head;
+  int dim_head;
+  int nranks = 1;
+  if (!transpose_qkv_wb) {
+    num_head = qkv_w_dims[1];
+    dim_head = qkv_w_dims[2];
+  } else {
+    nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
+    num_head = num_heads;
+    dim_head = dim_embed / (num_head * nranks);
+  }
+
+  int bsz_seq = batch_size * max_seq_len;
+  int hidden_size = num_head * dim_head;
+  int output_size = 3 * hidden_size;
+  int input_size = dim_embed;
+
+  phi::DenseTensor d_residual;
+  T *d_residual_data = nullptr;
+  if (add_residual) {
+    d_residual.Resize(input_x_dims);
+    d_residual_data =
+        dev_ctx.template Alloc<T>(&d_residual, d_residual.numel() * sizeof(T));
+  }
+
+  bool transA = false;
+  bool transB = transpose_qkv_wb ? false : true;
+  bool compute_qkv_bias = qkv_bias_p ? true : false;
+  auto layer_norm_compute =
+      phi::fusion::AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
+  auto qkv_compute = phi::fusion::AttnMatMul<T>(dev_ctx,
+                                                transA,
+                                                transB,
+                                                bsz_seq,
+                                                output_size,
+                                                input_size,
+                                                compute_qkv_bias);
+  phi::fusion::AttnDropoutParam attn_dropout_param(is_test,
+                                                   attn_dropout_implementation,
+                                                   attn_dropout_rate,
+                                                   is_upscale_in_train_1,
+                                                   attn_dropout_fix_seed,
+                                                   attn_dropout_seed,
+                                                   seed_1);
+  auto fmha_ref_compute = phi::fusion::FMHARef<T>(
+      dev_ctx, batch_size, max_seq_len, num_head, dim_head, attn_dropout_param);
+  output_size = hidden_size;
+  transA = false;
+  transB = false;
+  bool compute_bias = false;
+  // (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
+  auto out_linear_compute = phi::fusion::AttnMatMul<T>(
+      dev_ctx, transA, transB, bsz_seq, input_size, output_size, compute_bias);
+  phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
+      fused_dropout_layernorm_helper(
+          dev_ctx, bsz_seq, dim_embed, dropout_param2, ln_epsilon);
+
+  if (pre_layer_norm) {
+    fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
+        dev_ctx,
+        d_y_data,
+        dropout_mask_out_data,
+        d_out_linear_out_data,
+        d_residual_data,
+        d_out_linear_bias_data);
+  } else {
+    auto *ln_mean_2_data = ln_mean_2_p->data<U>();
+    auto *ln_var_2_data = ln_var_2_p->data<U>();
+    auto *bias_dropout_residual_out_data =
+        bias_dropout_residual_out_p->data<T>();
+    auto *d_ln_2_scale_data =
+        (ln_scale_2_grad == nullptr
+             ? nullptr
+             : dev_ctx.template Alloc<U>(ln_scale_2_grad,
+                                         ln_scale_2_grad->numel() * sizeof(U)));
+    auto *d_ln_bias_2_data =
+        (ln_bias_2_grad == nullptr
+             ? nullptr
+             : dev_ctx.template Alloc<U>(ln_bias_2_grad,
+                                         ln_bias_2_grad->numel() * sizeof(U)));
+    auto *d_bias_dropout_residual_out_data = dev_ctx.template Alloc<T>(
+        bias_dropout_residual_out_grad,
+        bias_dropout_residual_out_grad->numel() * sizeof(T));
+
+    fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
+        dev_ctx,
+        d_y_data,
+        bias_dropout_residual_out_data,
+        dropout_mask_out_data,
+        ln_2_scale_data,
+        ln_mean_2_data,
+        ln_var_2_data,
+        d_bias_dropout_residual_out_data,
+        d_ln_2_scale_data,
+        d_ln_bias_2_data,
+        d_out_linear_out_data,
+        d_out_linear_bias_data,
+        d_residual_data);
+  }
+
+  out_linear_compute.ComputeBackward(fmha_out_p,
+                                     out_linear_weight_p,
+                                     out_linear_out_grad,
+                                     fmha_out_grad,
+                                     out_linear_weight_grad,
+                                     nullptr);
+
+  if (transpose_qkv_wb) {
+    if (compute_qkv_bias) {
+      qkv_bias_out_grad->Resize(
+          {batch_size, max_seq_len, 3, num_head, dim_head});
+    } else {
+      qkv_out_grad->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
+    }
+  }
+
+  if (qkv_bias_p != nullptr) {
+    fmha_ref_compute.ComputeBackward(*transpose_out_2_p,
+                                     has_attn_dropout ? src_mask_p : nullptr,
+                                     *softmax_out_p,
+                                     *attn_dropout_mask_out_p,
+                                     *attn_dropout_out_p,
+                                     *qk_out_p,
+                                     *src_mask_out_p,
+                                     *fmha_out_grad,
+                                     qktv_out_grad,
+                                     attn_dropout_out_grad,
+                                     softmax_out_grad,
+                                     src_mask_out_grad,
+                                     qk_out_grad,
+                                     transpose_out_2_grad,
+                                     nullptr,
+                                     qkv_bias_out_grad);
+  } else {
+    fmha_ref_compute.ComputeBackward(*transpose_out_2_p,
+                                     has_attn_dropout ? src_mask_p : nullptr,
+                                     *softmax_out_p,
+                                     *attn_dropout_mask_out_p,
+                                     *attn_dropout_out_p,
+                                     *qk_out_p,
+                                     *src_mask_out_p,
+                                     *fmha_out_grad,
+                                     qktv_out_grad,
+                                     attn_dropout_out_grad,
+                                     softmax_out_grad,
+                                     src_mask_out_grad,
+                                     qk_out_grad,
+                                     transpose_out_2_grad,
+                                     nullptr,
+                                     qkv_out_grad);
+  }
+
+  if (transpose_qkv_wb) {
+    if (compute_qkv_bias) {
+      qkv_bias_out_grad->Resize({batch_size, max_seq_len, 3 * hidden_size});
+    } else {
+      qkv_out_grad->Resize({batch_size, max_seq_len, 3 * hidden_size});
+    }
+  }
+
+  if (pre_layer_norm) {
+    auto *ln_mean_p = ln_mean.get_ptr();
+    auto *ln_var_p = ln_var.get_ptr();
+    auto *ln_out_p = ln_out.get_ptr();
+    auto *ln_mean_data = ln_mean_p->data<U>();
+    auto *ln_var_data = ln_var_p->data<U>();
+    auto *ln_out_data = ln_out_p->data<T>();
+
+    auto *d_ln_out_data = dev_ctx.template Alloc<T>(
+        ln_out_grad, ln_out_grad->numel() * sizeof(T));
+    auto *d_ln_scale_data =
+        (ln_scale_grad == nullptr
+             ? nullptr
+             : dev_ctx.template Alloc<U>(ln_scale_grad,
+                                         ln_scale_grad->numel() * sizeof(U)));
+    auto *d_ln_bias_data =
+        (ln_bias_grad == nullptr
+             ? nullptr
+             : dev_ctx.template Alloc<U>(ln_bias_grad,
+                                         ln_bias_grad->numel() * sizeof(U)));
+    if (qkv_bias_p != nullptr) {
+      qkv_compute.ComputeBackward(ln_out_p,
+                                  qkv_weight_p,
+                                  qkv_bias_out_grad,
+                                  ln_out_grad,
+                                  qkv_weight_grad,
+                                  qkv_bias_grad);
+    } else {
+      qkv_compute.ComputeBackward(ln_out_p,
+                                  qkv_weight_p,
+                                  qkv_out_grad,
+                                  ln_out_grad,
+                                  qkv_weight_grad,
+                                  qkv_bias_grad);
+    }
+    // tensor model parallel
+    phi::fusion::AllReduce<T>(*ln_out_grad, ring_id, dev_ctx);
+    layer_norm_compute.ComputeBackward(x_data,
+                                       d_ln_out_data,
+                                       ln_scale_data,
+                                       ln_mean_data,
+                                       ln_var_data,
+                                       d_x_data,
+                                       d_ln_scale_data,
+                                       d_ln_bias_data);
+  } else {
+    if (qkv_bias_p != nullptr) {
+      qkv_compute.ComputeBackward(input_x,
+                                  qkv_weight_p,
+                                  qkv_bias_out_grad,
+                                  x_grad,
+                                  qkv_weight_grad,
+                                  qkv_bias_grad);
+    } else {
+      qkv_compute.ComputeBackward(input_x,
+                                  qkv_weight_p,
+                                  qkv_out_grad,
+                                  x_grad,
+                                  qkv_weight_grad,
+                                  qkv_bias_grad);
+    }
+    // tensor model parallel
+    phi::fusion::AllReduce<T>(*x_grad, ring_id, dev_ctx);
+  }
+
+  if (add_residual) {
+    // gradient accumulation
+    std::vector<const phi::DenseTensor *> ins = {&d_residual, x_grad};
+    std::vector<phi::DenseTensor *> outs = {x_grad};
+    phi::funcs::ElementwiseKernel<T>(
+        dev_ctx, ins, &outs, phi::funcs::AddFunctor<T>());
+  }
+}
+
 }  // namespace fusion
 }  // namespace phi
 
@@ -809,24 +813,27 @@ PD_REGISTER_KERNEL(fused_attention,
                    phi::dtype::float16,
                    double,
                    float) {
-  phi::DataType data_type;
-  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
-      kernel_key.dtype() == phi::DataType::FLOAT32) {
-    data_type = phi::DataType::FLOAT32;
-  } else {
-    data_type = phi::DataType::FLOAT64;
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(15).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(16).SetDataType(phi::DataType::FLOAT32);
   }
-  kernel->OutputAt(0).SetDataType(data_type);
-  kernel->OutputAt(1).SetDataType(data_type);
-  kernel->OutputAt(3).SetDataType(data_type);
-  kernel->OutputAt(4).SetDataType(data_type);
-  kernel->OutputAt(15).SetDataType(data_type);
-  kernel->OutputAt(16).SetDataType(data_type);
 }
 
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
-                        ops::FusedAttentionGradKernel<float>,
-                        ops::FusedAttentionGradKernel<double>,
-                        ops::FusedAttentionGradKernel<plat::float16>);
+PD_REGISTER_KERNEL(fused_attention_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedAttentionGradKernel,
+                   phi::dtype::float16,
+                   double,
+                   float) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(7).SetDataType(phi::DataType::FLOAT32);
+  }
+}
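The phi kernel above is reachable from the legacy static-graph IR only through an argument-mapping ("compat") function, which is what the second file adds. A KernelSignature names the phi kernel and lists, in order, the op's input names, attribute names, and output names; that order must match the kernel's parameter order exactly (here: 29 inputs, 16 attributes, 21 output grads). A minimal sketch of the pattern, with a made-up op "my_op" and assuming the usual compat header:

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

// Maps the legacy op's named arguments onto the positional signature of
// the phi kernel "my_op_grad": {inputs}, {attrs}, {outputs}.
KernelSignature MyOpGradArgumentMapping(const ArgumentMappingContext &ctx) {
  return KernelSignature(
      "my_op_grad", {"Out@GRAD", "X"}, {"axis"}, {"X@GRAD"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(my_op_grad, phi::MyOpGradArgumentMapping);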
diff --git a/paddle/phi/ops/compat/fused_attention_sig.cc b/paddle/phi/ops/compat/fused_attention_sig.cc
index 5fe1cb289e1..8e0c47b5317 100644
--- a/paddle/phi/ops/compat/fused_attention_sig.cc
+++ b/paddle/phi/ops/compat/fused_attention_sig.cc
@@ -58,7 +58,83 @@ KernelSignature AttentionFuseOpArgumentMapping(
                           "CacheKVOut",
                           "Y"});
 }
 
+KernelSignature AttentionGradFuseOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("fused_attention_grad",
+                         {"Y@GRAD",
+                          "X",
+                          "QKVW",
+                          "QKVBias",
+                          "QKVBiasOut",
+                          "SrcMask",
+                          "SrcMaskOut",
+                          "OutLinearW",
+                          "OutLinearBias",
+                          "LnScale",
+                          "LnBias",
+                          "Ln2Scale",
+                          "Ln2Bias",
+                          "LnOut",
+                          "LnMean",
+                          "LnVariance",
+                          "Ln2Mean",
+                          "Ln2Variance",
+                          "BiasDropoutResidualOut",
+                          "QKVOut",
+                          "TransposeOut2",
+                          "QKOut",
+                          "QKTVOut",
+                          "SoftmaxOut",
+                          "AttnDropoutMaskOut",
+                          "AttnDropoutOut",
+                          "FMHAOut",
+                          "OutLinearOut",
+                          "DropoutMaskOut"},
+                         {"num_heads",
+                          "transpose_qkv_wb",
+                          "pre_layer_norm",
+                          "epsilon",
+                          "attn_dropout_rate",
+                          "is_test",
+                          "attn_dropout_fix_seed",
+                          "attn_dropout_seed",
+                          "attn_dropout_implementation",
+                          "dropout_rate",
+                          "dropout_fix_seed",
+                          "dropout_seed",
+                          "dropout_implementation",
+                          "ln_epsilon",
+                          "add_residual",
+                          "ring_id"},
+                         {
+                             "QKVBias@GRAD",
+                             "QKVBiasOut@GRAD",
+                             "SrcMaskOut@GRAD",
+                             "OutLinearBias@GRAD",
+                             "LnScale@GRAD",
+                             "LnBias@GRAD",
+                             "Ln2Scale@GRAD",
+                             "Ln2Bias@GRAD",
+                             "X@GRAD",
+                             "QKVW@GRAD",
+                             "OutLinearW@GRAD",
+                             "LnOut@GRAD",
+                             "BiasDropoutResidualOut@GRAD",
+                             "QKVOut@GRAD",
+                             "QKTVOut@GRAD",
+                             "TransposeOut2@GRAD",
+                             "QKOut@GRAD",
+                             "SoftmaxOut@GRAD",
+                             "AttnDropoutOut@GRAD",
+                             "FMHAOut@GRAD",
+                             "OutLinearOut@GRAD",
+                         });
+}
+
 }  // namespace phi
 
 PD_REGISTER_ARG_MAPPING_FN(fused_attention,
                            phi::AttentionFuseOpArgumentMapping);
+
+PD_REGISTER_ARG_MAPPING_FN(fused_attention_grad,
+                           phi::AttentionGradFuseOpArgumentMapping);
--
GitLab
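A closing note on the float16 branches in both PD_REGISTER_KERNEL blocks: the LayerNorm statistics and parameter gradients are computed in U = LayerNormParamType<T>, which promotes half precision to float, so the registered output dtypes for those slots must be overridden to FLOAT32 when the kernel itself runs in FLOAT16. In the grad registration, outputs 4 through 7 are exactly LnScale@GRAD, LnBias@GRAD, Ln2Scale@GRAD, and Ln2Bias@GRAD in the signature above. Conceptually the promotion behaves like the simplified sketch below (the real alias lives in the fused-kernel headers; this is not its actual definition):

#include <type_traits>

#include "paddle/phi/common/float16.h"

// Promote float16 to float for LayerNorm statistics; leave other types as-is.
template <typename T>
using LayerNormParamTypeSketch =
    typename std::conditional<std::is_same<T, phi::dtype::float16>::value,
                              float,
                              T>::type;

static_assert(
    std::is_same<LayerNormParamTypeSketch<phi::dtype::float16>, float>::value,
    "half-precision LayerNorm stats accumulate in float");
static_assert(std::is_same<LayerNormParamTypeSketch<double>, double>::value,
              "wider types are unchanged");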