/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// NOTE(review): this copy of the file appears to have lost everything that
// was written between angle brackets: the two bare `#include` directives
// below have no header name, every `template` keyword has no parameter list,
// and calls such as ctx.Attr(...) / ctx.Input(...) / ctx.Output(...) /
// data() / mutable_data(...) carry no explicit template argument. The tokens
// are kept exactly as found; restore them from the upstream source before
// building.

#include
#include
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Sum-reduces `tensor` across the communicator registered for `ring_id`.
//
// The send and receive buffers are both taken from `tensor`, so the result
// overwrites the input. The collective is issued on `ctx.stream()`.
//
// tensor:  tensor to reduce; also receives the result.
// ring_id: communicator id looked up in NCCLCommContext; -1 makes this
//          function return immediately without doing anything.
// ctx:     device context supplying the place and the stream.
//
// Throws platform::errors::Unimplemented when ring_id != -1 and the build
// has neither PADDLE_WITH_NCCL nor PADDLE_WITH_RCCL defined.
template
static void AllReduce(framework::Tensor &tensor,  // NOLINT
                      const int ring_id,
                      const platform::CUDADeviceContext &ctx) {
  // ring_id == -1 disables the collective entirely.
  if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto dtype =
      platform::ToNCCLDataType(framework::TransToProtoVarType(tensor.dtype()));
  int64_t numel = tensor.numel();
  const void *sendbuff = tensor.data();
  auto place = ctx.GetPlace();
  void *recvbuff = tensor.mutable_data(place);
  auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
  auto stream = ctx.stream();
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
      sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}

// Forward kernel of the fused_attention op (registered at the end of this
// file).
//
// Compute() chains, in this order:
//   1. optional layer norm of X (only when attribute pre_layer_norm is set),
//   2. the QKV projection (qkv_compute),
//   3. the attention core (fmha_ref_compute),
//   4. the output linear projection (out_linear_compute),
//   5. AllReduce of OutLinearOut over ring_id,
//   6. bias + dropout + residual, followed by a second layer norm only when
//      pre_layer_norm is false.
template
class FusedAttentionOpKernel : public framework::OpKernel {
 public:
  // Runs the forward pass described on the class. All tensors and attributes
  // are fetched from `ctx` by name; results are written to the op's outputs.
  //
  // Raises platform::errors::InvalidArgument when pre_layer_norm is false
  // and add_residual is false.
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType;
    auto *input_x = ctx.Input("X");
    const auto pre_layer_norm = ctx.Attr("pre_layer_norm");
    const float epsilon = ctx.Attr("epsilon");
    auto *ln_scale = ctx.Input("LnScale");
    auto *ln_bias = ctx.Input("LnBias");
    auto *ln_mean = ctx.Output("LnMean");
    auto *ln_var = ctx.Output("LnVariance");
    auto *ln_out = ctx.Output("LnOut");

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    auto *qkv_weight = ctx.Input("QKVW");
    auto *qkv_bias = ctx.Input("QKVBias");
    auto *qkv_out = ctx.Output("QKVOut");
    auto *qkv_bias_out = ctx.Output("QKVBiasOut");

