未验证 提交 73f957cf 编写于 作者: W wanghuancoder 提交者: GitHub

fused_gate_attention manual code in eager (#43897)

* fused_gate_attention manual code in eager

* refine

* refine

* refine

* refine

* refine

* refine
上级 842f363d
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
#include "paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h"
#include "paddle/fluid/framework/convert_utils.h"
namespace egr {
......
add_subdirectory(manual)
add_subdirectory(utils)
add_subdirectory(generated)
......
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_subdirectory(fluid_manual)
set(fluid_manual_functions
${fluid_manual_functions}
PARENT_SCOPE)
set(fluid_manual_nodes
${fluid_manual_nodes}
PARENT_SCOPE)
endif()
add_subdirectory(forwards)
add_subdirectory(nodes)
set(fluid_manual_functions
${fluid_manual_functions}
PARENT_SCOPE)
set(fluid_manual_nodes
${fluid_manual_nodes}
PARENT_SCOPE)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "glog/logging.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/phi/api/all.h"
std::tuple<paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor>
fused_gate_attention_dygraph_function(
const paddle::experimental::Tensor& Query,
const paddle::experimental::Tensor& Key,
const paddle::experimental::Tensor& QueryWeight,
const paddle::experimental::Tensor& KeyWeight,
const paddle::experimental::Tensor& ValueWeight,
const paddle::experimental::Tensor& QKVWeight,
const paddle::experimental::Tensor& NonbatchedBias,
const paddle::experimental::Tensor& SrcMask,
const paddle::experimental::Tensor& GateWeight,
const paddle::experimental::Tensor& GateBias,
const paddle::experimental::Tensor& OutLinearWeight,
const paddle::experimental::Tensor& OutLinearBias,
const paddle::framework::AttributeMap& attr_map);
cc_library(
fused_gate_attention_fwd_func
SRCS fused_gate_attention_fwd_func.cc
DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
add_dependencies(fused_gate_attention_fwd_func eager_codegen)
set(fluid_manual_functions
fused_gate_attention_fwd_func
PARENT_SCOPE)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/amp_auto_cast.h"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h"
#include "paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#pragma GCC diagnostic ignored "-Wunused-variable"
std::tuple<paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor,
paddle::experimental::Tensor>
fused_gate_attention_dygraph_function(
const paddle::experimental::Tensor& Query,
const paddle::experimental::Tensor& Key,
const paddle::experimental::Tensor& QueryWeight,
const paddle::experimental::Tensor& KeyWeight,
const paddle::experimental::Tensor& ValueWeight,
const paddle::experimental::Tensor& QKVWeight,
const paddle::experimental::Tensor& NonbatchedBias,
const paddle::experimental::Tensor& SrcMask,
const paddle::experimental::Tensor& GateWeight,
const paddle::experimental::Tensor& GateBias,
const paddle::experimental::Tensor& OutLinearWeight,
const paddle::experimental::Tensor& OutLinearBias,
const paddle::framework::AttributeMap& attr_map) {
paddle::platform::RecordEvent dygraph_entrance_record_event(
"fused_gate_attention dygraph",
paddle::platform::TracerEventType::Operator,
1);
VLOG(3) << "Running Eager Forward Op: fused_gate_attention";
// Dygraph Forward Pass
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
VLOG(5) << "Check and Prepare For AMP";
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
amp_tensors_vector = {
{Query}, {SrcMask}, {OutLinearWeight}, {OutLinearBias}};
if (Key.initialized()) amp_tensors_vector.push_back({Key});
if (QueryWeight.initialized()) amp_tensors_vector.push_back({QueryWeight});
if (KeyWeight.initialized()) amp_tensors_vector.push_back({KeyWeight});
if (ValueWeight.initialized()) amp_tensors_vector.push_back({ValueWeight});
if (QKVWeight.initialized()) amp_tensors_vector.push_back({QKVWeight});
if (NonbatchedBias.initialized())
amp_tensors_vector.push_back({NonbatchedBias});
if (GateWeight.initialized()) amp_tensors_vector.push_back({GateWeight});
if (GateBias.initialized()) amp_tensors_vector.push_back({GateBias});
auto amp_dst_dtype =
egr::GetAmpDestDtype("fused_gate_attention", amp_tensors_vector);
auto NEW_Query =
egr::AmpAutoCast("Query", Query, amp_dst_dtype, "fused_gate_attention");
auto NEW_SrcMask = egr::AmpAutoCast(
"SrcMask", SrcMask, amp_dst_dtype, "fused_gate_attention");
auto NEW_OutLinearWeight = egr::AmpAutoCast("OutLinearWeight",
OutLinearWeight,
amp_dst_dtype,
"fused_gate_attention");
auto NEW_OutLinearBias = egr::AmpAutoCast(
"OutLinearBias", OutLinearBias, amp_dst_dtype, "fused_gate_attention");
auto NEW_Key = ((Key.initialized())
? egr::AmpAutoCast(
"Key", Key, amp_dst_dtype, "fused_gate_attention")
: Key);
auto NEW_QueryWeight =
((QueryWeight.initialized()) ? egr::AmpAutoCast("QueryWeight",
QueryWeight,
amp_dst_dtype,
"fused_gate_attention")
: QueryWeight);
auto NEW_KeyWeight =
((KeyWeight.initialized()) ? egr::AmpAutoCast("KeyWeight",
KeyWeight,
amp_dst_dtype,
"fused_gate_attention")
: KeyWeight);
auto NEW_ValueWeight =
((ValueWeight.initialized()) ? egr::AmpAutoCast("ValueWeight",
ValueWeight,
amp_dst_dtype,
"fused_gate_attention")
: ValueWeight);
auto NEW_QKVWeight =
((QKVWeight.initialized()) ? egr::AmpAutoCast("QKVWeight",
QKVWeight,
amp_dst_dtype,
"fused_gate_attention")
: QKVWeight);
auto NEW_NonbatchedBias = ((NonbatchedBias.initialized())
? egr::AmpAutoCast("NonbatchedBias",
NonbatchedBias,
amp_dst_dtype,
"fused_gate_attention")
: NonbatchedBias);
auto NEW_GateWeight =
((GateWeight.initialized()) ? egr::AmpAutoCast("GateWeight",
GateWeight,
amp_dst_dtype,
"fused_gate_attention")
: GateWeight);
auto NEW_GateBias =
((GateBias.initialized())
? egr::AmpAutoCast(
"GateBias", GateBias, amp_dst_dtype, "fused_gate_attention")
: GateBias);
{
paddle::imperative::AutoCastGuard guard(
egr::Controller::Instance().GetCurrentTracer(),
paddle::imperative::AmpLevel::O0);
return fused_gate_attention_dygraph_function(NEW_Query,
NEW_Key,
NEW_QueryWeight,
NEW_KeyWeight,
NEW_ValueWeight,
NEW_QKVWeight,
NEW_NonbatchedBias,
NEW_SrcMask,
NEW_GateWeight,
NEW_GateBias,
NEW_OutLinearWeight,
NEW_OutLinearBias,
attr_map);
}
}
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> ins =
{{"Query", egr::EagerUtils::TrySyncToVars(Query)},
{"SrcMask", egr::EagerUtils::TrySyncToVars(SrcMask)},
{"OutLinearWeight", egr::EagerUtils::TrySyncToVars(OutLinearWeight)},
{"OutLinearBias", egr::EagerUtils::TrySyncToVars(OutLinearBias)}};
if (Key.initialized()) ins["Key"] = egr::EagerUtils::TrySyncToVars(Key);
if (QueryWeight.initialized())
ins["QueryWeight"] = egr::EagerUtils::TrySyncToVars(QueryWeight);
if (KeyWeight.initialized())
ins["KeyWeight"] = egr::EagerUtils::TrySyncToVars(KeyWeight);
if (ValueWeight.initialized())
ins["ValueWeight"] = egr::EagerUtils::TrySyncToVars(ValueWeight);
if (QKVWeight.initialized())
ins["QKVWeight"] = egr::EagerUtils::TrySyncToVars(QKVWeight);
if (NonbatchedBias.initialized())
ins["NonbatchedBias"] = egr::EagerUtils::TrySyncToVars(NonbatchedBias);
if (GateWeight.initialized())
ins["GateWeight"] = egr::EagerUtils::TrySyncToVars(GateWeight);
if (GateBias.initialized())
ins["GateBias"] = egr::EagerUtils::TrySyncToVars(GateBias);
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> outs =
{{"QueryTransposeOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"KeyTransposeOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"ValueTransposeOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"QKVTransposeOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"SoftmaxOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"FMHAOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"GateOut",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}},
{"Out",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}}};
// Prepare Autograd Meta
egr::AutogradMeta* p_autograd_Query =
egr::EagerUtils::nullable_autograd_meta(Query);
egr::AutogradMeta* p_autograd_Key =
egr::EagerUtils::nullable_autograd_meta(Key);
egr::AutogradMeta* p_autograd_QueryWeight =
egr::EagerUtils::nullable_autograd_meta(QueryWeight);
egr::AutogradMeta* p_autograd_KeyWeight =
egr::EagerUtils::nullable_autograd_meta(KeyWeight);
egr::AutogradMeta* p_autograd_ValueWeight =
egr::EagerUtils::nullable_autograd_meta(ValueWeight);
egr::AutogradMeta* p_autograd_QKVWeight =
egr::EagerUtils::nullable_autograd_meta(QKVWeight);
egr::AutogradMeta* p_autograd_NonbatchedBias =
egr::EagerUtils::nullable_autograd_meta(NonbatchedBias);
egr::AutogradMeta* p_autograd_SrcMask =
egr::EagerUtils::nullable_autograd_meta(SrcMask);
egr::AutogradMeta* p_autograd_GateWeight =
egr::EagerUtils::nullable_autograd_meta(GateWeight);
egr::AutogradMeta* p_autograd_GateBias =
egr::EagerUtils::nullable_autograd_meta(GateBias);
egr::AutogradMeta* p_autograd_OutLinearWeight =
egr::EagerUtils::nullable_autograd_meta(OutLinearWeight);
egr::AutogradMeta* p_autograd_OutLinearBias =
egr::EagerUtils::nullable_autograd_meta(OutLinearBias);
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad =
egr::EagerUtils::ComputeRequireGrad(trace_backward,
p_autograd_Query,
p_autograd_Key,
p_autograd_QueryWeight,
p_autograd_KeyWeight,
p_autograd_ValueWeight,
p_autograd_QKVWeight,
p_autograd_NonbatchedBias,
p_autograd_SrcMask,
p_autograd_GateWeight,
p_autograd_GateBias,
p_autograd_OutLinearWeight,
p_autograd_OutLinearBias);
paddle::framework::AttributeMap attrs = attr_map;
paddle::framework::AttributeMap default_attrs;
egr::Controller::Instance().GetCurrentTracer()->TraceOp(
"fused_gate_attention",
ins,
outs,
attrs,
egr::Controller::Instance().GetExpectedPlace(),
&default_attrs,
true,
{});
paddle::experimental::Tensor QueryTransposeOut;
egr::EagerUtils::GetOutput(outs["QueryTransposeOut"][0], &QueryTransposeOut);
paddle::experimental::Tensor KeyTransposeOut;
egr::EagerUtils::GetOutput(outs["KeyTransposeOut"][0], &KeyTransposeOut);
paddle::experimental::Tensor ValueTransposeOut;
egr::EagerUtils::GetOutput(outs["ValueTransposeOut"][0], &ValueTransposeOut);
paddle::experimental::Tensor QKVTransposeOut;
egr::EagerUtils::GetOutput(outs["QKVTransposeOut"][0], &QKVTransposeOut);
paddle::experimental::Tensor SoftmaxOut;
egr::EagerUtils::GetOutput(outs["SoftmaxOut"][0], &SoftmaxOut);
paddle::experimental::Tensor FMHAOut;
egr::EagerUtils::GetOutput(outs["FMHAOut"][0], &FMHAOut);
paddle::experimental::Tensor GateOut;
egr::EagerUtils::GetOutput(outs["GateOut"][0], &GateOut);
paddle::experimental::Tensor Out;
egr::EagerUtils::GetOutput(outs["Out"][0], &Out);
{
paddle::platform::RecordEvent node_creation_record_event(
"fused_gate_attention node_creation",
paddle::platform::TracerEventType::Operator,
1);
egr::AutogradMeta* p_autograd_QueryTransposeOut =
egr::EagerUtils::autograd_meta(&QueryTransposeOut);
egr::AutogradMeta* p_autograd_KeyTransposeOut =
egr::EagerUtils::autograd_meta(&KeyTransposeOut);
egr::AutogradMeta* p_autograd_ValueTransposeOut =
egr::EagerUtils::autograd_meta(&ValueTransposeOut);
egr::AutogradMeta* p_autograd_QKVTransposeOut =
egr::EagerUtils::autograd_meta(&QKVTransposeOut);
egr::AutogradMeta* p_autograd_SoftmaxOut =
egr::EagerUtils::autograd_meta(&SoftmaxOut);
egr::AutogradMeta* p_autograd_FMHAOut =
egr::EagerUtils::autograd_meta(&FMHAOut);
egr::AutogradMeta* p_autograd_GateOut =
egr::EagerUtils::autograd_meta(&GateOut);
egr::AutogradMeta* p_autograd_Out = egr::EagerUtils::autograd_meta(&Out);
if (require_any_grad) {
VLOG(6) << " Construct Grad for fused_gate_attention ";
egr::EagerUtils::PassStopGradient(false,
p_autograd_QueryTransposeOut,
p_autograd_KeyTransposeOut,
p_autograd_ValueTransposeOut,
p_autograd_QKVTransposeOut,
p_autograd_SoftmaxOut,
p_autograd_FMHAOut,
p_autograd_GateOut,
p_autograd_Out);
// Create GradOpNode
auto grad_node = std::shared_ptr<fused_gate_attentionGradNodeCompat>(
new fused_gate_attentionGradNodeCompat(8, 12));
bool merge_qkv = true;
if (attrs.count("merge_qkv")) {
merge_qkv = BOOST_GET_CONST(bool, attrs.at("merge_qkv"));
}
bool has_gating = true;
if (attrs.count("has_gating")) {
has_gating = BOOST_GET_CONST(bool, attrs.at("has_gating"));
}
// Set Attributes
grad_node->SetAttrMap(std::move(attrs));
grad_node->SetDefaultAttrMap(std::move(default_attrs));
grad_node->SetTensorWrapperFMHAOut(FMHAOut);
grad_node->SetTensorWrapperQuery(Query);
grad_node->SetTensorWrapperSoftmaxOut(SoftmaxOut);
grad_node->SetTensorWrapperOutLinearBias(OutLinearBias);
grad_node->SetTensorWrapperOutLinearWeight(OutLinearWeight);
grad_node->SetGradOutMeta(Query, 0);
grad_node->SetGradOutMeta(OutLinearWeight, 10);
grad_node->SetGradOutMeta(OutLinearBias, 11);
if (merge_qkv) {
grad_node->SetTensorWrapperQKVTransposeOut(QKVTransposeOut);
grad_node->SetTensorWrapperQKVWeight(QKVWeight);
grad_node->SetGradOutMeta(QKVWeight, 5);
} else {
grad_node->SetTensorWrapperKey(Key);
grad_node->SetTensorWrapperQueryWeight(QueryWeight);
grad_node->SetTensorWrapperKeyWeight(KeyWeight);
grad_node->SetTensorWrapperValueWeight(ValueWeight);
grad_node->SetTensorWrapperQueryTransposeOut(QueryTransposeOut);
grad_node->SetTensorWrapperKeyTransposeOut(KeyTransposeOut);
grad_node->SetTensorWrapperValueTransposeOut(ValueTransposeOut);
grad_node->SetGradOutMeta(Key, 1);
grad_node->SetGradOutMeta(QueryWeight, 2);
grad_node->SetGradOutMeta(KeyWeight, 3);
grad_node->SetGradOutMeta(ValueWeight, 4);
}
if (has_gating) {
grad_node->SetTensorWrapperGateWeight(GateWeight);
grad_node->SetGradOutMeta(GateWeight, 8);
grad_node->SetTensorWrapperGateBias(GateBias);
grad_node->SetGradOutMeta(GateBias, 9);
grad_node->SetTensorWrapperGateOut(GateOut);
}
if (NonbatchedBias.initialized()) {
grad_node->SetTensorWrapperNonbatchedBias(NonbatchedBias);
grad_node->SetGradOutMeta(NonbatchedBias, 6);
}
egr::EagerUtils::SetOutRankWithSlot(p_autograd_QueryTransposeOut, 0);
grad_node->SetGradInMeta(QueryTransposeOut, 0);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_KeyTransposeOut, 1);
grad_node->SetGradInMeta(KeyTransposeOut, 1);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_ValueTransposeOut, 2);
grad_node->SetGradInMeta(ValueTransposeOut, 2);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKVTransposeOut, 3);
grad_node->SetGradInMeta(QKVTransposeOut, 3);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_SoftmaxOut, 4);
grad_node->SetGradInMeta(SoftmaxOut, 4);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_FMHAOut, 5);
grad_node->SetGradInMeta(FMHAOut, 5);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_GateOut, 6);
grad_node->SetGradInMeta(GateOut, 6);
egr::EagerUtils::SetOutRankWithSlot(p_autograd_Out, 7);
egr::EagerUtils::SetHistory(p_autograd_Out, grad_node);
grad_node->SetGradInMeta(Out, 7);
egr::EagerUtils::CheckAndRetainGrad(Out);
}
}
return std::make_tuple(QueryTransposeOut,
KeyTransposeOut,
ValueTransposeOut,
QKVTransposeOut,
SoftmaxOut,
FMHAOut,
GateOut,
Out);
}
cc_library(
fused_gate_attention_node
SRCS fused_gate_attention_node.cc
DEPS ${eager_deps} ${fluid_deps})
set(fluid_manual_nodes
fused_gate_attention_node
PARENT_SCOPE)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/phi/api/all.h"
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
fused_gate_attentionGradNodeCompat::operator()(
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>& grads,
bool create_graph,
bool is_new_grad) {
VLOG(3) << "Running Eager Backward Node: fused_gate_attentionGradNodeCompat";
const auto& out_metas = OutputMeta();
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
outputs(12);
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
hooked_grads0 =
fused_gate_attentionGradNodeCompat::ApplyGradientHooks(grads);
bool merge_qkv = true;
if (attr_map_.count("merge_qkv")) {
merge_qkv = BOOST_GET_CONST(bool, attr_map_.at("merge_qkv"));
}
bool has_gating = true;
if (attr_map_.count("has_gating")) {
has_gating = BOOST_GET_CONST(bool, attr_map_.at("has_gating"));
}
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> ins0 =
{{"FMHAOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->FMHAOut_))},
{"Out@GRAD", egr::EagerUtils::TrySyncToVars(hooked_grads0[7])},
{"OutLinearBias",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearBias_))},
{"OutLinearWeight",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearWeight_))},
{"Query",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->Query_))},
{"SoftmaxOut",
egr::EagerUtils::TrySyncToVars(
egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxOut_))}};
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>> outs0;
if ((!out_metas[11].empty()) && (!(out_metas[11][0].IsStopGradient()))) {
outs0.insert({"OutLinearBias@GRAD",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}});
}
if ((!out_metas[10].empty()) && (!(out_metas[10][0].IsStopGradient()))) {
outs0.insert({"OutLinearWeight@GRAD",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}});
}
if ((!out_metas[0].empty()) && (!(out_metas[0][0].IsStopGradient()))) {
outs0.insert({"Query@GRAD",
{std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())}});
}
if (merge_qkv) {
auto QKVTransposeOut =
egr::EagerUtils::RecoverTensorWrapper(&this->QKVTransposeOut_);
if (QKVTransposeOut.defined())
ins0["QKVTransposeOut"] = egr::EagerUtils::TrySyncToVars(QKVTransposeOut);
auto QKVWeight = egr::EagerUtils::RecoverTensorWrapper(&this->QKVWeight_);
if (QKVWeight.defined())
ins0["QKVWeight"] = egr::EagerUtils::TrySyncToVars(QKVWeight);
if (QKVWeight.defined() && (!out_metas[5].empty()) &&
(!out_metas[5][0].IsStopGradient()))
outs0["QKVWeight@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
} else {
auto Key = egr::EagerUtils::RecoverTensorWrapper(&this->Key_);
if (Key.defined()) ins0["Key"] = egr::EagerUtils::TrySyncToVars(Key);
auto QueryWeight =
egr::EagerUtils::RecoverTensorWrapper(&this->QueryWeight_);
if (QueryWeight.defined())
ins0["QueryWeight"] = egr::EagerUtils::TrySyncToVars(QueryWeight);
auto KeyWeight = egr::EagerUtils::RecoverTensorWrapper(&this->KeyWeight_);
if (KeyWeight.defined())
ins0["KeyWeight"] = egr::EagerUtils::TrySyncToVars(KeyWeight);
auto ValueWeight =
egr::EagerUtils::RecoverTensorWrapper(&this->ValueWeight_);
if (ValueWeight.defined())
ins0["ValueWeight"] = egr::EagerUtils::TrySyncToVars(ValueWeight);
auto QueryTransposeOut =
egr::EagerUtils::RecoverTensorWrapper(&this->QueryTransposeOut_);
if (QueryTransposeOut.defined())
ins0["QueryTransposeOut"] =
egr::EagerUtils::TrySyncToVars(QueryTransposeOut);
auto KeyTransposeOut =
egr::EagerUtils::RecoverTensorWrapper(&this->KeyTransposeOut_);
if (KeyTransposeOut.defined())
ins0["KeyTransposeOut"] = egr::EagerUtils::TrySyncToVars(KeyTransposeOut);
auto ValueTransposeOut =
egr::EagerUtils::RecoverTensorWrapper(&this->ValueTransposeOut_);
if (ValueTransposeOut.defined())
ins0["ValueTransposeOut"] =
egr::EagerUtils::TrySyncToVars(ValueTransposeOut);
if (Key.defined() && (!out_metas[1].empty()) &&
(!out_metas[1][0].IsStopGradient()))
outs0["Key@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
if (QueryWeight.defined() && (!out_metas[2].empty()) &&
(!out_metas[2][0].IsStopGradient()))
outs0["QueryWeight@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
if (KeyWeight.defined() && (!out_metas[3].empty()) &&
(!out_metas[3][0].IsStopGradient()))
outs0["KeyWeight@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
if (ValueWeight.defined() && (!out_metas[4].empty()) &&
(!out_metas[4][0].IsStopGradient()))
outs0["ValueWeight@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
if (has_gating) {
auto GateBias = egr::EagerUtils::RecoverTensorWrapper(&this->GateBias_);
if (GateBias.defined())
ins0["GateBias"] = egr::EagerUtils::TrySyncToVars(GateBias);
auto GateWeight = egr::EagerUtils::RecoverTensorWrapper(&this->GateWeight_);
if (GateWeight.defined())
ins0["GateWeight"] = egr::EagerUtils::TrySyncToVars(GateWeight);
auto GateOut = egr::EagerUtils::RecoverTensorWrapper(&this->GateOut_);
if (GateOut.defined())
ins0["GateOut"] = egr::EagerUtils::TrySyncToVars(GateOut);
if (GateBias.defined() && (!out_metas[9].empty()) &&
(!out_metas[9][0].IsStopGradient()))
outs0["GateBias@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
if (GateWeight.defined() && (!out_metas[8].empty()) &&
(!out_metas[8][0].IsStopGradient()))
outs0["GateWeight@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
auto NonbatchedBias =
egr::EagerUtils::RecoverTensorWrapper(&this->NonbatchedBias_);
if (NonbatchedBias.defined()) {
ins0["NonbatchedBias"] = egr::EagerUtils::TrySyncToVars(NonbatchedBias);
if ((!out_metas[6].empty()) && (!out_metas[6][0].IsStopGradient()))
outs0["NonbatchedBias@GRAD"] = {std::make_shared<egr::EagerVariable>(
egr::Controller::Instance().GenerateUniqueName())};
}
auto& attrs_map0 = this->attr_map_;
// Pass the entire attribute map to TraceOp
// The underlying kernel will pickup whatever attribute they need at runtime
egr::Controller::Instance().GetCurrentTracer()->TraceOp(
"fused_gate_attention_grad",
ins0,
outs0,
attrs_map0,
egr::Controller::Instance().GetExpectedPlace(),
&this->default_attr_map_,
false,
{});
if (outs0.find("Query@GRAD") != outs0.end()) {
outputs[0] = egr::EagerUtils::GetOutputs(outs0["Query@GRAD"]);
}
if (outs0.find("OutLinearBias@GRAD") != outs0.end()) {
outputs[11] = egr::EagerUtils::GetOutputs(outs0["OutLinearBias@GRAD"]);
}
if (outs0.find("OutLinearWeight@GRAD") != outs0.end()) {
outputs[10] = egr::EagerUtils::GetOutputs(outs0["OutLinearWeight@GRAD"]);
}
if (merge_qkv) {
if (outs0.find("QKVWeight@GRAD") != outs0.end()) {
outputs[5] = egr::EagerUtils::GetOutputs(outs0["QKVWeight@GRAD"]);
}
} else {
if (outs0.find("Key@GRAD") != outs0.end()) {
outputs[1] = egr::EagerUtils::GetOutputs(outs0["Key@GRAD"]);
}
if (outs0.find("QueryWeight@GRAD") != outs0.end()) {
outputs[2] = egr::EagerUtils::GetOutputs(outs0["QueryWeight@GRAD"]);
}
if (outs0.find("KeyWeight@GRAD") != outs0.end()) {
outputs[3] = egr::EagerUtils::GetOutputs(outs0["KeyWeight@GRAD"]);
}
if (outs0.find("ValueWeight@GRAD") != outs0.end()) {
outputs[4] = egr::EagerUtils::GetOutputs(outs0["ValueWeight@GRAD"]);
}
}
if (has_gating) {
if (outs0.find("GateBias@GRAD") != outs0.end()) {
outputs[9] = egr::EagerUtils::GetOutputs(outs0["GateBias@GRAD"]);
}
if (outs0.find("GateWeight@GRAD") != outs0.end()) {
outputs[8] = egr::EagerUtils::GetOutputs(outs0["GateWeight@GRAD"]);
}
}
if (NonbatchedBias.defined()) {
if (outs0.find("NonbatchedBias@GRAD") != outs0.end()) {
outputs[6] = egr::EagerUtils::GetOutputs(outs0["NonbatchedBias@GRAD"]);
}
}
if (NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&outputs);
return outputs;
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/imperative/tracer.h"
class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
public:
fused_gate_attentionGradNodeCompat() : egr::GradNodeBase() {
VLOG(7) << " Construct fused_gate_attentionGradNodeCompat ";
}
fused_gate_attentionGradNodeCompat(size_t bwd_in_slot_num,
size_t bwd_out_slot_num)
: egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {
VLOG(7) << " Construct fused_gate_attentionGradNodeCompat ";
}
~fused_gate_attentionGradNodeCompat() override {
VLOG(6) << " Destruct fused_gate_attentionGradNodeCompat ";
}
virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
operator()(
paddle::small_vector<std::vector<paddle::experimental::Tensor>, // NOLINT
egr::kSlotSmallVectorSize>& grads, // NOLINT
bool create_graph = false,
bool is_new_grad = false) override;
void ClearTensorWrappers() override {
FMHAOut_.clear();
GateBias_.clear();
GateOut_.clear();
GateWeight_.clear();
NonbatchedBias_.clear();
OutLinearBias_.clear();
OutLinearWeight_.clear();
QKVTransposeOut_.clear();
QKVWeight_.clear();
Query_.clear();
SoftmaxOut_.clear();
Key_.clear();
QueryWeight_.clear();
KeyWeight_.clear();
ValueWeight_.clear();
QueryTransposeOut_.clear();
KeyTransposeOut_.clear();
ValueTransposeOut_.clear();
SetIsTensorWrappersCleared(true);
}
std::string name() override { return "fused_gate_attentionGradNodeCompat"; }
std::shared_ptr<GradNodeBase> Copy() const override {
{
auto copied_node = std::shared_ptr<fused_gate_attentionGradNodeCompat>(
new fused_gate_attentionGradNodeCompat(*this));
return copied_node;
}
}
// SetX, SetY, ...
void SetTensorWrapperFMHAOut(const paddle::experimental::Tensor& FMHAOut) {
FMHAOut_ = egr::TensorWrapper(FMHAOut, false);
}
void SetTensorWrapperGateBias(const paddle::experimental::Tensor& GateBias) {
GateBias_ = egr::TensorWrapper(GateBias, false);
}
void SetTensorWrapperGateOut(const paddle::experimental::Tensor& GateOut) {
GateOut_ = egr::TensorWrapper(GateOut, false);
}
void SetTensorWrapperGateWeight(
const paddle::experimental::Tensor& GateWeight) {
GateWeight_ = egr::TensorWrapper(GateWeight, false);
}
void SetTensorWrapperNonbatchedBias(
const paddle::experimental::Tensor& NonbatchedBias) {
NonbatchedBias_ = egr::TensorWrapper(NonbatchedBias, false);
}
void SetTensorWrapperOutLinearBias(
const paddle::experimental::Tensor& OutLinearBias) {
OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false);
}
void SetTensorWrapperOutLinearWeight(
const paddle::experimental::Tensor& OutLinearWeight) {
OutLinearWeight_ = egr::TensorWrapper(OutLinearWeight, false);
}
void SetTensorWrapperQKVTransposeOut(
const paddle::experimental::Tensor& QKVTransposeOut) {
QKVTransposeOut_ = egr::TensorWrapper(QKVTransposeOut, false);
}
void SetTensorWrapperQKVWeight(
const paddle::experimental::Tensor& QKVWeight) {
QKVWeight_ = egr::TensorWrapper(QKVWeight, false);
}
void SetTensorWrapperQuery(const paddle::experimental::Tensor& Query) {
Query_ = egr::TensorWrapper(Query, false);
}
void SetTensorWrapperSoftmaxOut(
const paddle::experimental::Tensor& SoftmaxOut) {
SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false);
}
void SetTensorWrapperKey(const paddle::experimental::Tensor& Key) {
Key_ = egr::TensorWrapper(Key, false);
}
void SetTensorWrapperQueryWeight(
const paddle::experimental::Tensor& QueryWeight) {
QueryWeight_ = egr::TensorWrapper(QueryWeight, false);
}
void SetTensorWrapperKeyWeight(
const paddle::experimental::Tensor& KeyWeight) {
KeyWeight_ = egr::TensorWrapper(KeyWeight, false);
}
void SetTensorWrapperValueWeight(
const paddle::experimental::Tensor& ValueWeight) {
ValueWeight_ = egr::TensorWrapper(ValueWeight, false);
}
void SetTensorWrapperQueryTransposeOut(
const paddle::experimental::Tensor& QueryTransposeOut) {
QueryTransposeOut_ = egr::TensorWrapper(QueryTransposeOut, false);
}
void SetTensorWrapperKeyTransposeOut(
const paddle::experimental::Tensor& KeyTransposeOut) {
KeyTransposeOut_ = egr::TensorWrapper(KeyTransposeOut, false);
}
void SetTensorWrapperValueTransposeOut(
const paddle::experimental::Tensor& ValueTransposeOut) {
ValueTransposeOut_ = egr::TensorWrapper(ValueTransposeOut, false);
}
// SetAttrMap
void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {
attr_map_ = std::move(attr_map);
}
void SetDefaultAttrMap(paddle::framework::AttributeMap&& default_attr_map) {
default_attr_map_ = std::move(default_attr_map);
}
private:
// TensorWrappers
egr::TensorWrapper FMHAOut_;
egr::TensorWrapper GateBias_;
egr::TensorWrapper GateOut_;
egr::TensorWrapper GateWeight_;
egr::TensorWrapper NonbatchedBias_;
egr::TensorWrapper OutLinearBias_;
egr::TensorWrapper OutLinearWeight_;
egr::TensorWrapper QKVTransposeOut_;
egr::TensorWrapper QKVWeight_;
egr::TensorWrapper Query_;
egr::TensorWrapper SoftmaxOut_;
egr::TensorWrapper Key_;
egr::TensorWrapper QueryWeight_;
egr::TensorWrapper KeyWeight_;
egr::TensorWrapper ValueWeight_;
egr::TensorWrapper QueryTransposeOut_;
egr::TensorWrapper KeyTransposeOut_;
egr::TensorWrapper ValueTransposeOut_;
// Attribute Map
paddle::framework::AttributeMap attr_map_;
paddle::framework::AttributeMap default_attr_map_;
};
......@@ -51,7 +51,8 @@ static std::unordered_set<std::string> ops_to_fill_zero_for_empty_grads = {
"split", "rnn"};
/* --- Black Ops list that's NO NEED to apply code generation --- */
static std::unordered_set<std::string> black_ops_list = {"run_program"};
static std::unordered_set<std::string> black_ops_list = {
"run_program", "fused_gate_attention"};
static std::string LegalizeVariableName(const std::string& var_name) {
std::string ret = var_name;
......@@ -2972,7 +2973,10 @@ static std::string GenerateDygraphHFileIncludes() {
"#include \"paddle/phi/api/all.h\"\n"
"#include \"paddle/fluid/eager/utils.h\"\n"
"#include \"paddle/fluid/imperative/tracer.h\"\n"
"#include \"paddle/fluid/framework/op_registry.h\"\n\n";
"#include \"paddle/fluid/framework/op_registry.h\"\n"
"#include "
"\"paddle/fluid/eager/api/manual/fluid_manual/"
"dygraph_forward_api.h\"\n\n";
dygraph_forward_api_includes_str +=
"extern std::unordered_map<std::string, std::vector<std::string>> "
......@@ -3021,7 +3025,10 @@ static void GenerateNodeHFile(const std::string& node_h_path,
"#pragma once\n"
"#include \"paddle/fluid/eager/tensor_wrapper.h\"\n"
"#include \"paddle/fluid/imperative/tracer.h\"\n"
"#include \"paddle/fluid/eager/grad_node_info.h\"\n\n";
"#include \"paddle/fluid/eager/grad_node_info.h\"\n"
"#include "
"\"paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h\"\n\n";
std::ofstream node_h_stream(node_h_path, std::ios::out);
node_h_stream << node_h_include_str;
node_h_stream << grad_node_str;
......
......@@ -103,13 +103,13 @@ def GenerateFileStructureForIntermediateDygraph(eager_dir):
with open(nodes_level_cmakelist_path, "w") as f:
f.write(
"cc_library(dygraph_node SRCS nodes.cc DEPS ${eager_deps} ${fluid_deps})\n"
"cc_library(dygraph_node SRCS nodes.cc DEPS ${eager_deps} ${fluid_deps} ${fluid_manual_nodes})\n"
)
f.write("add_dependencies(dygraph_node eager_codegen)")
with open(forwards_level_cmakelist_path, "w") as f:
f.write(
"cc_library(dygraph_function SRCS dygraph_forward_functions.cc DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})\n"
"cc_library(dygraph_function SRCS dygraph_forward_functions.cc DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ${fluid_manual_functions})\n"
)
f.write("add_dependencies(dygraph_function eager_codegen)")
......
......@@ -363,7 +363,8 @@ class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating);
if (merge_qkv) {
PADDLE_ENFORCE_EQ(!key || query == key,
PADDLE_ENFORCE_EQ(
!key || query == key || query->data<T>() == key->data<T>(),
true,
platform::errors::InvalidArgument(
"key is expected to be nullptr or the same as "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册