diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h
index 32389e553d03c44c6ed350bb8ec40dd2156f03e4..7e0d679689c4a21f1c872621c8937ca153aeb23a 100644
--- a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h
+++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h
@@ -17,6 +17,23 @@
 #include "paddle/fluid/eager/tensor_wrapper.h"
 #include "paddle/fluid/imperative/tracer.h"
 
+template <typename T>
+const T& GetAttrWithDefault(
+    const paddle::framework::AttributeMap& attrs,
+    const paddle::framework::AttributeMap& default_attrs,
+    const std::string& name) {
+  auto iter1 = attrs.find(name);
+  if (iter1 != attrs.end()) {
+    return PADDLE_GET_CONST(T, iter1->second);
+  }
+  auto iter2 = default_attrs.find(name);
+  if (iter2 != default_attrs.end()) {
+    return PADDLE_GET_CONST(T, iter2->second);
+  }
+  PADDLE_THROW(
+      phi::errors::InvalidArgument("Attribute(%s) cannot be found.", name));
+}
+
 class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
  public:
  fused_gate_attentionGradNodeCompat() : egr::GradNodeBase() {
@@ -240,7 +257,9 @@ class fused_feedforwardGradNodeCompat : public egr::GradNodeBase {
   }
   void SetTensorWrapperDropout2Out(
       const paddle::experimental::Tensor& Dropout2Out) {
-    Dropout2Out_ = egr::TensorWrapper(Dropout2Out, false);
+    auto pre_layer_norm = GetAttrWithDefault<bool>(
+        attr_map_, default_attr_map_, "pre_layer_norm");
+    Dropout2Out_ = egr::TensorWrapper(Dropout2Out, pre_layer_norm);
   }
   void SetTensorWrapperLinear1Bias(
       const paddle::experimental::Tensor& Linear1Bias) {
@@ -427,27 +446,27 @@ class fused_attentionGradNodeCompat : public egr::GradNodeBase {
   }
   void SetTensorWrapperOutLinearOut(
       const paddle::experimental::Tensor& OutLinearOut) {
-    OutLinearOut_ = egr::TensorWrapper(OutLinearOut, false);
+    OutLinearOut_ = egr::TensorWrapper(OutLinearOut, true);
   }
   void SetTensorWrapperOutLinearW(
       const paddle::experimental::Tensor& OutLinearW) {
     OutLinearW_ = egr::TensorWrapper(OutLinearW, false);
   }
   void SetTensorWrapperQKOut(const paddle::experimental::Tensor& QKOut) {
-    QKOut_ = egr::TensorWrapper(QKOut, false);
+    QKOut_ = egr::TensorWrapper(QKOut, true);
   }
   void SetTensorWrapperQKTVOut(const paddle::experimental::Tensor& QKTVOut) {
-    QKTVOut_ = egr::TensorWrapper(QKTVOut, false);
+    QKTVOut_ = egr::TensorWrapper(QKTVOut, true);
   }
   void SetTensorWrapperQKVBias(const paddle::experimental::Tensor& QKVBias) {
     QKVBias_ = egr::TensorWrapper(QKVBias, false);
   }
   void SetTensorWrapperQKVBiasOut(
       const paddle::experimental::Tensor& QKVBiasOut) {
-    QKVBiasOut_ = egr::TensorWrapper(QKVBiasOut, false);
+    QKVBiasOut_ = egr::TensorWrapper(QKVBiasOut, true);
   }
   void SetTensorWrapperQKVOut(const paddle::experimental::Tensor& QKVOut) {
-    QKVOut_ = egr::TensorWrapper(QKVOut, false);
+    QKVOut_ = egr::TensorWrapper(QKVOut, true);
   }
   void SetTensorWrapperQKVW(const paddle::experimental::Tensor& QKVW) {
     QKVW_ = egr::TensorWrapper(QKVW, false);
diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc
index 30badd3125588d3866f0bd0efb7b096b456aac88..8d03ba451bdae47ee94a8426b8aac481cb1320bd 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cc
+++ b/paddle/fluid/operators/fused/fused_attention_op.cc
@@ -704,6 +704,13 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker {
   }
 };
 
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(FusedAttentionGradNoNeedBufferInferer,
+                                    "QKVBiasOut",
+                                    "QKVOut",
+                                    "QKOut",
+                                    "QKTVOut",
+                                    "OutLinearOut");
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -713,7 +720,9 @@ REGISTER_OPERATOR(fused_attention,
                   ops::FusedAttentionOpMaker,
                   ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
                   ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
+REGISTER_OPERATOR(fused_attention_grad,
+                  ops::FusedAttentionGradOp,
+                  ops::FusedAttentionGradNoNeedBufferInferer);
 
 REGISTER_OP_VERSION(fused_attention)
     .AddCheckpoint(
diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu
index 5ff01e0bc18cc4308c3cb2c31159adaf61203c5a..ac9e219075174d043367ae6a3bd52d6e4ec8a4ad 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cu
+++ b/paddle/fluid/operators/fused/fused_attention_op.cu
@@ -410,28 +410,24 @@ class FusedAttentionGradKernel : public framework::OpKernel {
         (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data();
 
     // fw output
-    auto *fmha_out = ctx.Input("FMHAOut");
-    auto *transpose_out_2 = ctx.Input("TransposeOut2");
-    auto *qk_out = ctx.Input("QKOut");
-    auto *qktv_out = ctx.Input("QKTVOut");
-    auto *softmax_out = ctx.Input("SoftmaxOut");
-    auto *attn_dropout_mask_out = ctx.Input("AttnDropoutMaskOut");
-    auto *attn_dropout_out = ctx.Input("AttnDropoutOut");
-    auto *src_mask_out = ctx.Input("SrcMaskOut");
-    auto *out_linear_out = ctx.Input("OutLinearOut");
-    auto *ln_2_mean = ctx.Input("Ln2Mean");
-    auto *ln_2_var = ctx.Input("Ln2Variance");
-    auto *dropout_mask_out = ctx.Input("DropoutMaskOut");
+    auto *fmha_out = ctx.Input("FMHAOut");
+    auto *transpose_out_2 = ctx.Input("TransposeOut2");
+    auto *qk_out = ctx.Input("QKOut");
+    auto *softmax_out = ctx.Input("SoftmaxOut");
+    auto *attn_dropout_mask_out =
+        ctx.Input("AttnDropoutMaskOut");
+    auto *attn_dropout_out = ctx.Input("AttnDropoutOut");
+    auto *src_mask_out = ctx.Input("SrcMaskOut");
+    auto *ln_2_mean = ctx.Input("Ln2Mean");
+    auto *ln_2_var = ctx.Input("Ln2Variance");
+    auto *dropout_mask_out = ctx.Input("DropoutMaskOut");
     auto *bias_dropout_residual_out =
         ctx.Input("BiasDropoutResidualOut");
     auto *fmha_out_data = fmha_out->data();
     auto *transpose_out_2_data = transpose_out_2->data();
-    auto *qk_out_data = qk_out->data();
-    auto *qktv_out_data = qktv_out->data();
     auto *softmax_out_data = softmax_out->data();
     auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr : src_mask_out->data();
-    auto *out_linear_out_data = out_linear_out->data();
     auto *dropout_mask_out_data = dropout_mask_out->data();
 
     // output's grad
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cc b/paddle/fluid/operators/fused/fused_feedforward_op.cc
index 9b8b256a9ee54f85a11bd811eda69a63caf6be4f..f47f01465da2df9b4b8936ed6059401d08fb92c6 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cc
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cc
@@ -276,10 +276,12 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
                    "Input",
                    "Dropout1Out",
                    "FusedFeedForwardGrad");
-    OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"),
-                   "Input",
-                   "Dropout2Out",
-                   "FusedFeedForwardGrad");
+    if (!pre_layer_norm) {
+      OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"),
+                     "Input",
+                     "Dropout2Out",
+                     "FusedFeedForwardGrad");
+    }
     OP_INOUT_CHECK(ctx->HasInput("Linear1Weight"),
                    "Input",
                    "Linear1Weight",
@@ -368,10 +370,12 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker {
     op->SetInput("Dropout2Mask", this->Output("Dropout2Mask"));
     op->SetInput("Linear1Out", this->Output("Linear1Out"));
     op->SetInput("Dropout1Out", this->Output("Dropout1Out"));
-    op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
     op->SetAttrMap(this->Attrs());
     bool pre_layer_norm =
         PADDLE_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
+    if (!pre_layer_norm) {
+      op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
+    }
 
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     if (pre_layer_norm) {
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu
index 95f13806562f2aa424a83bb5a01dcf2f7cdb2354..b8ba7b8810000ba5c2cc93306d044cde77083a49 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cu
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -339,7 +339,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel {
               const framework::Tensor& linear1_out,
               const framework::Tensor* ln1_out,
               const framework::Tensor& dropout1_out,
-              const framework::Tensor& dropout2_out,
+              const framework::Tensor* dropout2_out,
               const framework::Tensor& linear1_weight,
               const framework::Tensor* linear1_bias,
               const framework::Tensor& linear2_weight,
@@ -422,7 +422,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel {
     fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
         ctx,
         d_out.data(),
-        dropout2_out.data(),
+        dropout2_out->data(),
        dropout2_mask.data(),
         ln2_gamma_ptr,
         ln2_mean->data(),
@@ -506,7 +506,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel {
     auto* ln1_out =
        pre_layer_norm ? context.Input("Ln1Out") : nullptr;
     auto dropout1_out = *context.Input("Dropout1Out");
-    auto dropout2_out = *context.Input("Dropout2Out");
+    auto* dropout2_out = context.Input("Dropout2Out");
     auto linear1_weight = *context.Input("Linear1Weight");
     auto* linear1_bias = context.Input("Linear1Bias");
     auto linear2_weight = *context.Input("Linear2Weight");
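
Note on the new GetAttrWithDefault helper in nodes.h: it looks a name up in the attributes explicitly set on the op, falls back to the op's default attribute map, and only then raises an error. The sketch below is not Paddle code; it is a minimal, self-contained illustration of that lookup order, substituting std::variant for paddle::framework::Attribute and a plain exception for PADDLE_THROW. The names in main() are illustrative only.

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <variant>

// Stand-ins for paddle::framework::Attribute / AttributeMap.
using Attribute = std::variant<bool, int, float, std::string>;
using AttributeMap = std::map<std::string, Attribute>;

// Look up `name` in `attrs`, then in `default_attrs`; throw if absent.
// Mirrors the fallback order used by the helper added in nodes.h.
template <typename T>
const T& GetAttrWithDefault(const AttributeMap& attrs,
                            const AttributeMap& default_attrs,
                            const std::string& name) {
  if (auto it = attrs.find(name); it != attrs.end()) {
    return std::get<T>(it->second);
  }
  if (auto it = default_attrs.find(name); it != default_attrs.end()) {
    return std::get<T>(it->second);
  }
  throw std::invalid_argument("Attribute(" + name + ") cannot be found.");
}

int main() {
  AttributeMap attrs;                                         // set on the op
  AttributeMap default_attrs = {{"pre_layer_norm", false}};   // op defaults

  // Falls back to the default map, as the grad node does for
  // "pre_layer_norm" when deciding whether Dropout2Out keeps its buffer.
  bool pre_layer_norm =
      GetAttrWithDefault<bool>(attrs, default_attrs, "pre_layer_norm");
  std::cout << std::boolalpha << pre_layer_norm << "\n";  // prints: false
  return 0;
}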