    auto *src_mask = ctx.Input("SrcMask");
    auto *transpose_out_2 = ctx.Output("TransposeOut2");
    auto *cache_kv = ctx.Input("CacheKV");
    auto *cache_kv_out = ctx.Output("CacheKVOut");
    auto *qk_out = ctx.Output("QKOut");
    auto *qktv_out = ctx.Output("QKTVOut");
    auto *softmax_out = ctx.Output("SoftmaxOut");
    auto *attn_dropout_mask_out = ctx.Output("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Output("AttnDropoutOut");
    auto *src_mask_out = ctx.Output("SrcMaskOut");
    auto *fmha_out = ctx.Output("FMHAOut");

    auto *out_linear_weight = ctx.Input("OutLinearW");
    auto *out_linear_bias = ctx.Input("OutLinearBias");
    auto *out_linear_out = ctx.Output("OutLinearOut");

    auto *ln_scale_2 = ctx.Input("Ln2Scale");
    auto *ln_bias_2 = ctx.Input("Ln2Bias");
    auto *dropout_mask_out = ctx.Output("DropoutMaskOut");
    auto *bias_dropout_residual_out = ctx.Output("BiasDropoutResidualOut");
    auto *ln_mean_2 = ctx.Output("Ln2Mean");
    auto *ln_var_2 = ctx.Output("Ln2Variance");
    const float ln_epsilon = ctx.Attr("ln_epsilon");

    // Attributes of the attention dropout (the dropout inside the attention
    // core); they are bundled into attn_dropout_param further down.
    float attn_dropout_rate = ctx.Attr("attn_dropout_rate");
    bool is_test_1 = ctx.Attr("is_test");
    auto &dropout_implementation_1 = ctx.Attr("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    // Seed1 is an optional input; absent means no seed tensor is passed on.
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr("attn_dropout_seed");
    int ring_id = ctx.Attr("ring_id");

    // final output.
    auto *out = ctx.Output("Y");

    // NOTE(review): several of the *_data pointers obtained below are not
    // read again in this function. The mutable_data(...) calls are
    // presumably kept because they allocate the output buffers - confirm
    // before removing any of them.

    // get data ptr for qkv part.
    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    auto *x_data = input_x->data();
    auto *qkv_weight_data = qkv_weight->data();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data();
    auto *qkv_out_data = qkv_out->mutable_data(ctx.GetPlace());
    // QKVBiasOut is only touched when a QKV bias is supplied.
    auto *qkv_bias_out_data =
        (qkv_bias == nullptr) ? nullptr
                              : qkv_bias_out->mutable_data(ctx.GetPlace());

    // get data ptr for FMHA.
    auto *transpose_out_2_data =
        transpose_out_2->mutable_data(ctx.GetPlace());
    auto *cache_kv_out_data =
        (cache_kv_out == nullptr)
            ? nullptr
            : cache_kv_out->mutable_data(ctx.GetPlace());
    auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace());
    auto *qktv_out_data = qktv_out->mutable_data(ctx.GetPlace());
    // SrcMaskOut is only touched when a source mask is supplied.
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr
                              : src_mask_out->mutable_data(ctx.GetPlace());
    auto *softmax_out_data = softmax_out->mutable_data(ctx.GetPlace());
    auto *attn_dropout_mask_out_data =
        attn_dropout_mask_out->mutable_data(ctx.GetPlace());
    auto *attn_dropout_out_data =
        attn_dropout_out->mutable_data(ctx.GetPlace());
    auto *fmha_out_data = fmha_out->mutable_data(ctx.GetPlace());

    // get data ptr for out_linear.
    auto *out_linear_weight_data = out_linear_weight->data();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data();
    auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace());

    // get data ptr for bias+dropout+residual+layernorm
    auto *dropout_mask_out_data =
        dropout_mask_out->mutable_data(ctx.GetPlace());
    auto *final_out_data = out->mutable_data(ctx.GetPlace());

