Unverified commit 6ef5d343, authored by S sneaxiy, committed by GitHub

Refine the memory usage of fused_attention and fused_feedforward ops (#47236)

* fix fused_attention fused_feedforward

* fix ci

* fix ci

* fix ci PADDLE_GET_CONST

* fix ci ut
Parent 17a03629
@@ -17,6 +17,23 @@
 #include "paddle/fluid/eager/tensor_wrapper.h"
 #include "paddle/fluid/imperative/tracer.h"
 
+template <typename T>
+const T& GetAttrWithDefault(
+    const paddle::framework::AttributeMap& attrs,
+    const paddle::framework::AttributeMap& default_attrs,
+    const std::string& name) {
+  auto iter1 = attrs.find(name);
+  if (iter1 != attrs.end()) {
+    return PADDLE_GET_CONST(T, iter1->second);
+  }
+  auto iter2 = default_attrs.find(name);
+  if (iter2 != default_attrs.end()) {
+    return PADDLE_GET_CONST(T, iter2->second);
+  }
+  PADDLE_THROW(
+      phi::errors::InvalidArgument("Attribute(%s) cannot be found.", name));
+}
+
 class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
  public:
  fused_gate_attentionGradNodeCompat() : egr::GradNodeBase() {
@@ -240,7 +257,9 @@ class fused_feedforwardGradNodeCompat : public egr::GradNodeBase {
   }
   void SetTensorWrapperDropout2Out(
       const paddle::experimental::Tensor& Dropout2Out) {
-    Dropout2Out_ = egr::TensorWrapper(Dropout2Out, false);
+    auto pre_layer_norm = GetAttrWithDefault<bool>(
+        attr_map_, default_attr_map_, "pre_layer_norm");
+    Dropout2Out_ = egr::TensorWrapper(Dropout2Out, pre_layer_norm);
   }
   void SetTensorWrapperLinear1Bias(
       const paddle::experimental::Tensor& Linear1Bias) {
@@ -427,27 +446,27 @@ class fused_attentionGradNodeCompat : public egr::GradNodeBase {
   }
   void SetTensorWrapperOutLinearOut(
       const paddle::experimental::Tensor& OutLinearOut) {
-    OutLinearOut_ = egr::TensorWrapper(OutLinearOut, false);
+    OutLinearOut_ = egr::TensorWrapper(OutLinearOut, true);
   }
   void SetTensorWrapperOutLinearW(
       const paddle::experimental::Tensor& OutLinearW) {
     OutLinearW_ = egr::TensorWrapper(OutLinearW, false);
   }
   void SetTensorWrapperQKOut(const paddle::experimental::Tensor& QKOut) {
-    QKOut_ = egr::TensorWrapper(QKOut, false);
+    QKOut_ = egr::TensorWrapper(QKOut, true);
   }
   void SetTensorWrapperQKTVOut(const paddle::experimental::Tensor& QKTVOut) {
-    QKTVOut_ = egr::TensorWrapper(QKTVOut, false);
+    QKTVOut_ = egr::TensorWrapper(QKTVOut, true);
   }
   void SetTensorWrapperQKVBias(const paddle::experimental::Tensor& QKVBias) {
     QKVBias_ = egr::TensorWrapper(QKVBias, false);
   }
   void SetTensorWrapperQKVBiasOut(
       const paddle::experimental::Tensor& QKVBiasOut) {
-    QKVBiasOut_ = egr::TensorWrapper(QKVBiasOut, false);
+    QKVBiasOut_ = egr::TensorWrapper(QKVBiasOut, true);
   }
   void SetTensorWrapperQKVOut(const paddle::experimental::Tensor& QKVOut) {
-    QKVOut_ = egr::TensorWrapper(QKVOut, false);
+    QKVOut_ = egr::TensorWrapper(QKVOut, true);
   }
   void SetTensorWrapperQKVW(const paddle::experimental::Tensor& QKVW) {
     QKVW_ = egr::TensorWrapper(QKVW, false);
......
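Note: the second constructor argument of egr::TensorWrapper in this file controls whether the backward node needs the wrapped tensor's data buffer. Passing true (or pre_layer_norm for Dropout2Out) means only the metadata has to be kept for backward, so the forward activation's buffer can be released earlier, which is the memory saving this PR targets. Below is a minimal sketch of the idea, written with hypothetical simplified types rather than Paddle's actual TensorWrapper:

// Illustrative only: SimpleTensor and ShapeOnlyWrapper are not Paddle types.
#include <cstdint>
#include <memory>
#include <vector>

struct SimpleTensor {
  std::vector<int64_t> dims;                 // shape metadata
  std::shared_ptr<std::vector<float>> data;  // the (possibly large) buffer
};

class ShapeOnlyWrapper {
 public:
  ShapeOnlyWrapper(const SimpleTensor& t, bool no_need_buffer) : dims_(t.dims) {
    if (!no_need_buffer) {
      data_ = t.data;  // backward reads the values, so keep the buffer alive
    }
    // no_need_buffer == true: keep the shape only; the buffer may be freed
  }
  const std::vector<int64_t>& dims() const { return dims_; }
  bool holds_buffer() const { return data_ != nullptr; }

 private:
  std::vector<int64_t> dims_;
  std::shared_ptr<std::vector<float>> data_;
};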
@@ -704,6 +704,13 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(FusedAttentionGradNoNeedBufferInferer,
+                                    "QKVBiasOut",
+                                    "QKVOut",
+                                    "QKOut",
+                                    "QKTVOut",
+                                    "OutLinearOut");
+
 }  // namespace operators
 }  // namespace paddle
@@ -713,7 +720,9 @@ REGISTER_OPERATOR(fused_attention,
                   ops::FusedAttentionOpMaker,
                   ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
                   ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
+REGISTER_OPERATOR(fused_attention_grad,
+                  ops::FusedAttentionGradOp,
+                  ops::FusedAttentionGradNoNeedBufferInferer);
 
 REGISTER_OP_VERSION(fused_attention)
     .AddCheckpoint(
......
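DECLARE_NO_NEED_BUFFER_VARS_INFERER registers, for fused_attention_grad, the inputs whose data buffers the backward op never reads, so the framework can drop those buffers instead of keeping them alive between forward and backward. A rough sketch of what such an inferer amounts to, using an illustrative class rather than Paddle's real inferer machinery:

#include <string>
#include <unordered_set>

// Illustrative stand-in for a no-need-buffer inferer: it only reports which
// grad-op inputs are shape-only, mirroring the variable list declared above.
struct FusedAttentionGradNoNeedBufferVarsSketch {
  const std::unordered_set<std::string>& operator()() const {
    static const std::unordered_set<std::string> vars = {
        "QKVBiasOut", "QKVOut", "QKOut", "QKTVOut", "OutLinearOut"};
    return vars;
  }
};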
@@ -416,13 +416,11 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
     auto *transpose_out_2 = ctx.Input<phi::DenseTensor>("TransposeOut2");
     auto *qk_out = ctx.Input<phi::DenseTensor>("QKOut");
-    auto *qktv_out = ctx.Input<phi::DenseTensor>("QKTVOut");
     auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
     auto *attn_dropout_mask_out =
         ctx.Input<phi::DenseTensor>("AttnDropoutMaskOut");
     auto *attn_dropout_out = ctx.Input<phi::DenseTensor>("AttnDropoutOut");
     auto *src_mask_out = ctx.Input<phi::DenseTensor>("SrcMaskOut");
-    auto *out_linear_out = ctx.Input<phi::DenseTensor>("OutLinearOut");
     auto *ln_2_mean = ctx.Input<phi::DenseTensor>("Ln2Mean");
     auto *ln_2_var = ctx.Input<phi::DenseTensor>("Ln2Variance");
     auto *dropout_mask_out = ctx.Input<phi::DenseTensor>("DropoutMaskOut");
@@ -430,12 +428,9 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
         ctx.Input<phi::DenseTensor>("BiasDropoutResidualOut");
     auto *fmha_out_data = fmha_out->data<T>();
     auto *transpose_out_2_data = transpose_out_2->data<T>();
-    auto *qk_out_data = qk_out->data<T>();
-    auto *qktv_out_data = qktv_out->data<T>();
     auto *softmax_out_data = softmax_out->data<T>();
     auto *src_mask_out_data =
         (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
-    auto *out_linear_out_data = out_linear_out->data<T>();
     auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
 
     // output's grad
......
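The lines deleted in the two hunks above were where the grad kernel read the data of QKOut and fetched QKTVOut and OutLinearOut; after this change those buffers are never dereferenced here, which is what makes the no-need-buffer declaration for them safe. A small sketch of the shape-only usage pattern, again with a hypothetical simplified type:

#include <cstdint>
#include <vector>

// A no-need-buffer input still carries valid dims, so a grad kernel can size
// its outputs from the saved tensor without ever calling data<T>() on it.
// ShapeOnlyTensor is illustrative, not Paddle's DenseTensor.
struct ShapeOnlyTensor {
  std::vector<int64_t> dims;  // metadata survives even if the buffer is freed
};

inline int64_t NumElements(const ShapeOnlyTensor& t) {
  int64_t n = 1;
  for (int64_t d : t.dims) n *= d;
  return n;
}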
@@ -276,10 +276,12 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
                    "Input",
                    "Dropout1Out",
                    "FusedFeedForwardGrad");
-    OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"),
-                   "Input",
-                   "Dropout2Out",
-                   "FusedFeedForwardGrad");
+    if (!pre_layer_norm) {
+      OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"),
+                     "Input",
+                     "Dropout2Out",
+                     "FusedFeedForwardGrad");
+    }
     OP_INOUT_CHECK(ctx->HasInput("Linear1Weight"),
                    "Input",
                    "Linear1Weight",
@@ -368,10 +370,12 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("Dropout2Mask", this->Output("Dropout2Mask"));
     op->SetInput("Linear1Out", this->Output("Linear1Out"));
     op->SetInput("Dropout1Out", this->Output("Dropout1Out"));
-    op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
 
     op->SetAttrMap(this->Attrs());
     bool pre_layer_norm = PADDLE_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
+    if (!pre_layer_norm) {
+      op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
+    }
 
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     if (pre_layer_norm) {
......
@@ -337,7 +337,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
                const phi::DenseTensor& linear1_out,
                const phi::DenseTensor* ln1_out,
                const phi::DenseTensor& dropout1_out,
-               const phi::DenseTensor& dropout2_out,
+               const phi::DenseTensor* dropout2_out,
                const phi::DenseTensor& linear1_weight,
                const phi::DenseTensor* linear1_bias,
                const phi::DenseTensor& linear2_weight,
@@ -420,7 +420,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
       fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
           ctx,
           d_out.data<T>(),
-          dropout2_out.data<T>(),
+          dropout2_out->data<T>(),
           dropout2_mask.data<uint8_t>(),
           ln2_gamma_ptr,
           ln2_mean->data<U>(),
@@ -504,7 +504,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
     auto* ln1_out =
         pre_layer_norm ? context.Input<phi::DenseTensor>("Ln1Out") : nullptr;
     auto dropout1_out = *context.Input<phi::DenseTensor>("Dropout1Out");
-    auto dropout2_out = *context.Input<phi::DenseTensor>("Dropout2Out");
+    auto* dropout2_out = context.Input<phi::DenseTensor>("Dropout2Out");
     auto linear1_weight = *context.Input<phi::DenseTensor>("Linear1Weight");
     auto* linear1_bias = context.Input<phi::DenseTensor>("Linear1Bias");
     auto linear2_weight = *context.Input<phi::DenseTensor>("Linear2Weight");
......
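Because Dropout2Out is now an input of fused_feedforward_grad only when pre_layer_norm is false, the grad kernel takes it as a pointer instead of a value and dereferences it only on the branch that uses it. A minimal sketch of that optional-input pattern; the function and names here are hypothetical, not Paddle's API:

#include <stdexcept>

// Hypothetical illustration of the pointer-based optional input: the tensor
// may legitimately be absent (pre_layer_norm == true), so it is checked and
// dereferenced only on the branch that needs it.
template <typename TensorT>
void FeedForwardGradSketch(bool pre_layer_norm, const TensorT* dropout2_out) {
  if (!pre_layer_norm) {
    if (dropout2_out == nullptr) {
      throw std::invalid_argument(
          "Dropout2Out is required when pre_layer_norm is false");
    }
    // ... read dropout2_out's data here, as the kernel above does ...
  }
  // pre_layer_norm == true: Dropout2Out is never read and need not be kept.
}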