From 2afa9b7652c924a589244a946a109f1d4651f343 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Thu, 7 Jul 2022 12:54:19 +0800 Subject: [PATCH] [Eager] Menual fused attention in eager (#43974) * fused_gate_attention manual code in eager --- .../manual/fluid_manual/dygraph_forward_api.h | 34 + .../fluid_manual/forwards/CMakeLists.txt | 8 + .../forwards/fused_attention_fwd_func.cc | 628 ++++++++++++++++++ .../manual/fluid_manual/nodes/CMakeLists.txt | 7 +- .../nodes/fused_attention_node.cc | 366 ++++++++++ .../api/manual/fluid_manual/nodes/nodes.h | 202 ++++++ .../auto_code_generator/eager_generator.cc | 6 +- .../unittests/test_fused_attention_op.py | 4 +- .../unittests/test_fused_gate_attention_op.py | 4 +- 9 files changed, 1250 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc create mode 100644 paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_attention_node.cc diff --git a/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h b/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h index 397e549e61..91d556f955 100644 --- a/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h +++ b/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h @@ -67,3 +67,37 @@ fused_feedforward_dygraph_function( const paddle::experimental::Tensor& Ln2Scale, const paddle::experimental::Tensor& Ln2Bias, const paddle::framework::AttributeMap& attr_map); + +std::tuple +fused_attention_dygraph_function( + const paddle::experimental::Tensor& X, + const paddle::experimental::Tensor& LnScale, + const paddle::experimental::Tensor& LnBias, + const paddle::experimental::Tensor& QKVW, + const paddle::experimental::Tensor& QKVBias, + const paddle::experimental::Tensor& CacheKV, + const paddle::experimental::Tensor& SrcMask, + const paddle::experimental::Tensor& OutLinearW, + const paddle::experimental::Tensor& OutLinearBias, + const paddle::experimental::Tensor& Ln2Scale, + const paddle::experimental::Tensor& Ln2Bias, + const paddle::framework::AttributeMap& attr_map); diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/CMakeLists.txt b/paddle/fluid/eager/api/manual/fluid_manual/forwards/CMakeLists.txt index 305df1c92c..4912663ef1 100644 --- a/paddle/fluid/eager/api/manual/fluid_manual/forwards/CMakeLists.txt +++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/CMakeLists.txt @@ -12,6 +12,14 @@ cc_library( add_dependencies(fused_feedforward_fwd_func eager_codegen) +cc_library( + fused_attention_fwd_func + SRCS fused_attention_fwd_func.cc + DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS}) + +add_dependencies(fused_attention_fwd_func eager_codegen) + set(fluid_manual_functions fused_gate_attention_fwd_func fused_feedforward_fwd_func + fused_attention_fwd_func PARENT_SCOPE) diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc new file mode 100644 index 0000000000..b058fa50ac --- /dev/null +++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc @@ -0,0 +1,628 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/accumulation/accumulation_node.h" +#include "paddle/fluid/eager/amp_auto_cast.h" +#include "paddle/fluid/eager/amp_utils.h" +#include "paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h" +#include "paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/platform/profiler/event_tracing.h" + +#pragma GCC diagnostic ignored "-Wunused-variable" + +std::tuple +fused_attention_dygraph_function( + const paddle::experimental::Tensor& X, + const paddle::experimental::Tensor& LnScale, + const paddle::experimental::Tensor& LnBias, + const paddle::experimental::Tensor& QKVW, + const paddle::experimental::Tensor& QKVBias, + const paddle::experimental::Tensor& CacheKV, + const paddle::experimental::Tensor& SrcMask, + const paddle::experimental::Tensor& OutLinearW, + const paddle::experimental::Tensor& OutLinearBias, + const paddle::experimental::Tensor& Ln2Scale, + const paddle::experimental::Tensor& Ln2Bias, + const paddle::framework::AttributeMap& attr_map) { + paddle::platform::RecordEvent dygraph_entrance_record_event( + "fused_attention dygraph", + paddle::platform::TracerEventType::Operator, + 1); + VLOG(3) << "Running Eager Forward Op: fused_attention"; + // Dygraph Forward Pass + + if (egr::Controller::Instance().GetAMPLevel() != + paddle::imperative::AmpLevel::O0) { + VLOG(5) << "Check and Prepare For AMP"; + + paddle::small_vector, + egr::kSlotSmallVectorSize> + amp_tensors_vector = {{X}, {QKVW}, {OutLinearW}}; + if (LnScale.initialized()) amp_tensors_vector.push_back({LnScale}); + if (LnBias.initialized()) amp_tensors_vector.push_back({LnBias}); + if (QKVBias.initialized()) amp_tensors_vector.push_back({QKVBias}); + if (CacheKV.initialized()) amp_tensors_vector.push_back({CacheKV}); + if (SrcMask.initialized()) amp_tensors_vector.push_back({SrcMask}); + if (OutLinearBias.initialized()) + amp_tensors_vector.push_back({OutLinearBias}); + if (Ln2Scale.initialized()) amp_tensors_vector.push_back({Ln2Scale}); + if (Ln2Bias.initialized()) amp_tensors_vector.push_back({Ln2Bias}); + + auto amp_dst_dtype = + egr::GetAmpDestDtype("fused_attention", amp_tensors_vector); + + auto NEW_X = egr::AmpAutoCast("X", X, amp_dst_dtype, "fused_attention"); + auto NEW_QKVW = + egr::AmpAutoCast("QKVW", QKVW, amp_dst_dtype, "fused_attention"); + auto NEW_OutLinearW = egr::AmpAutoCast( + "OutLinearW", OutLinearW, amp_dst_dtype, "fused_attention"); + auto NEW_LnScale = + ((LnScale.initialized()) + ? egr::AmpAutoCast( + "LnScale", LnScale, amp_dst_dtype, "fused_attention") + : LnScale); + auto NEW_LnBias = + ((LnBias.initialized()) + ? egr::AmpAutoCast( + "LnBias", LnBias, amp_dst_dtype, "fused_attention") + : LnBias); + auto NEW_QKVBias = + ((QKVBias.initialized()) + ? egr::AmpAutoCast( + "QKVBias", QKVBias, amp_dst_dtype, "fused_attention") + : QKVBias); + auto NEW_CacheKV = + ((CacheKV.initialized()) + ? egr::AmpAutoCast( + "CacheKV", CacheKV, amp_dst_dtype, "fused_attention") + : CacheKV); + auto NEW_SrcMask = + ((SrcMask.initialized()) + ? egr::AmpAutoCast( + "SrcMask", SrcMask, amp_dst_dtype, "fused_attention") + : SrcMask); + auto NEW_OutLinearBias = + ((OutLinearBias.initialized()) ? egr::AmpAutoCast("OutLinearBias", + OutLinearBias, + amp_dst_dtype, + "fused_attention") + : OutLinearBias); + auto NEW_Ln2Scale = + ((Ln2Scale.initialized()) + ? egr::AmpAutoCast( + "Ln2Scale", Ln2Scale, amp_dst_dtype, "fused_attention") + : Ln2Scale); + auto NEW_Ln2Bias = + ((Ln2Bias.initialized()) + ? egr::AmpAutoCast( + "Ln2Bias", Ln2Bias, amp_dst_dtype, "fused_attention") + : Ln2Bias); + + { + paddle::imperative::AutoCastGuard guard( + egr::Controller::Instance().GetCurrentTracer(), + paddle::imperative::AmpLevel::O0); + return fused_attention_dygraph_function(NEW_X, + NEW_LnScale, + NEW_LnBias, + NEW_QKVW, + NEW_QKVBias, + NEW_CacheKV, + NEW_SrcMask, + NEW_OutLinearW, + NEW_OutLinearBias, + NEW_Ln2Scale, + NEW_Ln2Bias, + attr_map); + } + } + + std::map>> ins = + {{"X", egr::EagerUtils::TrySyncToVars(X)}, + {"QKVW", egr::EagerUtils::TrySyncToVars(QKVW)}, + {"OutLinearW", egr::EagerUtils::TrySyncToVars(OutLinearW)}}; + if (LnScale.initialized()) + ins["LnScale"] = egr::EagerUtils::TrySyncToVars(LnScale); + if (LnBias.initialized()) + ins["LnBias"] = egr::EagerUtils::TrySyncToVars(LnBias); + if (QKVBias.initialized()) + ins["QKVBias"] = egr::EagerUtils::TrySyncToVars(QKVBias); + if (CacheKV.initialized()) + ins["CacheKV"] = egr::EagerUtils::TrySyncToVars(CacheKV); + if (SrcMask.initialized()) + ins["SrcMask"] = egr::EagerUtils::TrySyncToVars(SrcMask); + if (OutLinearBias.initialized()) + ins["OutLinearBias"] = egr::EagerUtils::TrySyncToVars(OutLinearBias); + if (Ln2Scale.initialized()) + ins["Ln2Scale"] = egr::EagerUtils::TrySyncToVars(Ln2Scale); + if (Ln2Bias.initialized()) + ins["Ln2Bias"] = egr::EagerUtils::TrySyncToVars(Ln2Bias); + + std::map>> outs = + {{"LnMean", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"LnVariance", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"LnOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"QKVOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"QKVBiasOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"TransposeOut2", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"QKOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"QKTVOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"SoftmaxOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"AttnDropoutMaskOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"AttnDropoutOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"SrcMaskOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"FMHAOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"OutLinearOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"DropoutMaskOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"Ln2Mean", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"Ln2Variance", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"BiasDropoutResidualOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"CacheKVOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"Y", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}}; + + // Prepare Autograd Meta + egr::AutogradMeta* p_autograd_X = egr::EagerUtils::nullable_autograd_meta(X); + egr::AutogradMeta* p_autograd_LnScale = + egr::EagerUtils::nullable_autograd_meta(LnScale); + egr::AutogradMeta* p_autograd_LnBias = + egr::EagerUtils::nullable_autograd_meta(LnBias); + egr::AutogradMeta* p_autograd_QKVW = + egr::EagerUtils::nullable_autograd_meta(QKVW); + egr::AutogradMeta* p_autograd_QKVBias = + egr::EagerUtils::nullable_autograd_meta(QKVBias); + egr::AutogradMeta* p_autograd_CacheKV = + egr::EagerUtils::nullable_autograd_meta(CacheKV); + egr::AutogradMeta* p_autograd_SrcMask = + egr::EagerUtils::nullable_autograd_meta(SrcMask); + egr::AutogradMeta* p_autograd_OutLinearW = + egr::EagerUtils::nullable_autograd_meta(OutLinearW); + egr::AutogradMeta* p_autograd_OutLinearBias = + egr::EagerUtils::nullable_autograd_meta(OutLinearBias); + egr::AutogradMeta* p_autograd_Ln2Scale = + egr::EagerUtils::nullable_autograd_meta(Ln2Scale); + egr::AutogradMeta* p_autograd_Ln2Bias = + egr::EagerUtils::nullable_autograd_meta(Ln2Bias); + + bool trace_backward = egr::Controller::Instance().HasGrad(); + + bool require_any_grad = + egr::EagerUtils::ComputeRequireGrad(trace_backward, + p_autograd_X, + p_autograd_LnScale, + p_autograd_LnBias, + p_autograd_QKVW, + p_autograd_QKVBias, + p_autograd_CacheKV, + p_autograd_SrcMask, + p_autograd_OutLinearW, + p_autograd_OutLinearBias, + p_autograd_Ln2Scale, + p_autograd_Ln2Bias); + + paddle::framework::AttributeMap attrs = attr_map; + paddle::framework::AttributeMap default_attrs; + egr::Controller::Instance().GetCurrentTracer()->TraceOp( + "fused_attention", + ins, + outs, + attrs, + egr::Controller::Instance().GetExpectedPlace(), + &default_attrs, + true, + {}); + + paddle::experimental::Tensor LnMean; + egr::EagerUtils::GetOutput(outs["LnMean"][0], &LnMean); + paddle::experimental::Tensor LnVariance; + egr::EagerUtils::GetOutput(outs["LnVariance"][0], &LnVariance); + paddle::experimental::Tensor LnOut; + egr::EagerUtils::GetOutput(outs["LnOut"][0], &LnOut); + paddle::experimental::Tensor QKVOut; + egr::EagerUtils::GetOutput(outs["QKVOut"][0], &QKVOut); + paddle::experimental::Tensor QKVBiasOut; + egr::EagerUtils::GetOutput(outs["QKVBiasOut"][0], &QKVBiasOut); + paddle::experimental::Tensor TransposeOut2; + egr::EagerUtils::GetOutput(outs["TransposeOut2"][0], &TransposeOut2); + paddle::experimental::Tensor QKOut; + egr::EagerUtils::GetOutput(outs["QKOut"][0], &QKOut); + paddle::experimental::Tensor QKTVOut; + egr::EagerUtils::GetOutput(outs["QKTVOut"][0], &QKTVOut); + paddle::experimental::Tensor SoftmaxOut; + egr::EagerUtils::GetOutput(outs["SoftmaxOut"][0], &SoftmaxOut); + paddle::experimental::Tensor AttnDropoutMaskOut; + egr::EagerUtils::GetOutput(outs["AttnDropoutMaskOut"][0], + &AttnDropoutMaskOut); + paddle::experimental::Tensor AttnDropoutOut; + egr::EagerUtils::GetOutput(outs["AttnDropoutOut"][0], &AttnDropoutOut); + paddle::experimental::Tensor SrcMaskOut; + egr::EagerUtils::GetOutput(outs["SrcMaskOut"][0], &SrcMaskOut); + paddle::experimental::Tensor FMHAOut; + egr::EagerUtils::GetOutput(outs["FMHAOut"][0], &FMHAOut); + paddle::experimental::Tensor OutLinearOut; + egr::EagerUtils::GetOutput(outs["OutLinearOut"][0], &OutLinearOut); + paddle::experimental::Tensor DropoutMaskOut; + egr::EagerUtils::GetOutput(outs["DropoutMaskOut"][0], &DropoutMaskOut); + paddle::experimental::Tensor Ln2Mean; + egr::EagerUtils::GetOutput(outs["Ln2Mean"][0], &Ln2Mean); + paddle::experimental::Tensor Ln2Variance; + egr::EagerUtils::GetOutput(outs["Ln2Variance"][0], &Ln2Variance); + paddle::experimental::Tensor BiasDropoutResidualOut; + egr::EagerUtils::GetOutput(outs["BiasDropoutResidualOut"][0], + &BiasDropoutResidualOut); + paddle::experimental::Tensor CacheKVOut; + egr::EagerUtils::GetOutput(outs["CacheKVOut"][0], &CacheKVOut); + paddle::experimental::Tensor Y; + egr::EagerUtils::GetOutput(outs["Y"][0], &Y); + + { + paddle::platform::RecordEvent node_creation_record_event( + "fused_attention node_creation", + paddle::platform::TracerEventType::Operator, + 1); + egr::AutogradMeta* p_autograd_LnMean = + egr::EagerUtils::autograd_meta(&LnMean); + egr::AutogradMeta* p_autograd_LnVariance = + egr::EagerUtils::autograd_meta(&LnVariance); + egr::AutogradMeta* p_autograd_LnOut = + egr::EagerUtils::autograd_meta(&LnOut); + egr::AutogradMeta* p_autograd_QKVOut = + egr::EagerUtils::autograd_meta(&QKVOut); + egr::AutogradMeta* p_autograd_QKVBiasOut = + egr::EagerUtils::autograd_meta(&QKVBiasOut); + egr::AutogradMeta* p_autograd_TransposeOut2 = + egr::EagerUtils::autograd_meta(&TransposeOut2); + egr::AutogradMeta* p_autograd_QKOut = + egr::EagerUtils::autograd_meta(&QKOut); + egr::AutogradMeta* p_autograd_QKTVOut = + egr::EagerUtils::autograd_meta(&QKTVOut); + egr::AutogradMeta* p_autograd_SoftmaxOut = + egr::EagerUtils::autograd_meta(&SoftmaxOut); + egr::AutogradMeta* p_autograd_AttnDropoutMaskOut = + egr::EagerUtils::autograd_meta(&AttnDropoutMaskOut); + egr::AutogradMeta* p_autograd_AttnDropoutOut = + egr::EagerUtils::autograd_meta(&AttnDropoutOut); + egr::AutogradMeta* p_autograd_SrcMaskOut = + egr::EagerUtils::autograd_meta(&SrcMaskOut); + egr::AutogradMeta* p_autograd_FMHAOut = + egr::EagerUtils::autograd_meta(&FMHAOut); + egr::AutogradMeta* p_autograd_OutLinearOut = + egr::EagerUtils::autograd_meta(&OutLinearOut); + egr::AutogradMeta* p_autograd_DropoutMaskOut = + egr::EagerUtils::autograd_meta(&DropoutMaskOut); + egr::AutogradMeta* p_autograd_Ln2Mean = + egr::EagerUtils::autograd_meta(&Ln2Mean); + egr::AutogradMeta* p_autograd_Ln2Variance = + egr::EagerUtils::autograd_meta(&Ln2Variance); + egr::AutogradMeta* p_autograd_BiasDropoutResidualOut = + egr::EagerUtils::autograd_meta(&BiasDropoutResidualOut); + egr::AutogradMeta* p_autograd_CacheKVOut = + egr::EagerUtils::autograd_meta(&CacheKVOut); + egr::AutogradMeta* p_autograd_Y = egr::EagerUtils::autograd_meta(&Y); + if (require_any_grad) { + VLOG(6) << " Construct Grad for fused_attention "; + egr::EagerUtils::PassStopGradient(false, + p_autograd_LnMean, + p_autograd_LnVariance, + p_autograd_LnOut, + p_autograd_QKVOut, + p_autograd_QKVBiasOut, + p_autograd_TransposeOut2, + p_autograd_QKOut, + p_autograd_QKTVOut, + p_autograd_SoftmaxOut, + p_autograd_AttnDropoutMaskOut, + p_autograd_AttnDropoutOut, + p_autograd_SrcMaskOut, + p_autograd_FMHAOut, + p_autograd_OutLinearOut, + p_autograd_DropoutMaskOut, + p_autograd_Ln2Mean, + p_autograd_Ln2Variance, + p_autograd_BiasDropoutResidualOut, + p_autograd_CacheKVOut, + p_autograd_Y); + // Create GradOpNode + auto grad_node = std::shared_ptr( + new fused_attentionGradNodeCompat(20, 23)); + + bool pre_layer_norm = false; + if (attrs.count("pre_layer_norm")) { + pre_layer_norm = BOOST_GET_CONST(bool, attrs.at("pre_layer_norm")); + } + + // Set Attributes + grad_node->SetAttrMap(std::move(attrs)); + grad_node->SetDefaultAttrMap(std::move(default_attrs)); + + grad_node->SetTensorWrapperX(X); + grad_node->SetTensorWrapperQKVW(QKVW); + grad_node->SetTensorWrapperOutLinearW(OutLinearW); + grad_node->SetTensorWrapperQKVOut(QKVOut); + grad_node->SetTensorWrapperTransposeOut2(TransposeOut2); + grad_node->SetTensorWrapperQKOut(QKOut); + grad_node->SetTensorWrapperQKTVOut(QKTVOut); + grad_node->SetTensorWrapperSoftmaxOut(SoftmaxOut); + grad_node->SetTensorWrapperAttnDropoutMaskOut(AttnDropoutMaskOut); + grad_node->SetTensorWrapperAttnDropoutOut(AttnDropoutOut); + grad_node->SetTensorWrapperFMHAOut(FMHAOut); + grad_node->SetTensorWrapperOutLinearOut(OutLinearOut); + grad_node->SetTensorWrapperDropoutMaskOut(DropoutMaskOut); + + grad_node->SetGradOutMeta(X, 0); + grad_node->SetGradOutMeta(QKVW, 3); + grad_node->SetGradOutMeta(OutLinearW, 7); + + if (QKVBias.initialized()) { + grad_node->SetTensorWrapperQKVBias(QKVBias); + grad_node->SetTensorWrapperQKVBiasOut(QKVBiasOut); + grad_node->SetGradOutMeta(QKVBias, 4); + + auto QKVBiasOut_accumulation_node = + std::make_shared(p_autograd_QKVBiasOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKVBiasOut, 0); + egr::EagerUtils::SetHistory(p_autograd_QKVBiasOut, + QKVBiasOut_accumulation_node); + QKVBiasOut_accumulation_node->SetGradInMeta(QKVBiasOut, 0); + egr::EagerUtils::CheckAndRetainGrad(QKVBiasOut); + grad_node->SetGradOutMeta(QKVBiasOut, 11); + } + + if (SrcMask.initialized()) { + grad_node->SetTensorWrapperSrcMask(SrcMask); + grad_node->SetTensorWrapperSrcMaskOut(SrcMaskOut); + + auto SrcMaskOut_accumulation_node = + std::make_shared(p_autograd_SrcMaskOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_SrcMaskOut, 0); + egr::EagerUtils::SetHistory(p_autograd_SrcMaskOut, + SrcMaskOut_accumulation_node); + SrcMaskOut_accumulation_node->SetGradInMeta(SrcMaskOut, 0); + egr::EagerUtils::CheckAndRetainGrad(SrcMaskOut); + grad_node->SetGradOutMeta(SrcMaskOut, 12); + } + + if (OutLinearBias.initialized()) { + grad_node->SetTensorWrapperOutLinearBias(OutLinearBias); + grad_node->SetGradOutMeta(OutLinearBias, 8); + } + + if (pre_layer_norm) { + if (LnScale.initialized()) { + grad_node->SetTensorWrapperLnScale(LnScale); + grad_node->SetGradOutMeta(LnScale, 1); + } + if (LnBias.initialized()) { + grad_node->SetTensorWrapperLnBias(LnBias); + grad_node->SetGradOutMeta(LnBias, 2); + } + if (LnOut.initialized()) { + grad_node->SetTensorWrapperLnOut(LnOut); + + auto LnOut_accumulation_node = + std::make_shared(p_autograd_LnOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_LnOut, 0); + egr::EagerUtils::SetHistory(p_autograd_LnOut, + LnOut_accumulation_node); + LnOut_accumulation_node->SetGradInMeta(LnOut, 0); + egr::EagerUtils::CheckAndRetainGrad(LnOut); + grad_node->SetGradOutMeta(LnOut, 13); + } + if (LnMean.initialized()) { + grad_node->SetTensorWrapperLnMean(LnMean); + } + if (LnVariance.initialized()) { + grad_node->SetTensorWrapperLnVariance(LnVariance); + } + } else { + if (Ln2Scale.initialized()) { + grad_node->SetTensorWrapperLn2Scale(Ln2Scale); + grad_node->SetGradOutMeta(Ln2Scale, 9); + } + if (Ln2Bias.initialized()) { + grad_node->SetTensorWrapperLn2Bias(Ln2Bias); + grad_node->SetGradOutMeta(Ln2Bias, 10); + } + grad_node->SetTensorWrapperBiasDropoutResidualOut( + BiasDropoutResidualOut); + grad_node->SetTensorWrapperLn2Mean(Ln2Mean); + grad_node->SetTensorWrapperLn2Variance(Ln2Variance); + + auto BiasDropoutResidualOut_accumulation_node = + std::make_shared( + p_autograd_BiasDropoutResidualOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_BiasDropoutResidualOut, + 0); + egr::EagerUtils::SetHistory(p_autograd_BiasDropoutResidualOut, + BiasDropoutResidualOut_accumulation_node); + BiasDropoutResidualOut_accumulation_node->SetGradInMeta( + BiasDropoutResidualOut, 0); + egr::EagerUtils::CheckAndRetainGrad(BiasDropoutResidualOut); + grad_node->SetGradOutMeta(BiasDropoutResidualOut, 14); + } + + egr::EagerUtils::SetOutRankWithSlot(p_autograd_LnMean, 0); + grad_node->SetGradInMeta(LnMean, 0); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_LnVariance, 1); + grad_node->SetGradInMeta(LnVariance, 1); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_AttnDropoutMaskOut, 9); + grad_node->SetGradInMeta(AttnDropoutMaskOut, 9); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_DropoutMaskOut, 14); + grad_node->SetGradInMeta(DropoutMaskOut, 14); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_Ln2Mean, 15); + grad_node->SetGradInMeta(Ln2Mean, 15); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_Ln2Variance, 16); + grad_node->SetGradInMeta(Ln2Variance, 16); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_CacheKVOut, 18); + egr::EagerUtils::SetHistory(p_autograd_CacheKVOut, grad_node); + grad_node->SetGradInMeta(CacheKVOut, 18); + egr::EagerUtils::CheckAndRetainGrad(CacheKVOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_Y, 19); + egr::EagerUtils::SetHistory(p_autograd_Y, grad_node); + grad_node->SetGradInMeta(Y, 19); + egr::EagerUtils::CheckAndRetainGrad(Y); + + auto QKVOut_accumulation_node = + std::make_shared(p_autograd_QKVOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKVOut, 0); + egr::EagerUtils::SetHistory(p_autograd_QKVOut, QKVOut_accumulation_node); + QKVOut_accumulation_node->SetGradInMeta(QKVOut, 0); + egr::EagerUtils::CheckAndRetainGrad(QKVOut); + grad_node->SetGradOutMeta(QKVOut, 15); + + auto QKTVOut_accumulation_node = + std::make_shared(p_autograd_QKTVOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKTVOut, 0); + egr::EagerUtils::SetHistory(p_autograd_QKTVOut, + QKTVOut_accumulation_node); + QKTVOut_accumulation_node->SetGradInMeta(QKTVOut, 0); + egr::EagerUtils::CheckAndRetainGrad(QKTVOut); + grad_node->SetGradOutMeta(QKTVOut, 16); + + auto TransposeOut2_accumulation_node = + std::make_shared(p_autograd_TransposeOut2); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_TransposeOut2, 0); + egr::EagerUtils::SetHistory(p_autograd_TransposeOut2, + TransposeOut2_accumulation_node); + TransposeOut2_accumulation_node->SetGradInMeta(TransposeOut2, 0); + egr::EagerUtils::CheckAndRetainGrad(TransposeOut2); + grad_node->SetGradOutMeta(TransposeOut2, 17); + + auto QKOut_accumulation_node = + std::make_shared(p_autograd_QKOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKOut, 0); + egr::EagerUtils::SetHistory(p_autograd_QKOut, QKOut_accumulation_node); + QKOut_accumulation_node->SetGradInMeta(QKOut, 0); + egr::EagerUtils::CheckAndRetainGrad(QKOut); + grad_node->SetGradOutMeta(QKOut, 18); + + auto SoftmaxOut_accumulation_node = + std::make_shared(p_autograd_SoftmaxOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_SoftmaxOut, 0); + egr::EagerUtils::SetHistory(p_autograd_SoftmaxOut, + SoftmaxOut_accumulation_node); + SoftmaxOut_accumulation_node->SetGradInMeta(SoftmaxOut, 0); + egr::EagerUtils::CheckAndRetainGrad(SoftmaxOut); + grad_node->SetGradOutMeta(SoftmaxOut, 19); + + auto AttnDropoutOut_accumulation_node = + std::make_shared( + p_autograd_AttnDropoutOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_AttnDropoutOut, 0); + egr::EagerUtils::SetHistory(p_autograd_AttnDropoutOut, + AttnDropoutOut_accumulation_node); + AttnDropoutOut_accumulation_node->SetGradInMeta(AttnDropoutOut, 0); + egr::EagerUtils::CheckAndRetainGrad(AttnDropoutOut); + grad_node->SetGradOutMeta(AttnDropoutOut, 20); + + auto FMHAOut_accumulation_node = + std::make_shared(p_autograd_FMHAOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_FMHAOut, 0); + egr::EagerUtils::SetHistory(p_autograd_FMHAOut, + FMHAOut_accumulation_node); + FMHAOut_accumulation_node->SetGradInMeta(FMHAOut, 0); + egr::EagerUtils::CheckAndRetainGrad(FMHAOut); + grad_node->SetGradOutMeta(FMHAOut, 21); + + auto OutLinearOut_accumulation_node = + std::make_shared(p_autograd_OutLinearOut); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_OutLinearOut, 0); + egr::EagerUtils::SetHistory(p_autograd_OutLinearOut, + OutLinearOut_accumulation_node); + OutLinearOut_accumulation_node->SetGradInMeta(OutLinearOut, 0); + egr::EagerUtils::CheckAndRetainGrad(OutLinearOut); + grad_node->SetGradOutMeta(OutLinearOut, 22); + } + } + + return std::make_tuple(LnMean, + LnVariance, + LnOut, + QKVOut, + QKVBiasOut, + TransposeOut2, + QKOut, + QKTVOut, + SoftmaxOut, + AttnDropoutMaskOut, + AttnDropoutOut, + SrcMaskOut, + FMHAOut, + OutLinearOut, + DropoutMaskOut, + Ln2Mean, + Ln2Variance, + BiasDropoutResidualOut, + CacheKVOut, + Y); +} diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/CMakeLists.txt b/paddle/fluid/eager/api/manual/fluid_manual/nodes/CMakeLists.txt index 4eaa43a4b5..28c034e8b5 100644 --- a/paddle/fluid/eager/api/manual/fluid_manual/nodes/CMakeLists.txt +++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/CMakeLists.txt @@ -8,6 +8,11 @@ cc_library( SRCS fused_feedforward_node.cc DEPS ${eager_deps} ${fluid_deps}) +cc_library( + fused_attention_node + SRCS fused_attention_node.cc + DEPS ${eager_deps} ${fluid_deps}) + set(fluid_manual_nodes - fused_gate_attention_node fused_feedforward_node + fused_gate_attention_node fused_feedforward_node fused_attention_node PARENT_SCOPE) diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_attention_node.cc b/paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_attention_node.cc new file mode 100644 index 0000000000..990cfb5226 --- /dev/null +++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_attention_node.cc @@ -0,0 +1,366 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "glog/logging.h" +#include "paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/imperative/tracer.h" +#include "paddle/phi/api/all.h" + +paddle::small_vector, + egr::kSlotSmallVectorSize> +fused_attentionGradNodeCompat::operator()( + paddle::small_vector, + egr::kSlotSmallVectorSize>& grads, + bool create_graph, + bool is_new_grad) { + VLOG(3) << "Running Eager Backward Node: fused_attentionGradNodeCompat"; + const auto& out_metas = OutputMeta(); + paddle::small_vector, + egr::kSlotSmallVectorSize> + outputs(23); + paddle::small_vector, + egr::kSlotSmallVectorSize> + hooked_grads0 = fused_attentionGradNodeCompat::ApplyGradientHooks(grads); + + bool pre_layer_norm = false; + if (attr_map_.count("pre_layer_norm")) { + pre_layer_norm = BOOST_GET_CONST(bool, attr_map_.at("pre_layer_norm")); + } + + std::map>> ins0 = + {{"AttnDropoutMaskOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->AttnDropoutMaskOut_))}, + {"AttnDropoutOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->AttnDropoutOut_))}, + {"DropoutMaskOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->DropoutMaskOut_))}, + {"FMHAOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->FMHAOut_))}, + {"OutLinearOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearOut_))}, + {"OutLinearW", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearW_))}, + {"QKOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->QKOut_))}, + {"QKTVOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->QKTVOut_))}, + {"QKVOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->QKVOut_))}, + {"QKVW", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->QKVW_))}, + {"SoftmaxOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxOut_))}, + {"TransposeOut2", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->TransposeOut2_))}, + {"X", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->X_))}, + {"Y@GRAD", egr::EagerUtils::TrySyncToVars(hooked_grads0[19])}}; + std::map>> outs0; + + if ((!out_metas[7].empty()) && (!(out_metas[7][0].IsStopGradient()))) { + outs0.insert({"OutLinearW@GRAD", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}); + } + if ((!out_metas[3].empty()) && (!(out_metas[3][0].IsStopGradient()))) { + outs0.insert({"QKVW@GRAD", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}); + } + if ((!out_metas[0].empty()) && (!(out_metas[0][0].IsStopGradient()))) { + outs0.insert({"X@GRAD", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}); + } + + auto QKVOut = egr::EagerUtils::RecoverTensorWrapper(&this->QKVOut_); + if (QKVOut.defined() && (!out_metas[15].empty()) && + (!out_metas[15][0].IsStopGradient())) + outs0["QKVOut@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + auto QKTVOut = egr::EagerUtils::RecoverTensorWrapper(&this->QKTVOut_); + if (QKTVOut.defined() && (!out_metas[16].empty()) && + (!out_metas[16][0].IsStopGradient())) + outs0["QKTVOut@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + auto TransposeOut2 = + egr::EagerUtils::RecoverTensorWrapper(&this->TransposeOut2_); + if (TransposeOut2.defined() && (!out_metas[17].empty()) && + (!out_metas[17][0].IsStopGradient())) + outs0["TransposeOut2@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + auto QKOut = egr::EagerUtils::RecoverTensorWrapper(&this->QKOut_); + if (QKOut.defined() && (!out_metas[18].empty()) && + (!out_metas[18][0].IsStopGradient())) + outs0["QKOut@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + auto SoftmaxOut = egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxOut_); + if (SoftmaxOut.defined() && (!out_metas[19].empty()) && + (!out_metas[19][0].IsStopGradient())) + outs0["SoftmaxOut@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + auto AttnDropoutOut = + egr::EagerUtils::RecoverTensorWrapper(&this->AttnDropoutOut_); + if (AttnDropoutOut.defined() && (!out_metas[20].empty()) && + (!out_metas[20][0].IsStopGradient())) + outs0["AttnDropoutOut@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + auto FMHAOut = egr::EagerUtils::RecoverTensorWrapper(&this->FMHAOut_); + if (FMHAOut.defined() && (!out_metas[21].empty()) && + (!out_metas[21][0].IsStopGradient())) + outs0["FMHAOut@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + auto OutLinearOut = + egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearOut_); + if (OutLinearOut.defined() && (!out_metas[22].empty()) && + (!out_metas[22][0].IsStopGradient())) + outs0["OutLinearOut@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + + auto QKVBias = egr::EagerUtils::RecoverTensorWrapper(&this->QKVBias_); + if (QKVBias.defined()) { + ins0["QKVBias"] = egr::EagerUtils::TrySyncToVars(QKVBias); + auto QKVBiasOut = egr::EagerUtils::RecoverTensorWrapper(&this->QKVBiasOut_); + ins0["QKVBiasOut"] = egr::EagerUtils::TrySyncToVars(QKVBiasOut); + if (QKVBias.defined() && (!out_metas[4].empty()) && + (!out_metas[4][0].IsStopGradient())) + outs0["QKVBias@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + if (QKVBiasOut.defined() && (!out_metas[11].empty()) && + (!out_metas[11][0].IsStopGradient())) + outs0["QKVBiasOut@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + auto SrcMask = egr::EagerUtils::RecoverTensorWrapper(&this->SrcMask_); + if (SrcMask.defined()) { + ins0["SrcMask"] = egr::EagerUtils::TrySyncToVars(SrcMask); + auto SrcMaskOut = egr::EagerUtils::RecoverTensorWrapper(&this->SrcMaskOut_); + ins0["SrcMaskOut"] = egr::EagerUtils::TrySyncToVars(SrcMaskOut); + if (SrcMaskOut.defined() && (!out_metas[12].empty()) && + (!out_metas[12][0].IsStopGradient())) + outs0["SrcMaskOut@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + auto OutLinearBias = + egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearBias_); + if (OutLinearBias.defined()) { + ins0["OutLinearBias"] = egr::EagerUtils::TrySyncToVars(OutLinearBias); + if (OutLinearBias.defined() && (!out_metas[8].empty()) && + (!out_metas[8][0].IsStopGradient())) + outs0["OutLinearBias@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + if (pre_layer_norm) { + auto LnScale = egr::EagerUtils::RecoverTensorWrapper(&this->LnScale_); + if (LnScale.defined()) { + ins0["LnScale"] = egr::EagerUtils::TrySyncToVars(LnScale); + if (LnScale.defined() && (!out_metas[1].empty()) && + (!out_metas[1][0].IsStopGradient())) + outs0["LnScale@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + auto LnBias = egr::EagerUtils::RecoverTensorWrapper(&this->LnBias_); + if (LnBias.defined()) { + ins0["LnBias"] = egr::EagerUtils::TrySyncToVars(LnBias); + if (LnBias.defined() && (!out_metas[2].empty()) && + (!out_metas[2][0].IsStopGradient())) + outs0["LnBias@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + auto LnOut = egr::EagerUtils::RecoverTensorWrapper(&this->LnOut_); + if (LnOut.defined()) { + ins0["LnOut"] = egr::EagerUtils::TrySyncToVars(LnOut); + if (LnOut.defined() && (!out_metas[13].empty()) && + (!out_metas[13][0].IsStopGradient())) + outs0["LnOut@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + auto LnMean = egr::EagerUtils::RecoverTensorWrapper(&this->LnMean_); + if (LnMean.defined()) { + ins0["LnMean"] = egr::EagerUtils::TrySyncToVars(LnMean); + } + + auto LnVariance = egr::EagerUtils::RecoverTensorWrapper(&this->LnVariance_); + if (LnVariance.defined()) { + ins0["LnVariance"] = egr::EagerUtils::TrySyncToVars(LnVariance); + } + } else { + auto Ln2Scale = egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Scale_); + if (Ln2Scale.defined()) { + ins0["Ln2Scale"] = egr::EagerUtils::TrySyncToVars(Ln2Scale); + if (Ln2Scale.defined() && (!out_metas[9].empty()) && + (!out_metas[9][0].IsStopGradient())) + outs0["Ln2Scale@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + auto Ln2Bias = egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Bias_); + if (Ln2Bias.defined()) { + ins0["Ln2Bias"] = egr::EagerUtils::TrySyncToVars(Ln2Bias); + if (Ln2Bias.defined() && (!out_metas[10].empty()) && + (!out_metas[10][0].IsStopGradient())) + outs0["Ln2Bias@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + auto BiasDropoutResidualOut = + egr::EagerUtils::RecoverTensorWrapper(&this->BiasDropoutResidualOut_); + auto Ln2Mean = egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Mean_); + auto Ln2Variance = + egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Variance_); + ins0["BiasDropoutResidualOut"] = + egr::EagerUtils::TrySyncToVars(BiasDropoutResidualOut); + ins0["Ln2Mean"] = egr::EagerUtils::TrySyncToVars(Ln2Mean); + ins0["Ln2Variance"] = egr::EagerUtils::TrySyncToVars(Ln2Variance); + if (BiasDropoutResidualOut.defined() && (!out_metas[14].empty()) && + (!out_metas[14][0].IsStopGradient())) + outs0["BiasDropoutResidualOut@GRAD"] = { + std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + auto& attrs_map0 = this->attr_map_; + // Pass the entire attribute map to TraceOp + // The underlying kernel will pickup whatever attribute they need at runtime + egr::Controller::Instance().GetCurrentTracer()->TraceOp( + "fused_attention_grad", + ins0, + outs0, + attrs_map0, + egr::Controller::Instance().GetExpectedPlace(), + &this->default_attr_map_, + false, + {}); + + if (outs0.find("OutLinearW@GRAD") != outs0.end()) { + outputs[7] = egr::EagerUtils::GetOutputs(outs0["OutLinearW@GRAD"]); + } + if (outs0.find("QKVW@GRAD") != outs0.end()) { + outputs[3] = egr::EagerUtils::GetOutputs(outs0["QKVW@GRAD"]); + } + if (outs0.find("X@GRAD") != outs0.end()) { + outputs[0] = egr::EagerUtils::GetOutputs(outs0["X@GRAD"]); + } + + if (outs0.find("QKVOut@GRAD") != outs0.end()) { + outputs[15] = egr::EagerUtils::GetOutputs(outs0["QKVOut@GRAD"]); + } + if (outs0.find("QKTVOut@GRAD") != outs0.end()) { + outputs[16] = egr::EagerUtils::GetOutputs(outs0["QKTVOut@GRAD"]); + } + if (outs0.find("TransposeOut2@GRAD") != outs0.end()) { + outputs[17] = egr::EagerUtils::GetOutputs(outs0["TransposeOut2@GRAD"]); + } + if (outs0.find("QKOut@GRAD") != outs0.end()) { + outputs[18] = egr::EagerUtils::GetOutputs(outs0["QKOut@GRAD"]); + } + if (outs0.find("SoftmaxOut@GRAD") != outs0.end()) { + outputs[19] = egr::EagerUtils::GetOutputs(outs0["SoftmaxOut@GRAD"]); + } + if (outs0.find("AttnDropoutOut@GRAD") != outs0.end()) { + outputs[20] = egr::EagerUtils::GetOutputs(outs0["AttnDropoutOut@GRAD"]); + } + if (outs0.find("FMHAOut@GRAD") != outs0.end()) { + outputs[21] = egr::EagerUtils::GetOutputs(outs0["FMHAOut@GRAD"]); + } + if (outs0.find("OutLinearOut@GRAD") != outs0.end()) { + outputs[22] = egr::EagerUtils::GetOutputs(outs0["OutLinearOut@GRAD"]); + } + + if (QKVBias.defined()) { + if (outs0.find("QKVBias@GRAD") != outs0.end()) { + outputs[4] = egr::EagerUtils::GetOutputs(outs0["QKVBias@GRAD"]); + } + if (outs0.find("QKVBiasOut@GRAD") != outs0.end()) { + outputs[11] = egr::EagerUtils::GetOutputs(outs0["QKVBiasOut@GRAD"]); + } + } + + if (SrcMask.defined()) { + if (outs0.find("SrcMaskOut@GRAD") != outs0.end()) { + outputs[12] = egr::EagerUtils::GetOutputs(outs0["SrcMaskOut@GRAD"]); + } + } + + if (OutLinearBias.defined()) { + if (outs0.find("OutLinearBias@GRAD") != outs0.end()) { + outputs[8] = egr::EagerUtils::GetOutputs(outs0["OutLinearBias@GRAD"]); + } + } + + if (pre_layer_norm) { + auto LnScale = egr::EagerUtils::RecoverTensorWrapper(&this->LnScale_); + if (LnScale.defined()) { + if (outs0.find("LnScale@GRAD") != outs0.end()) { + outputs[1] = egr::EagerUtils::GetOutputs(outs0["LnScale@GRAD"]); + } + } + + auto LnBias = egr::EagerUtils::RecoverTensorWrapper(&this->LnBias_); + if (LnBias.defined()) { + if (outs0.find("LnBias@GRAD") != outs0.end()) { + outputs[2] = egr::EagerUtils::GetOutputs(outs0["LnBias@GRAD"]); + } + } + + auto LnOut = egr::EagerUtils::RecoverTensorWrapper(&this->LnOut_); + if (LnOut.defined()) { + if (outs0.find("LnOut@GRAD") != outs0.end()) { + outputs[13] = egr::EagerUtils::GetOutputs(outs0["LnOut@GRAD"]); + } + } + } else { + auto Ln2Scale = egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Scale_); + if (Ln2Scale.defined()) { + if (outs0.find("Ln2Scale@GRAD") != outs0.end()) { + outputs[9] = egr::EagerUtils::GetOutputs(outs0["Ln2Scale@GRAD"]); + } + } + + auto Ln2Bias = egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Bias_); + if (Ln2Bias.defined()) { + if (outs0.find("Ln2Bias@GRAD") != outs0.end()) { + outputs[10] = egr::EagerUtils::GetOutputs(outs0["Ln2Bias@GRAD"]); + } + } + if (outs0.find("BiasDropoutResidualOut@GRAD") != outs0.end()) { + outputs[14] = + egr::EagerUtils::GetOutputs(outs0["BiasDropoutResidualOut@GRAD"]); + } + } + + if (NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&outputs); + return outputs; +} diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h index 52d3b44d7b..571deb4e9c 100644 --- a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h +++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h @@ -329,3 +329,205 @@ class fused_feedforwardGradNodeCompat : public egr::GradNodeBase { paddle::framework::AttributeMap attr_map_; paddle::framework::AttributeMap default_attr_map_; }; + +class fused_attentionGradNodeCompat : public egr::GradNodeBase { + public: + fused_attentionGradNodeCompat() : egr::GradNodeBase() { + VLOG(7) << " Construct fused_attentionGradNodeCompat "; + } + fused_attentionGradNodeCompat(size_t bwd_in_slot_num, size_t bwd_out_slot_num) + : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) { + VLOG(7) << " Construct fused_attentionGradNodeCompat "; + } + ~fused_attentionGradNodeCompat() override { + VLOG(6) << " Destruct fused_attentionGradNodeCompat "; + } + + virtual paddle::small_vector, + egr::kSlotSmallVectorSize> + operator()( + paddle::small_vector, // NOLINT + egr::kSlotSmallVectorSize>& grads, // NOLINT + bool create_graph = false, + bool is_new_grad = false) override; + + void ClearTensorWrappers() override { + AttnDropoutMaskOut_.clear(); + AttnDropoutOut_.clear(); + BiasDropoutResidualOut_.clear(); + DropoutMaskOut_.clear(); + FMHAOut_.clear(); + Ln2Bias_.clear(); + Ln2Mean_.clear(); + Ln2Scale_.clear(); + Ln2Variance_.clear(); + OutLinearBias_.clear(); + OutLinearOut_.clear(); + OutLinearW_.clear(); + QKOut_.clear(); + QKTVOut_.clear(); + QKVBias_.clear(); + QKVBiasOut_.clear(); + QKVOut_.clear(); + QKVW_.clear(); + SoftmaxOut_.clear(); + SrcMask_.clear(); + SrcMaskOut_.clear(); + TransposeOut2_.clear(); + X_.clear(); + + SetIsTensorWrappersCleared(true); + } + std::string name() override { return "fused_attentionGradNodeCompat"; } + + std::shared_ptr Copy() const override { + { + auto copied_node = std::shared_ptr( + new fused_attentionGradNodeCompat(*this)); + return copied_node; + } + } + + // SetX, SetY, ... + void SetTensorWrapperAttnDropoutMaskOut( + const paddle::experimental::Tensor& AttnDropoutMaskOut) { + AttnDropoutMaskOut_ = egr::TensorWrapper(AttnDropoutMaskOut, false); + } + void SetTensorWrapperAttnDropoutOut( + const paddle::experimental::Tensor& AttnDropoutOut) { + AttnDropoutOut_ = egr::TensorWrapper(AttnDropoutOut, false); + } + void SetTensorWrapperBiasDropoutResidualOut( + const paddle::experimental::Tensor& BiasDropoutResidualOut) { + BiasDropoutResidualOut_ = egr::TensorWrapper(BiasDropoutResidualOut, false); + } + void SetTensorWrapperDropoutMaskOut( + const paddle::experimental::Tensor& DropoutMaskOut) { + DropoutMaskOut_ = egr::TensorWrapper(DropoutMaskOut, false); + } + void SetTensorWrapperFMHAOut(const paddle::experimental::Tensor& FMHAOut) { + FMHAOut_ = egr::TensorWrapper(FMHAOut, false); + } + void SetTensorWrapperLn2Bias(const paddle::experimental::Tensor& Ln2Bias) { + Ln2Bias_ = egr::TensorWrapper(Ln2Bias, false); + } + void SetTensorWrapperLn2Mean(const paddle::experimental::Tensor& Ln2Mean) { + Ln2Mean_ = egr::TensorWrapper(Ln2Mean, false); + } + void SetTensorWrapperLn2Scale(const paddle::experimental::Tensor& Ln2Scale) { + Ln2Scale_ = egr::TensorWrapper(Ln2Scale, false); + } + void SetTensorWrapperLn2Variance( + const paddle::experimental::Tensor& Ln2Variance) { + Ln2Variance_ = egr::TensorWrapper(Ln2Variance, false); + } + void SetTensorWrapperOutLinearBias( + const paddle::experimental::Tensor& OutLinearBias) { + OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false); + } + void SetTensorWrapperOutLinearOut( + const paddle::experimental::Tensor& OutLinearOut) { + OutLinearOut_ = egr::TensorWrapper(OutLinearOut, false); + } + void SetTensorWrapperOutLinearW( + const paddle::experimental::Tensor& OutLinearW) { + OutLinearW_ = egr::TensorWrapper(OutLinearW, false); + } + void SetTensorWrapperQKOut(const paddle::experimental::Tensor& QKOut) { + QKOut_ = egr::TensorWrapper(QKOut, false); + } + void SetTensorWrapperQKTVOut(const paddle::experimental::Tensor& QKTVOut) { + QKTVOut_ = egr::TensorWrapper(QKTVOut, false); + } + void SetTensorWrapperQKVBias(const paddle::experimental::Tensor& QKVBias) { + QKVBias_ = egr::TensorWrapper(QKVBias, false); + } + void SetTensorWrapperQKVBiasOut( + const paddle::experimental::Tensor& QKVBiasOut) { + QKVBiasOut_ = egr::TensorWrapper(QKVBiasOut, false); + } + void SetTensorWrapperQKVOut(const paddle::experimental::Tensor& QKVOut) { + QKVOut_ = egr::TensorWrapper(QKVOut, false); + } + void SetTensorWrapperQKVW(const paddle::experimental::Tensor& QKVW) { + QKVW_ = egr::TensorWrapper(QKVW, false); + } + void SetTensorWrapperSoftmaxOut( + const paddle::experimental::Tensor& SoftmaxOut) { + SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false); + } + void SetTensorWrapperSrcMask(const paddle::experimental::Tensor& SrcMask) { + SrcMask_ = egr::TensorWrapper(SrcMask, false); + } + void SetTensorWrapperSrcMaskOut( + const paddle::experimental::Tensor& SrcMaskOut) { + SrcMaskOut_ = egr::TensorWrapper(SrcMaskOut, false); + } + void SetTensorWrapperTransposeOut2( + const paddle::experimental::Tensor& TransposeOut2) { + TransposeOut2_ = egr::TensorWrapper(TransposeOut2, false); + } + void SetTensorWrapperX(const paddle::experimental::Tensor& X) { + X_ = egr::TensorWrapper(X, false); + } + void SetTensorWrapperLnScale(const paddle::experimental::Tensor& LnScale) { + LnScale_ = egr::TensorWrapper(LnScale, false); + } + void SetTensorWrapperLnBias(const paddle::experimental::Tensor& LnBias) { + LnBias_ = egr::TensorWrapper(LnBias, false); + } + void SetTensorWrapperLnOut(const paddle::experimental::Tensor& LnOut) { + LnOut_ = egr::TensorWrapper(LnOut, false); + } + void SetTensorWrapperLnMean(const paddle::experimental::Tensor& LnMean) { + LnMean_ = egr::TensorWrapper(LnMean, false); + } + void SetTensorWrapperLnVariance( + const paddle::experimental::Tensor& LnVariance) { + LnVariance_ = egr::TensorWrapper(LnVariance, false); + } + + // SetAttrMap + void SetAttrMap(paddle::framework::AttributeMap&& attr_map) { + attr_map_ = std::move(attr_map); + } + void SetDefaultAttrMap(paddle::framework::AttributeMap&& default_attr_map) { + default_attr_map_ = std::move(default_attr_map); + } + + private: + // TensorWrappers + egr::TensorWrapper AttnDropoutMaskOut_; + egr::TensorWrapper AttnDropoutOut_; + egr::TensorWrapper BiasDropoutResidualOut_; + egr::TensorWrapper DropoutMaskOut_; + egr::TensorWrapper FMHAOut_; + egr::TensorWrapper Ln2Bias_; + egr::TensorWrapper Ln2Mean_; + egr::TensorWrapper Ln2Scale_; + egr::TensorWrapper Ln2Variance_; + egr::TensorWrapper OutLinearBias_; + egr::TensorWrapper OutLinearOut_; + egr::TensorWrapper OutLinearW_; + egr::TensorWrapper QKOut_; + egr::TensorWrapper QKTVOut_; + egr::TensorWrapper QKVBias_; + egr::TensorWrapper QKVBiasOut_; + egr::TensorWrapper QKVOut_; + egr::TensorWrapper QKVW_; + egr::TensorWrapper SoftmaxOut_; + egr::TensorWrapper SrcMask_; + egr::TensorWrapper SrcMaskOut_; + egr::TensorWrapper TransposeOut2_; + egr::TensorWrapper X_; + + egr::TensorWrapper LnScale_; + egr::TensorWrapper LnBias_; + egr::TensorWrapper LnOut_; + egr::TensorWrapper LnMean_; + egr::TensorWrapper LnVariance_; + + // Attribute Map + paddle::framework::AttributeMap attr_map_; + paddle::framework::AttributeMap default_attr_map_; +}; diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 6eb35eb13f..1b3c7fd8e4 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -51,8 +51,10 @@ static std::unordered_set ops_to_fill_zero_for_empty_grads = { "split", "rnn"}; /* --- Black Ops list that's NO NEED to apply code generation --- */ -static std::unordered_set black_ops_list = { - "run_program", "fused_gate_attention", "fused_feedforward"}; +static std::unordered_set black_ops_list = {"run_program", + "fused_gate_attention", + "fused_feedforward", + "fused_attention"}; static std::string LegalizeVariableName(const std::string& var_name) { std::string ret = var_name; diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 6507cc1ee3..1ad29ecadd 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -26,9 +26,7 @@ from paddle import tensor from paddle.fluid import layers import unittest from op_test import OpTest -from paddle.fluid.framework import default_main_program, _enable_legacy_dygraph - -_enable_legacy_dygraph() +from paddle.fluid.framework import default_main_program default_main_program().random_seed = 42 diff --git a/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py index 0aad7ec758..8b8d378e5c 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py @@ -26,11 +26,9 @@ import unittest from op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float from test_sparse_attention_op import get_cuda_version from paddle import _C_ops -from paddle.fluid.framework import default_main_program, _enable_legacy_dygraph +from paddle.fluid.framework import default_main_program from paddle.fluid import core -_enable_legacy_dygraph() - @unittest.skipIf(not core.is_compiled_with_cuda(), "Paddle is not compiled with CUDA") -- GitLab