未验证 提交 2afa9b76 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] Menual fused attention in eager (#43974)

* fused_gate_attention manual code in eager
上级 9aaae254
...@@ -67,3 +67,37 @@ fused_feedforward_dygraph_function( ...@@ -67,3 +67,37 @@ fused_feedforward_dygraph_function(
const paddle::experimental::Tensor& Ln2Scale, const paddle::experimental::Tensor& Ln2Scale,
const paddle::experimental::Tensor& Ln2Bias, const paddle::experimental::Tensor& Ln2Bias,
const paddle::framework::AttributeMap& attr_map); const paddle::framework::AttributeMap& attr_map);
std::tuple<paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor>
fused_attention_dygraph_function(
const paddle::experimental::Tensor& X,
const paddle::experimental::Tensor& LnScale,
const paddle::experimental::Tensor& LnBias,
const paddle::experimental::Tensor& QKVW,
const paddle::experimental::Tensor& QKVBias,
const paddle::experimental::Tensor& CacheKV,
const paddle::experimental::Tensor& SrcMask,
const paddle::experimental::Tensor& OutLinearW,
const paddle::experimental::Tensor& OutLinearBias,
const paddle::experimental::Tensor& Ln2Scale,
const paddle::experimental::Tensor& Ln2Bias,
const paddle::framework::AttributeMap& attr_map);
...@@ -12,6 +12,14 @@ cc_library( ...@@ -12,6 +12,14 @@ cc_library(
add_dependencies(fused_feedforward_fwd_func eager_codegen) add_dependencies(fused_feedforward_fwd_func eager_codegen)
cc_library(
fused_attention_fwd_func
SRCS fused_attention_fwd_func.cc
DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
add_dependencies(fused_attention_fwd_func eager_codegen)
set(fluid_manual_functions set(fluid_manual_functions
fused_gate_attention_fwd_func fused_feedforward_fwd_func fused_gate_attention_fwd_func fused_feedforward_fwd_func
fused_attention_fwd_func
PARENT_SCOPE) PARENT_SCOPE)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/amp_auto_cast.h"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h"
#include "paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#pragma GCC diagnostic ignored "-Wunused-variable"
std::tuple<paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor>
fused_attention_dygraph_function(
const paddle::experimental::Tensor& X,
const paddle::experimental::Tensor& LnScale,
const paddle::experimental::Tensor& LnBias,
const paddle::experimental::Tensor& QKVW,
const paddle::experimental::Tensor& QKVBias,
const paddle::experimental::Tensor& CacheKV,
const paddle::experimental::Tensor& SrcMask,
const paddle::experimental::Tensor& OutLinearW,
const paddle::experimental::Tensor& OutLinearBias,
const paddle::experimental::Tensor& Ln2Scale,
const paddle::experimental::Tensor& Ln2Bias,
const paddle::framework::AttributeMap& attr_map) {
paddle::platform::RecordEvent dygraph_entrance_record_event(
"fused_attention dygraph",
paddle::platform::TracerEventType::Operator,
1);
VLOG(3) << "Running Eager Forward Op: fused_attention";
// Dygraph Forward Pass
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
VLOG(5) << "Check and Prepare For AMP";
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
amp_tensors_vector = {{X}, {QKVW}, {OutLinearW}};
if (LnScale.initialized()) amp_tensors_vector.push_back({LnScale});
if (LnBias.initialized()) amp_tensors_vector.push_back({LnBias});
if (QKVBias.initialized()) amp_tensors_vector.push_back({QKVBias});
if (CacheKV.initialized()) amp_tensors_vector.push_back({CacheKV});
if (SrcMask.initialized()) amp_tensors_vector.push_back({SrcMask});
if (OutLinearBias.initialized())
amp_tensors_vector.push_back({OutLinearBias});
if (Ln2Scale.initialized()) amp_tensors_vector.push_back({Ln2Scale});
if (Ln2Bias.initialized()) amp_tensors_vector.push_back({Ln2Bias});
auto amp_dst_dtype =
egr::GetAmpDestDtype("fused_attention", amp_tensors_vector);
auto NEW_X = egr::AmpAutoCast("X", X, amp_dst_dtype, "fused_attention");
auto NEW_QKVW =
egr::AmpAutoCast("QKVW", QKVW, amp_dst_dtype, "fused_attention");
auto NEW_OutLinearW = egr::AmpAutoCast(
"OutLinearW", OutLinearW, amp_dst_dtype, "fused_attention");
auto NEW_LnScale =
((LnScale.initialized())
? egr::AmpAutoCast(
"LnScale", LnScale, amp_dst_dtype, "fused_attention")
: LnScale);
auto NEW_LnBias =
((LnBias.initialized())
? egr::AmpAutoCast(
"LnBias", LnBias, amp_dst_dtype, "fused_attention")
: LnBias);
auto NEW_QKVBias =
((QKVBias.initialized())
? egr::AmpAutoCast(
"QKVBias", QKVBias, amp_dst_dtype, "fused_attention")
: QKVBias);
auto NEW_CacheKV =
((CacheKV.initialized())
? egr::AmpAutoCast(
"CacheKV", CacheKV, amp_dst_dtype, "fused_attention")
: CacheKV);
auto NEW_SrcMask =
((SrcMask.initialized())
? egr::AmpAutoCast(
"SrcMask", SrcMask, amp_dst_dtype, "fused_attention")
: SrcMask);
auto NEW_OutLinearBias =
((OutLinearBias.initialized()) ? egr::AmpAutoCast("OutLinearBias",
OutLinearBias,
amp_dst_dtype,
"fused_attention")
: OutLinearBias);
auto NEW_Ln2Scale =
((Ln2Scale.initialized())
? egr::AmpAutoCast(
"Ln2Scale", Ln2Scale, amp_dst_dtype, "fused_attention")
: Ln2Scale);
auto NEW_Ln2Bias =
((Ln2Bias.initialized())
? egr::AmpAutoCast(
"Ln2Bias", Ln2Bias, amp_dst_dtype, "fused_attention")
: Ln2Bias);
{
paddle::imperative::AutoCastGuard guard(
egr::Controller::Instance().GetCurrentTracer(),
paddle::imperative::AmpLevel::O0);
return fused_attention_dygraph_function(NEW_X,
NEW_LnScale,
NEW_LnBias,
NEW_QKVW,
NEW_QKVBias,
NEW_CacheKV,
NEW_SrcMask,
NEW_OutLinearW,
NEW_OutLinearBias,
NEW_Ln2Scale,
NEW_Ln2Bias,
attr_map);
}
}
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> ins =
{{"X", egr::EagerUtils::TrySyncToVars(X)},
{"QKVW", egr::EagerUtils::TrySyncToVars(QKVW)},
{"OutLinearW", egr::EagerUtils::TrySyncToVars(OutLinearW)}};
if (LnScale.initialized())
ins["LnScale"] = egr::EagerUtils::TrySyncToVars(LnScale);
if (LnBias.initialized())
ins["LnBias"] = egr::EagerUtils::TrySyncToVars(LnBias);
if (QKVBias.initialized())
ins["QKVBias"] = egr::EagerUtils::TrySyncToVars(QKVBias);
if (CacheKV.initialized())
ins["CacheKV"] = egr::EagerUtils::TrySyncToVars(CacheKV);
if (SrcMask.initialized())
ins["SrcMask"] = egr::EagerUtils::TrySyncToVars(SrcMask);
if (OutLinearBias.initialized())
ins["OutLinearBias"] = egr::EagerUtils::TrySyncToVars(OutLinearBias);
if (Ln2Scale.initialized())
ins["Ln2Scale"] = egr::EagerUtils::TrySyncToVars(Ln2Scale);
if (Ln2Bias.initialized())
ins["Ln2Bias"] = egr::EagerUtils::TrySyncToVars(Ln2Bias);
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> outs =
{{"LnMean",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"LnVariance",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"LnOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"QKVOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"QKVBiasOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"TransposeOut2",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"QKOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"QKTVOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"SoftmaxOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"AttnDropoutMaskOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"AttnDropoutOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"SrcMaskOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"FMHAOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"OutLinearOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"DropoutMaskOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"Ln2Mean",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"Ln2Variance",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"BiasDropoutResidualOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"CacheKVOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"Y",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}}};
// Prepare Autograd Meta
egr::AutogradMeta* p_autograd_X = egr::EagerUtils::nullable_autograd_meta(X);
egr::AutogradMeta* p_autograd_LnScale =
egr::EagerUtils::nullable_autograd_meta(LnScale);
egr::AutogradMeta* p_autograd_LnBias =
egr::EagerUtils::nullable_autograd_meta(LnBias);
egr::AutogradMeta* p_autograd_QKVW =
egr::EagerUtils::nullable_autograd_meta(QKVW);
egr::AutogradMeta* p_autograd_QKVBias =
egr::EagerUtils::nullable_autograd_meta(QKVBias);
egr::AutogradMeta* p_autograd_CacheKV =
egr::EagerUtils::nullable_autograd_meta(CacheKV);
egr::AutogradMeta* p_autograd_SrcMask =
egr::EagerUtils::nullable_autograd_meta(SrcMask);
egr::AutogradMeta* p_autograd_OutLinearW =
egr::EagerUtils::nullable_autograd_meta(OutLinearW);
egr::AutogradMeta* p_autograd_OutLinearBias =
egr::EagerUtils::nullable_autograd_meta(OutLinearBias);
egr::AutogradMeta* p_autograd_Ln2Scale =
egr::EagerUtils::nullable_autograd_meta(Ln2Scale);
egr::AutogradMeta* p_autograd_Ln2Bias =
egr::EagerUtils::nullable_autograd_meta(Ln2Bias);
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad =
egr::EagerUtils::ComputeRequireGrad(trace_backward,
p_autograd_X,
p_autograd_LnScale,
p_autograd_LnBias,
p_autograd_QKVW,
p_autograd_QKVBias,
p_autograd_CacheKV,
p_autograd_SrcMask,
p_autograd_OutLinearW,
p_autograd_OutLinearBias,
p_autograd_Ln2Scale,
p_autograd_Ln2Bias);
paddle::framework::AttributeMap attrs = attr_map;
paddle::framework::AttributeMap default_attrs;
egr::Controller::Instance().GetCurrentTracer()->TraceOp(
"fused_attention",
ins,
outs,
attrs,
egr::Controller::Instance().GetExpectedPlace(),
&default_attrs,
true,
{});
paddle::experimental::Tensor LnMean;
egr::EagerUtils::GetOutput(outs["LnMean"][0], &LnMean);
paddle::experimental::Tensor LnVariance;
egr::EagerUtils::GetOutput(outs["LnVariance"][0], &LnVariance);
paddle::experimental::Tensor LnOut;
egr::EagerUtils::GetOutput(outs["LnOut"][0], &LnOut);
paddle::experimental::Tensor QKVOut;
egr::EagerUtils::GetOutput(outs["QKVOut"][0], &QKVOut);
paddle::experimental::Tensor QKVBiasOut;
egr::EagerUtils::GetOutput(outs["QKVBiasOut"][0], &QKVBiasOut);
paddle::experimental::Tensor TransposeOut2;
egr::EagerUtils::GetOutput(outs["TransposeOut2"][0], &TransposeOut2);
paddle::experimental::Tensor QKOut;
egr::EagerUtils::GetOutput(outs["QKOut"][0], &QKOut);
paddle::experimental::Tensor QKTVOut;
egr::EagerUtils::GetOutput(outs["QKTVOut"][0], &QKTVOut);
paddle::experimental::Tensor SoftmaxOut;
egr::EagerUtils::GetOutput(outs["SoftmaxOut"][0], &SoftmaxOut);
paddle::experimental::Tensor AttnDropoutMaskOut;
egr::EagerUtils::GetOutput(outs["AttnDropoutMaskOut"][0],
&AttnDropoutMaskOut);
paddle::experimental::Tensor AttnDropoutOut;
egr::EagerUtils::GetOutput(outs["AttnDropoutOut"][0], &AttnDropoutOut);
paddle::experimental::Tensor SrcMaskOut;
egr::EagerUtils::GetOutput(outs["SrcMaskOut"][0], &SrcMaskOut);
paddle::experimental::Tensor FMHAOut;
egr::EagerUtils::GetOutput(outs["FMHAOut"][0], &FMHAOut);
paddle::experimental::Tensor OutLinearOut;
egr::EagerUtils::GetOutput(outs["OutLinearOut"][0], &OutLinearOut);
paddle::experimental::Tensor DropoutMaskOut;
egr::EagerUtils::GetOutput(outs["DropoutMaskOut"][0], &DropoutMaskOut);
paddle::experimental::Tensor Ln2Mean;
egr::EagerUtils::GetOutput(outs["Ln2Mean"][0], &Ln2Mean);
paddle::experimental::Tensor Ln2Variance;
egr::EagerUtils::GetOutput(outs["Ln2Variance"][0], &Ln2Variance);
paddle::experimental::Tensor BiasDropoutResidualOut;
egr::EagerUtils::GetOutput(outs["BiasDropoutResidualOut"][0],
&BiasDropoutResidualOut);
paddle::experimental::Tensor CacheKVOut;
egr::EagerUtils::GetOutput(outs["CacheKVOut"][0], &CacheKVOut);
paddle::experimental::Tensor Y;
egr::EagerUtils::GetOutput(outs["Y"][0], &Y);
{
paddle::platform::RecordEvent node_creation_record_event(
"fused_attention node_creation",
paddle::platform::TracerEventType::Operator,
1);
egr::AutogradMeta* p_autograd_LnMean =
egr::EagerUtils::autograd_meta(&LnMean);
egr::AutogradMeta* p_autograd_LnVariance =
egr::EagerUtils::autograd_meta(&LnVariance);
egr::AutogradMeta* p_autograd_LnOut =
egr::EagerUtils::autograd_meta(&LnOut);
egr::AutogradMeta* p_autograd_QKVOut =
egr::EagerUtils::autograd_meta(&QKVOut);
egr::AutogradMeta* p_autograd_QKVBiasOut =
egr::EagerUtils::autograd_meta(&QKVBiasOut);
egr::AutogradMeta* p_autograd_TransposeOut2 =
egr::EagerUtils::autograd_meta(&TransposeOut2);
egr::AutogradMeta* p_autograd_QKOut =
egr::EagerUtils::autograd_meta(&QKOut);
egr::AutogradMeta* p_autograd_QKTVOut =
egr::EagerUtils::autograd_meta(&QKTVOut);
egr::AutogradMeta* p_autograd_SoftmaxOut =
egr::EagerUtils::autograd_meta(&SoftmaxOut);
egr::AutogradMeta* p_autograd_AttnDropoutMaskOut =
egr::EagerUtils::autograd_meta(&AttnDropoutMaskOut);
egr::AutogradMeta* p_autograd_AttnDropoutOut =
egr::EagerUtils::autograd_meta(&AttnDropoutOut);
egr::AutogradMeta* p_autograd_SrcMaskOut =
egr::EagerUtils::autograd_meta(&SrcMaskOut);
egr::AutogradMeta* p_autograd_FMHAOut =
egr::EagerUtils::autograd_meta(&FMHAOut);
egr::AutogradMeta* p_autograd_OutLinearOut =
egr::EagerUtils::autograd_meta(&OutLinearOut);
egr::AutogradMeta* p_autograd_DropoutMaskOut =
egr::EagerUtils::autograd_meta(&DropoutMaskOut);
egr::AutogradMeta* p_autograd_Ln2Mean =
egr::EagerUtils::autograd_meta(&Ln2Mean);
egr::AutogradMeta* p_autograd_Ln2Variance =
egr::EagerUtils::autograd_meta(&Ln2Variance);
egr::AutogradMeta* p_autograd_BiasDropoutResidualOut =
egr::EagerUtils::autograd_meta(&BiasDropoutResidualOut);
egr::AutogradMeta* p_autograd_CacheKVOut =
egr::EagerUtils::autograd_meta(&CacheKVOut);
egr::AutogradMeta* p_autograd_Y = egr::EagerUtils::autograd_meta(&Y);
if (require_any_grad) {
VLOG(6) << " Construct Grad for fused_attention ";
egr::EagerUtils::PassStopGradient(false,
p_autograd_LnMean,
p_autograd_LnVariance,
p_autograd_LnOut,
p_autograd_QKVOut,
p_autograd_QKVBiasOut,
p_autograd_TransposeOut2,
p_autograd_QKOut,
p_autograd_QKTVOut,
p_autograd_SoftmaxOut,
p_autograd_AttnDropoutMaskOut,
p_autograd_AttnDropoutOut,
p_autograd_SrcMaskOut,
p_autograd_FMHAOut,
p_autograd_OutLinearOut,
p_autograd_DropoutMaskOut,
p_autograd_Ln2Mean,
p_autograd_Ln2Variance,
p_autograd_BiasDropoutResidualOut,
p_autograd_CacheKVOut,
p_autograd_Y);
// Create GradOpNode
auto grad_node = std::shared_ptr<fused_attentionGradNodeCompat>(
new fused_attentionGradNodeCompat(20, 23));
bool pre_layer_norm = false;
if (attrs.count("pre_layer_norm")) {
pre_layer_norm = BOOST_GET_CONST(bool, attrs.at("pre_layer_norm"));
}
// Set Attributes
grad_node->SetAttrMap(std::move(attrs));
grad_node->SetDefaultAttrMap(std::move(default_attrs));
grad_node->SetTensorWrapperX(X);
grad_node->SetTensorWrapperQKVW(QKVW);
grad_node->SetTensorWrapperOutLinearW(OutLinearW);
grad_node->SetTensorWrapperQKVOut(QKVOut);
grad_node->SetTensorWrapperTransposeOut2(TransposeOut2);
grad_node->SetTensorWrapperQKOut(QKOut);
grad_node->SetTensorWrapperQKTVOut(QKTVOut);
grad_node->SetTensorWrapperSoftmaxOut(SoftmaxOut);
grad_node->SetTensorWrapperAttnDropoutMaskOut(AttnDropoutMaskOut);
grad_node->SetTensorWrapperAttnDropoutOut(AttnDropoutOut);
grad_node->SetTensorWrapperFMHAOut(FMHAOut);
grad_node->SetTensorWrapperOutLinearOut(OutLinearOut);
grad_node->SetTensorWrapperDropoutMaskOut(DropoutMaskOut);
grad_node->SetGradOutMeta(X, 0);
grad_node->SetGradOutMeta(QKVW, 3);
grad_node->SetGradOutMeta(OutLinearW, 7);
if (QKVBias.initialized()) {
grad_node->SetTensorWrapperQKVBias(QKVBias);
grad_node->SetTensorWrapperQKVBiasOut(QKVBiasOut);
grad_node->SetGradOutMeta(QKVBias, 4);
auto QKVBiasOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_QKVBiasOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKVBiasOut, 0);
egr::EagerUtils::SetHistory(p_autograd_QKVBiasOut,
QKVBiasOut_accumulation_node);
QKVBiasOut_accumulation_node->SetGradInMeta(QKVBiasOut, 0);
egr::EagerUtils::CheckAndRetainGrad(QKVBiasOut);
grad_node->SetGradOutMeta(QKVBiasOut, 11);
}
if (SrcMask.initialized()) {
grad_node->SetTensorWrapperSrcMask(SrcMask);
grad_node->SetTensorWrapperSrcMaskOut(SrcMaskOut);
auto SrcMaskOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_SrcMaskOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_SrcMaskOut, 0);
egr::EagerUtils::SetHistory(p_autograd_SrcMaskOut,
SrcMaskOut_accumulation_node);
SrcMaskOut_accumulation_node->SetGradInMeta(SrcMaskOut, 0);
egr::EagerUtils::CheckAndRetainGrad(SrcMaskOut);
grad_node->SetGradOutMeta(SrcMaskOut, 12);
}
if (OutLinearBias.initialized()) {
grad_node->SetTensorWrapperOutLinearBias(OutLinearBias);
grad_node->SetGradOutMeta(OutLinearBias, 8);
}
if (pre_layer_norm) {
if (LnScale.initialized()) {
grad_node->SetTensorWrapperLnScale(LnScale);
grad_node->SetGradOutMeta(LnScale, 1);
}
if (LnBias.initialized()) {
grad_node->SetTensorWrapperLnBias(LnBias);
grad_node->SetGradOutMeta(LnBias, 2);
}
if (LnOut.initialized()) {
grad_node->SetTensorWrapperLnOut(LnOut);
auto LnOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_LnOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_LnOut, 0);
egr::EagerUtils::SetHistory(p_autograd_LnOut,
LnOut_accumulation_node);
LnOut_accumulation_node->SetGradInMeta(LnOut, 0);
egr::EagerUtils::CheckAndRetainGrad(LnOut);
grad_node->SetGradOutMeta(LnOut, 13);
}
if (LnMean.initialized()) {
grad_node->SetTensorWrapperLnMean(LnMean);
}
if (LnVariance.initialized()) {
grad_node->SetTensorWrapperLnVariance(LnVariance);
}
} else {
if (Ln2Scale.initialized()) {
grad_node->SetTensorWrapperLn2Scale(Ln2Scale);
grad_node->SetGradOutMeta(Ln2Scale, 9);
}
if (Ln2Bias.initialized()) {
grad_node->SetTensorWrapperLn2Bias(Ln2Bias);
grad_node->SetGradOutMeta(Ln2Bias, 10);
}
grad_node->SetTensorWrapperBiasDropoutResidualOut(
BiasDropoutResidualOut);
grad_node->SetTensorWrapperLn2Mean(Ln2Mean);
grad_node->SetTensorWrapperLn2Variance(Ln2Variance);
auto BiasDropoutResidualOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(
p_autograd_BiasDropoutResidualOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_BiasDropoutResidualOut,
0);
egr::EagerUtils::SetHistory(p_autograd_BiasDropoutResidualOut,
BiasDropoutResidualOut_accumulation_node);
BiasDropoutResidualOut_accumulation_node->SetGradInMeta(
BiasDropoutResidualOut, 0);
egr::EagerUtils::CheckAndRetainGrad(BiasDropoutResidualOut);
grad_node->SetGradOutMeta(BiasDropoutResidualOut, 14);
}
egr::EagerUtils::SetOutRankWithSlot(p_autograd_LnMean, 0);
grad_node->SetGradInMeta(LnMean, 0);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_LnVariance, 1);
grad_node->SetGradInMeta(LnVariance, 1);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_AttnDropoutMaskOut, 9);
grad_node->SetGradInMeta(AttnDropoutMaskOut, 9);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_DropoutMaskOut, 14);
grad_node->SetGradInMeta(DropoutMaskOut, 14);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Ln2Mean, 15);
grad_node->SetGradInMeta(Ln2Mean, 15);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Ln2Variance, 16);
grad_node->SetGradInMeta(Ln2Variance, 16);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_CacheKVOut, 18);
egr::EagerUtils::SetHistory(p_autograd_CacheKVOut, grad_node);
grad_node->SetGradInMeta(CacheKVOut, 18);
egr::EagerUtils::CheckAndRetainGrad(CacheKVOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Y, 19);
egr::EagerUtils::SetHistory(p_autograd_Y, grad_node);
grad_node->SetGradInMeta(Y, 19);
egr::EagerUtils::CheckAndRetainGrad(Y);
auto QKVOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_QKVOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKVOut, 0);
egr::EagerUtils::SetHistory(p_autograd_QKVOut, QKVOut_accumulation_node);
QKVOut_accumulation_node->SetGradInMeta(QKVOut, 0);
egr::EagerUtils::CheckAndRetainGrad(QKVOut);
grad_node->SetGradOutMeta(QKVOut, 15);
auto QKTVOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_QKTVOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKTVOut, 0);
egr::EagerUtils::SetHistory(p_autograd_QKTVOut,
QKTVOut_accumulation_node);
QKTVOut_accumulation_node->SetGradInMeta(QKTVOut, 0);
egr::EagerUtils::CheckAndRetainGrad(QKTVOut);
grad_node->SetGradOutMeta(QKTVOut, 16);
auto TransposeOut2_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_TransposeOut2);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_TransposeOut2, 0);
egr::EagerUtils::SetHistory(p_autograd_TransposeOut2,
TransposeOut2_accumulation_node);
TransposeOut2_accumulation_node->SetGradInMeta(TransposeOut2, 0);
egr::EagerUtils::CheckAndRetainGrad(TransposeOut2);
grad_node->SetGradOutMeta(TransposeOut2, 17);
auto QKOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_QKOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKOut, 0);
egr::EagerUtils::SetHistory(p_autograd_QKOut, QKOut_accumulation_node);
QKOut_accumulation_node->SetGradInMeta(QKOut, 0);
egr::EagerUtils::CheckAndRetainGrad(QKOut);
grad_node->SetGradOutMeta(QKOut, 18);
auto SoftmaxOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_SoftmaxOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_SoftmaxOut, 0);
egr::EagerUtils::SetHistory(p_autograd_SoftmaxOut,
SoftmaxOut_accumulation_node);
SoftmaxOut_accumulation_node->SetGradInMeta(SoftmaxOut, 0);
egr::EagerUtils::CheckAndRetainGrad(SoftmaxOut);
grad_node->SetGradOutMeta(SoftmaxOut, 19);
auto AttnDropoutOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(
p_autograd_AttnDropoutOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_AttnDropoutOut, 0);
egr::EagerUtils::SetHistory(p_autograd_AttnDropoutOut,
AttnDropoutOut_accumulation_node);
AttnDropoutOut_accumulation_node->SetGradInMeta(AttnDropoutOut, 0);
egr::EagerUtils::CheckAndRetainGrad(AttnDropoutOut);
grad_node->SetGradOutMeta(AttnDropoutOut, 20);
auto FMHAOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_FMHAOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_FMHAOut, 0);
egr::EagerUtils::SetHistory(p_autograd_FMHAOut,
FMHAOut_accumulation_node);
FMHAOut_accumulation_node->SetGradInMeta(FMHAOut, 0);
egr::EagerUtils::CheckAndRetainGrad(FMHAOut);
grad_node->SetGradOutMeta(FMHAOut, 21);
auto OutLinearOut_accumulation_node =
std::make_shared<egr::GradNodeAccumulation>(p_autograd_OutLinearOut);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_OutLinearOut, 0);
egr::EagerUtils::SetHistory(p_autograd_OutLinearOut,
OutLinearOut_accumulation_node);
OutLinearOut_accumulation_node->SetGradInMeta(OutLinearOut, 0);
egr::EagerUtils::CheckAndRetainGrad(OutLinearOut);
grad_node->SetGradOutMeta(OutLinearOut, 22);
}
}
return std::make_tuple(LnMean,
LnVariance,
LnOut,
QKVOut,
QKVBiasOut,
TransposeOut2,
QKOut,
QKTVOut,
SoftmaxOut,
AttnDropoutMaskOut,
AttnDropoutOut,
SrcMaskOut,
FMHAOut,
OutLinearOut,
DropoutMaskOut,
Ln2Mean,
Ln2Variance,
BiasDropoutResidualOut,
CacheKVOut,
Y);
}
...@@ -8,6 +8,11 @@ cc_library( ...@@ -8,6 +8,11 @@ cc_library(
SRCS fused_feedforward_node.cc SRCS fused_feedforward_node.cc
DEPS ${eager_deps} ${fluid_deps}) DEPS ${eager_deps} ${fluid_deps})
cc_library(
fused_attention_node
SRCS fused_attention_node.cc
DEPS ${eager_deps} ${fluid_deps})
set(fluid_manual_nodes set(fluid_manual_nodes
fused_gate_attention_node fused_feedforward_node fused_gate_attention_node fused_feedforward_node fused_attention_node
PARENT_SCOPE) PARENT_SCOPE)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/phi/api/all.h"
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
fused_attentionGradNodeCompat::operator()(
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>& grads,
bool create_graph,
bool is_new_grad) {
VLOG(3) << "Running Eager Backward Node: fused_attentionGradNodeCompat";
const auto& out_metas = OutputMeta();
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
outputs(23);
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
hooked_grads0 = fused_attentionGradNodeCompat::ApplyGradientHooks(grads);
bool pre_layer_norm = false;
if (attr_map_.count("pre_layer_norm")) {
pre_layer_norm = BOOST_GET_CONST(bool, attr_map_.at("pre_layer_norm"));
}
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> ins0 =
{{"AttnDropoutMaskOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->AttnDropoutMaskOut_))},
{"AttnDropoutOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->AttnDropoutOut_))},
{"DropoutMaskOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->DropoutMaskOut_))},
{"FMHAOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->FMHAOut_))},
{"OutLinearOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearOut_))},
{"OutLinearW",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearW_))},
{"QKOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->QKOut_))},
{"QKTVOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->QKTVOut_))},
{"QKVOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->QKVOut_))},
{"QKVW",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->QKVW_))},
{"SoftmaxOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxOut_))},
{"TransposeOut2",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->TransposeOut2_))},
{"X",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->X_))},
{"Y@GRAD", egr::EagerUtils::TrySyncToVars(hooked_grads0[19])}};
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> outs0;
if ((!out_metas[7].empty()) && (!(out_metas[7][0].IsStopGradient()))) {
outs0.insert({"OutLinearW@GRAD",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}});
}
if ((!out_metas[3].empty()) && (!(out_metas[3][0].IsStopGradient()))) {
outs0.insert({"QKVW@GRAD",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}});
}
if ((!out_metas[0].empty()) && (!(out_metas[0][0].IsStopGradient()))) {
outs0.insert({"X@GRAD",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}});
}
auto QKVOut = egr::EagerUtils::RecoverTensorWrapper(&this->QKVOut_);
if (QKVOut.defined() && (!out_metas[15].empty()) &&
(!out_metas[15][0].IsStopGradient()))
outs0["QKVOut@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
auto QKTVOut = egr::EagerUtils::RecoverTensorWrapper(&this->QKTVOut_);
if (QKTVOut.defined() && (!out_metas[16].empty()) &&
(!out_metas[16][0].IsStopGradient()))
outs0["QKTVOut@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
auto TransposeOut2 =
egr::EagerUtils::RecoverTensorWrapper(&this->TransposeOut2_);
if (TransposeOut2.defined() && (!out_metas[17].empty()) &&
(!out_metas[17][0].IsStopGradient()))
outs0["TransposeOut2@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
auto QKOut = egr::EagerUtils::RecoverTensorWrapper(&this->QKOut_);
if (QKOut.defined() && (!out_metas[18].empty()) &&
(!out_metas[18][0].IsStopGradient()))
outs0["QKOut@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
auto SoftmaxOut = egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxOut_);
if (SoftmaxOut.defined() && (!out_metas[19].empty()) &&
(!out_metas[19][0].IsStopGradient()))
outs0["SoftmaxOut@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
auto AttnDropoutOut =
egr::EagerUtils::RecoverTensorWrapper(&this->AttnDropoutOut_);
if (AttnDropoutOut.defined() && (!out_metas[20].empty()) &&
(!out_metas[20][0].IsStopGradient()))
outs0["AttnDropoutOut@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
auto FMHAOut = egr::EagerUtils::RecoverTensorWrapper(&this->FMHAOut_);
if (FMHAOut.defined() && (!out_metas[21].empty()) &&
(!out_metas[21][0].IsStopGradient()))
outs0["FMHAOut@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
auto OutLinearOut =
egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearOut_);
if (OutLinearOut.defined() && (!out_metas[22].empty()) &&
(!out_metas[22][0].IsStopGradient()))
outs0["OutLinearOut@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
auto QKVBias = egr::EagerUtils::RecoverTensorWrapper(&this->QKVBias_);
if (QKVBias.defined()) {
ins0["QKVBias"] = egr::EagerUtils::TrySyncToVars(QKVBias);
auto QKVBiasOut = egr::EagerUtils::RecoverTensorWrapper(&this->QKVBiasOut_);
ins0["QKVBiasOut"] = egr::EagerUtils::TrySyncToVars(QKVBiasOut);
if (QKVBias.defined() && (!out_metas[4].empty()) &&
(!out_metas[4][0].IsStopGradient()))
outs0["QKVBias@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
if (QKVBiasOut.defined() && (!out_metas[11].empty()) &&
(!out_metas[11][0].IsStopGradient()))
outs0["QKVBiasOut@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
auto SrcMask = egr::EagerUtils::RecoverTensorWrapper(&this->SrcMask_);
if (SrcMask.defined()) {
ins0["SrcMask"] = egr::EagerUtils::TrySyncToVars(SrcMask);
auto SrcMaskOut = egr::EagerUtils::RecoverTensorWrapper(&this->SrcMaskOut_);
ins0["SrcMaskOut"] = egr::EagerUtils::TrySyncToVars(SrcMaskOut);
if (SrcMaskOut.defined() && (!out_metas[12].empty()) &&
(!out_metas[12][0].IsStopGradient()))
outs0["SrcMaskOut@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
auto OutLinearBias =
egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearBias_);
if (OutLinearBias.defined()) {
ins0["OutLinearBias"] = egr::EagerUtils::TrySyncToVars(OutLinearBias);
if (OutLinearBias.defined() && (!out_metas[8].empty()) &&
(!out_metas[8][0].IsStopGradient()))
outs0["OutLinearBias@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
if (pre_layer_norm) {
auto LnScale = egr::EagerUtils::RecoverTensorWrapper(&this->LnScale_);
if (LnScale.defined()) {
ins0["LnScale"] = egr::EagerUtils::TrySyncToVars(LnScale);
if (LnScale.defined() && (!out_metas[1].empty()) &&
(!out_metas[1][0].IsStopGradient()))
outs0["LnScale@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
auto LnBias = egr::EagerUtils::RecoverTensorWrapper(&this->LnBias_);
if (LnBias.defined()) {
ins0["LnBias"] = egr::EagerUtils::TrySyncToVars(LnBias);
if (LnBias.defined() && (!out_metas[2].empty()) &&
(!out_metas[2][0].IsStopGradient()))
outs0["LnBias@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
auto LnOut = egr::EagerUtils::RecoverTensorWrapper(&this->LnOut_);
if (LnOut.defined()) {
ins0["LnOut"] = egr::EagerUtils::TrySyncToVars(LnOut);
if (LnOut.defined() && (!out_metas[13].empty()) &&
(!out_metas[13][0].IsStopGradient()))
outs0["LnOut@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
auto LnMean = egr::EagerUtils::RecoverTensorWrapper(&this->LnMean_);
if (LnMean.defined()) {
ins0["LnMean"] = egr::EagerUtils::TrySyncToVars(LnMean);
}
auto LnVariance = egr::EagerUtils::RecoverTensorWrapper(&this->LnVariance_);
if (LnVariance.defined()) {
ins0["LnVariance"] = egr::EagerUtils::TrySyncToVars(LnVariance);
}
} else {
auto Ln2Scale = egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Scale_);
if (Ln2Scale.defined()) {
ins0["Ln2Scale"] = egr::EagerUtils::TrySyncToVars(Ln2Scale);
if (Ln2Scale.defined() && (!out_metas[9].empty()) &&
(!out_metas[9][0].IsStopGradient()))
outs0["Ln2Scale@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
auto Ln2Bias = egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Bias_);
if (Ln2Bias.defined()) {
ins0["Ln2Bias"] = egr::EagerUtils::TrySyncToVars(Ln2Bias);
if (Ln2Bias.defined() && (!out_metas[10].empty()) &&
(!out_metas[10][0].IsStopGradient()))
outs0["Ln2Bias@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
auto BiasDropoutResidualOut =
egr::EagerUtils::RecoverTensorWrapper(&this->BiasDropoutResidualOut_);
auto Ln2Mean = egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Mean_);
auto Ln2Variance =
egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Variance_);
ins0["BiasDropoutResidualOut"] =
egr::EagerUtils::TrySyncToVars(BiasDropoutResidualOut);
ins0["Ln2Mean"] = egr::EagerUtils::TrySyncToVars(Ln2Mean);
ins0["Ln2Variance"] = egr::EagerUtils::TrySyncToVars(Ln2Variance);
if (BiasDropoutResidualOut.defined() && (!out_metas[14].empty()) &&
(!out_metas[14][0].IsStopGradient()))
outs0["BiasDropoutResidualOut@GRAD"] = {
std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
auto& attrs_map0 = this->attr_map_;
// Pass the entire attribute map to TraceOp
// The underlying kernel will pickup whatever attribute they need at runtime
egr::Controller::Instance().GetCurrentTracer()->TraceOp(
"fused_attention_grad",
ins0,
outs0,
attrs_map0,
egr::Controller::Instance().GetExpectedPlace(),
&this->default_attr_map_,
false,
{});
if (outs0.find("OutLinearW@GRAD") != outs0.end()) {
outputs[7] = egr::EagerUtils::GetOutputs(outs0["OutLinearW@GRAD"]);
}
if (outs0.find("QKVW@GRAD") != outs0.end()) {
outputs[3] = egr::EagerUtils::GetOutputs(outs0["QKVW@GRAD"]);
}
if (outs0.find("X@GRAD") != outs0.end()) {
outputs[0] = egr::EagerUtils::GetOutputs(outs0["X@GRAD"]);
}
if (outs0.find("QKVOut@GRAD") != outs0.end()) {
outputs[15] = egr::EagerUtils::GetOutputs(outs0["QKVOut@GRAD"]);
}
if (outs0.find("QKTVOut@GRAD") != outs0.end()) {
outputs[16] = egr::EagerUtils::GetOutputs(outs0["QKTVOut@GRAD"]);
}
if (outs0.find("TransposeOut2@GRAD") != outs0.end()) {
outputs[17] = egr::EagerUtils::GetOutputs(outs0["TransposeOut2@GRAD"]);
}
if (outs0.find("QKOut@GRAD") != outs0.end()) {
outputs[18] = egr::EagerUtils::GetOutputs(outs0["QKOut@GRAD"]);
}
if (outs0.find("SoftmaxOut@GRAD") != outs0.end()) {
outputs[19] = egr::EagerUtils::GetOutputs(outs0["SoftmaxOut@GRAD"]);
}
if (outs0.find("AttnDropoutOut@GRAD") != outs0.end()) {
outputs[20] = egr::EagerUtils::GetOutputs(outs0["AttnDropoutOut@GRAD"]);
}
if (outs0.find("FMHAOut@GRAD") != outs0.end()) {
outputs[21] = egr::EagerUtils::GetOutputs(outs0["FMHAOut@GRAD"]);
}
if (outs0.find("OutLinearOut@GRAD") != outs0.end()) {
outputs[22] = egr::EagerUtils::GetOutputs(outs0["OutLinearOut@GRAD"]);
}
if (QKVBias.defined()) {
if (outs0.find("QKVBias@GRAD") != outs0.end()) {
outputs[4] = egr::EagerUtils::GetOutputs(outs0["QKVBias@GRAD"]);
}
if (outs0.find("QKVBiasOut@GRAD") != outs0.end()) {
outputs[11] = egr::EagerUtils::GetOutputs(outs0["QKVBiasOut@GRAD"]);
}
}
if (SrcMask.defined()) {
if (outs0.find("SrcMaskOut@GRAD") != outs0.end()) {
outputs[12] = egr::EagerUtils::GetOutputs(outs0["SrcMaskOut@GRAD"]);
}
}
if (OutLinearBias.defined()) {
if (outs0.find("OutLinearBias@GRAD") != outs0.end()) {
outputs[8] = egr::EagerUtils::GetOutputs(outs0["OutLinearBias@GRAD"]);
}
}
if (pre_layer_norm) {
auto LnScale = egr::EagerUtils::RecoverTensorWrapper(&this->LnScale_);
if (LnScale.defined()) {
if (outs0.find("LnScale@GRAD") != outs0.end()) {
outputs[1] = egr::EagerUtils::GetOutputs(outs0["LnScale@GRAD"]);
}
}
auto LnBias = egr::EagerUtils::RecoverTensorWrapper(&this->LnBias_);
if (LnBias.defined()) {
if (outs0.find("LnBias@GRAD") != outs0.end()) {
outputs[2] = egr::EagerUtils::GetOutputs(outs0["LnBias@GRAD"]);
}
}
auto LnOut = egr::EagerUtils::RecoverTensorWrapper(&this->LnOut_);
if (LnOut.defined()) {
if (outs0.find("LnOut@GRAD") != outs0.end()) {
outputs[13] = egr::EagerUtils::GetOutputs(outs0["LnOut@GRAD"]);
}
}
} else {
auto Ln2Scale = egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Scale_);
if (Ln2Scale.defined()) {
if (outs0.find("Ln2Scale@GRAD") != outs0.end()) {
outputs[9] = egr::EagerUtils::GetOutputs(outs0["Ln2Scale@GRAD"]);
}
}
auto Ln2Bias = egr::EagerUtils::RecoverTensorWrapper(&this->Ln2Bias_);
if (Ln2Bias.defined()) {
if (outs0.find("Ln2Bias@GRAD") != outs0.end()) {
outputs[10] = egr::EagerUtils::GetOutputs(outs0["Ln2Bias@GRAD"]);
}
}
if (outs0.find("BiasDropoutResidualOut@GRAD") != outs0.end()) {
outputs[14] =
egr::EagerUtils::GetOutputs(outs0["BiasDropoutResidualOut@GRAD"]);
}
}
if (NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&outputs);
return outputs;
}
...@@ -329,3 +329,205 @@ class fused_feedforwardGradNodeCompat : public egr::GradNodeBase { ...@@ -329,3 +329,205 @@ class fused_feedforwardGradNodeCompat : public egr::GradNodeBase {
paddle::framework::AttributeMap attr_map_; paddle::framework::AttributeMap attr_map_;
paddle::framework::AttributeMap default_attr_map_; paddle::framework::AttributeMap default_attr_map_;
}; };
class fused_attentionGradNodeCompat : public egr::GradNodeBase {
public:
fused_attentionGradNodeCompat() : egr::GradNodeBase() {
VLOG(7) << " Construct fused_attentionGradNodeCompat ";
}
fused_attentionGradNodeCompat(size_t bwd_in_slot_num, size_t bwd_out_slot_num)
: egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {
VLOG(7) << " Construct fused_attentionGradNodeCompat ";
}
~fused_attentionGradNodeCompat() override {
VLOG(6) << " Destruct fused_attentionGradNodeCompat ";
}
virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
operator()(
paddle::small_vector<std::vector<paddle::experimental::Tensor>, // NOLINT
egr::kSlotSmallVectorSize>& grads, // NOLINT
bool create_graph = false,
bool is_new_grad = false) override;
void ClearTensorWrappers() override {
AttnDropoutMaskOut_.clear();
AttnDropoutOut_.clear();
BiasDropoutResidualOut_.clear();
DropoutMaskOut_.clear();
FMHAOut_.clear();
Ln2Bias_.clear();
Ln2Mean_.clear();
Ln2Scale_.clear();
Ln2Variance_.clear();
OutLinearBias_.clear();
OutLinearOut_.clear();
OutLinearW_.clear();
QKOut_.clear();
QKTVOut_.clear();
QKVBias_.clear();
QKVBiasOut_.clear();
QKVOut_.clear();
QKVW_.clear();
SoftmaxOut_.clear();
SrcMask_.clear();
SrcMaskOut_.clear();
TransposeOut2_.clear();
X_.clear();
SetIsTensorWrappersCleared(true);
}
std::string name() override { return "fused_attentionGradNodeCompat"; }
std::shared_ptr<GradNodeBase> Copy() const override {
{
auto copied_node = std::shared_ptr<fused_attentionGradNodeCompat>(
new fused_attentionGradNodeCompat(*this));
return copied_node;
}
}
// SetX, SetY, ...
void SetTensorWrapperAttnDropoutMaskOut(
const paddle::experimental::Tensor& AttnDropoutMaskOut) {
AttnDropoutMaskOut_ = egr::TensorWrapper(AttnDropoutMaskOut, false);
}
void SetTensorWrapperAttnDropoutOut(
const paddle::experimental::Tensor& AttnDropoutOut) {
AttnDropoutOut_ = egr::TensorWrapper(AttnDropoutOut, false);
}
void SetTensorWrapperBiasDropoutResidualOut(
const paddle::experimental::Tensor& BiasDropoutResidualOut) {
BiasDropoutResidualOut_ = egr::TensorWrapper(BiasDropoutResidualOut, false);
}
void SetTensorWrapperDropoutMaskOut(
const paddle::experimental::Tensor& DropoutMaskOut) {
DropoutMaskOut_ = egr::TensorWrapper(DropoutMaskOut, false);
}
void SetTensorWrapperFMHAOut(const paddle::experimental::Tensor& FMHAOut) {
FMHAOut_ = egr::TensorWrapper(FMHAOut, false);
}
void SetTensorWrapperLn2Bias(const paddle::experimental::Tensor& Ln2Bias) {
Ln2Bias_ = egr::TensorWrapper(Ln2Bias, false);
}
void SetTensorWrapperLn2Mean(const paddle::experimental::Tensor& Ln2Mean) {
Ln2Mean_ = egr::TensorWrapper(Ln2Mean, false);
}
void SetTensorWrapperLn2Scale(const paddle::experimental::Tensor& Ln2Scale) {
Ln2Scale_ = egr::TensorWrapper(Ln2Scale, false);
}
void SetTensorWrapperLn2Variance(
const paddle::experimental::Tensor& Ln2Variance) {
Ln2Variance_ = egr::TensorWrapper(Ln2Variance, false);
}
void SetTensorWrapperOutLinearBias(
const paddle::experimental::Tensor& OutLinearBias) {
OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false);
}
void SetTensorWrapperOutLinearOut(
const paddle::experimental::Tensor& OutLinearOut) {
OutLinearOut_ = egr::TensorWrapper(OutLinearOut, false);
}
void SetTensorWrapperOutLinearW(
const paddle::experimental::Tensor& OutLinearW) {
OutLinearW_ = egr::TensorWrapper(OutLinearW, false);
}
void SetTensorWrapperQKOut(const paddle::experimental::Tensor& QKOut) {
QKOut_ = egr::TensorWrapper(QKOut, false);
}
void SetTensorWrapperQKTVOut(const paddle::experimental::Tensor& QKTVOut) {
QKTVOut_ = egr::TensorWrapper(QKTVOut, false);
}
void SetTensorWrapperQKVBias(const paddle::experimental::Tensor& QKVBias) {
QKVBias_ = egr::TensorWrapper(QKVBias, false);
}
void SetTensorWrapperQKVBiasOut(
const paddle::experimental::Tensor& QKVBiasOut) {
QKVBiasOut_ = egr::TensorWrapper(QKVBiasOut, false);
}
void SetTensorWrapperQKVOut(const paddle::experimental::Tensor& QKVOut) {
QKVOut_ = egr::TensorWrapper(QKVOut, false);
}
void SetTensorWrapperQKVW(const paddle::experimental::Tensor& QKVW) {
QKVW_ = egr::TensorWrapper(QKVW, false);
}
void SetTensorWrapperSoftmaxOut(
const paddle::experimental::Tensor& SoftmaxOut) {
SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false);
}
void SetTensorWrapperSrcMask(const paddle::experimental::Tensor& SrcMask) {
SrcMask_ = egr::TensorWrapper(SrcMask, false);
}
void SetTensorWrapperSrcMaskOut(
const paddle::experimental::Tensor& SrcMaskOut) {
SrcMaskOut_ = egr::TensorWrapper(SrcMaskOut, false);
}
void SetTensorWrapperTransposeOut2(
const paddle::experimental::Tensor& TransposeOut2) {
TransposeOut2_ = egr::TensorWrapper(TransposeOut2, false);
}
void SetTensorWrapperX(const paddle::experimental::Tensor& X) {
X_ = egr::TensorWrapper(X, false);
}
void SetTensorWrapperLnScale(const paddle::experimental::Tensor& LnScale) {
LnScale_ = egr::TensorWrapper(LnScale, false);
}
void SetTensorWrapperLnBias(const paddle::experimental::Tensor& LnBias) {
LnBias_ = egr::TensorWrapper(LnBias, false);
}
void SetTensorWrapperLnOut(const paddle::experimental::Tensor& LnOut) {
LnOut_ = egr::TensorWrapper(LnOut, false);
}
void SetTensorWrapperLnMean(const paddle::experimental::Tensor& LnMean) {
LnMean_ = egr::TensorWrapper(LnMean, false);
}
void SetTensorWrapperLnVariance(
const paddle::experimental::Tensor& LnVariance) {
LnVariance_ = egr::TensorWrapper(LnVariance, false);
}
// SetAttrMap
void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {
attr_map_ = std::move(attr_map);
}
void SetDefaultAttrMap(paddle::framework::AttributeMap&& default_attr_map) {
default_attr_map_ = std::move(default_attr_map);
}
private:
// TensorWrappers
egr::TensorWrapper AttnDropoutMaskOut_;
egr::TensorWrapper AttnDropoutOut_;
egr::TensorWrapper BiasDropoutResidualOut_;
egr::TensorWrapper DropoutMaskOut_;
egr::TensorWrapper FMHAOut_;
egr::TensorWrapper Ln2Bias_;
egr::TensorWrapper Ln2Mean_;
egr::TensorWrapper Ln2Scale_;
egr::TensorWrapper Ln2Variance_;
egr::TensorWrapper OutLinearBias_;
egr::TensorWrapper OutLinearOut_;
egr::TensorWrapper OutLinearW_;
egr::TensorWrapper QKOut_;
egr::TensorWrapper QKTVOut_;
egr::TensorWrapper QKVBias_;
egr::TensorWrapper QKVBiasOut_;
egr::TensorWrapper QKVOut_;
egr::TensorWrapper QKVW_;
egr::TensorWrapper SoftmaxOut_;
egr::TensorWrapper SrcMask_;
egr::TensorWrapper SrcMaskOut_;
egr::TensorWrapper TransposeOut2_;
egr::TensorWrapper X_;
egr::TensorWrapper LnScale_;
egr::TensorWrapper LnBias_;
egr::TensorWrapper LnOut_;
egr::TensorWrapper LnMean_;
egr::TensorWrapper LnVariance_;
// Attribute Map
paddle::framework::AttributeMap attr_map_;
paddle::framework::AttributeMap default_attr_map_;
};
...@@ -51,8 +51,10 @@ static std::unordered_set<std::string> ops_to_fill_zero_for_empty_grads = { ...@@ -51,8 +51,10 @@ static std::unordered_set<std::string> ops_to_fill_zero_for_empty_grads = {
"split", "rnn"}; "split", "rnn"};
/* --- Black Ops list that's NO NEED to apply code generation --- */ /* --- Black Ops list that's NO NEED to apply code generation --- */
static std::unordered_set<std::string> black_ops_list = { static std::unordered_set<std::string> black_ops_list = {"run_program",
"run_program", "fused_gate_attention", "fused_feedforward"}; "fused_gate_attention",
"fused_feedforward",
"fused_attention"};
static std::string LegalizeVariableName(const std::string& var_name) { static std::string LegalizeVariableName(const std::string& var_name) {
std::string ret = var_name; std::string ret = var_name;
......
...@@ -26,9 +26,7 @@ from paddle import tensor ...@@ -26,9 +26,7 @@ from paddle import tensor
from paddle.fluid import layers from paddle.fluid import layers
import unittest import unittest
from op_test import OpTest from op_test import OpTest
from paddle.fluid.framework import default_main_program, _enable_legacy_dygraph from paddle.fluid.framework import default_main_program
_enable_legacy_dygraph()
default_main_program().random_seed = 42 default_main_program().random_seed = 42
......
...@@ -26,11 +26,9 @@ import unittest ...@@ -26,11 +26,9 @@ import unittest
from op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float from op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float
from test_sparse_attention_op import get_cuda_version from test_sparse_attention_op import get_cuda_version
from paddle import _C_ops from paddle import _C_ops
from paddle.fluid.framework import default_main_program, _enable_legacy_dygraph from paddle.fluid.framework import default_main_program
from paddle.fluid import core from paddle.fluid import core
_enable_legacy_dygraph()
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"Paddle is not compiled with CUDA") "Paddle is not compiled with CUDA")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册