/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// NOTE(review): in this copy of the file the two bare `#include` directives
// below carry no header name, and template parameter/argument lists appear to
// be missing throughout (e.g. `template class ...`, `ctx.Input("X")`,
// `using U = LayerNormParamType;`, and `T` is used without a visible
// declaration). This looks like text inside angle brackets was stripped --
// TODO confirm against the upstream file before building.
#include
#include
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Forward kernel registered below for the `fused_attention` op.
//
// Compute() runs the following steps in order:
//   1. (only when pre_layer_norm) layer norm of X into LnOut;
//   2. QKV projection of LnOut (pre_layer_norm) or X (otherwise);
//   3. FMHARef forward on QKVBiasOut (when QKVBias is given) or QKVOut;
//   4. output projection of FMHAOut into OutLinearOut (no bias in the GEMM);
//   5. pre_layer_norm:  Y = residual + dropout(OutLinearOut + bias)
//      otherwise:       Y = layernorm(residual + dropout(OutLinearOut + bias))
//      where the residual is X.
template
class FusedAttentionOpKernel : public framework::OpKernel {
 public:
  // Reads all inputs/attributes from `ctx`, obtains the output buffers and
  // launches the helper computations described on the class.
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType;
    auto *input_x = ctx.Input("X");
    const auto pre_layer_norm = ctx.Attr("pre_layer_norm");
    const float epsilon = ctx.Attr("epsilon");
    // Tensors of the first layer norm; they are only touched inside the
    // `if (pre_layer_norm)` branch further down.
    auto *ln_scale = ctx.Input("LnScale");
    auto *ln_bias = ctx.Input("LnBias");
    auto *ln_mean = ctx.Output("LnMean");
    auto *ln_var = ctx.Output("LnVariance");
    auto *ln_out = ctx.Output("LnOut");

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    auto *qkv_weight = ctx.Input("QKVW");
    auto *qkv_bias = ctx.Input("QKVBias");
    auto *qkv_out = ctx.Output("QKVOut");
    auto *qkv_bias_out = ctx.Output("QKVBiasOut");

    // Inputs and intermediate outputs handed to FMHARef::ComputeForward.
    auto *src_mask = ctx.Input("SrcMask");
    auto *transpose_out_2 = ctx.Output("TransposeOut2");
    auto *qk_out = ctx.Output("QKOut");
    auto *qktv_out = ctx.Output("QKTVOut");
    auto *softmax_out = ctx.Output("SoftmaxOut");
    auto *attn_dropout_mask_out = ctx.Output("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Output("AttnDropoutOut");
    auto *src_mask_out = ctx.Output("SrcMaskOut");
    auto *fmha_out = ctx.Output("FMHAOut");

    // Output projection.
    auto *out_linear_weight = ctx.Input("OutLinearW");
    auto *out_linear_bias = ctx.Input("OutLinearBias");
    auto *out_linear_out = ctx.Output("OutLinearOut");

    // Second layer norm / dropout stage; the Ln2* and BiasDropoutResidualOut
    // tensors are only used when pre_layer_norm is false.
    auto *ln_scale_2 = ctx.Input("Ln2Scale");
    auto *ln_bias_2 = ctx.Input("Ln2Bias");
    auto *dropout_mask_out = ctx.Output("DropoutMaskOut");
    auto *bias_dropout_residual_out = ctx.Output("BiasDropoutResidualOut");
    auto *ln_mean_2 = ctx.Output("Ln2Mean");
    auto *ln_var_2 = ctx.Output("Ln2Variance");
    const float ln_epsilon = ctx.Attr("ln_epsilon");

    // Attributes of the attention dropout, packed into AttnDropoutParam below.
    float attn_dropout_rate = ctx.Attr("attn_dropout_rate");
    bool is_test_1 = ctx.Attr("attn_dropout_is_test");
    auto &dropout_implementation_1 = ctx.Attr("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    // Seed1 is optional: seed_1 stays nullptr when the input is absent.
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr("attn_dropout_seed");

    // final output.
    auto *out = ctx.Output("Y");

    // get data ptr for qkv part.
    // NOTE(review): several pointers obtained in this section (e.g.
    // qkv_weight_data, qkv_bias_data, qkv_out_data) are not referenced again
    // in this function; the mutable_data calls presumably exist to allocate
    // the outputs -- confirm before removing any of them.
    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    auto *x_data = input_x->data();
    auto *qkv_weight_data = qkv_weight->data();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data();
    auto *qkv_out_data = qkv_out->mutable_data(ctx.GetPlace());
    // QKVBiasOut is only requested when a QKV bias was supplied.
    auto *qkv_bias_out_data =
        (qkv_bias == nullptr) ? nullptr
                              : qkv_bias_out->mutable_data(ctx.GetPlace());

    // get data ptr for FMHA.
    auto *transpose_out_2_data =
        transpose_out_2->mutable_data(ctx.GetPlace());
    auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace());
    auto *qktv_out_data = qktv_out->mutable_data(ctx.GetPlace());
    // SrcMaskOut is only requested when a source mask was supplied.
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr
                              : src_mask_out->mutable_data(ctx.GetPlace());
    auto *softmax_out_data = softmax_out->mutable_data(ctx.GetPlace());
    auto *attn_dropout_mask_out_data =
        attn_dropout_mask_out->mutable_data(ctx.GetPlace());
    auto *attn_dropout_out_data =
        attn_dropout_out->mutable_data(ctx.GetPlace());
    auto *fmha_out_data = fmha_out->mutable_data(ctx.GetPlace());

    // get data ptr for out_linear.
    auto *out_linear_weight_data = out_linear_weight->data();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data();
    auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace());

