未验证 提交 9a6dd8f8 编写于 作者: S sneaxiy 提交者: GitHub

[Cherry-pick][Release/2.4]Refine the memory usage of fused_attention and...

[Cherry-pick][Release/2.4]Refine the memory usage of fused_attention and fused_feedforward ops (#47235)

* fix fused_attention fused_feedforward

* fix ci

* fix ci

* fix ci PADDLE_GET_CONST

* fix ci ut
上级 942ab42f
......@@ -17,6 +17,23 @@
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/imperative/tracer.h"
// Looks up attribute `name` and returns a const reference to its value as T.
//
// `attrs` is searched first; `default_attrs` is consulted only when `name` is
// absent from `attrs`. The stored value is extracted with PADDLE_GET_CONST.
// If neither map contains `name`, PADDLE_THROW is invoked with a
// phi::errors::InvalidArgument error naming the missing attribute.
//
// NOTE(review): the returned reference presumably aliases the value held
// inside the map entry, so it should not outlive the map it came from —
// confirm against the definition of PADDLE_GET_CONST.
template <typename T>
const T& GetAttrWithDefault(
const paddle::framework::AttributeMap& attrs,
const paddle::framework::AttributeMap& default_attrs,
const std::string& name) {
// Entries in `attrs` take precedence over `default_attrs`.
auto iter1 = attrs.find(name);
if (iter1 != attrs.end()) {
return PADDLE_GET_CONST(T, iter1->second);
}
// Fall back to the default attribute map.
auto iter2 = default_attrs.find(name);
if (iter2 != default_attrs.end()) {
return PADDLE_GET_CONST(T, iter2->second);
}
// Reached only when both lookups miss. No return statement follows, so
// PADDLE_THROW is presumably non-returning (TODO confirm).
PADDLE_THROW(
phi::errors::InvalidArgument("Attribute(%s) cannot be found.", name));
}
class fused_gate_attentionGradNodeCompat : public egr::GradNodeBase {
public:
fused_gate_attentionGradNodeCompat() : egr::GradNodeBase() {
......@@ -240,7 +257,9 @@ class fused_feedforwardGradNodeCompat : public egr::GradNodeBase {
}
// Stores Dropout2Out in Dropout2Out_ as an egr::TensorWrapper.
//
// NOTE(review): this text comes from a diff hunk that shows old and new lines
// together without +/- markers. The assignment passing `false` looks like the
// removed line and the pre_layer_norm-based assignment its replacement; only
// one of them should exist in the real file — confirm against the repository.
void SetTensorWrapperDropout2Out(
const paddle::experimental::Tensor& Dropout2Out) {
Dropout2Out_ = egr::TensorWrapper(Dropout2Out, false);
// Read "pre_layer_norm" from attr_map_, falling back to default_attr_map_.
auto pre_layer_norm = GetAttrWithDefault<bool>(
attr_map_, default_attr_map_, "pre_layer_norm");
// The flag becomes TensorWrapper's second argument (presumably a
// "buffer not needed" switch — TODO confirm). This lines up with the grad op
// maker in this diff, which only wires Dropout2Out when pre_layer_norm is
// false.
Dropout2Out_ = egr::TensorWrapper(Dropout2Out, pre_layer_norm);
}
void SetTensorWrapperLinear1Bias(
const paddle::experimental::Tensor& Linear1Bias) {
......@@ -427,27 +446,27 @@ class fused_attentionGradNodeCompat : public egr::GradNodeBase {
}
// Stores OutLinearOut in OutLinearOut_ as an egr::TensorWrapper.
// NOTE(review): diff hunk without +/- markers — the `false` line looks like
// the removed version and the `true` line its replacement; confirm upstream.
// The second argument is presumably a "buffer not needed" flag (TODO confirm);
// "OutLinearOut" is also listed in FusedAttentionGradNoNeedBufferInferer.
void SetTensorWrapperOutLinearOut(
const paddle::experimental::Tensor& OutLinearOut) {
OutLinearOut_ = egr::TensorWrapper(OutLinearOut, false);
OutLinearOut_ = egr::TensorWrapper(OutLinearOut, true);
}
// Stores OutLinearW in OutLinearW_ as an egr::TensorWrapper, passing `false`
// as the wrapper's second argument (presumably "buffer not needed" — TODO
// confirm). This line is unchanged context in the diff.
void SetTensorWrapperOutLinearW(
const paddle::experimental::Tensor& OutLinearW) {
OutLinearW_ = egr::TensorWrapper(OutLinearW, false);
}
// Stores QKOut in QKOut_ as an egr::TensorWrapper.
// NOTE(review): diff hunk without +/- markers — the `false` line looks like
// the removed version and the `true` line its replacement; confirm upstream.
// "QKOut" is also listed in FusedAttentionGradNoNeedBufferInferer.
void SetTensorWrapperQKOut(const paddle::experimental::Tensor& QKOut) {
QKOut_ = egr::TensorWrapper(QKOut, false);
QKOut_ = egr::TensorWrapper(QKOut, true);
}
// Stores QKTVOut in QKTVOut_ as an egr::TensorWrapper.
// NOTE(review): diff hunk without +/- markers — the `false` line looks like
// the removed version and the `true` line its replacement; confirm upstream.
// "QKTVOut" is also listed in FusedAttentionGradNoNeedBufferInferer.
void SetTensorWrapperQKTVOut(const paddle::experimental::Tensor& QKTVOut) {
QKTVOut_ = egr::TensorWrapper(QKTVOut, false);
QKTVOut_ = egr::TensorWrapper(QKTVOut, true);
}
// Stores QKVBias in QKVBias_ as an egr::TensorWrapper, passing `false` as the
// wrapper's second argument (presumably "buffer not needed" — TODO confirm).
// This line is unchanged context in the diff.
void SetTensorWrapperQKVBias(const paddle::experimental::Tensor& QKVBias) {
QKVBias_ = egr::TensorWrapper(QKVBias, false);
}
// Stores QKVBiasOut in QKVBiasOut_ as an egr::TensorWrapper.
// NOTE(review): diff hunk without +/- markers — the `false` line looks like
// the removed version and the `true` line its replacement; confirm upstream.
// "QKVBiasOut" is also listed in FusedAttentionGradNoNeedBufferInferer.
void SetTensorWrapperQKVBiasOut(
const paddle::experimental::Tensor& QKVBiasOut) {
QKVBiasOut_ = egr::TensorWrapper(QKVBiasOut, false);
QKVBiasOut_ = egr::TensorWrapper(QKVBiasOut, true);
}
// Stores QKVOut in QKVOut_ as an egr::TensorWrapper.
// NOTE(review): diff hunk without +/- markers — the `false` line looks like
// the removed version and the `true` line its replacement; confirm upstream.
// "QKVOut" is also listed in FusedAttentionGradNoNeedBufferInferer.
void SetTensorWrapperQKVOut(const paddle::experimental::Tensor& QKVOut) {
QKVOut_ = egr::TensorWrapper(QKVOut, false);
QKVOut_ = egr::TensorWrapper(QKVOut, true);
}
void SetTensorWrapperQKVW(const paddle::experimental::Tensor& QKVW) {
QKVW_ = egr::TensorWrapper(QKVW, false);
......
......@@ -704,6 +704,13 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
// Declares FusedAttentionGradNoNeedBufferInferer over the five variable names
// listed below; it is attached to fused_attention_grad via REGISTER_OPERATOR.
// NOTE(review): presumably this marks grad-op inputs whose data buffers are
// never read, so they need not be kept alive — confirm against the macro's
// documentation. The grad kernel hunk in this diff no longer calls data<T>()
// on QKTVOut or OutLinearOut, which is consistent with that reading.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(FusedAttentionGradNoNeedBufferInferer,
"QKVBiasOut",
"QKVOut",
"QKOut",
"QKTVOut",
"OutLinearOut");
} // namespace operators
} // namespace paddle
......@@ -713,7 +720,9 @@ REGISTER_OPERATOR(fused_attention,
ops::FusedAttentionOpMaker,
ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
// Registers the fused_attention_grad operator.
// NOTE(review): diff hunk without +/- markers — the single-line registration
// looks like the removed version and the three-line form, which additionally
// passes FusedAttentionGradNoNeedBufferInferer, its replacement. Only one
// registration should exist in the real file; confirm upstream.
REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
REGISTER_OPERATOR(fused_attention_grad,
ops::FusedAttentionGradOp,
ops::FusedAttentionGradNoNeedBufferInferer);
REGISTER_OP_VERSION(fused_attention)
.AddCheckpoint(
......
......@@ -410,28 +410,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
(out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
// fw output
auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
auto *transpose_out_2 = ctx.Input<Tensor>("TransposeOut2");
auto *qk_out = ctx.Input<Tensor>("QKOut");
auto *qktv_out = ctx.Input<Tensor>("QKTVOut");
auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut");
auto *attn_dropout_mask_out = ctx.Input<Tensor>("AttnDropoutMaskOut");
auto *attn_dropout_out = ctx.Input<Tensor>("AttnDropoutOut");
auto *src_mask_out = ctx.Input<Tensor>("SrcMaskOut");
auto *out_linear_out = ctx.Input<Tensor>("OutLinearOut");
auto *ln_2_mean = ctx.Input<Tensor>("Ln2Mean");
auto *ln_2_var = ctx.Input<Tensor>("Ln2Variance");
auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
auto *transpose_out_2 = ctx.Input<phi::DenseTensor>("TransposeOut2");
auto *qk_out = ctx.Input<phi::DenseTensor>("QKOut");
auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
auto *attn_dropout_mask_out =
ctx.Input<phi::DenseTensor>("AttnDropoutMaskOut");
auto *attn_dropout_out = ctx.Input<phi::DenseTensor>("AttnDropoutOut");
auto *src_mask_out = ctx.Input<phi::DenseTensor>("SrcMaskOut");
auto *ln_2_mean = ctx.Input<phi::DenseTensor>("Ln2Mean");
auto *ln_2_var = ctx.Input<phi::DenseTensor>("Ln2Variance");
auto *dropout_mask_out = ctx.Input<phi::DenseTensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Input<Tensor>("BiasDropoutResidualOut");
auto *fmha_out_data = fmha_out->data<T>();
auto *transpose_out_2_data = transpose_out_2->data<T>();
auto *qk_out_data = qk_out->data<T>();
auto *qktv_out_data = qktv_out->data<T>();
auto *softmax_out_data = softmax_out->data<T>();
auto *src_mask_out_data =
(src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
auto *out_linear_out_data = out_linear_out->data<T>();
auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
// output's grad
......
......@@ -276,10 +276,12 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {
"Input",
"Dropout1Out",
"FusedFeedForwardGrad");
OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"),
"Input",
"Dropout2Out",
"FusedFeedForwardGrad");
if (!pre_layer_norm) {
OP_INOUT_CHECK(ctx->HasInput("Dropout2Out"),
"Input",
"Dropout2Out",
"FusedFeedForwardGrad");
}
OP_INOUT_CHECK(ctx->HasInput("Linear1Weight"),
"Input",
"Linear1Weight",
......@@ -368,10 +370,12 @@ class FusedFeedForwardOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Dropout2Mask", this->Output("Dropout2Mask"));
op->SetInput("Linear1Out", this->Output("Linear1Out"));
op->SetInput("Dropout1Out", this->Output("Dropout1Out"));
op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
op->SetAttrMap(this->Attrs());
bool pre_layer_norm = PADDLE_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
if (!pre_layer_norm) {
op->SetInput("Dropout2Out", this->Output("Dropout2Out"));
}
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
if (pre_layer_norm) {
......
......@@ -339,7 +339,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
const framework::Tensor& linear1_out,
const framework::Tensor* ln1_out,
const framework::Tensor& dropout1_out,
const framework::Tensor& dropout2_out,
const framework::Tensor* dropout2_out,
const framework::Tensor& linear1_weight,
const framework::Tensor* linear1_bias,
const framework::Tensor& linear2_weight,
......@@ -422,7 +422,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx,
d_out.data<T>(),
dropout2_out.data<T>(),
dropout2_out->data<T>(),
dropout2_mask.data<uint8_t>(),
ln2_gamma_ptr,
ln2_mean->data<U>(),
......@@ -506,7 +506,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
auto* ln1_out =
pre_layer_norm ? context.Input<framework::Tensor>("Ln1Out") : nullptr;
auto dropout1_out = *context.Input<framework::Tensor>("Dropout1Out");
auto dropout2_out = *context.Input<framework::Tensor>("Dropout2Out");
auto* dropout2_out = context.Input<framework::Tensor>("Dropout2Out");
auto linear1_weight = *context.Input<framework::Tensor>("Linear1Weight");
auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias");
auto linear2_weight = *context.Input<framework::Tensor>("Linear2Weight");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册