From 73f957cf56e9ee7fea5bb338adc91bc224daf1ce Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Thu, 30 Jun 2022 19:35:26 +0800 Subject: [PATCH] fused_gate_attention manual code in eager (#43897) * fused_gate_attention manual code in eager * refine * refine * refine * refine * refine * refine --- paddle/fluid/eager/amp_auto_cast.h | 1 + paddle/fluid/eager/api/CMakeLists.txt | 1 + paddle/fluid/eager/api/manual/CMakeLists.txt | 9 + .../api/manual/fluid_manual/CMakeLists.txt | 8 + .../manual/fluid_manual/dygraph_forward_api.h | 44 ++ .../fluid_manual/forwards/CMakeLists.txt | 10 + .../forwards/fused_gate_attention_fwd_func.cc | 389 ++++++++++++++++++ .../manual/fluid_manual/nodes/CMakeLists.txt | 8 + .../nodes/fused_gate_attention_node.cc | 233 +++++++++++ .../api/manual/fluid_manual/nodes/nodes.h | 176 ++++++++ .../auto_code_generator/eager_generator.cc | 13 +- .../generate_file_structures.py | 10 +- .../fused/fused_gate_attention_op.cu | 15 +- 13 files changed, 902 insertions(+), 15 deletions(-) create mode 100644 paddle/fluid/eager/api/manual/CMakeLists.txt create mode 100644 paddle/fluid/eager/api/manual/fluid_manual/CMakeLists.txt create mode 100644 paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h create mode 100644 paddle/fluid/eager/api/manual/fluid_manual/forwards/CMakeLists.txt create mode 100644 paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc create mode 100644 paddle/fluid/eager/api/manual/fluid_manual/nodes/CMakeLists.txt create mode 100644 paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_gate_attention_node.cc create mode 100644 paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h diff --git a/paddle/fluid/eager/amp_auto_cast.h b/paddle/fluid/eager/amp_auto_cast.h index ed05a6e69c0..5110f6f883e 100644 --- a/paddle/fluid/eager/amp_auto_cast.h +++ b/paddle/fluid/eager/amp_auto_cast.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h" +#include "paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h" #include "paddle/fluid/framework/convert_utils.h" namespace egr { diff --git a/paddle/fluid/eager/api/CMakeLists.txt b/paddle/fluid/eager/api/CMakeLists.txt index 4525a58a44d..0da46bbbfbb 100644 --- a/paddle/fluid/eager/api/CMakeLists.txt +++ b/paddle/fluid/eager/api/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(manual) add_subdirectory(utils) add_subdirectory(generated) diff --git a/paddle/fluid/eager/api/manual/CMakeLists.txt b/paddle/fluid/eager/api/manual/CMakeLists.txt new file mode 100644 index 00000000000..ebfcaad2eea --- /dev/null +++ b/paddle/fluid/eager/api/manual/CMakeLists.txt @@ -0,0 +1,9 @@ +if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) + add_subdirectory(fluid_manual) + set(fluid_manual_functions + ${fluid_manual_functions} + PARENT_SCOPE) + set(fluid_manual_nodes + ${fluid_manual_nodes} + PARENT_SCOPE) +endif() diff --git a/paddle/fluid/eager/api/manual/fluid_manual/CMakeLists.txt b/paddle/fluid/eager/api/manual/fluid_manual/CMakeLists.txt new file mode 100644 index 00000000000..254f4a7246d --- /dev/null +++ b/paddle/fluid/eager/api/manual/fluid_manual/CMakeLists.txt @@ -0,0 +1,8 @@ +add_subdirectory(forwards) +add_subdirectory(nodes) +set(fluid_manual_functions + ${fluid_manual_functions} + PARENT_SCOPE) +set(fluid_manual_nodes + ${fluid_manual_nodes} + PARENT_SCOPE) diff --git a/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h b/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h new file mode 100644 index 00000000000..3715544b923 --- /dev/null +++ b/paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "glog/logging.h" +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/imperative/tracer.h" +#include "paddle/phi/api/all.h" + +std::tuple +fused_gate_attention_dygraph_function( + const paddle::experimental::Tensor& Query, + const paddle::experimental::Tensor& Key, + const paddle::experimental::Tensor& QueryWeight, + const paddle::experimental::Tensor& KeyWeight, + const paddle::experimental::Tensor& ValueWeight, + const paddle::experimental::Tensor& QKVWeight, + const paddle::experimental::Tensor& NonbatchedBias, + const paddle::experimental::Tensor& SrcMask, + const paddle::experimental::Tensor& GateWeight, + const paddle::experimental::Tensor& GateBias, + const paddle::experimental::Tensor& OutLinearWeight, + const paddle::experimental::Tensor& OutLinearBias, + const paddle::framework::AttributeMap& attr_map); diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/CMakeLists.txt b/paddle/fluid/eager/api/manual/fluid_manual/forwards/CMakeLists.txt new file mode 100644 index 00000000000..2a7d72eb7ca --- /dev/null +++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/CMakeLists.txt @@ -0,0 +1,10 @@ +cc_library( + fused_gate_attention_fwd_func + SRCS fused_gate_attention_fwd_func.cc + DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS}) + +add_dependencies(fused_gate_attention_fwd_func eager_codegen) + +set(fluid_manual_functions + fused_gate_attention_fwd_func + PARENT_SCOPE) diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc new file mode 100644 index 00000000000..81b4db4df20 --- /dev/null +++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc @@ -0,0 +1,389 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/amp_auto_cast.h" +#include "paddle/fluid/eager/amp_utils.h" +#include "paddle/fluid/eager/api/manual/fluid_manual/dygraph_forward_api.h" +#include "paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/platform/profiler/event_tracing.h" + +#pragma GCC diagnostic ignored "-Wunused-variable" + +std::tuple +fused_gate_attention_dygraph_function( + const paddle::experimental::Tensor& Query, + const paddle::experimental::Tensor& Key, + const paddle::experimental::Tensor& QueryWeight, + const paddle::experimental::Tensor& KeyWeight, + const paddle::experimental::Tensor& ValueWeight, + const paddle::experimental::Tensor& QKVWeight, + const paddle::experimental::Tensor& NonbatchedBias, + const paddle::experimental::Tensor& SrcMask, + const paddle::experimental::Tensor& GateWeight, + const paddle::experimental::Tensor& GateBias, + const paddle::experimental::Tensor& OutLinearWeight, + const paddle::experimental::Tensor& OutLinearBias, + const paddle::framework::AttributeMap& attr_map) { + paddle::platform::RecordEvent dygraph_entrance_record_event( + "fused_gate_attention dygraph", + paddle::platform::TracerEventType::Operator, + 1); + VLOG(3) << "Running Eager Forward Op: fused_gate_attention"; + // Dygraph Forward Pass + + if (egr::Controller::Instance().GetAMPLevel() != + paddle::imperative::AmpLevel::O0) { + VLOG(5) << "Check and Prepare For AMP"; + + paddle::small_vector, + egr::kSlotSmallVectorSize> + amp_tensors_vector = { + {Query}, {SrcMask}, {OutLinearWeight}, {OutLinearBias}}; + if (Key.initialized()) amp_tensors_vector.push_back({Key}); + if (QueryWeight.initialized()) amp_tensors_vector.push_back({QueryWeight}); + if (KeyWeight.initialized()) amp_tensors_vector.push_back({KeyWeight}); + if (ValueWeight.initialized()) amp_tensors_vector.push_back({ValueWeight}); + if (QKVWeight.initialized()) amp_tensors_vector.push_back({QKVWeight}); + if (NonbatchedBias.initialized()) + amp_tensors_vector.push_back({NonbatchedBias}); + if (GateWeight.initialized()) amp_tensors_vector.push_back({GateWeight}); + if (GateBias.initialized()) amp_tensors_vector.push_back({GateBias}); + + auto amp_dst_dtype = + egr::GetAmpDestDtype("fused_gate_attention", amp_tensors_vector); + + auto NEW_Query = + egr::AmpAutoCast("Query", Query, amp_dst_dtype, "fused_gate_attention"); + auto NEW_SrcMask = egr::AmpAutoCast( + "SrcMask", SrcMask, amp_dst_dtype, "fused_gate_attention"); + auto NEW_OutLinearWeight = egr::AmpAutoCast("OutLinearWeight", + OutLinearWeight, + amp_dst_dtype, + "fused_gate_attention"); + auto NEW_OutLinearBias = egr::AmpAutoCast( + "OutLinearBias", OutLinearBias, amp_dst_dtype, "fused_gate_attention"); + auto NEW_Key = ((Key.initialized()) + ? egr::AmpAutoCast( + "Key", Key, amp_dst_dtype, "fused_gate_attention") + : Key); + auto NEW_QueryWeight = + ((QueryWeight.initialized()) ? egr::AmpAutoCast("QueryWeight", + QueryWeight, + amp_dst_dtype, + "fused_gate_attention") + : QueryWeight); + auto NEW_KeyWeight = + ((KeyWeight.initialized()) ? egr::AmpAutoCast("KeyWeight", + KeyWeight, + amp_dst_dtype, + "fused_gate_attention") + : KeyWeight); + auto NEW_ValueWeight = + ((ValueWeight.initialized()) ? egr::AmpAutoCast("ValueWeight", + ValueWeight, + amp_dst_dtype, + "fused_gate_attention") + : ValueWeight); + auto NEW_QKVWeight = + ((QKVWeight.initialized()) ? egr::AmpAutoCast("QKVWeight", + QKVWeight, + amp_dst_dtype, + "fused_gate_attention") + : QKVWeight); + auto NEW_NonbatchedBias = ((NonbatchedBias.initialized()) + ? egr::AmpAutoCast("NonbatchedBias", + NonbatchedBias, + amp_dst_dtype, + "fused_gate_attention") + : NonbatchedBias); + auto NEW_GateWeight = + ((GateWeight.initialized()) ? egr::AmpAutoCast("GateWeight", + GateWeight, + amp_dst_dtype, + "fused_gate_attention") + : GateWeight); + auto NEW_GateBias = + ((GateBias.initialized()) + ? egr::AmpAutoCast( + "GateBias", GateBias, amp_dst_dtype, "fused_gate_attention") + : GateBias); + + { + paddle::imperative::AutoCastGuard guard( + egr::Controller::Instance().GetCurrentTracer(), + paddle::imperative::AmpLevel::O0); + return fused_gate_attention_dygraph_function(NEW_Query, + NEW_Key, + NEW_QueryWeight, + NEW_KeyWeight, + NEW_ValueWeight, + NEW_QKVWeight, + NEW_NonbatchedBias, + NEW_SrcMask, + NEW_GateWeight, + NEW_GateBias, + NEW_OutLinearWeight, + NEW_OutLinearBias, + attr_map); + } + } + + std::map>> ins = + {{"Query", egr::EagerUtils::TrySyncToVars(Query)}, + {"SrcMask", egr::EagerUtils::TrySyncToVars(SrcMask)}, + {"OutLinearWeight", egr::EagerUtils::TrySyncToVars(OutLinearWeight)}, + {"OutLinearBias", egr::EagerUtils::TrySyncToVars(OutLinearBias)}}; + if (Key.initialized()) ins["Key"] = egr::EagerUtils::TrySyncToVars(Key); + if (QueryWeight.initialized()) + ins["QueryWeight"] = egr::EagerUtils::TrySyncToVars(QueryWeight); + if (KeyWeight.initialized()) + ins["KeyWeight"] = egr::EagerUtils::TrySyncToVars(KeyWeight); + if (ValueWeight.initialized()) + ins["ValueWeight"] = egr::EagerUtils::TrySyncToVars(ValueWeight); + if (QKVWeight.initialized()) + ins["QKVWeight"] = egr::EagerUtils::TrySyncToVars(QKVWeight); + if (NonbatchedBias.initialized()) + ins["NonbatchedBias"] = egr::EagerUtils::TrySyncToVars(NonbatchedBias); + if (GateWeight.initialized()) + ins["GateWeight"] = egr::EagerUtils::TrySyncToVars(GateWeight); + if (GateBias.initialized()) + ins["GateBias"] = egr::EagerUtils::TrySyncToVars(GateBias); + + std::map>> outs = + {{"QueryTransposeOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"KeyTransposeOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"ValueTransposeOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"QKVTransposeOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"SoftmaxOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"FMHAOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"GateOut", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}, + {"Out", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}}; + + // Prepare Autograd Meta + egr::AutogradMeta* p_autograd_Query = + egr::EagerUtils::nullable_autograd_meta(Query); + egr::AutogradMeta* p_autograd_Key = + egr::EagerUtils::nullable_autograd_meta(Key); + egr::AutogradMeta* p_autograd_QueryWeight = + egr::EagerUtils::nullable_autograd_meta(QueryWeight); + egr::AutogradMeta* p_autograd_KeyWeight = + egr::EagerUtils::nullable_autograd_meta(KeyWeight); + egr::AutogradMeta* p_autograd_ValueWeight = + egr::EagerUtils::nullable_autograd_meta(ValueWeight); + egr::AutogradMeta* p_autograd_QKVWeight = + egr::EagerUtils::nullable_autograd_meta(QKVWeight); + egr::AutogradMeta* p_autograd_NonbatchedBias = + egr::EagerUtils::nullable_autograd_meta(NonbatchedBias); + egr::AutogradMeta* p_autograd_SrcMask = + egr::EagerUtils::nullable_autograd_meta(SrcMask); + egr::AutogradMeta* p_autograd_GateWeight = + egr::EagerUtils::nullable_autograd_meta(GateWeight); + egr::AutogradMeta* p_autograd_GateBias = + egr::EagerUtils::nullable_autograd_meta(GateBias); + egr::AutogradMeta* p_autograd_OutLinearWeight = + egr::EagerUtils::nullable_autograd_meta(OutLinearWeight); + egr::AutogradMeta* p_autograd_OutLinearBias = + egr::EagerUtils::nullable_autograd_meta(OutLinearBias); + + bool trace_backward = egr::Controller::Instance().HasGrad(); + + bool require_any_grad = + egr::EagerUtils::ComputeRequireGrad(trace_backward, + p_autograd_Query, + p_autograd_Key, + p_autograd_QueryWeight, + p_autograd_KeyWeight, + p_autograd_ValueWeight, + p_autograd_QKVWeight, + p_autograd_NonbatchedBias, + p_autograd_SrcMask, + p_autograd_GateWeight, + p_autograd_GateBias, + p_autograd_OutLinearWeight, + p_autograd_OutLinearBias); + + paddle::framework::AttributeMap attrs = attr_map; + paddle::framework::AttributeMap default_attrs; + egr::Controller::Instance().GetCurrentTracer()->TraceOp( + "fused_gate_attention", + ins, + outs, + attrs, + egr::Controller::Instance().GetExpectedPlace(), + &default_attrs, + true, + {}); + + paddle::experimental::Tensor QueryTransposeOut; + egr::EagerUtils::GetOutput(outs["QueryTransposeOut"][0], &QueryTransposeOut); + paddle::experimental::Tensor KeyTransposeOut; + egr::EagerUtils::GetOutput(outs["KeyTransposeOut"][0], &KeyTransposeOut); + paddle::experimental::Tensor ValueTransposeOut; + egr::EagerUtils::GetOutput(outs["ValueTransposeOut"][0], &ValueTransposeOut); + paddle::experimental::Tensor QKVTransposeOut; + egr::EagerUtils::GetOutput(outs["QKVTransposeOut"][0], &QKVTransposeOut); + paddle::experimental::Tensor SoftmaxOut; + egr::EagerUtils::GetOutput(outs["SoftmaxOut"][0], &SoftmaxOut); + paddle::experimental::Tensor FMHAOut; + egr::EagerUtils::GetOutput(outs["FMHAOut"][0], &FMHAOut); + paddle::experimental::Tensor GateOut; + egr::EagerUtils::GetOutput(outs["GateOut"][0], &GateOut); + paddle::experimental::Tensor Out; + egr::EagerUtils::GetOutput(outs["Out"][0], &Out); + + { + paddle::platform::RecordEvent node_creation_record_event( + "fused_gate_attention node_creation", + paddle::platform::TracerEventType::Operator, + 1); + egr::AutogradMeta* p_autograd_QueryTransposeOut = + egr::EagerUtils::autograd_meta(&QueryTransposeOut); + egr::AutogradMeta* p_autograd_KeyTransposeOut = + egr::EagerUtils::autograd_meta(&KeyTransposeOut); + egr::AutogradMeta* p_autograd_ValueTransposeOut = + egr::EagerUtils::autograd_meta(&ValueTransposeOut); + egr::AutogradMeta* p_autograd_QKVTransposeOut = + egr::EagerUtils::autograd_meta(&QKVTransposeOut); + egr::AutogradMeta* p_autograd_SoftmaxOut = + egr::EagerUtils::autograd_meta(&SoftmaxOut); + egr::AutogradMeta* p_autograd_FMHAOut = + egr::EagerUtils::autograd_meta(&FMHAOut); + egr::AutogradMeta* p_autograd_GateOut = + egr::EagerUtils::autograd_meta(&GateOut); + egr::AutogradMeta* p_autograd_Out = egr::EagerUtils::autograd_meta(&Out); + if (require_any_grad) { + VLOG(6) << " Construct Grad for fused_gate_attention "; + egr::EagerUtils::PassStopGradient(false, + p_autograd_QueryTransposeOut, + p_autograd_KeyTransposeOut, + p_autograd_ValueTransposeOut, + p_autograd_QKVTransposeOut, + p_autograd_SoftmaxOut, + p_autograd_FMHAOut, + p_autograd_GateOut, + p_autograd_Out); + // Create GradOpNode + auto grad_node = std::shared_ptr( + new fused_gate_attentionGradNodeCompat(8, 12)); + + bool merge_qkv = true; + if (attrs.count("merge_qkv")) { + merge_qkv = BOOST_GET_CONST(bool, attrs.at("merge_qkv")); + } + + bool has_gating = true; + if (attrs.count("has_gating")) { + has_gating = BOOST_GET_CONST(bool, attrs.at("has_gating")); + } + + // Set Attributes + grad_node->SetAttrMap(std::move(attrs)); + grad_node->SetDefaultAttrMap(std::move(default_attrs)); + + grad_node->SetTensorWrapperFMHAOut(FMHAOut); + grad_node->SetTensorWrapperQuery(Query); + grad_node->SetTensorWrapperSoftmaxOut(SoftmaxOut); + grad_node->SetTensorWrapperOutLinearBias(OutLinearBias); + grad_node->SetTensorWrapperOutLinearWeight(OutLinearWeight); + + grad_node->SetGradOutMeta(Query, 0); + grad_node->SetGradOutMeta(OutLinearWeight, 10); + grad_node->SetGradOutMeta(OutLinearBias, 11); + + if (merge_qkv) { + grad_node->SetTensorWrapperQKVTransposeOut(QKVTransposeOut); + grad_node->SetTensorWrapperQKVWeight(QKVWeight); + grad_node->SetGradOutMeta(QKVWeight, 5); + } else { + grad_node->SetTensorWrapperKey(Key); + grad_node->SetTensorWrapperQueryWeight(QueryWeight); + grad_node->SetTensorWrapperKeyWeight(KeyWeight); + grad_node->SetTensorWrapperValueWeight(ValueWeight); + grad_node->SetTensorWrapperQueryTransposeOut(QueryTransposeOut); + grad_node->SetTensorWrapperKeyTransposeOut(KeyTransposeOut); + grad_node->SetTensorWrapperValueTransposeOut(ValueTransposeOut); + + grad_node->SetGradOutMeta(Key, 1); + grad_node->SetGradOutMeta(QueryWeight, 2); + grad_node->SetGradOutMeta(KeyWeight, 3); + grad_node->SetGradOutMeta(ValueWeight, 4); + } + + if (has_gating) { + grad_node->SetTensorWrapperGateWeight(GateWeight); + grad_node->SetGradOutMeta(GateWeight, 8); + grad_node->SetTensorWrapperGateBias(GateBias); + grad_node->SetGradOutMeta(GateBias, 9); + grad_node->SetTensorWrapperGateOut(GateOut); + } + + if (NonbatchedBias.initialized()) { + grad_node->SetTensorWrapperNonbatchedBias(NonbatchedBias); + grad_node->SetGradOutMeta(NonbatchedBias, 6); + } + + egr::EagerUtils::SetOutRankWithSlot(p_autograd_QueryTransposeOut, 0); + grad_node->SetGradInMeta(QueryTransposeOut, 0); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_KeyTransposeOut, 1); + grad_node->SetGradInMeta(KeyTransposeOut, 1); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_ValueTransposeOut, 2); + grad_node->SetGradInMeta(ValueTransposeOut, 2); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_QKVTransposeOut, 3); + grad_node->SetGradInMeta(QKVTransposeOut, 3); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_SoftmaxOut, 4); + grad_node->SetGradInMeta(SoftmaxOut, 4); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_FMHAOut, 5); + grad_node->SetGradInMeta(FMHAOut, 5); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_GateOut, 6); + grad_node->SetGradInMeta(GateOut, 6); + egr::EagerUtils::SetOutRankWithSlot(p_autograd_Out, 7); + egr::EagerUtils::SetHistory(p_autograd_Out, grad_node); + grad_node->SetGradInMeta(Out, 7); + egr::EagerUtils::CheckAndRetainGrad(Out); + } + } + + return std::make_tuple(QueryTransposeOut, + KeyTransposeOut, + ValueTransposeOut, + QKVTransposeOut, + SoftmaxOut, + FMHAOut, + GateOut, + Out); +} diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/CMakeLists.txt b/paddle/fluid/eager/api/manual/fluid_manual/nodes/CMakeLists.txt new file mode 100644 index 00000000000..fb5e1292235 --- /dev/null +++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/CMakeLists.txt @@ -0,0 +1,8 @@ +cc_library( + fused_gate_attention_node + SRCS fused_gate_attention_node.cc + DEPS ${eager_deps} ${fluid_deps}) + +set(fluid_manual_nodes + fused_gate_attention_node + PARENT_SCOPE) diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_gate_attention_node.cc b/paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_gate_attention_node.cc new file mode 100644 index 00000000000..a1ccaf09de8 --- /dev/null +++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/fused_gate_attention_node.cc @@ -0,0 +1,233 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "glog/logging.h" +#include "paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/imperative/tracer.h" +#include "paddle/phi/api/all.h" + +paddle::small_vector, + egr::kSlotSmallVectorSize> +fused_gate_attentionGradNodeCompat::operator()( + paddle::small_vector, + egr::kSlotSmallVectorSize>& grads, + bool create_graph, + bool is_new_grad) { + VLOG(3) << "Running Eager Backward Node: fused_gate_attentionGradNodeCompat"; + + const auto& out_metas = OutputMeta(); + paddle::small_vector, + egr::kSlotSmallVectorSize> + outputs(12); + paddle::small_vector, + egr::kSlotSmallVectorSize> + hooked_grads0 = + fused_gate_attentionGradNodeCompat::ApplyGradientHooks(grads); + + bool merge_qkv = true; + if (attr_map_.count("merge_qkv")) { + merge_qkv = BOOST_GET_CONST(bool, attr_map_.at("merge_qkv")); + } + + bool has_gating = true; + if (attr_map_.count("has_gating")) { + has_gating = BOOST_GET_CONST(bool, attr_map_.at("has_gating")); + } + + std::map>> ins0 = + {{"FMHAOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->FMHAOut_))}, + {"Out@GRAD", egr::EagerUtils::TrySyncToVars(hooked_grads0[7])}, + {"OutLinearBias", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearBias_))}, + {"OutLinearWeight", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->OutLinearWeight_))}, + {"Query", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->Query_))}, + {"SoftmaxOut", + egr::EagerUtils::TrySyncToVars( + egr::EagerUtils::RecoverTensorWrapper(&this->SoftmaxOut_))}}; + std::map>> outs0; + + if ((!out_metas[11].empty()) && (!(out_metas[11][0].IsStopGradient()))) { + outs0.insert({"OutLinearBias@GRAD", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}); + } + if ((!out_metas[10].empty()) && (!(out_metas[10][0].IsStopGradient()))) { + outs0.insert({"OutLinearWeight@GRAD", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}); + } + if ((!out_metas[0].empty()) && (!(out_metas[0][0].IsStopGradient()))) { + outs0.insert({"Query@GRAD", + {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}}); + } + + if (merge_qkv) { + auto QKVTransposeOut = + egr::EagerUtils::RecoverTensorWrapper(&this->QKVTransposeOut_); + if (QKVTransposeOut.defined()) + ins0["QKVTransposeOut"] = egr::EagerUtils::TrySyncToVars(QKVTransposeOut); + auto QKVWeight = egr::EagerUtils::RecoverTensorWrapper(&this->QKVWeight_); + if (QKVWeight.defined()) + ins0["QKVWeight"] = egr::EagerUtils::TrySyncToVars(QKVWeight); + if (QKVWeight.defined() && (!out_metas[5].empty()) && + (!out_metas[5][0].IsStopGradient())) + outs0["QKVWeight@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } else { + auto Key = egr::EagerUtils::RecoverTensorWrapper(&this->Key_); + if (Key.defined()) ins0["Key"] = egr::EagerUtils::TrySyncToVars(Key); + auto QueryWeight = + egr::EagerUtils::RecoverTensorWrapper(&this->QueryWeight_); + if (QueryWeight.defined()) + ins0["QueryWeight"] = egr::EagerUtils::TrySyncToVars(QueryWeight); + auto KeyWeight = egr::EagerUtils::RecoverTensorWrapper(&this->KeyWeight_); + if (KeyWeight.defined()) + ins0["KeyWeight"] = egr::EagerUtils::TrySyncToVars(KeyWeight); + auto ValueWeight = + egr::EagerUtils::RecoverTensorWrapper(&this->ValueWeight_); + if (ValueWeight.defined()) + ins0["ValueWeight"] = egr::EagerUtils::TrySyncToVars(ValueWeight); + auto QueryTransposeOut = + egr::EagerUtils::RecoverTensorWrapper(&this->QueryTransposeOut_); + if (QueryTransposeOut.defined()) + ins0["QueryTransposeOut"] = + egr::EagerUtils::TrySyncToVars(QueryTransposeOut); + auto KeyTransposeOut = + egr::EagerUtils::RecoverTensorWrapper(&this->KeyTransposeOut_); + if (KeyTransposeOut.defined()) + ins0["KeyTransposeOut"] = egr::EagerUtils::TrySyncToVars(KeyTransposeOut); + auto ValueTransposeOut = + egr::EagerUtils::RecoverTensorWrapper(&this->ValueTransposeOut_); + if (ValueTransposeOut.defined()) + ins0["ValueTransposeOut"] = + egr::EagerUtils::TrySyncToVars(ValueTransposeOut); + + if (Key.defined() && (!out_metas[1].empty()) && + (!out_metas[1][0].IsStopGradient())) + outs0["Key@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + if (QueryWeight.defined() && (!out_metas[2].empty()) && + (!out_metas[2][0].IsStopGradient())) + outs0["QueryWeight@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + if (KeyWeight.defined() && (!out_metas[3].empty()) && + (!out_metas[3][0].IsStopGradient())) + outs0["KeyWeight@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + if (ValueWeight.defined() && (!out_metas[4].empty()) && + (!out_metas[4][0].IsStopGradient())) + outs0["ValueWeight@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + if (has_gating) { + auto GateBias = egr::EagerUtils::RecoverTensorWrapper(&this->GateBias_); + if (GateBias.defined()) + ins0["GateBias"] = egr::EagerUtils::TrySyncToVars(GateBias); + auto GateWeight = egr::EagerUtils::RecoverTensorWrapper(&this->GateWeight_); + if (GateWeight.defined()) + ins0["GateWeight"] = egr::EagerUtils::TrySyncToVars(GateWeight); + auto GateOut = egr::EagerUtils::RecoverTensorWrapper(&this->GateOut_); + if (GateOut.defined()) + ins0["GateOut"] = egr::EagerUtils::TrySyncToVars(GateOut); + if (GateBias.defined() && (!out_metas[9].empty()) && + (!out_metas[9][0].IsStopGradient())) + outs0["GateBias@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + if (GateWeight.defined() && (!out_metas[8].empty()) && + (!out_metas[8][0].IsStopGradient())) + outs0["GateWeight@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + auto NonbatchedBias = + egr::EagerUtils::RecoverTensorWrapper(&this->NonbatchedBias_); + if (NonbatchedBias.defined()) { + ins0["NonbatchedBias"] = egr::EagerUtils::TrySyncToVars(NonbatchedBias); + if ((!out_metas[6].empty()) && (!out_metas[6][0].IsStopGradient())) + outs0["NonbatchedBias@GRAD"] = {std::make_shared( + egr::Controller::Instance().GenerateUniqueName())}; + } + + auto& attrs_map0 = this->attr_map_; + // Pass the entire attribute map to TraceOp + // The underlying kernel will pickup whatever attribute they need at runtime + egr::Controller::Instance().GetCurrentTracer()->TraceOp( + "fused_gate_attention_grad", + ins0, + outs0, + attrs_map0, + egr::Controller::Instance().GetExpectedPlace(), + &this->default_attr_map_, + false, + {}); + + if (outs0.find("Query@GRAD") != outs0.end()) { + outputs[0] = egr::EagerUtils::GetOutputs(outs0["Query@GRAD"]); + } + if (outs0.find("OutLinearBias@GRAD") != outs0.end()) { + outputs[11] = egr::EagerUtils::GetOutputs(outs0["OutLinearBias@GRAD"]); + } + if (outs0.find("OutLinearWeight@GRAD") != outs0.end()) { + outputs[10] = egr::EagerUtils::GetOutputs(outs0["OutLinearWeight@GRAD"]); + } + + if (merge_qkv) { + if (outs0.find("QKVWeight@GRAD") != outs0.end()) { + outputs[5] = egr::EagerUtils::GetOutputs(outs0["QKVWeight@GRAD"]); + } + } else { + if (outs0.find("Key@GRAD") != outs0.end()) { + outputs[1] = egr::EagerUtils::GetOutputs(outs0["Key@GRAD"]); + } + if (outs0.find("QueryWeight@GRAD") != outs0.end()) { + outputs[2] = egr::EagerUtils::GetOutputs(outs0["QueryWeight@GRAD"]); + } + if (outs0.find("KeyWeight@GRAD") != outs0.end()) { + outputs[3] = egr::EagerUtils::GetOutputs(outs0["KeyWeight@GRAD"]); + } + if (outs0.find("ValueWeight@GRAD") != outs0.end()) { + outputs[4] = egr::EagerUtils::GetOutputs(outs0["ValueWeight@GRAD"]); + } + } + + if (has_gating) { + if (outs0.find("GateBias@GRAD") != outs0.end()) { + outputs[9] = egr::EagerUtils::GetOutputs(outs0["GateBias@GRAD"]); + } + if (outs0.find("GateWeight@GRAD") != outs0.end()) { + outputs[8] = egr::EagerUtils::GetOutputs(outs0["GateWeight@GRAD"]); + } + } + + if (NonbatchedBias.defined()) { + if (outs0.find("NonbatchedBias@GRAD") != outs0.end()) { + outputs[6] = egr::EagerUtils::GetOutputs(outs0["NonbatchedBias@GRAD"]); + } + } + + if (NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&outputs); + return outputs; +} diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h new file mode 100644 index 00000000000..0f0fac4b725 --- /dev/null +++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h @@ -0,0 +1,176 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/tensor_wrapper.h" +#include "paddle/fluid/imperative/tracer.h" + +class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase { + public: + fused_gate_attentionGradNodeCompat() : egr::GradNodeBase() { + VLOG(7) << " Construct fused_gate_attentionGradNodeCompat "; + } + fused_gate_attentionGradNodeCompat(size_t bwd_in_slot_num, + size_t bwd_out_slot_num) + : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) { + VLOG(7) << " Construct fused_gate_attentionGradNodeCompat "; + } + ~fused_gate_attentionGradNodeCompat() override { + VLOG(6) << " Destruct fused_gate_attentionGradNodeCompat "; + } + + virtual paddle::small_vector, + egr::kSlotSmallVectorSize> + operator()( + paddle::small_vector, // NOLINT + egr::kSlotSmallVectorSize>& grads, // NOLINT + bool create_graph = false, + bool is_new_grad = false) override; + + void ClearTensorWrappers() override { + FMHAOut_.clear(); + GateBias_.clear(); + GateOut_.clear(); + GateWeight_.clear(); + NonbatchedBias_.clear(); + OutLinearBias_.clear(); + OutLinearWeight_.clear(); + QKVTransposeOut_.clear(); + QKVWeight_.clear(); + Query_.clear(); + SoftmaxOut_.clear(); + Key_.clear(); + QueryWeight_.clear(); + KeyWeight_.clear(); + ValueWeight_.clear(); + QueryTransposeOut_.clear(); + KeyTransposeOut_.clear(); + ValueTransposeOut_.clear(); + + SetIsTensorWrappersCleared(true); + } + std::string name() override { return "fused_gate_attentionGradNodeCompat"; } + + std::shared_ptr Copy() const override { + { + auto copied_node = std::shared_ptr( + new fused_gate_attentionGradNodeCompat(*this)); + return copied_node; + } + } + + // SetX, SetY, ... + void SetTensorWrapperFMHAOut(const paddle::experimental::Tensor& FMHAOut) { + FMHAOut_ = egr::TensorWrapper(FMHAOut, false); + } + void SetTensorWrapperGateBias(const paddle::experimental::Tensor& GateBias) { + GateBias_ = egr::TensorWrapper(GateBias, false); + } + void SetTensorWrapperGateOut(const paddle::experimental::Tensor& GateOut) { + GateOut_ = egr::TensorWrapper(GateOut, false); + } + void SetTensorWrapperGateWeight( + const paddle::experimental::Tensor& GateWeight) { + GateWeight_ = egr::TensorWrapper(GateWeight, false); + } + void SetTensorWrapperNonbatchedBias( + const paddle::experimental::Tensor& NonbatchedBias) { + NonbatchedBias_ = egr::TensorWrapper(NonbatchedBias, false); + } + void SetTensorWrapperOutLinearBias( + const paddle::experimental::Tensor& OutLinearBias) { + OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false); + } + void SetTensorWrapperOutLinearWeight( + const paddle::experimental::Tensor& OutLinearWeight) { + OutLinearWeight_ = egr::TensorWrapper(OutLinearWeight, false); + } + void SetTensorWrapperQKVTransposeOut( + const paddle::experimental::Tensor& QKVTransposeOut) { + QKVTransposeOut_ = egr::TensorWrapper(QKVTransposeOut, false); + } + void SetTensorWrapperQKVWeight( + const paddle::experimental::Tensor& QKVWeight) { + QKVWeight_ = egr::TensorWrapper(QKVWeight, false); + } + void SetTensorWrapperQuery(const paddle::experimental::Tensor& Query) { + Query_ = egr::TensorWrapper(Query, false); + } + void SetTensorWrapperSoftmaxOut( + const paddle::experimental::Tensor& SoftmaxOut) { + SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false); + } + void SetTensorWrapperKey(const paddle::experimental::Tensor& Key) { + Key_ = egr::TensorWrapper(Key, false); + } + void SetTensorWrapperQueryWeight( + const paddle::experimental::Tensor& QueryWeight) { + QueryWeight_ = egr::TensorWrapper(QueryWeight, false); + } + void SetTensorWrapperKeyWeight( + const paddle::experimental::Tensor& KeyWeight) { + KeyWeight_ = egr::TensorWrapper(KeyWeight, false); + } + void SetTensorWrapperValueWeight( + const paddle::experimental::Tensor& ValueWeight) { + ValueWeight_ = egr::TensorWrapper(ValueWeight, false); + } + void SetTensorWrapperQueryTransposeOut( + const paddle::experimental::Tensor& QueryTransposeOut) { + QueryTransposeOut_ = egr::TensorWrapper(QueryTransposeOut, false); + } + void SetTensorWrapperKeyTransposeOut( + const paddle::experimental::Tensor& KeyTransposeOut) { + KeyTransposeOut_ = egr::TensorWrapper(KeyTransposeOut, false); + } + void SetTensorWrapperValueTransposeOut( + const paddle::experimental::Tensor& ValueTransposeOut) { + ValueTransposeOut_ = egr::TensorWrapper(ValueTransposeOut, false); + } + + // SetAttrMap + void SetAttrMap(paddle::framework::AttributeMap&& attr_map) { + attr_map_ = std::move(attr_map); + } + void SetDefaultAttrMap(paddle::framework::AttributeMap&& default_attr_map) { + default_attr_map_ = std::move(default_attr_map); + } + + private: + // TensorWrappers + egr::TensorWrapper FMHAOut_; + egr::TensorWrapper GateBias_; + egr::TensorWrapper GateOut_; + egr::TensorWrapper GateWeight_; + egr::TensorWrapper NonbatchedBias_; + egr::TensorWrapper OutLinearBias_; + egr::TensorWrapper OutLinearWeight_; + egr::TensorWrapper QKVTransposeOut_; + egr::TensorWrapper QKVWeight_; + egr::TensorWrapper Query_; + egr::TensorWrapper SoftmaxOut_; + + egr::TensorWrapper Key_; + egr::TensorWrapper QueryWeight_; + egr::TensorWrapper KeyWeight_; + egr::TensorWrapper ValueWeight_; + egr::TensorWrapper QueryTransposeOut_; + egr::TensorWrapper KeyTransposeOut_; + egr::TensorWrapper ValueTransposeOut_; + + // Attribute Map + paddle::framework::AttributeMap attr_map_; + paddle::framework::AttributeMap default_attr_map_; +}; diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 6910f9e537f..bbd6ea64946 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -51,7 +51,8 @@ static std::unordered_set ops_to_fill_zero_for_empty_grads = { "split", "rnn"}; /* --- Black Ops list that's NO NEED to apply code generation --- */ -static std::unordered_set black_ops_list = {"run_program"}; +static std::unordered_set black_ops_list = { + "run_program", "fused_gate_attention"}; static std::string LegalizeVariableName(const std::string& var_name) { std::string ret = var_name; @@ -2972,7 +2973,10 @@ static std::string GenerateDygraphHFileIncludes() { "#include \"paddle/phi/api/all.h\"\n" "#include \"paddle/fluid/eager/utils.h\"\n" "#include \"paddle/fluid/imperative/tracer.h\"\n" - "#include \"paddle/fluid/framework/op_registry.h\"\n\n"; + "#include \"paddle/fluid/framework/op_registry.h\"\n" + "#include " + "\"paddle/fluid/eager/api/manual/fluid_manual/" + "dygraph_forward_api.h\"\n\n"; dygraph_forward_api_includes_str += "extern std::unordered_map> " @@ -3021,7 +3025,10 @@ static void GenerateNodeHFile(const std::string& node_h_path, "#pragma once\n" "#include \"paddle/fluid/eager/tensor_wrapper.h\"\n" "#include \"paddle/fluid/imperative/tracer.h\"\n" - "#include \"paddle/fluid/eager/grad_node_info.h\"\n\n"; + "#include \"paddle/fluid/eager/grad_node_info.h\"\n" + "#include " + "\"paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h\"\n\n"; + std::ofstream node_h_stream(node_h_path, std::ios::out); node_h_stream << node_h_include_str; node_h_stream << grad_node_str; diff --git a/paddle/fluid/eager/auto_code_generator/generate_file_structures.py b/paddle/fluid/eager/auto_code_generator/generate_file_structures.py index fdb8529515d..a7cd1dc8c46 100644 --- a/paddle/fluid/eager/auto_code_generator/generate_file_structures.py +++ b/paddle/fluid/eager/auto_code_generator/generate_file_structures.py @@ -1,11 +1,11 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -103,13 +103,13 @@ def GenerateFileStructureForIntermediateDygraph(eager_dir): with open(nodes_level_cmakelist_path, "w") as f: f.write( - "cc_library(dygraph_node SRCS nodes.cc DEPS ${eager_deps} ${fluid_deps})\n" + "cc_library(dygraph_node SRCS nodes.cc DEPS ${eager_deps} ${fluid_deps} ${fluid_manual_nodes})\n" ) f.write("add_dependencies(dygraph_node eager_codegen)") with open(forwards_level_cmakelist_path, "w") as f: f.write( - "cc_library(dygraph_function SRCS dygraph_forward_functions.cc DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})\n" + "cc_library(dygraph_function SRCS dygraph_forward_functions.cc DEPS ${eager_deps} ${fluid_deps} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ${fluid_manual_functions})\n" ) f.write("add_dependencies(dygraph_function eager_codegen)") diff --git a/paddle/fluid/operators/fused/fused_gate_attention_op.cu b/paddle/fluid/operators/fused/fused_gate_attention_op.cu index 0d219a4f76d..7400246f407 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_gate_attention_op.cu @@ -363,13 +363,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel { dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating); if (merge_qkv) { - PADDLE_ENFORCE_EQ(!key || query == key, - true, - platform::errors::InvalidArgument( - "key is expected to be nullptr or the same as " - "query, but recieved key=%p, query=%p.", - key, - query)); + PADDLE_ENFORCE_EQ( + !key || query == key || query->data() == key->data(), + true, + platform::errors::InvalidArgument( + "key is expected to be nullptr or the same as " + "query, but recieved key=%p, query=%p.", + key, + query)); // 1. Merged QKV Matmul: einsum(nbhqk,nbkhc -> nbqhc) Tensor *qkv_out = config.GetQKVOut(); -- GitLab