    // get data ptr for bias+dropout+residual+layernorm
    auto *dropout_mask_out_data =
        dropout_mask_out->mutable_data(ctx.GetPlace());
    auto *final_out_data = out->mutable_data(ctx.GetPlace());

    // Problem sizes, taken from X's dims and from dims 1 and 2 of QKVW.
    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];

    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    // Number of rows fed to the layer norm / GEMM helpers.
    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    // Q, K and V are produced by a single projection, hence the factor 3.
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), epsilon,
                                            bsz_seq, dim_embed);

    // The QKV GEMM adds the bias only when QKVBias was provided.
    bool compute_bias = true;
    if (qkv_bias == nullptr) {
      compute_bias = false;
    }
    // (transA, transB) = (false, true); compute_bias as decided just above.
    auto qkv_compute =
        AttnMatMul(ctx.cuda_device_context(), false, true, bsz_seq,
                   output_size, input_size, compute_bias);

    AttnDropoutParam attn_dropout_param(
        is_test_1, dropout_implementation_1, attn_dropout_rate,
        is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
    auto fmha_ref_compute =
        FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
                dim_head, attn_dropout_param);

    // output_size is reused: it was 3 * hidden_size for the QKV projection and
    // is hidden_size for the output projection constructed next.
    output_size = hidden_size;
    // (transA, transB, compute_bias) = (false, false, false)
    auto out_linear_compute =
        AttnMatMul(ctx.cuda_device_context(), false, false, bsz_seq,
                   output_size, input_size, false);
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper fused_dropout_layernorm_helper(
        ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
        ln_epsilon);

    if (pre_layer_norm) {
      // Pre-LN: normalize X into LnOut first, then project LnOut.
      auto *ln_scale_data =
          (ln_scale == nullptr ? nullptr : ln_scale->data());
      auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data());
      auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace());
      auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace());
      auto *ln_out_data = ln_out->mutable_data(ctx.GetPlace());

      layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
                                        ln_out_data, ln_mean_data, ln_var_data);
      qkv_compute.ComputeForward(qkv_weight, ln_out, qkv_bias, qkv_out,
                                 qkv_bias_out);
    } else {
      // Post-LN: project X directly.
      qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
                                 qkv_bias_out);
    }
    // The attention core consumes QKVOut when there is no bias and QKVBiasOut
    // otherwise; all other arguments are the same in both calls.
    if (qkv_bias == nullptr) {
      fmha_ref_compute.ComputeForward(*qkv_out, src_mask, transpose_out_2,
                                      qk_out, src_mask_out, softmax_out,
                                      attn_dropout_mask_out, attn_dropout_out,
                                      qktv_out, fmha_out);
    } else {
      fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
                                      qk_out, src_mask_out, softmax_out,
                                      attn_dropout_mask_out, attn_dropout_out,
                                      qktv_out, fmha_out);
    }

    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight:   [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
    // The bias is passed as nullptr here; out_linear_bias_data is instead
    // handed to the fused dropout helper calls below.
    out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
                                      out_linear_out, nullptr);
    if (pre_layer_norm) {
      // output = (residual + dropout(input + bias))
      fused_dropout_layernorm_helper.ResidualDropoutBias(
          ctx.cuda_device_context(), out_linear_out_data, x_data,
          out_linear_bias_data, final_out_data, dropout_mask_out_data);
    } else {
      auto *ln_scale_2_data =
          (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data());
      auto *ln_bias_2_data =
          (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data());
      auto *bias_dropout_residual_out_data =
          bias_dropout_residual_out->mutable_data(ctx.GetPlace());
      auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace());
      auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace());
      // output = layernorm(residual + dropout(input + bias))
      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
          ctx.cuda_device_context(), out_linear_out_data, x_data,
          out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
          bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
          ln_mean_2_data, ln_var_2_data);
    }
  }
};

