From 6ef5d3436f615908d2be75d09bdca1f3bc2023d8 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Wed, 26 Oct 2022 10:49:43 +0800
Subject: [PATCH] Refine the memory usage of fused_attention and
 fused_feedforward ops (#47236)

* fix fused_attention fused_feedforward

* fix ci

* fix ci

* fix ci PADDLE_GET_CONST

* fix ci ut
---
 .../api/manual/fluid_manual/nodes/nodes.h     | 31 +++++++++++++++----
 .../operators/fused/fused_attention_op.cc     | 11 ++++++-
 .../operators/fused/fused_attention_op.cu     |  5 ---
 .../operators/fused/fused_feedforward_op.cc   | 14 ++++++---
 .../operators/fused/fused_feedforward_op.cu   |  6 ++--
 5 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h
index 32389e553d0..7e0d679689c 100644
--- a/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h
+++ b/paddle/fluid/eager/api/manual/fluid_manual/nodes/nodes.h
@@ -17,6 +17,23 @@
 #include "paddle/fluid/eager/tensor_wrapper.h"
 #include "paddle/fluid/imperative/tracer.h"
 
+template <typename T>
+const T& GetAttrWithDefault(
+    const paddle::framework::AttributeMap& attrs,
+    const paddle::framework::AttributeMap& default_attrs,
+    const std::string& name) {
+  auto iter1 = attrs.find(name);
+  if (iter1 != attrs.end()) {
+    return PADDLE_GET_CONST(T, iter1->second);
+  }
+  auto iter2 = default_attrs.find(name);
+  if (iter2 != default_attrs.end()) {
+    return PADDLE_GET_CONST(T, iter2->second);
+  }
+  PADDLE_THROW(
+      phi::errors::InvalidArgument("Attribute(%s) cannot be found.", name));
+}
+
 class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
  public:
   fused_gate_attentionGradNodeCompat() : egr::GradNodeBase() {
@@ -240,7 +257,9 @@ class fused_feedforwardGradNodeCompat : public egr::GradNodeBase {
   }
   void SetTensorWrapperDropout2Out(
       const paddle::experimental::Tensor& Dropout2Out) {
-    Dropout2Out_ = egr::TensorWrapper(Dropout2Out, false);
+    auto pre_layer_norm = GetAttrWithDefault<bool>(
+        attr_map_, default_attr_map_, "pre_layer_norm");
+    Dropout2Out_ = egr::TensorWrapper(Dropout2Out, pre_layer_norm);
   }
   void SetTensorWrapperLinear1Bias(
       const paddle::experimental::Tensor& Linear1Bias) {
@@ -427,27 +446,27 @@ class fused_attentionGradNodeCompat : public egr::GradNodeBase {
   }
   void SetTensorWrapperOutLinearOut(
       const paddle::experimental::Tensor& OutLinearOut) {
-    OutLinearOut_ = egr::TensorWrapper(OutLinearOut, false);
+    OutLinearOut_ = egr::TensorWrapper(OutLinearOut, true);
   }
   void SetTensorWrapperOutLinearW(
       const paddle::experimental::Tensor& OutLinearW) {
     OutLinearW_ = egr::TensorWrapper(OutLinearW, false);
   }
   void SetTensorWrapperQKOut(const paddle::experimental::Tensor& QKOut) {
-    QKOut_ = egr::TensorWrapper(QKOut, false);
+    QKOut_ = egr::TensorWrapper(QKOut, true);
   }
   void SetTensorWrapperQKTVOut(const paddle::experimental::Tensor& QKTVOut) {
-    QKTVOut_ = egr::TensorWrapper(QKTVOut, false);
+    QKTVOut_ = egr::TensorWrapper(QKTVOut, true);
   }
   void SetTensorWrapperQKVBias(const paddle::experimental::Tensor& QKVBias) {
     QKVBias_ = egr::TensorWrapper(QKVBias, false);
   }
   void SetTensorWrapperQKVBiasOut(
       const paddle::experimental::Tensor& QKVBiasOut) {
-    QKVBiasOut_ = egr::TensorWrapper(QKVBiasOut, false);
+    QKVBiasOut_ = egr::TensorWrapper(QKVBiasOut, true);
   }
   void SetTensorWrapperQKVOut(const paddle::experimental::Tensor& QKVOut) {
-    QKVOut_ = egr::TensorWrapper(QKVOut, false);
+    QKVOut_ = egr::TensorWrapper(QKVOut, true);
   }
   void SetTensorWrapperQKVW(const paddle::experimental::Tensor& QKVW) {
     QKVW_ = egr::TensorWrapper(QKVW, false);
diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc
index e1c3bcdd83f..03c97ec345f 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cc
+++ b/paddle/fluid/operators/fused/fused_attention_op.cc
@@ -704,6 +704,13 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(FusedAttentionGradNoNeedBufferInferer,
+                                    "QKVBiasOut",
+                                    "QKVOut",
+                                    "QKOut",
+                                    "QKTVOut",
+                                    "OutLinearOut");
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -713,7 +720,9 @@ REGISTER_OPERATOR(fused_attention,
                   ops::FusedAttentionOpMaker,
                   ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
                   ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
+REGISTER_OPERATOR(fused_attention_grad,
+                  ops::FusedAttentionGradOp,
+                  ops::FusedAttentionGradNoNeedBufferInferer);
 
 REGISTER_OP_VERSION(fused_attention)
     .AddCheckpoint(
diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu
index d03b76adef3..f37c1a0be8b 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cu
+++ b/paddle/fluid/operators/fused/fused_attention_op.cu
@@ -416,13 +416,11 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
     auto *transpose_out_2 = ctx.Input<phi::DenseTensor>("TransposeOut2");
     auto *qk_out = ctx.Input<phi::DenseTensor>("QKOut");
-    auto *qktv_out = ctx.Input<phi::DenseTensor>("QKTVOut");
     auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
     auto *attn_dropout_mask_out =
         ctx.Input<phi::DenseTensor>("AttnDropoutMaskOut");
     auto *attn_dropout_out = ctx.Input<phi::DenseTensor>("AttnDropoutOut");
     auto *src_mask_out = ctx.Input<phi::DenseTensor>("SrcMaskOut");
-    auto *out_linear_out = ctx.Input<phi::DenseTensor>("OutLinearOut");
     auto *ln_2_mean = ctx.Input<phi::DenseTensor>("Ln2Mean");
     auto *ln_2_var = ctx.Input<phi::DenseTensor>("Ln2Variance");
     auto *dropout_mask_out = ctx.Input<phi::DenseTensor>("DropoutMaskOut");
@@ -430,12 +428,9 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         ctx.Input<phi::DenseTensor>("BiasDropoutResidualOut");
     auto *fmha_out_data = fmha_out->data<T>();
     auto *transpose_out_2_data = transpose_out_2->data<T>();
-    auto *qk_out_data = qk_out->data<T>();
-    auto *qktv_out_data = qktv_out->data<T>();
     auto *softmax_out_data = softmax_out->data<T>();
     auto *src_mask_out_data =
         (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
-    auto *out_linear_out_data = out_linear_out->data<T>();
     auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
 
     // output's grad
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cc b/paddle/fluid/operators/fused/fused_feedforward_op.cc
index 71fe468f780..aaf84c7b1ea 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cc
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cc
@@ -276,10 +276,12 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
                    "Input",
                    "Dropout1Out",
                    "FusedFeedForwardGrad");
-    OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"),
-                   "Input",
-                   "Dropout2Out",
-                   "FusedFeedForwardGrad");
+    if (!pre_layer_norm) {
+      OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"),
+                     "Input",
+                     "Dropout2Out",
+                     "FusedFeedForwardGrad");
+    }
     OP_INOUT_CHECK(ctx->HasInput("Linear1Weight"),
                    "Input",
                    "Linear1Weight",
@@ -368,10 +370,12 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("Dropout2Mask", this->Output("Dropout2Mask"));
     op->SetInput("Linear1Out", this->Output("Linear1Out"));
     op->SetInput("Dropout1Out", this->Output("Dropout1Out"));
-    op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
     op->SetAttrMap(this->Attrs());
     bool pre_layer_norm =
         PADDLE_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
+    if (!pre_layer_norm) {
+      op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
+    }
 
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     if (pre_layer_norm) {
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu
index 90af129296c..39dfe969e3d 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cu
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -337,7 +337,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
                const phi::DenseTensor& linear1_out,
                const phi::DenseTensor* ln1_out,
                const phi::DenseTensor& dropout1_out,
-               const phi::DenseTensor& dropout2_out,
+               const phi::DenseTensor* dropout2_out,
                const phi::DenseTensor& linear1_weight,
                const phi::DenseTensor* linear1_bias,
                const phi::DenseTensor& linear2_weight,
@@ -420,7 +420,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
     fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
         ctx,
         d_out.data<T>(),
-        dropout2_out.data<T>(),
+        dropout2_out->data<T>(),
         dropout2_mask.data<uint8_t>(),
         ln2_gamma_ptr,
         ln2_mean->data<U>(),
@@ -504,7 +504,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
     auto* ln1_out =
         pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Out") : nullptr;
     auto dropout1_out = *context.Input<phi::DenseTensor>("Dropout1Out");
-    auto dropout2_out = *context.Input<phi::DenseTensor>("Dropout2Out");
+    auto* dropout2_out = context.Input<phi::DenseTensor>("Dropout2Out");
     auto linear1_weight = *context.Input<phi::DenseTensor>("Linear1Weight");
     auto* linear1_bias = context.Input<phi::DenseTensor>("Linear1Bias");
     auto linear2_weight = *context.Input<phi::DenseTensor>("Linear2Weight");
-- 
GitLab
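
Note on the GetAttrWithDefault helper added to nodes.h above: it looks an attribute up in the explicitly-set map first, falls back to the op's registered defaults, and throws if neither map contains it. Below is a minimal standalone sketch of that lookup-then-fallback pattern, not the Paddle code itself: AttrMap and Attribute are illustrative stand-ins for paddle::framework::AttributeMap, and std::get plays the role of PADDLE_GET_CONST.

// Standalone illustration of the attribute lookup-with-fallback pattern.
// Compile with -std=c++17. The types here are hypothetical stand-ins.
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <variant>

using Attribute = std::variant<bool, int, float, std::string>;
using AttrMap = std::map<std::string, Attribute>;

template <typename T>
const T& GetAttrWithDefault(const AttrMap& attrs,
                            const AttrMap& default_attrs,
                            const std::string& name) {
  // Prefer an attribute that was set explicitly on the op.
  if (auto it = attrs.find(name); it != attrs.end()) {
    return std::get<T>(it->second);
  }
  // Otherwise fall back to the op's registered default value.
  if (auto it = default_attrs.find(name); it != default_attrs.end()) {
    return std::get<T>(it->second);
  }
  // Neither map knows the attribute: fail loudly, mirroring PADDLE_THROW.
  throw std::invalid_argument("Attribute(" + name + ") cannot be found.");
}

int main() {
  AttrMap attrs;                                     // nothing set explicitly
  AttrMap default_attrs{{"pre_layer_norm", true}};   // registered default
  bool pre_layer_norm =
      GetAttrWithDefault<bool>(attrs, default_attrs, "pre_layer_norm");
  std::cout << std::boolalpha << pre_layer_norm << "\n";  // prints "true"
}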
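Note on the memory saving itself: flipping the second TensorWrapper argument to true and registering FusedAttentionGradNoNeedBufferInferer both tell the framework that the backward pass needs only the metadata (shape/dtype) of QKVBiasOut, QKVOut, QKOut, QKTVOut and OutLinearOut, so their forward buffers can be released early. The sketch below illustrates that "no need buffer" idea with made-up Tensor/SavedTensor types; it is not Paddle's egr::TensorWrapper.

// Illustrative sketch of a shape-only saved tensor (hypothetical types).
#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>

struct Tensor {
  std::vector<int64_t> dims;
  std::shared_ptr<std::vector<float>> data;  // simplified: always float32
};

struct SavedTensor {
  // When no_need_buffer is true, keep only the dims so the forward
  // activation's buffer can be freed as soon as forward stops using it.
  SavedTensor(const Tensor& t, bool no_need_buffer) : dims(t.dims) {
    if (!no_need_buffer) {
      data = t.data;  // keep the buffer alive only when backward reads it
    }
  }

  int64_t numel() const {
    return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }

  std::vector<int64_t> dims;
  std::shared_ptr<std::vector<float>> data;  // null when buffer not needed
};

int main() {
  Tensor qk_out{{2, 8, 128, 128},
                std::make_shared<std::vector<float>>(2 * 8 * 128 * 128)};
  SavedTensor saved(qk_out, /*no_need_buffer=*/true);
  qk_out.data.reset();              // forward activation released early
  assert(saved.numel() == 2 * 8 * 128 * 128);  // shape still available
  assert(saved.data == nullptr);    // ...but no buffer is kept alive
}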