From 8e02f290b99a0c9e896cdcd42b86b889d896eb57 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Sun, 29 Jan 2023 08:17:52 +0800 Subject: [PATCH] Fused attention pass backward pattern (#49855) --- .../framework/ir/fused_attention_pass.cc | 539 +++++++++++++++++- .../fluid/framework/ir/fused_attention_pass.h | 111 +++- .../unittests/test_fused_attention_pass.py | 19 +- 3 files changed, 659 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/ir/fused_attention_pass.cc b/paddle/fluid/framework/ir/fused_attention_pass.cc index 771bf958d21..72fa90db9b1 100644 --- a/paddle/fluid/framework/ir/fused_attention_pass.cc +++ b/paddle/fluid/framework/ir/fused_attention_pass.cc @@ -327,8 +327,441 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x, bool has_attn_mask, bool do_dropout, bool add_residual) { - // TODO(Yuang Liu): finish the backward pattern - return nullptr; + // post layer norm + PDNode* post_layer_norm_grad_out_node{nullptr}; + if (post_layer_norm) { + auto* post_layer_norm_grad_node = + pattern->NewNode(post_layer_norm_grad_op_repr()) + ->assert_is_op("layer_norm_grad"); + auto* post_layer_norm_grad_bias_node = + pattern->NewNode(post_layer_norm_grad_bias_repr()) + ->assert_is_op_input("layer_norm_grad", "Bias"); + auto* post_layer_norm_grad_scale_node = + pattern->NewNode(post_layer_norm_grad_scale_repr()) + ->assert_is_op_input("layer_norm_grad", "Scale"); + auto* post_layer_norm_grad_mean_node = + pattern->NewNode(post_layer_norm_grad_mean_repr()) + ->assert_is_op_input("layer_norm_grad", "Mean"); + auto* post_layer_norm_grad_variance_node = + pattern->NewNode(post_layer_norm_grad_variance_repr()) + ->assert_is_op_input("layer_norm_grad", "Variance"); + auto* post_layer_norm_grad_x_node = + pattern->NewNode(post_layer_norm_grad_x_repr()) + ->assert_is_op_input("layer_norm_grad", "X"); + post_layer_norm_grad_out_node = + pattern->NewNode(post_layer_norm_grad_x_grad_repr()) + ->assert_is_op_output("layer_norm_grad", "X@GRAD"); + auto* post_layer_norm_grad_bias_grad_node = + pattern->NewNode(post_layer_norm_grad_bias_grad_repr()) + ->assert_is_op_output("layer_norm_grad", "Bias@GRAD"); + auto* post_layer_norm_grad_scale_grad_node = + pattern->NewNode(post_layer_norm_grad_scale_grad_repr()) + ->assert_is_op_output("layer_norm_grad", "Scale@GRAD"); + post_layer_norm_grad_node + ->LinksFrom({x, + post_layer_norm_grad_bias_node, + post_layer_norm_grad_scale_node, + post_layer_norm_grad_mean_node, + post_layer_norm_grad_variance_node, + post_layer_norm_grad_x_node}) + .LinksTo({post_layer_norm_grad_out_node, + post_layer_norm_grad_bias_grad_node, + post_layer_norm_grad_scale_grad_node}); + } + + // add residual + PDNode* residual_ele_add_grad_out_node{nullptr}; + PDNode* residual_ele_add_grad_x_node{nullptr}; + PDNode* residual_ele_add_grad_x_grad_node{nullptr}; + if (add_residual) { + PDNode* ele_add_grad_input = x; + if (post_layer_norm) { + ele_add_grad_input = post_layer_norm_grad_out_node; + } + auto* residual_ele_add_grad_node = + pattern->NewNode(residual_ele_add_grad_op_repr()) + ->assert_is_op("elementwise_add_grad"); + residual_ele_add_grad_x_node = + pattern->NewNode(residual_ele_add_grad_x_repr()) + ->assert_is_op_input("elementwise_add_grad", "X"); + auto* residual_ele_add_grad_bias_node = + pattern->NewNode(residual_ele_add_grad_bias_repr()) + ->assert_is_op_input("elementwise_add_grad", "Y"); + residual_ele_add_grad_out_node = + pattern->NewNode(residual_ele_add_grad_bias_grad_repr()) + ->assert_is_op_output("elementwise_add_grad", "Y@GRAD"); + 
residual_ele_add_grad_x_grad_node = + pattern->NewNode(residual_ele_add_grad_x_grad_repr()) + ->assert_is_op_output("elementwise_add_grad", "X@GRAD"); + ele_add_grad_input->assert_is_op_input("elementwise_add_grad", "Out@GRAD"); + residual_ele_add_grad_node + ->LinksFrom({ele_add_grad_input, + residual_ele_add_grad_x_node, + residual_ele_add_grad_bias_node}) + .LinksTo({residual_ele_add_grad_x_grad_node, + residual_ele_add_grad_out_node}); + } + + // get the real input x for dropout grad + PDNode* out_linear_grad_input_node = x; + if (post_layer_norm && !add_residual) { + out_linear_grad_input_node = post_layer_norm_grad_out_node; + } else if (add_residual) { + out_linear_grad_input_node = residual_ele_add_grad_out_node; + } + + // out linear part + auto* out_linear_dropout_grad_node = + pattern->NewNode(out_linear_dropout_grad_op_repr()) + ->assert_is_op("dropout_grad"); + auto* out_linear_dropout_grad_mask_node = + pattern->NewNode(out_linear_dropout_grad_mask_repr()) + ->assert_is_op_input("dropout_grad", "Mask"); + auto* out_linear_dropout_grad_out_node = + pattern->NewNode(out_linear_dropout_grad_out_repr()) + ->assert_is_op_output("dropout_grad", "X@GRAD"); + out_linear_grad_input_node->assert_is_op_input("dropout_grad", "Out@GRAD"); + out_linear_dropout_grad_node + ->LinksFrom( + {out_linear_grad_input_node, out_linear_dropout_grad_mask_node}) + .LinksTo({out_linear_dropout_grad_out_node}); + + auto* out_linear_ele_add_grad_node = + pattern->NewNode(out_linear_ele_add_grad_op_repr()) + ->assert_is_op("elementwise_add_grad"); + auto* out_linear_ele_add_grad_x_node = + pattern->NewNode(out_linear_ele_add_grad_x_repr()) + ->assert_is_op_input("elementwise_add_grad", "X"); + auto* out_linear_ele_add_grad_bias_node = + pattern->NewNode(out_linear_ele_add_grad_bias_repr()) + ->assert_is_op_input("elementwise_add_grad", "Y"); + auto* out_linear_ele_add_grad_x_grad_node = + pattern->NewNode(out_linear_ele_add_grad_x_grad_repr()) + ->assert_is_op_output("elementwise_add_grad", "X@GRAD"); + auto* out_linear_ele_add_grad_bias_grad_node = + pattern->NewNode(out_linear_ele_add_grad_bias_grad_repr()) + ->assert_is_op_output("elementwise_add_grad", "Y@GRAD"); + out_linear_dropout_grad_out_node->assert_is_op_input("elementwise_add_grad", + "Out@GRAD"); + out_linear_ele_add_grad_node + ->LinksFrom({out_linear_dropout_grad_out_node, + out_linear_ele_add_grad_x_node, + out_linear_ele_add_grad_bias_node}) + .LinksTo({out_linear_ele_add_grad_x_grad_node, + out_linear_ele_add_grad_bias_grad_node}); + + auto* out_linear_matmul_grad_node = + pattern->NewNode(out_linear_matmul_grad_op_repr()) + ->assert_is_op("matmul_v2_grad"); + auto* out_linear_matmul_grad_x_node = + pattern->NewNode(out_linear_matmul_grad_x_repr()) + ->assert_is_op_input("matmul_v2_grad", "X"); + auto* out_linear_matmul_grad_w_node = + pattern->NewNode(out_linear_matmul_grad_w_repr()) + ->assert_is_op_input("matmul_v2_grad", "Y"); + auto* out_linear_matmul_grad_x_grad_node = + pattern->NewNode(out_linear_matmul_grad_x_grad_repr()) + ->assert_is_op_output("matmul_v2_grad", "X@GRAD"); + auto* out_linear_matmul_grad_w_grad_node = + pattern->NewNode(out_linear_matmul_grad_w_grad_repr()) + ->assert_is_op_output("matmul_v2_grad", "Y@GRAD"); + out_linear_ele_add_grad_x_grad_node->assert_is_op_input("matmul_v2_grad", + "Out@GRAD"); + out_linear_matmul_grad_node + ->LinksFrom({out_linear_ele_add_grad_x_grad_node, + out_linear_matmul_grad_x_node, + out_linear_matmul_grad_w_node}) + .LinksTo({out_linear_matmul_grad_x_grad_node, + 
out_linear_matmul_grad_w_grad_node}); + + // core attention part + auto* qkv_reshape_grad_node = pattern->NewNode(qkv_reshape_grad_op_repr()) + ->assert_is_op("reshape2_grad"); + auto* qkv_reshape_grad_x_shape_node = + pattern->NewNode(qkv_reshape_grad_x_shape_repr()) + ->assert_is_op_input("reshape2_grad", "XShape"); + auto* qkv_reshape_grad_out_node = + pattern->NewNode(qkv_reshape_grad_out_repr()) + ->assert_is_op_output("reshape2_grad", "X@GRAD"); + out_linear_matmul_grad_x_grad_node->assert_is_op_input("reshape2_grad", + "Out@GRAD"); + qkv_reshape_grad_node + ->LinksFrom( + {out_linear_matmul_grad_x_grad_node, qkv_reshape_grad_x_shape_node}) + .LinksTo({qkv_reshape_grad_out_node}); + + auto* qkv_transpose_grad_node = pattern->NewNode(qkv_transpose_grad_op_repr()) + ->assert_is_op("transpose2_grad"); + auto* qkv_transpose_grad_x_shape_node = + pattern->NewNode(qkv_transpose_grad_x_shape_repr()) + ->assert_is_op_input("transpose2_grad", "XShape"); + auto* qkv_transpose_grad_out_node = + pattern->NewNode(qkv_transpose_grad_out_repr()) + ->assert_is_op_output("transpose2_grad", "X@GRAD"); + qkv_reshape_grad_out_node->assert_is_op_input("transpose2_grad", "Out@GRAD"); + qkv_transpose_grad_node + ->LinksFrom({qkv_reshape_grad_out_node, qkv_transpose_grad_x_shape_node}) + .LinksTo({qkv_transpose_grad_out_node}); + + auto* qkv_matmul_grad_node = pattern->NewNode(qkv_matmul_grad_op_repr()) + ->assert_is_op("matmul_v2_grad"); + auto* qkv_matmul_grad_x_node = + pattern->NewNode(qkv_matmul_grad_x_repr()) + ->assert_is_op_input("matmul_v2_grad", "X"); + auto* qkv_matmul_grad_w_node = + pattern->NewNode(qkv_matmul_grad_w_repr()) + ->assert_is_op_input("matmul_v2_grad", "Y"); + auto* qkv_matmul_grad_x_grad_node = + pattern->NewNode(qkv_matmul_grad_x_grad_repr()) + ->assert_is_op_output("matmul_v2_grad", "X@GRAD"); + auto* qkv_matmul_grad_w_grad_node = + pattern->NewNode(qkv_matmul_grad_w_grad_repr()) + ->assert_is_op_output("matmul_v2_grad", "Y@GRAD"); + qkv_transpose_grad_out_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD"); + qkv_matmul_grad_node + ->LinksFrom({qkv_transpose_grad_out_node, + qkv_matmul_grad_x_node, + qkv_matmul_grad_w_node}) + .LinksTo({qkv_matmul_grad_x_grad_node, qkv_matmul_grad_w_grad_node}); + + PDNode* attn_dropout_grad_out_node{nullptr}; + if (do_dropout) { + auto* attn_dropout_grad_node = pattern->NewNode(attn_dropout_grad_op_repr()) + ->assert_is_op("dropout_grad"); + auto* attn_dropout_grad_mask_node = + pattern->NewNode(attn_dropout_grad_mask_repr()) + ->assert_is_op_input("dropout_grad", "Mask"); + attn_dropout_grad_out_node = + pattern->NewNode(attn_dropout_grad_out_repr()) + ->assert_is_op_output("dropout_grad", "X@GRAD"); + qkv_matmul_grad_x_grad_node->assert_is_op_input("dropout_grad", "Out@GRAD"); + attn_dropout_grad_node + ->LinksFrom({qkv_matmul_grad_x_grad_node, attn_dropout_grad_mask_node}) + .LinksTo({attn_dropout_grad_out_node}); + } + + PDNode* qk_softmax_grad_input_node = + do_dropout ? 
attn_dropout_grad_out_node : qkv_matmul_grad_x_grad_node; + auto* qk_softmax_grad_node = + pattern->NewNode(qk_softmax_grad_op_repr())->assert_is_op("softmax_grad"); + auto* qk_softmax_grad_fwd_out_node = + pattern->NewNode(qk_softmax_grad_fwd_out_repr()) + ->assert_is_op_input("softmax_grad", "Out"); + auto* qk_softmax_grad_out = + pattern->NewNode(qk_softmax_grad_out_repr()) + ->assert_is_op_output("softmax_grad", "X@GRAD"); + qk_softmax_grad_input_node->assert_is_op_input("softmax_grad", "Out@GRAD"); + qk_softmax_grad_node + ->LinksFrom({qk_softmax_grad_input_node, qk_softmax_grad_fwd_out_node}) + .LinksTo({qk_softmax_grad_out}); + + PDNode* add_mask_ele_add_grad_x_grad_node{nullptr}; + if (has_attn_mask) { + auto* add_mask_ele_add_grad_node = + pattern->NewNode(add_mask_ele_add_grad_op_repr()) + ->assert_is_op("elementwise_add_grad"); + auto* add_mask_ele_add_grad_x_node = + pattern->NewNode(add_mask_ele_add_grad_x_repr()) + ->assert_is_op_input("elementwise_add_grad", "X"); + auto* add_mask_ele_add_grad_bias_node = + pattern->NewNode(add_mask_ele_add_grad_bias_repr()) + ->assert_is_op_input("elementwise_add_grad", "Y"); + add_mask_ele_add_grad_x_grad_node = + pattern->NewNode(add_mask_ele_add_grad_x_grad_repr()) + ->assert_is_op_output("elementwise_add_grad", "X@GRAD"); + qk_softmax_grad_out->assert_is_op_input("elementwise_add_grad", "Out@GRAD"); + add_mask_ele_add_grad_node + ->LinksFrom({add_mask_ele_add_grad_x_node, + add_mask_ele_add_grad_bias_node, + qk_softmax_grad_out}) + .LinksTo({add_mask_ele_add_grad_x_grad_node}); + } + + PDNode* qk_scale_grad_input_node = + has_attn_mask ? add_mask_ele_add_grad_x_grad_node : qk_softmax_grad_out; + auto* qk_scale_grad_node = + pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale"); + auto* qk_scale_grad_out_node = + pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale"); + qk_scale_grad_input_node->assert_is_op_input("scale", "X"); + qk_scale_grad_node->LinksFrom({qk_scale_grad_input_node}) + .LinksTo({qk_scale_grad_out_node}); + + auto* qk_matmul_grad_node = pattern->NewNode(qk_matmul_grad_op_repr()) + ->assert_is_op("matmul_v2_grad"); + auto* qk_matmul_grad_x_node = pattern->NewNode(qk_matmul_grad_x_repr()) + ->assert_is_op_input("matmul_v2_grad", "X"); + auto* qk_matmul_grad_w_node = pattern->NewNode(qk_matmul_grad_w_repr()) + ->assert_is_op_input("matmul_v2_grad", "Y"); + auto* qk_matmul_grad_x_grad_node = + pattern->NewNode(qk_matmul_grad_x_grad_repr()) + ->assert_is_op_output("matmul_v2_grad", "X@GRAD"); + auto* qk_matmul_grad_w_grad_node = + pattern->NewNode(qk_matmul_grad_w_grad_repr()) + ->assert_is_op_output("matmul_v2_grad", "Y@GRAD"); + qk_scale_grad_out_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD"); + qk_matmul_grad_node + ->LinksFrom({qk_scale_grad_out_node, + qk_matmul_grad_x_node, + qk_matmul_grad_w_node}) + .LinksTo({qk_matmul_grad_x_grad_node, qk_matmul_grad_w_grad_node}); + + // fuse qkv projection + auto* fuse_qkv_split_grad_node = + pattern->NewNode(fuse_qkv_split_grad_op_repr())->assert_is_op("concat"); + auto* fuse_qkv_split_grad_out_node = + pattern->NewNode(fuse_qkv_split_grad_out_repr()) + ->assert_is_op_output("concat"); + qk_matmul_grad_x_grad_node->assert_is_op_input("concat"); // q grad + qk_matmul_grad_w_grad_node->assert_is_op_input("concat"); // k grad + qkv_matmul_grad_w_grad_node->assert_is_op_input("concat"); // v grad + fuse_qkv_split_grad_node + ->LinksFrom({qk_matmul_grad_x_grad_node, + qk_matmul_grad_w_grad_node, + qkv_matmul_grad_w_grad_node}) + 
.LinksTo({fuse_qkv_split_grad_out_node}); + + auto* fuse_qkv_transpose_grad_node = + pattern->NewNode(fuse_qkv_transpose_grad_op_repr()) + ->assert_is_op("transpose2_grad"); + auto* fuse_qkv_transpose_grad_x_shape_node = + pattern->NewNode(fuse_qkv_transpose_grad_x_shape_repr()) + ->assert_is_op_input("transpose2_grad", "XShape"); + auto* fuse_qkv_transpose_grad_out_node = + pattern->NewNode(fuse_qkv_transpose_grad_out_repr()) + ->assert_is_op_output("transpose2_grad", "X@GRAD"); + fuse_qkv_split_grad_out_node->assert_is_op_input("transpose2_grad", + "Out@GRAD"); + fuse_qkv_transpose_grad_node + ->LinksFrom( + {fuse_qkv_split_grad_out_node, fuse_qkv_transpose_grad_x_shape_node}) + .LinksTo({fuse_qkv_transpose_grad_out_node}); + + auto* fuse_qkv_reshape_grad_node = + pattern->NewNode(fuse_qkv_reshape_grad_op_repr()) + ->assert_is_op("reshape2_grad"); + auto* fuse_qkv_reshape_grad_x_shape_node = + pattern->NewNode(fuse_qkv_reshape_grad_x_shape_repr()) + ->assert_is_op_input("reshape2_grad", "XShape"); + auto* fuse_qkv_reshape_grad_out_node = + pattern->NewNode(fuse_qkv_reshape_grad_out_repr()) + ->assert_is_op_output("reshape2_grad", "X@GRAD"); + fuse_qkv_transpose_grad_out_node->assert_is_op_input("reshape2_grad", + "Out@GRAD"); + fuse_qkv_reshape_grad_node + ->LinksFrom({fuse_qkv_transpose_grad_out_node, + fuse_qkv_reshape_grad_x_shape_node}) + .LinksTo({fuse_qkv_reshape_grad_out_node}); + + auto* fuse_qkv_ele_add_grad_node = + pattern->NewNode(fuse_qkv_ele_add_grad_op_repr()) + ->assert_is_op("elementwise_add_grad"); + auto* fuse_qkv_ele_add_grad_x_node = + pattern->NewNode(fuse_qkv_ele_add_grad_x_repr()) + ->assert_is_op_input("elementwise_add_grad", "X"); + auto* fuse_qkv_ele_add_grad_bias_node = + pattern->NewNode(fuse_qkv_ele_add_grad_bias_repr()) + ->assert_is_op_input("elementwise_add_grad", "Y"); + auto* fuse_qkv_ele_add_grad_x_grad_node = + pattern->NewNode(fuse_qkv_ele_add_grad_x_grad_repr()) + ->assert_is_op_output("elementwise_add_grad", "X@GRAD"); + auto* fuse_qkv_ele_add_grad_bias_grad_node = + pattern->NewNode(fuse_qkv_ele_add_grad_bias_grad_repr()) + ->assert_is_op_output("elementwise_add_grad", "Y@GRAD"); + fuse_qkv_reshape_grad_out_node->assert_is_op_input("elementwise_add_grad", + "Out@GRAD"); + fuse_qkv_ele_add_grad_node + ->LinksFrom({fuse_qkv_reshape_grad_out_node, + fuse_qkv_ele_add_grad_x_node, + fuse_qkv_ele_add_grad_bias_node}) + .LinksTo({fuse_qkv_ele_add_grad_x_grad_node, + fuse_qkv_ele_add_grad_bias_grad_node}); + + auto* fuse_qkv_matmul_grad_node = + pattern->NewNode(fuse_qkv_matmul_grad_op_repr()) + ->assert_is_op("matmul_v2_grad"); + auto* fuse_qkv_matmul_grad_x_node = + pattern->NewNode(fuse_qkv_matmul_grad_x_repr()) + ->assert_is_op_input("matmul_v2_grad", "X"); + auto* fuse_qkv_matmul_grad_w_node = + pattern->NewNode(fuse_qkv_matmul_grad_w_repr()) + ->assert_is_op_input("matmul_v2_grad", "Y"); + auto* fuse_qkv_matmul_grad_x_grad_node = + pattern->NewNode(fuse_qkv_matmul_grad_x_grad_repr()) + ->assert_is_op_output("matmul_v2_grad", "X@GRAD"); + auto* fuse_qkv_matmul_grad_w_grad_node = + pattern->NewNode(fuse_qkv_matmul_grad_w_grad_repr()) + ->assert_is_op_output("matmul_v2_grad", "Y@GRAD"); + fuse_qkv_ele_add_grad_x_grad_node->assert_is_op_input("matmul_v2_grad", + "Out@GRAD"); + fuse_qkv_matmul_grad_node + ->LinksFrom({fuse_qkv_ele_add_grad_x_grad_node, + fuse_qkv_matmul_grad_x_node, + fuse_qkv_matmul_grad_w_node}) + .LinksTo( + {fuse_qkv_matmul_grad_x_grad_node, fuse_qkv_matmul_grad_w_grad_node}); + + if (!pre_layer_norm) { + return 
fuse_qkv_matmul_grad_x_grad_node; + } + + // pre layer norm + auto* pre_layer_norm_grad_node = + pattern->NewNode(pre_layer_norm_grad_op_repr()) + ->assert_is_op("layer_norm_grad"); + auto* pre_layer_norm_grad_scale_node = + pattern->NewNode(pre_layer_norm_grad_scale_repr()) + ->assert_is_op_input("layer_norm_grad", "Scale"); + auto* pre_layer_norm_grad_bias_node = + pattern->NewNode(pre_layer_norm_grad_bias_repr()) + ->assert_is_op_input("layer_norm_grad", "Bias"); + auto* pre_layer_norm_grad_mean_node = + pattern->NewNode(pre_layer_norm_grad_mean_repr()) + ->assert_is_op_input("layer_norm_grad", "Mean"); + auto* pre_layer_norm_grad_variance_node = + pattern->NewNode(pre_layer_norm_grad_variance_repr()) + ->assert_is_op_input("layer_norm_grad", "Variance"); + auto* pre_layer_norm_grad_x_node = + add_residual ? residual_ele_add_grad_x_node + : pattern->NewNode(pre_layer_norm_grad_x_repr()) + ->assert_is_op_input("layer_norm_grad", "X"); + auto* pre_layer_norm_grad_scale_grad_node = + pattern->NewNode(pre_layer_norm_grad_scale_grad_repr()) + ->assert_is_op_output("layer_norm_grad", "Scale@GRAD"); + auto* pre_layer_norm_grad_bias_grad_node = + pattern->NewNode(pre_layer_norm_grad_bias_grad_repr()) + ->assert_is_op_output("layer_norm_grad", "Bias@GRAD"); + auto* pre_layer_norm_grad_x_grad_node = + pattern->NewNode(pre_layer_norm_grad_x_grad_repr()) + ->assert_is_op_output("layer_norm_grad", "X@GRAD"); + fuse_qkv_matmul_grad_x_grad_node->assert_is_op_input("layer_norm_grad", + "Y@GRAD"); + pre_layer_norm_grad_node + ->LinksFrom({fuse_qkv_matmul_grad_x_grad_node, + pre_layer_norm_grad_scale_node, + pre_layer_norm_grad_bias_node, + pre_layer_norm_grad_mean_node, + pre_layer_norm_grad_variance_node, + pre_layer_norm_grad_x_node}) + .LinksTo({pre_layer_norm_grad_scale_grad_node, + pre_layer_norm_grad_bias_grad_node, + pre_layer_norm_grad_x_grad_node}); + + if (!add_residual) { + return pre_layer_norm_grad_x_grad_node; + } + + auto* grad_accumulation_sum_node = + pattern->NewNode(grad_accumulation_sum_op_repr())->assert_is_op("sum"); + auto* grad_accumulation_sum_out_node = + pattern->NewNode(grad_accumulation_out_repr()) + ->assert_is_op_output("sum"); + grad_accumulation_sum_node + ->LinksFrom( + {pre_layer_norm_grad_x_grad_node, residual_ele_add_grad_x_grad_node}) + .LinksTo({grad_accumulation_sum_out_node}); + + return grad_accumulation_sum_out_node; } } // namespace patterns @@ -437,7 +870,107 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const { } ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const { - // TODO(Yuang Liu): finish the pass + GraphPatternDetector gpd; + auto* x = gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(name_scope_, "x")) + ->AsInput() + ->assert_is_op_input("layer_norm_grad", "Y@GRAD"); + patterns::FusedAttentionGradPattern fused_attention_grad_pattern( + gpd.mutable_pattern(), "fused_attention_grad_pattern"); + + fused_attention_grad_pattern(x, + /* pre_layer_norm */ true, + /* post_layer_norm */ true, + /* has_attn_mask */ true, + /* do_dropout */ true, + /* add_residual */ true); + + int found_fused_attention = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(3) << "handle FusedMultiHeadAttention backward pass's fusion"; + + GET_IR_NODE_FROM_SUBGRAPH(post_layer_norm_grad_op_node, + post_layer_norm_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_grad_op_node, + residual_ele_add_grad_op, + fused_attention_grad_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH(out_linear_dropout_grad_op_node, + out_linear_dropout_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_grad_op_node, + out_linear_ele_add_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_grad_op_node, + out_linear_matmul_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(qkv_reshape_grad_op_node, + qkv_reshape_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(qkv_transpose_grad_op_node, + qkv_transpose_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(qkv_matmul_grad_op_node, + qkv_matmul_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(attn_dropout_grad_op_node, + attn_dropout_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(qk_softmax_grad_op_node, + qk_softmax_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(add_mask_ele_add_grad_op_node, + add_mask_ele_add_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + qk_scale_grad_op_node, qk_scale_grad_op, fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(qk_matmul_grad_op_node, + qk_matmul_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_split_grad_op_node, + fuse_qkv_split_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_transpose_grad_op_node, + fuse_qkv_transpose_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_reshape_grad_op_node, + fuse_qkv_reshape_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_grad_op_node, + fuse_qkv_ele_add_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_matmul_grad_op_node, + fuse_qkv_matmul_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_grad_op_node, + pre_layer_norm_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(grad_accumulation_sum_op_node, + grad_accumulation_sum_op, + fused_attention_grad_pattern); + + // TODO(Yuang Liu): finish the handler + + GraphSafeRemoveNodes( + g, {post_layer_norm_grad_op_node, residual_ele_add_grad_op_node, + out_linear_dropout_grad_op_node, out_linear_ele_add_grad_op_node, + out_linear_matmul_grad_op_node, qkv_reshape_grad_op_node, + qkv_transpose_grad_op_node, qkv_matmul_grad_op_node, + attn_dropout_grad_op_node, qk_softmax_grad_op_node, + add_mask_ele_add_grad_op_node, qk_scale_grad_op_node, + qk_matmul_grad_op_node, fuse_qkv_split_grad_op_node, + fuse_qkv_transpose_grad_op_node, fuse_qkv_reshape_grad_op_node, + fuse_qkv_ele_add_grad_op_node, fuse_qkv_matmul_grad_op_node, + pre_layer_norm_grad_op_node, grad_accumulation_sum_op_node}); + + found_fused_attention++; + }; + + gpd(graph, handler); + AddStatis(found_fused_attention); + return graph; } diff --git a/paddle/fluid/framework/ir/fused_attention_pass.h b/paddle/fluid/framework/ir/fused_attention_pass.h index 5ec1aac41ec..d360f7f6520 100644 --- a/paddle/fluid/framework/ir/fused_attention_pass.h +++ b/paddle/fluid/framework/ir/fused_attention_pass.h @@ -140,7 +140,116 @@ struct FusedAttentionGradPattern : public PatternBase { bool do_dropout, // dropout the softmax(qk) or not bool add_residual); // add residual to out linear or not - // TODO(Yuang Liu): add backward pattern + // post layer norm grad + PATTERN_DECL_NODE(post_layer_norm_grad_op); + PATTERN_DECL_NODE(post_layer_norm_grad_scale); + PATTERN_DECL_NODE(post_layer_norm_grad_bias); + 
PATTERN_DECL_NODE(post_layer_norm_grad_mean); + PATTERN_DECL_NODE(post_layer_norm_grad_variance); + PATTERN_DECL_NODE(post_layer_norm_grad_x); + PATTERN_DECL_NODE(post_layer_norm_grad_scale_grad); + PATTERN_DECL_NODE(post_layer_norm_grad_bias_grad); + PATTERN_DECL_NODE(post_layer_norm_grad_x_grad); + + // residual grad + PATTERN_DECL_NODE(residual_ele_add_grad_op); + PATTERN_DECL_NODE(residual_ele_add_grad_x); + PATTERN_DECL_NODE(residual_ele_add_grad_bias); + PATTERN_DECL_NODE(residual_ele_add_grad_bias_grad); + PATTERN_DECL_NODE(residual_ele_add_grad_x_grad); + + // out linear grad + PATTERN_DECL_NODE(out_linear_dropout_grad_op); + PATTERN_DECL_NODE(out_linear_dropout_grad_mask); + PATTERN_DECL_NODE(out_linear_dropout_grad_out); + + PATTERN_DECL_NODE(out_linear_ele_add_grad_op); + PATTERN_DECL_NODE(out_linear_ele_add_grad_x); + PATTERN_DECL_NODE(out_linear_ele_add_grad_bias); + PATTERN_DECL_NODE(out_linear_ele_add_grad_x_grad); + PATTERN_DECL_NODE(out_linear_ele_add_grad_bias_grad); + + PATTERN_DECL_NODE(out_linear_matmul_grad_op); + PATTERN_DECL_NODE(out_linear_matmul_grad_x); + PATTERN_DECL_NODE(out_linear_matmul_grad_w); + PATTERN_DECL_NODE(out_linear_matmul_grad_x_grad); + PATTERN_DECL_NODE(out_linear_matmul_grad_w_grad); + + // core attention grad + PATTERN_DECL_NODE(qkv_reshape_grad_op); + PATTERN_DECL_NODE(qkv_reshape_grad_x_shape); + PATTERN_DECL_NODE(qkv_reshape_grad_out); + + PATTERN_DECL_NODE(qkv_transpose_grad_op); + PATTERN_DECL_NODE(qkv_transpose_grad_x_shape); + PATTERN_DECL_NODE(qkv_transpose_grad_out); + + PATTERN_DECL_NODE(qkv_matmul_grad_op); + PATTERN_DECL_NODE(qkv_matmul_grad_x); + PATTERN_DECL_NODE(qkv_matmul_grad_w); + PATTERN_DECL_NODE(qkv_matmul_grad_x_grad); + PATTERN_DECL_NODE(qkv_matmul_grad_w_grad); + + PATTERN_DECL_NODE(attn_dropout_grad_op); + PATTERN_DECL_NODE(attn_dropout_grad_mask); + PATTERN_DECL_NODE(attn_dropout_grad_out); + + PATTERN_DECL_NODE(qk_softmax_grad_op); + PATTERN_DECL_NODE(qk_softmax_grad_fwd_out); + PATTERN_DECL_NODE(qk_softmax_grad_out); + + PATTERN_DECL_NODE(add_mask_ele_add_grad_op); + PATTERN_DECL_NODE(add_mask_ele_add_grad_x); + PATTERN_DECL_NODE(add_mask_ele_add_grad_bias); + PATTERN_DECL_NODE(add_mask_ele_add_grad_x_grad); + + PATTERN_DECL_NODE(qk_scale_grad_op); + PATTERN_DECL_NODE(qk_scale_grad_out); + + PATTERN_DECL_NODE(qk_matmul_grad_op); + PATTERN_DECL_NODE(qk_matmul_grad_x); + PATTERN_DECL_NODE(qk_matmul_grad_w); + PATTERN_DECL_NODE(qk_matmul_grad_x_grad); + PATTERN_DECL_NODE(qk_matmul_grad_w_grad); + + // fuse qkv projection grad + PATTERN_DECL_NODE(fuse_qkv_split_grad_op); // concat op + PATTERN_DECL_NODE(fuse_qkv_split_grad_out); + + PATTERN_DECL_NODE(fuse_qkv_transpose_grad_op); + PATTERN_DECL_NODE(fuse_qkv_transpose_grad_x_shape); + PATTERN_DECL_NODE(fuse_qkv_transpose_grad_out); + + PATTERN_DECL_NODE(fuse_qkv_reshape_grad_op); + PATTERN_DECL_NODE(fuse_qkv_reshape_grad_x_shape); + PATTERN_DECL_NODE(fuse_qkv_reshape_grad_out); + + PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_op); + PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x); + PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias); + PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x_grad); + PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias_grad); + + PATTERN_DECL_NODE(fuse_qkv_matmul_grad_op); + PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x); + PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w); + PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x_grad); + PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w_grad); + + // pre layer norm grad + PATTERN_DECL_NODE(pre_layer_norm_grad_op); + 
PATTERN_DECL_NODE(pre_layer_norm_grad_scale); + PATTERN_DECL_NODE(pre_layer_norm_grad_bias); + PATTERN_DECL_NODE(pre_layer_norm_grad_mean); + PATTERN_DECL_NODE(pre_layer_norm_grad_variance); + PATTERN_DECL_NODE(pre_layer_norm_grad_x); + PATTERN_DECL_NODE(pre_layer_norm_grad_scale_grad); + PATTERN_DECL_NODE(pre_layer_norm_grad_bias_grad); + PATTERN_DECL_NODE(pre_layer_norm_grad_x_grad); + + // grad accumulation + PATTERN_DECL_NODE(grad_accumulation_sum_op); + PATTERN_DECL_NODE(grad_accumulation_out); }; } // namespace patterns diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py b/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py index ff2e2f73286..cce05d8747c 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py @@ -114,9 +114,7 @@ class TestFusedAttentionPass(unittest.TestCase): hidden_size = 768 num_heads = 12 - x_data = np.random.rand(batch_size, seq_len, hidden_size).astype( - 'float32' - ) + x_data = np.random.rand(batch_size, seq_len, seq_len).astype('float32') mask_data = np.random.rand( batch_size, num_heads, seq_len, seq_len ).astype('float32') @@ -127,7 +125,7 @@ class TestFusedAttentionPass(unittest.TestCase): with paddle.static.program_guard(main_prog, startup_prog): data = paddle.static.data( name="x", - shape=[-1, seq_len, hidden_size], + shape=[-1, seq_len, seq_len], dtype='float32', ) if self.add_mask: @@ -138,6 +136,7 @@ class TestFusedAttentionPass(unittest.TestCase): ) else: attn_mask = None + data_linear = paddle.nn.Linear(seq_len, hidden_size) multi_head_attn = MultiHeadAttention( hidden_size, num_heads, @@ -146,7 +145,9 @@ class TestFusedAttentionPass(unittest.TestCase): post_ln=self.post_ln, attn_dropout=self.attn_dropout, ) - out = multi_head_attn(data, attn_mask) + + attn_input = data_linear(data) + out = multi_head_attn(attn_input, attn_mask) loss = paddle.mean(out) sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001) @@ -156,7 +157,13 @@ class TestFusedAttentionPass(unittest.TestCase): pass_manager.apply([main_prog], [startup_prog]) ops = main_prog.global_block().ops - assert ops[0].type == 'reduce_mean' + assert ops[2].type == 'reduce_mean' + assert ops[4].type == 'reduce_mean_grad' + # two ops for linear, one op for reduce mean + # one fill constant + # one op for reduce mean grad, two ops for linear bwd + # the eighth op should be the optimizer + assert ops[7].type == 'sgd' if __name__ == "__main__": -- GitLab
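
Usage note (not part of the patch): the sketch below shows how the rewritten unit test above could exercise this backward-pattern pass end to end. It is a minimal sketch, not the author's code. The paddle.distributed.passes import path, the pass name "fused_attention", and the full set of MultiHeadAttention constructor flags are assumptions inferred from the test file, not confirmed by this diff; the data layout, the extra Linear in front of the attention block, pass_manager.apply(...), and the op-count checks mirror test_fused_attention_pass.py as changed here.

# Sketch only. Assumes the MultiHeadAttention helper class defined in
# test_fused_attention_pass.py is in scope, that the pass is registered as
# "fused_attention", and that PassManager/new_pass live in
# paddle.distributed.passes.
import paddle
from paddle.distributed.passes import PassManager, new_pass

paddle.enable_static()
# hidden_size=768 and num_heads=12 come from the test; seq_len is illustrative.
seq_len, hidden_size, num_heads = 128, 768, 12

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data("x", [-1, seq_len, seq_len], "float32")
    mask = paddle.static.data(
        "attn_mask", [-1, num_heads, seq_len, seq_len], "float32"
    )
    # A Linear in front of the attention block, as in the updated test, so the
    # pattern's entry node is produced by an op inside the program.
    attn_input = paddle.nn.Linear(seq_len, hidden_size)(x)
    attn = MultiHeadAttention(
        hidden_size,
        num_heads,
        add_residual=True,   # flags assumed to mirror the test's settings
        pre_ln=True,
        post_ln=True,
        attn_dropout=True,
    )
    loss = paddle.mean(attn(attn_input, mask))
    paddle.fluid.optimizer.SGD(learning_rate=0.001).minimize(loss)

# Apply the IR pass. With this patch the backward attention subgraph is now
# matched and its ops are removed (inserting the fused grad op is still a TODO
# in the handler), so only the surrounding linear ops, reduce_mean,
# fill_constant, their grad ops, and sgd remain, which is what the updated
# asserts on ops[2], ops[4], and ops[7] check.
pass_manager = PassManager([new_pass("fused_attention")])
pass_manager.apply([main_prog], [startup_prog])
print([op.type for op in main_prog.global_block().ops])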