// Backward kernel registered below for the `fused_attention_grad` op.
//
// Compute() walks the forward steps in reverse order:
//   1. grad of the final residual/dropout/bias stage (with layer norm when
//      pre_layer_norm is false), which also fills the local d_residual;
//   2. output projection backward;
//   3. FMHARef backward;
//   4. QKV projection backward (and layer norm backward when pre_layer_norm);
//   5. element-wise add of d_residual and d_x, written back into d_x.
template
class FusedAttentionGradKernel : public framework::OpKernel {
 public:
  // Reads forward inputs/outputs and the gradient of Y from `ctx`, obtains the
  // gradient buffers and launches the backward helper computations.
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType;
    const auto pre_layer_norm = ctx.Attr("pre_layer_norm");
    const float epsilon = ctx.Attr("epsilon");
    const float ln2epsilon = ctx.Attr("ln_epsilon");

    // Same attention-dropout attributes as read by the forward kernel.
    float attn_dropout_prob = ctx.Attr("attn_dropout_rate");
    bool is_test_1 = ctx.Attr("attn_dropout_is_test");
    auto &dropout_implementation_1 = ctx.Attr("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    // Seed1 is optional: seed_1 stays nullptr when the input is absent.
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr("attn_dropout_seed");

    // get inputs.
    auto *d_y = ctx.Input(framework::GradVarName("Y"));
    auto *d_y_data = d_y->data();

    // fw input
    auto *input_x = ctx.Input("X");
    auto *ln_scale = ctx.Input("LnScale");
    auto *ln_2_scale = ctx.Input("Ln2Scale");
    auto *x_data = input_x->data();
    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data());
    auto *ln_2_scale_data =
        (ln_2_scale == nullptr ? nullptr : ln_2_scale->data());

