// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/tensor_wrapper.h" #include "paddle/fluid/imperative/tracer.h" class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase { public: fused_gate_attentionGradNodeCompat() : egr::GradNodeBase() { VLOG(7) << " Construct fused_gate_attentionGradNodeCompat "; } fused_gate_attentionGradNodeCompat(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) { VLOG(7) << " Construct fused_gate_attentionGradNodeCompat "; } ~fused_gate_attentionGradNodeCompat() override { VLOG(6) << " Destruct fused_gate_attentionGradNodeCompat "; } virtual paddle::small_vector, egr::kSlotSmallVectorSize> operator()( paddle::small_vector, // NOLINT egr::kSlotSmallVectorSize>& grads, // NOLINT bool create_graph = false, bool is_new_grad = false) override; void ClearTensorWrappers() override { FMHAOut_.clear(); GateBias_.clear(); GateOut_.clear(); GateWeight_.clear(); NonbatchedBias_.clear(); OutLinearBias_.clear(); OutLinearWeight_.clear(); QKVTransposeOut_.clear(); QKVWeight_.clear(); Query_.clear(); SoftmaxOut_.clear(); Key_.clear(); QueryWeight_.clear(); KeyWeight_.clear(); ValueWeight_.clear(); QueryTransposeOut_.clear(); KeyTransposeOut_.clear(); ValueTransposeOut_.clear(); SetIsTensorWrappersCleared(true); } std::string name() override { return "fused_gate_attentionGradNodeCompat"; } std::shared_ptr Copy() const override { { auto copied_node = std::shared_ptr( new fused_gate_attentionGradNodeCompat(*this)); return copied_node; } } // SetX, SetY, ... void SetTensorWrapperFMHAOut(const paddle::experimental::Tensor& FMHAOut) { FMHAOut_ = egr::TensorWrapper(FMHAOut, false); } void SetTensorWrapperGateBias(const paddle::experimental::Tensor& GateBias) { GateBias_ = egr::TensorWrapper(GateBias, false); } void SetTensorWrapperGateOut(const paddle::experimental::Tensor& GateOut) { GateOut_ = egr::TensorWrapper(GateOut, false); } void SetTensorWrapperGateWeight( const paddle::experimental::Tensor& GateWeight) { GateWeight_ = egr::TensorWrapper(GateWeight, false); } void SetTensorWrapperNonbatchedBias( const paddle::experimental::Tensor& NonbatchedBias) { NonbatchedBias_ = egr::TensorWrapper(NonbatchedBias, false); } void SetTensorWrapperOutLinearBias( const paddle::experimental::Tensor& OutLinearBias) { OutLinearBias_ = egr::TensorWrapper(OutLinearBias, false); } void SetTensorWrapperOutLinearWeight( const paddle::experimental::Tensor& OutLinearWeight) { OutLinearWeight_ = egr::TensorWrapper(OutLinearWeight, false); } void SetTensorWrapperQKVTransposeOut( const paddle::experimental::Tensor& QKVTransposeOut) { QKVTransposeOut_ = egr::TensorWrapper(QKVTransposeOut, false); } void SetTensorWrapperQKVWeight( const paddle::experimental::Tensor& QKVWeight) { QKVWeight_ = egr::TensorWrapper(QKVWeight, false); } void SetTensorWrapperQuery(const paddle::experimental::Tensor& Query) { Query_ = egr::TensorWrapper(Query, false); } void SetTensorWrapperSoftmaxOut( const paddle::experimental::Tensor& SoftmaxOut) { SoftmaxOut_ = egr::TensorWrapper(SoftmaxOut, false); } void SetTensorWrapperKey(const paddle::experimental::Tensor& Key) { Key_ = egr::TensorWrapper(Key, false); } void SetTensorWrapperQueryWeight( const paddle::experimental::Tensor& QueryWeight) { QueryWeight_ = egr::TensorWrapper(QueryWeight, false); } void SetTensorWrapperKeyWeight( const paddle::experimental::Tensor& KeyWeight) { KeyWeight_ = egr::TensorWrapper(KeyWeight, false); } void SetTensorWrapperValueWeight( const paddle::experimental::Tensor& ValueWeight) { ValueWeight_ = egr::TensorWrapper(ValueWeight, false); } void SetTensorWrapperQueryTransposeOut( const paddle::experimental::Tensor& QueryTransposeOut) { QueryTransposeOut_ = egr::TensorWrapper(QueryTransposeOut, false); } void SetTensorWrapperKeyTransposeOut( const paddle::experimental::Tensor& KeyTransposeOut) { KeyTransposeOut_ = egr::TensorWrapper(KeyTransposeOut, false); } void SetTensorWrapperValueTransposeOut( const paddle::experimental::Tensor& ValueTransposeOut) { ValueTransposeOut_ = egr::TensorWrapper(ValueTransposeOut, false); } // SetAttrMap void SetAttrMap(paddle::framework::AttributeMap&& attr_map) { attr_map_ = std::move(attr_map); } void SetDefaultAttrMap(paddle::framework::AttributeMap&& default_attr_map) { default_attr_map_ = std::move(default_attr_map); } private: // TensorWrappers egr::TensorWrapper FMHAOut_; egr::TensorWrapper GateBias_; egr::TensorWrapper GateOut_; egr::TensorWrapper GateWeight_; egr::TensorWrapper NonbatchedBias_; egr::TensorWrapper OutLinearBias_; egr::TensorWrapper OutLinearWeight_; egr::TensorWrapper QKVTransposeOut_; egr::TensorWrapper QKVWeight_; egr::TensorWrapper Query_; egr::TensorWrapper SoftmaxOut_; egr::TensorWrapper Key_; egr::TensorWrapper QueryWeight_; egr::TensorWrapper KeyWeight_; egr::TensorWrapper ValueWeight_; egr::TensorWrapper QueryTransposeOut_; egr::TensorWrapper KeyTransposeOut_; egr::TensorWrapper ValueTransposeOut_; // Attribute Map paddle::framework::AttributeMap attr_map_; paddle::framework::AttributeMap default_attr_map_; }; class fused_feedforwardGradNodeCompat : public egr::GradNodeBase { public: fused_feedforwardGradNodeCompat() : egr::GradNodeBase() { VLOG(7) << " Construct fused_feedforwardGradNodeCompat "; } fused_feedforwardGradNodeCompat(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) { VLOG(7) << " Construct fused_feedforwardGradNodeCompat "; } ~fused_feedforwardGradNodeCompat() override { VLOG(6) << " Destruct fused_feedforwardGradNodeCompat "; } virtual paddle::small_vector, egr::kSlotSmallVectorSize> operator()( paddle::small_vector, // NOLINT egr::kSlotSmallVectorSize>& grads, // NOLINT bool create_graph = false, bool is_new_grad = false) override; void ClearTensorWrappers() override { Dropout1Mask_.clear(); Dropout1Out_.clear(); Dropout2Mask_.clear(); Dropout2Out_.clear(); Linear1Bias_.clear(); Linear1Out_.clear(); Linear1Weight_.clear(); Linear2Bias_.clear(); Linear2Weight_.clear(); Ln2Bias_.clear(); Ln2Mean_.clear(); Ln2Scale_.clear(); Ln2Variance_.clear(); X_.clear(); SetIsTensorWrappersCleared(true); } std::string name() override { return "fused_feedforwardGradNodeCompat"; } std::shared_ptr Copy() const override { { auto copied_node = std::shared_ptr( new fused_feedforwardGradNodeCompat(*this)); return copied_node; } } // SetX, SetY, ... void SetTensorWrapperDropout1Mask( const paddle::experimental::Tensor& Dropout1Mask) { Dropout1Mask_ = egr::TensorWrapper(Dropout1Mask, false); } void SetTensorWrapperDropout1Out( const paddle::experimental::Tensor& Dropout1Out) { Dropout1Out_ = egr::TensorWrapper(Dropout1Out, false); } void SetTensorWrapperDropout2Mask( const paddle::experimental::Tensor& Dropout2Mask) { Dropout2Mask_ = egr::TensorWrapper(Dropout2Mask, false); } void SetTensorWrapperDropout2Out( const paddle::experimental::Tensor& Dropout2Out) { Dropout2Out_ = egr::TensorWrapper(Dropout2Out, false); } void SetTensorWrapperLinear1Bias( const paddle::experimental::Tensor& Linear1Bias) { Linear1Bias_ = egr::TensorWrapper(Linear1Bias, false); } void SetTensorWrapperLinear1Out( const paddle::experimental::Tensor& Linear1Out) { Linear1Out_ = egr::TensorWrapper(Linear1Out, false); } void SetTensorWrapperLinear1Weight( const paddle::experimental::Tensor& Linear1Weight) { Linear1Weight_ = egr::TensorWrapper(Linear1Weight, false); } void SetTensorWrapperLinear2Bias( const paddle::experimental::Tensor& Linear2Bias) { Linear2Bias_ = egr::TensorWrapper(Linear2Bias, false); } void SetTensorWrapperLinear2Weight( const paddle::experimental::Tensor& Linear2Weight) { Linear2Weight_ = egr::TensorWrapper(Linear2Weight, false); } void SetTensorWrapperLn2Bias(const paddle::experimental::Tensor& Ln2Bias) { Ln2Bias_ = egr::TensorWrapper(Ln2Bias, false); } void SetTensorWrapperLn2Mean(const paddle::experimental::Tensor& Ln2Mean) { Ln2Mean_ = egr::TensorWrapper(Ln2Mean, false); } void SetTensorWrapperLn2Scale(const paddle::experimental::Tensor& Ln2Scale) { Ln2Scale_ = egr::TensorWrapper(Ln2Scale, false); } void SetTensorWrapperLn2Variance( const paddle::experimental::Tensor& Ln2Variance) { Ln2Variance_ = egr::TensorWrapper(Ln2Variance, false); } void SetTensorWrapperX(const paddle::experimental::Tensor& X) { X_ = egr::TensorWrapper(X, false); } void SetTensorWrapperLn1Scale(const paddle::experimental::Tensor& Ln1Scale) { Ln1Scale_ = egr::TensorWrapper(Ln1Scale, false); } void SetTensorWrapperLn1Bias(const paddle::experimental::Tensor& Ln1Bias) { Ln1Bias_ = egr::TensorWrapper(Ln1Bias, false); } void SetTensorWrapperLn1Out(const paddle::experimental::Tensor& Ln1Out) { Ln1Out_ = egr::TensorWrapper(Ln1Out, false); } void SetTensorWrapperLn1Mean(const paddle::experimental::Tensor& Ln1Mean) { Ln1Mean_ = egr::TensorWrapper(Ln1Mean, false); } void SetTensorWrapperLn1Variance( const paddle::experimental::Tensor& Ln1Variance) { Ln1Variance_ = egr::TensorWrapper(Ln1Variance, false); } // SetAttrMap void SetAttrMap(paddle::framework::AttributeMap&& attr_map) { attr_map_ = std::move(attr_map); } void SetDefaultAttrMap(paddle::framework::AttributeMap&& default_attr_map) { default_attr_map_ = std::move(default_attr_map); } private: // TensorWrappers egr::TensorWrapper Dropout1Mask_; egr::TensorWrapper Dropout1Out_; egr::TensorWrapper Dropout2Mask_; egr::TensorWrapper Dropout2Out_; egr::TensorWrapper Linear1Bias_; egr::TensorWrapper Linear1Out_; egr::TensorWrapper Linear1Weight_; egr::TensorWrapper Linear2Bias_; egr::TensorWrapper Linear2Weight_; egr::TensorWrapper Ln2Bias_; egr::TensorWrapper Ln2Mean_; egr::TensorWrapper Ln2Scale_; egr::TensorWrapper Ln2Variance_; egr::TensorWrapper X_; egr::TensorWrapper Ln1Scale_; egr::TensorWrapper Ln1Bias_; egr::TensorWrapper Ln1Out_; egr::TensorWrapper Ln1Mean_; egr::TensorWrapper Ln1Variance_; // Attribute Map paddle::framework::AttributeMap attr_map_; paddle::framework::AttributeMap default_attr_map_; };