    // Problem sizes, read from the dims of X and QKVW (layouts as noted in
    // the comments next to the QKVW lookup above).
    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];

    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(),
                                            epsilon, bsz_seq, dim_embed);

    // The QKV GEMM adds a bias only when QKVBias is provided.
    bool compute_bias = true;
    if (qkv_bias == nullptr) {
      compute_bias = false;
    }
    // (transA, transB) = (false, true); compute_bias as decided just above.
    auto qkv_compute =
        AttnMatMul(ctx.cuda_device_context(), false, true, bsz_seq,
                   output_size, input_size, compute_bias);

    AttnDropoutParam attn_dropout_param(
        is_test_1, dropout_implementation_1, attn_dropout_rate,
        is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
    auto fmha_ref_compute =
        FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
                dim_head, attn_dropout_param);

    // output_size is reused for the output projection from here on.
    output_size = hidden_size;
    // (transA, transB, compute_bias) = (false, false, false)
    // NOTE(Yuang Liu): For general input size == output size, change the
    // position won't have effects. For mp, the output size is mp_head * dkey
    // which is actually the input size. While the input size is hidden size,
    // which is actually the output size. So for out linear, switch the
    // input size and output size.
    auto out_linear_compute =
        AttnMatMul(ctx.cuda_device_context(), false, false, bsz_seq,
                   input_size, output_size, false);
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper fused_dropout_layernorm_helper(
        ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
        ln_epsilon);

    // Stage 1 + 2: (optional) layer norm, then the QKV projection.
    if (pre_layer_norm) {
      auto *ln_scale_data =
          (ln_scale == nullptr ? nullptr : ln_scale->data());
      auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data());
      auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace());
      auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace());
      auto *ln_out_data = ln_out->mutable_data(ctx.GetPlace());

      layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
                                        ln_out_data, ln_mean_data,
                                        ln_var_data);
      // The projection consumes the normalized tensor LnOut.
      qkv_compute.ComputeForward(qkv_weight, ln_out, qkv_bias, qkv_out,
                                 qkv_bias_out);
    } else {
      // No pre layer norm: the projection consumes X directly.
      qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
                                 qkv_bias_out);
    }

    // Stage 3: attention core. Its input is QKVOut when there is no QKV
    // bias and QKVBiasOut otherwise; all other arguments are the same.
    if (qkv_bias == nullptr) {
      fmha_ref_compute.ComputeForward(
          *qkv_out, cache_kv, src_mask, transpose_out_2, cache_kv_out, qk_out,
          src_mask_out, softmax_out, attn_dropout_mask_out, attn_dropout_out,
          qktv_out, fmha_out);
    } else {
      fmha_ref_compute.ComputeForward(
          *qkv_bias_out, cache_kv, src_mask, transpose_out_2, cache_kv_out,
          qk_out, src_mask_out, softmax_out, attn_dropout_mask_out,
          attn_dropout_out, qktv_out, fmha_out);
    }

    // Stage 4: output projection (no bias here; OutLinearBias is applied by
    // the fused dropout helper below).
    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight:   [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
    out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
                                      out_linear_out, nullptr);
    // Stage 5: tensor model parallel - sum OutLinearOut across ring_id
    // (AllReduce is a no-op when ring_id == -1).
    AllReduce(*out_linear_out, ring_id, ctx.cuda_device_context());

    // Stage 6: bias + dropout + residual (+ layer norm).
    bool add_residual = ctx.Attr("add_residual");
    // A null residual pointer is passed on when add_residual is false.
    const T *residual_ptr = add_residual ? x_data : nullptr;
    if (pre_layer_norm) {
      // output = (residual + dropout(input + bias))
      fused_dropout_layernorm_helper.ResidualDropoutBias(
          ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
          out_linear_bias_data, final_out_data, dropout_mask_out_data);
    } else {
      // TODO(Xreki): support post layer_norm case when add_residual is false.
      PADDLE_ENFORCE_EQ(add_residual, true,
                        platform::errors::InvalidArgument(
                            "Attribute add_residual is expected to be true "
                            "when pre_layer_norm is false."));

      const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data() : nullptr;
      const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data() : nullptr;
      T *bias_dropout_residual_out_ptr =
          bias_dropout_residual_out->mutable_data(ctx.GetPlace());
      U *ln_mean_2_ptr = ln_mean_2->mutable_data(ctx.GetPlace());
      U *ln_var_2_ptr = ln_var_2->mutable_data(ctx.GetPlace());
      // output = layernorm(residual + dropout(input + bias))
      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
          ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
          out_linear_bias_data, ln_scale_2_ptr, ln_bias_2_ptr,
          bias_dropout_residual_out_ptr, dropout_mask_out_data,
          final_out_data, ln_mean_2_ptr, ln_var_2_ptr);
    }
  }
};

// Backward kernel of the fused_attention op (registered as
// fused_attention_grad at the end of this file).
//
// Compute() walks the forward stages in reverse:
//   1. gradient of bias + dropout + residual (+ second layer norm when
//      pre_layer_norm is false),
//   2. gradient of the output linear projection,
//   3. gradient of the attention core,
//   4. gradient of the QKV projection, followed by an AllReduce of the
//      resulting input gradient over ring_id,
//   5. gradient of the first layer norm (only when pre_layer_norm is set),
//   6. accumulation of the residual gradient into X@GRAD when add_residual
//      is set.
template
class FusedAttentionGradKernel : public framework::OpKernel {
 public:
  // Runs the backward pass described on the class. Forward inputs, forward
  // outputs and Y@GRAD are read from `ctx`; gradients are written to the
  // op's @GRAD outputs.
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType;
    const auto pre_layer_norm = ctx.Attr("pre_layer_norm");
    const float epsilon = ctx.Attr("epsilon");
    const float ln2epsilon = ctx.Attr("ln_epsilon");