    // fw parameters.
    auto *src_mask = ctx.Input("SrcMask");
    auto *qkv_weight = ctx.Input("QKVW");
    auto *qkv_bias = ctx.Input("QKVBias");
    auto *out_linear_weight = ctx.Input("OutLinearW");
    auto *out_linear_bias = ctx.Input("OutLinearBias");
    auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data());
    auto *qkv_weight_data = qkv_weight->data();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data();
    auto *out_linear_weight_data = out_linear_weight->data();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data();

    // fw output
    auto *fmha_out = ctx.Input("FMHAOut");
    auto *transpose_out_2 = ctx.Input("TransposeOut2");
    auto *qk_out = ctx.Input("QKOut");
    auto *qktv_out = ctx.Input("QKTVOut");
    auto *softmax_out = ctx.Input("SoftmaxOut");
    auto *attn_dropout_mask_out = ctx.Input("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Input("AttnDropoutOut");
    auto *src_mask_out = ctx.Input("SrcMaskOut");
    auto *out_linear_out = ctx.Input("OutLinearOut");
    auto *ln_2_mean = ctx.Input("Ln2Mean");
    auto *ln_2_var = ctx.Input("Ln2Variance");
    auto *dropout_mask_out = ctx.Input("DropoutMaskOut");
    auto *bias_dropout_residual_out = ctx.Input("BiasDropoutResidualOut");
    auto *fmha_out_data = fmha_out->data();
    auto *transpose_out_2_data = transpose_out_2->data();
    auto *qk_out_data = qk_out->data();
    auto *qktv_out_data = qktv_out->data();
    auto *softmax_out_data = softmax_out->data();
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr : src_mask_out->data();
    auto *out_linear_out_data = out_linear_out->data();
    auto *dropout_mask_out_data = dropout_mask_out->data();

    // output's grad
    auto *d_x = ctx.Output(framework::GradVarName("X"));
    auto *d_qkv_out = ctx.Output(framework::GradVarName("QKVOut"));
    auto *d_qkv_bias_out = ctx.Output(framework::GradVarName("QKVBiasOut"));
    auto *d_qktv_out = ctx.Output(framework::GradVarName("QKTVOut"));
    auto *d_transpose_out_2 =
        ctx.Output(framework::GradVarName("TransposeOut2"));
    auto *d_qk_out = ctx.Output(framework::GradVarName("QKOut"));
    auto *d_softmax_out = ctx.Output(framework::GradVarName("SoftmaxOut"));
    auto *d_attn_dropout_out =
        ctx.Output(framework::GradVarName("AttnDropoutOut"));
    auto *d_src_mask_out = ctx.Output(framework::GradVarName("SrcMaskOut"));
    auto *d_fmha_out = ctx.Output(framework::GradVarName("FMHAOut"));
    auto *d_out_linear_out =
        ctx.Output(framework::GradVarName("OutLinearOut"));
    auto *d_bias_dropout_residual_out =
        ctx.Output(framework::GradVarName("BiasDropoutResidualOut"));
    auto *d_x_data = d_x->mutable_data(ctx.GetPlace());
    // When qkv_bias is not nullptr, d_qkv_out is equal to d_qkv_bias_out, so
    // the space can be reused: only one of the two buffers is requested,
    // selected by whether the QKVBiasOut gradient output exists.
    auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
                               ? nullptr
                               : d_qkv_out->mutable_data(ctx.GetPlace());
    auto *d_qkv_bias_out_data =
        (d_qkv_bias_out == nullptr)
            ? nullptr
            : d_qkv_bias_out->mutable_data(ctx.GetPlace());
    auto *d_qktv_out_data = d_qktv_out->mutable_data(ctx.GetPlace());
    auto *d_transpose_out_2_data =
        d_transpose_out_2->mutable_data(ctx.GetPlace());
    auto *d_qk_out_data = d_qk_out->mutable_data(ctx.GetPlace());
    auto *d_softmax_out_data = d_softmax_out->mutable_data(ctx.GetPlace());
    auto *d_attn_dropout_out_data =
        d_attn_dropout_out->mutable_data(ctx.GetPlace());
    // The SrcMaskOut gradient is only requested when a source mask exists.
    auto *d_src_mask_out_data =
        (src_mask == nullptr) ? nullptr
                              : d_src_mask_out->mutable_data(ctx.GetPlace());
    auto *d_fmha_out_data = d_fmha_out->mutable_data(ctx.GetPlace());
    auto *d_out_linear_out_data =
        d_out_linear_out->mutable_data(ctx.GetPlace());

    // parameter grad
    auto *d_qkv_weight = ctx.Output(framework::GradVarName("QKVW"));
    auto *d_qkv_bias = ctx.Output(framework::GradVarName("QKVBias"));
    auto *d_out_linear_weight =
        ctx.Output(framework::GradVarName("OutLinearW"));
    auto *d_out_linear_bias =
        ctx.Output(framework::GradVarName("OutLinearBias"));
    auto *d_ln_2_scale = ctx.Output(framework::GradVarName("Ln2Scale"));
    auto *d_ln_2_bias = ctx.Output(framework::GradVarName("Ln2Bias"));

    auto *d_qkv_weight_data = d_qkv_weight->mutable_data(ctx.GetPlace());
    auto *d_qkv_bias_data = (d_qkv_bias == nullptr)
                                ? nullptr
                                : d_qkv_bias->mutable_data(ctx.GetPlace());
    auto *d_out_linear_weight_data =
        d_out_linear_weight->mutable_data(ctx.GetPlace());
    auto *d_out_linear_bias_data =
        (d_out_linear_bias == nullptr)
            ? nullptr
            : d_out_linear_bias->mutable_data(ctx.GetPlace());

    // Problem sizes, derived exactly as in the forward kernel.
    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    // Local scratch tensor with X's dims. Its pointer is passed to the
    // dropout/residual grad helpers below (presumably as the residual-branch
    // gradient they write) and it is added into d_x at the end of Compute.
    Tensor d_residual;
    d_residual.Resize(input_x_dims);
    T *d_residual_data = d_residual.mutable_data(ctx.GetPlace());

    // Configuration of the QKV GEMM; the bias is handled only when QKVBias
    // was provided.
    bool transA = false;
    bool transB = true;
    bool compute_qkv_bias = true;
    if (qkv_bias == nullptr) {
      compute_qkv_bias = false;
    }
    auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), epsilon,
                                            bsz_seq, dim_embed);
    auto qkv_compute =
        AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq,
                   output_size, input_size, compute_qkv_bias);
    AttnDropoutParam attn_dropout_param(
        is_test_1, dropout_implementation_1, attn_dropout_prob,
        is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
    auto fmha_ref_compute =
        FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
                dim_head, attn_dropout_param);
    // output_size / transA / transB are reused with new values for the output
    // projection GEMM, which is built without bias handling.
    output_size = hidden_size;
    transA = false;
    transB = false;
    bool compute_bias = false;
    auto out_linear_compute =
        AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq,
                   output_size, input_size, compute_bias);
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper fused_dropout_layernorm_helper(
        ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
        ln2epsilon);

    // Step 1: backward of the last forward stage.
    if (pre_layer_norm) {
      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
          ctx.cuda_device_context(), d_y_data, dropout_mask_out_data,
          d_out_linear_out_data, d_residual_data, d_out_linear_bias_data);
    } else {
      auto *ln_2_mean_data = ln_2_mean->data();
      auto *ln_2_var_data = ln_2_var->data();
      auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data();
      auto *d_ln_2_scale_data =
          (d_ln_2_scale == nullptr
               ? nullptr
               : d_ln_2_scale->mutable_data(ctx.GetPlace()));
      auto *d_ln_2_bias_data =
          (d_ln_2_bias == nullptr
               ? nullptr
               : d_ln_2_bias->mutable_data(ctx.GetPlace()));
      auto *d_bias_dropout_residual_out_data =
          d_bias_dropout_residual_out->mutable_data(ctx.GetPlace());

      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
          ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
          dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
          d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
          d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
    }

    // Step 2: output projection backward. The last argument is nullptr, which
    // matches this GEMM having been built with compute_bias == false.
    out_linear_compute.ComputeBackward(fmha_out, out_linear_weight,
                                       d_out_linear_out, d_fmha_out,
                                       d_out_linear_weight, nullptr);

    // Step 3: attention core backward. The two calls differ only in the last
    // argument: the gradient goes to d_qkv_bias_out when a QKV bias exists and
    // to d_qkv_out otherwise.
    // NOTE(review): *src_mask_out is dereferenced unconditionally here, while
    // src_mask_out_data above is guarded on src_mask -- confirm SrcMaskOut is
    // always present.
    if (qkv_bias != nullptr) {
      fmha_ref_compute.ComputeBackward(
          *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
          *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
          d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
          d_transpose_out_2, nullptr, d_qkv_bias_out);
    } else {
      fmha_ref_compute.ComputeBackward(
          *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
          *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
          d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
          d_transpose_out_2, nullptr, d_qkv_out);
    }

    // Step 4: QKV projection backward, plus the first layer norm's backward
    // when pre_layer_norm is set.
    if (pre_layer_norm) {
      auto *ln_mean = ctx.Input("LnMean");
      auto *ln_var = ctx.Input("LnVariance");
      auto *ln_out = ctx.Input("LnOut");
      auto *ln_mean_data = ln_mean->data();
      auto *ln_var_data = ln_var->data();
      auto *ln_out_data = ln_out->data();

      auto *d_ln_out = ctx.Output(framework::GradVarName("LnOut"));
      auto *d_ln_scale = ctx.Output(framework::GradVarName("LnScale"));
      auto *d_ln_bias = ctx.Output(framework::GradVarName("LnBias"));
      auto *d_ln_out_data = d_ln_out->mutable_data(ctx.GetPlace());
      auto *d_ln_scale_data =
          (d_ln_scale == nullptr ? nullptr
                                 : d_ln_scale->mutable_data(ctx.GetPlace()));
      auto *d_ln_bias_data =
          (d_ln_bias == nullptr ? nullptr
                                : d_ln_bias->mutable_data(ctx.GetPlace()));
      // The GEMM input in the forward pass was LnOut, so its gradient lands in
      // d_ln_out and is then pushed through the layer norm into d_x.
      if (qkv_bias != nullptr) {
        qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out,
                                    d_ln_out, d_qkv_weight, d_qkv_bias);
      } else {
        qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_out, d_ln_out,
                                    d_qkv_weight, d_qkv_bias);
      }
      layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
                                         ln_mean_data, ln_var_data, d_x_data,
                                         d_ln_scale_data, d_ln_bias_data);
    } else {
      // The GEMM input in the forward pass was X, so d_x is written directly.
      if (qkv_bias != nullptr) {
        qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
                                    d_qkv_weight, d_qkv_bias);
      } else {
        qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_out, d_x,
                                    d_qkv_weight, d_qkv_bias);
      }
    }
    // gradient accumulation
    // d_x is listed both as an input and as the only output, so the result of
    // combining d_residual and d_x with AddFunctor is stored back into d_x.
    std::vector ins;
    std::vector outs;
    ins.emplace_back(&d_residual);
    ins.emplace_back(d_x);
    outs.emplace_back(d_x);
    int elewise_add_axis = -1;
    LaunchElementwiseCudaKernel(
        ctx.cuda_device_context(), ins, &outs, elewise_add_axis,
        AddFunctor());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Register the forward and backward kernels for CUDA.
// NOTE(review): each registration lists the same kernel name three times with
// no template argument visible; presumably these were three different element
// types before the angle-bracket contents were lost -- TODO confirm.
REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel,
                        ops::FusedAttentionOpKernel,
                        ops::FusedAttentionOpKernel);
REGISTER_OP_CUDA_KERNEL(fused_attention_grad, ops::FusedAttentionGradKernel,
                        ops::FusedAttentionGradKernel,
                        ops::FusedAttentionGradKernel);