    // Attention-dropout attributes; bundled into attn_dropout_param below.
    float attn_dropout_prob = ctx.Attr("attn_dropout_rate");
    bool is_test_1 = ctx.Attr("is_test");
    auto &dropout_implementation_1 = ctx.Attr("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    // Seed1 is an optional input; absent means no seed tensor is passed on.
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr("attn_dropout_seed");
    int ring_id = ctx.Attr("ring_id");

    // get inputs.
    auto *d_y = ctx.Input(framework::GradVarName("Y"));
    auto *d_y_data = d_y->data();

    // fw input
    auto *input_x = ctx.Input("X");
    auto *ln_scale = ctx.Input("LnScale");
    auto *ln_2_scale = ctx.Input("Ln2Scale");
    auto *x_data = input_x->data();
    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data());
    auto *ln_2_scale_data =
        (ln_2_scale == nullptr ? nullptr : ln_2_scale->data());

    // fw parameters.
    auto *src_mask = ctx.Input("SrcMask");
    auto *qkv_weight = ctx.Input("QKVW");
    auto *qkv_bias = ctx.Input("QKVBias");
    auto *out_linear_weight = ctx.Input("OutLinearW");
    auto *out_linear_bias = ctx.Input("OutLinearBias");
    auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data());
    auto *qkv_weight_data = qkv_weight->data();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data();
    auto *out_linear_weight_data = out_linear_weight->data();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data();

    // fw output
    auto *fmha_out = ctx.Input("FMHAOut");
    auto *transpose_out_2 = ctx.Input("TransposeOut2");
    auto *qk_out = ctx.Input("QKOut");
    auto *qktv_out = ctx.Input("QKTVOut");
    auto *softmax_out = ctx.Input("SoftmaxOut");
    auto *attn_dropout_mask_out = ctx.Input("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Input("AttnDropoutOut");
    auto *src_mask_out = ctx.Input("SrcMaskOut");
    auto *out_linear_out = ctx.Input("OutLinearOut");
    auto *ln_2_mean = ctx.Input("Ln2Mean");
    auto *ln_2_var = ctx.Input("Ln2Variance");
    auto *dropout_mask_out = ctx.Input("DropoutMaskOut");
    auto *bias_dropout_residual_out = ctx.Input("BiasDropoutResidualOut");
    auto *fmha_out_data = fmha_out->data();
    auto *transpose_out_2_data = transpose_out_2->data();
    auto *qk_out_data = qk_out->data();
    auto *qktv_out_data = qktv_out->data();
    auto *softmax_out_data = softmax_out->data();
    // SrcMaskOut is only read when a source mask was supplied.
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr : src_mask_out->data();
    auto *out_linear_out_data = out_linear_out->data();
    auto *dropout_mask_out_data = dropout_mask_out->data();

    // output's grad
    auto *d_x = ctx.Output(framework::GradVarName("X"));
    auto *d_qkv_out = ctx.Output(framework::GradVarName("QKVOut"));
    auto *d_qkv_bias_out = ctx.Output(framework::GradVarName("QKVBiasOut"));
    auto *d_qktv_out = ctx.Output(framework::GradVarName("QKTVOut"));
    auto *d_transpose_out_2 =
        ctx.Output(framework::GradVarName("TransposeOut2"));
    auto *d_qk_out = ctx.Output(framework::GradVarName("QKOut"));
    auto *d_softmax_out = ctx.Output(framework::GradVarName("SoftmaxOut"));
    auto *d_attn_dropout_out =
        ctx.Output(framework::GradVarName("AttnDropoutOut"));
    auto *d_src_mask_out = ctx.Output(framework::GradVarName("SrcMaskOut"));
    auto *d_fmha_out = ctx.Output(framework::GradVarName("FMHAOut"));
    auto *d_out_linear_out =
        ctx.Output(framework::GradVarName("OutLinearOut"));
    auto *d_bias_dropout_residual_out =
        ctx.Output(framework::GradVarName("BiasDropoutResidualOut"));

    // NOTE(review): as in the forward kernel, several *_data pointers
    // obtained in this function are not read again; the mutable_data(...)
    // calls are presumably kept for their allocation effect - confirm before
    // removing any of them.
    auto *d_x_data = d_x->mutable_data(ctx.GetPlace());
    // When QKVBiasOut@GRAD is present, no separate buffer is allocated for
    // QKVOut@GRAD (original note: with a QKV bias, d_qkv_out is equal to
    // d_qkv_bias_out, so the space can be reused).
    auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
                               ? nullptr
                               : d_qkv_out->mutable_data(ctx.GetPlace());
    auto *d_qkv_bias_out_data =
        (d_qkv_bias_out == nullptr)
            ? nullptr
            : d_qkv_bias_out->mutable_data(ctx.GetPlace());
    auto *d_qktv_out_data = d_qktv_out->mutable_data(ctx.GetPlace());
    auto *d_transpose_out_2_data =
        d_transpose_out_2->mutable_data(ctx.GetPlace());
    auto *d_qk_out_data = d_qk_out->mutable_data(ctx.GetPlace());
    auto *d_softmax_out_data = d_softmax_out->mutable_data(ctx.GetPlace());
    auto *d_attn_dropout_out_data =
        d_attn_dropout_out->mutable_data(ctx.GetPlace());
    // SrcMaskOut@GRAD is only allocated when a source mask was supplied.
    auto *d_src_mask_out_data =
        (src_mask == nullptr) ? nullptr
                              : d_src_mask_out->mutable_data(ctx.GetPlace());
    auto *d_fmha_out_data = d_fmha_out->mutable_data(ctx.GetPlace());
    auto *d_out_linear_out_data =
        d_out_linear_out->mutable_data(ctx.GetPlace());

    // parameter grad
    auto *d_qkv_weight = ctx.Output(framework::GradVarName("QKVW"));
    auto *d_qkv_bias = ctx.Output(framework::GradVarName("QKVBias"));
    auto *d_out_linear_weight =
        ctx.Output(framework::GradVarName("OutLinearW"));
    auto *d_out_linear_bias =
        ctx.Output(framework::GradVarName("OutLinearBias"));
    auto *d_ln_2_scale = ctx.Output(framework::GradVarName("Ln2Scale"));
    auto *d_ln_2_bias = ctx.Output(framework::GradVarName("Ln2Bias"));

    auto *d_qkv_weight_data = d_qkv_weight->mutable_data(ctx.GetPlace());
    auto *d_qkv_bias_data =
        (d_qkv_bias == nullptr) ? nullptr
                                : d_qkv_bias->mutable_data(ctx.GetPlace());
    auto *d_out_linear_weight_data =
        d_out_linear_weight->mutable_data(ctx.GetPlace());
    auto *d_out_linear_bias_data =
        (d_out_linear_bias == nullptr)
            ? nullptr
            : d_out_linear_bias->mutable_data(ctx.GetPlace());

    // Problem sizes, derived exactly as in the forward kernel.
    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    // Temporary buffer for the gradient flowing through the residual
    // connection. It is sized like X and allocated only when add_residual is
    // set; otherwise a null pointer is handed to the helpers below.
    bool add_residual = ctx.Attr("add_residual");
    Tensor d_residual;
    T *d_residual_data = nullptr;
    if (add_residual) {
      d_residual.Resize(input_x_dims);
      d_residual_data = d_residual.mutable_data(ctx.GetPlace());
    }

    // Same GEMM configuration as the forward QKV projection.
    bool transA = false;
    bool transB = true;
    bool compute_qkv_bias = qkv_bias ? true : false;
    auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(),
                                            epsilon, bsz_seq, dim_embed);
    auto qkv_compute =
        AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq,
                   output_size, input_size, compute_qkv_bias);
    AttnDropoutParam attn_dropout_param(
        is_test_1, dropout_implementation_1, attn_dropout_prob,
        is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
    auto fmha_ref_compute =
        FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
                dim_head, attn_dropout_param);
    // output_size, transA, transB are reused for the output projection.
    output_size = hidden_size;
    transA = false;
    transB = false;
    bool compute_bias = false;
    // (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
    auto out_linear_compute =
        AttnMatMul(ctx.cuda_device_context(), transA, transB, bsz_seq,
                   input_size, output_size, compute_bias);
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper fused_dropout_layernorm_helper(
        ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
        ln2epsilon);

    // Stage 1: gradient of bias + dropout + residual (+ layer norm). Both
    // branches fill OutLinearOut@GRAD, OutLinearBias@GRAD and d_residual.
    if (pre_layer_norm) {
      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
          ctx.cuda_device_context(), d_y_data, dropout_mask_out_data,
          d_out_linear_out_data, d_residual_data, d_out_linear_bias_data);
    } else {
      auto *ln_2_mean_data = ln_2_mean->data();
      auto *ln_2_var_data = ln_2_var->data();
      auto *bias_dropout_residual_out_data =
          bias_dropout_residual_out->data();
      auto *d_ln_2_scale_data =
          (d_ln_2_scale == nullptr
               ? nullptr
               : d_ln_2_scale->mutable_data(ctx.GetPlace()));
      auto *d_ln_2_bias_data =
          (d_ln_2_bias == nullptr
               ? nullptr
               : d_ln_2_bias->mutable_data(ctx.GetPlace()));
      auto *d_bias_dropout_residual_out_data =
          d_bias_dropout_residual_out->mutable_data(ctx.GetPlace());

      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
          ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
          dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data,
          ln_2_var_data, d_bias_dropout_residual_out_data, d_ln_2_scale_data,
          d_ln_2_bias_data, d_out_linear_out_data, d_out_linear_bias_data,
          d_residual_data);
    }

    // Stage 2: gradient of the output projection. The bias-gradient slot is
    // null because OutLinearBias@GRAD was already produced in stage 1.
    out_linear_compute.ComputeBackward(fmha_out, out_linear_weight,
                                       d_out_linear_out, d_fmha_out,
                                       d_out_linear_weight, nullptr);

    // Stage 3: gradient of the attention core. The last argument receives
    // the gradient w.r.t. the attention input: QKVBiasOut@GRAD when a QKV
    // bias exists, QKVOut@GRAD otherwise.
    if (qkv_bias != nullptr) {
      fmha_ref_compute.ComputeBackward(
          *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
          *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
          d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
          d_transpose_out_2, nullptr, d_qkv_bias_out);
    } else {
      fmha_ref_compute.ComputeBackward(
          *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
          *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
          d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
          d_transpose_out_2, nullptr, d_qkv_out);
    }

    // Stage 4 + 5: gradient of the QKV projection and, with pre layer norm,
    // of the layer norm that fed it.
    if (pre_layer_norm) {
      auto *ln_mean = ctx.Input("LnMean");
      auto *ln_var = ctx.Input("LnVariance");
      auto *ln_out = ctx.Input("LnOut");
      auto *ln_mean_data = ln_mean->data();
      auto *ln_var_data = ln_var->data();
      auto *ln_out_data = ln_out->data();

      auto *d_ln_out = ctx.Output(framework::GradVarName("LnOut"));
      auto *d_ln_scale = ctx.Output(framework::GradVarName("LnScale"));
      auto *d_ln_bias = ctx.Output(framework::GradVarName("LnBias"));
      auto *d_ln_out_data = d_ln_out->mutable_data(ctx.GetPlace());
      auto *d_ln_scale_data =
          (d_ln_scale == nullptr ? nullptr
                                 : d_ln_scale->mutable_data(ctx.GetPlace()));
      auto *d_ln_bias_data =
          (d_ln_bias == nullptr ? nullptr
                                : d_ln_bias->mutable_data(ctx.GetPlace()));
      // The projection's input was LnOut, so its input gradient goes to
      // LnOut@GRAD.
      if (qkv_bias != nullptr) {
        qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out,
                                    d_ln_out, d_qkv_weight, d_qkv_bias);
      } else {
        qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_out, d_ln_out,
                                    d_qkv_weight, d_qkv_bias);
      }
      // tensor model parallel: sum LnOut@GRAD across ring_id before the
      // layer-norm backward consumes it.
      AllReduce(*d_ln_out, ring_id, ctx.cuda_device_context());
      layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
                                         ln_mean_data, ln_var_data, d_x_data,
                                         d_ln_scale_data, d_ln_bias_data);
    } else {
      // The projection's input was X, so its input gradient goes straight to
      // X@GRAD.
      if (qkv_bias != nullptr) {
        qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
                                    d_qkv_weight, d_qkv_bias);
      } else {
        qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_out, d_x,
                                    d_qkv_weight, d_qkv_bias);
      }
      // tensor model parallel: sum X@GRAD across ring_id.
      AllReduce(*d_x, ring_id, ctx.cuda_device_context());
    }

    // Stage 6: the residual branch also contributes to X@GRAD.
    if (add_residual) {
      // gradient accumulation: d_x = d_residual + d_x, written back into d_x.
      std::vector ins = {&d_residual, d_x};
      std::vector outs = {d_x};
      phi::funcs::ElementwiseKernel(ctx.cuda_device_context(), ins, &outs,
                                    phi::funcs::AddFunctor());
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the forward and backward CUDA kernels.
// NOTE(review): the three kernel entries of each registration look identical
// here; presumably they differed by element type before the template
// arguments were lost - confirm against the upstream source.
REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel,
                        ops::FusedAttentionOpKernel,
                        ops::FusedAttentionOpKernel);
REGISTER_OP_CUDA_KERNEL(fused_attention_grad, ops::FusedAttentionGradKernel,
                        ops::FusedAttentionGradKernel,
                        ops::FusedAttentionGradKernel);