From 2b848aef55e47c514cef216af51ba7bcad1e43b7 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Wed, 1 Feb 2023 20:10:35 +0800
Subject: [PATCH] Fused attention pass fwd, create the fused_attention op. (#50125)

---
 .../framework/ir/fused_attention_pass.cc      | 246 +++++++++++++++---
 .../fluid/framework/ir/fused_attention_pass.h |  24 +-
 .../unittests/test_fused_attention_pass.py    |  14 +-
 3 files changed, 231 insertions(+), 53 deletions(-)

diff --git a/paddle/fluid/framework/ir/fused_attention_pass.cc b/paddle/fluid/framework/ir/fused_attention_pass.cc
index 72fa90db9b..7b0f469ff8 100644
--- a/paddle/fluid/framework/ir/fused_attention_pass.cc
+++ b/paddle/fluid/framework/ir/fused_attention_pass.cc
@@ -22,7 +22,6 @@ namespace patterns {
 
 PDNode* FusedAttentionPattern::operator()(PDNode* x,
                                           bool pre_layer_norm,
-                                          bool post_layer_norm,
                                           bool has_attn_mask,
                                           bool do_dropout,
                                           bool add_residual) {
@@ -259,7 +258,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
   out_linear_dropout_node->LinksFrom({out_linear_ele_add_out_node})
       .LinksTo({out_linear_dropout_mask_node, out_linear_dropout_out_node});
 
-  if (!add_residual && !post_layer_norm) {
+  if (!add_residual && pre_layer_norm) {
     return out_linear_dropout_out_node;
   }
@@ -276,7 +275,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
     residual_ele_add_node->LinksFrom({x, out_linear_dropout_out_node})
         .LinksTo({residual_ele_add_out_node});
 
-    if (!post_layer_norm) {
+    if (pre_layer_norm) {
       return residual_ele_add_out_node;
     }
   }
@@ -323,13 +322,12 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
 
 PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
                                               bool pre_layer_norm,
-                                              bool post_layer_norm,
                                               bool has_attn_mask,
                                               bool do_dropout,
                                               bool add_residual) {
   // post layer norm
   PDNode* post_layer_norm_grad_out_node{nullptr};
-  if (post_layer_norm) {
+  if (!pre_layer_norm) {
     auto* post_layer_norm_grad_node =
         pattern->NewNode(post_layer_norm_grad_op_repr())
             ->assert_is_op("layer_norm_grad");
@@ -375,7 +373,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   PDNode* residual_ele_add_grad_x_grad_node{nullptr};
   if (add_residual) {
     PDNode* ele_add_grad_input = x;
-    if (post_layer_norm) {
+    if (!pre_layer_norm) {
       ele_add_grad_input = post_layer_norm_grad_out_node;
     }
     auto* residual_ele_add_grad_node =
@@ -404,7 +402,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
   // get the real input x for dropout grad
   PDNode* out_linear_grad_input_node = x;
-  if (post_layer_norm && !add_residual) {
+  if (!pre_layer_norm && !add_residual) {
     out_linear_grad_input_node = post_layer_norm_grad_out_node;
   } else if (add_residual) {
     out_linear_grad_input_node = residual_ele_add_grad_out_node;
   }
@@ -769,11 +767,11 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
 void FusedAttentionsPass::ApplyImpl(Graph* graph) const {
   FusePassBase::Init(name_scope_, graph);
 
-  graph = PreMaskDropResPostFwd(graph);
-  graph = PreMaskDropResPostBwd(graph);
+  graph = PreMaskDropResFwd(graph);
+  graph = PreMaskDropResBwd(graph);
 }
 
-ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
+ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
                 ->NewNode(patterns::PDNodeName(name_scope_, "x"))
@@ -784,7 +782,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
 
   fused_attention_pattern(x,
                           /* pre_layer_norm */ true,
-                          /* post_layer_norm */ true,
                           /* has_attn_mask */ true,
                           /* do_dropout */ true,
                           /* add_residual */ true);
@@ -835,10 +832,191 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
                               fused_attention_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         residual_ele_add_op_node, residual_ele_add_op, fused_attention_pattern);
+
+    OpDesc fused_attention_op_desc(pre_layer_norm_op_node->Op()->Block());
+    fused_attention_op_desc.SetType("fused_attention");
+    fused_attention_op_desc.SetInput("X", {subgraph.at(x)->Name()});
+
+    fused_attention_op_desc.SetAttr("pre_layer_norm", true);
+    GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_scale_node,
+                              pre_layer_norm_scale,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_bias_node, pre_layer_norm_bias, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_out_node, pre_layer_norm_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        pre_layer_norm_mean_node, pre_layer_norm_mean, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_variance_node,
+                              pre_layer_norm_variance,
+                              fused_attention_pattern);
+    fused_attention_op_desc.SetInput("LnScale",
+                                     {pre_layer_norm_scale_node->Name()});
+    fused_attention_op_desc.SetInput("LnBias",
+                                     {pre_layer_norm_bias_node->Name()});
+    fused_attention_op_desc.SetOutput("LnOut",
+                                      {pre_layer_norm_out_node->Name()});
+    fused_attention_op_desc.SetOutput("LnMean",
+                                      {pre_layer_norm_mean_node->Name()});
+    fused_attention_op_desc.SetOutput("LnVariance",
+                                      {pre_layer_norm_variance_node->Name()});
+    fused_attention_op_desc.SetAttr(
+        "epsilon",
+        PADDLE_GET_CONST(float,
+                         pre_layer_norm_op_node->Op()->GetAttr("epsilon")));
+
+    fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
+    std::vector<int> shape = PADDLE_GET_CONST(
+        std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
+    fused_attention_op_desc.SetAttr("num_heads", shape[2]);
     GET_IR_NODE_FROM_SUBGRAPH(
-        post_layer_norm_op_node, post_layer_norm_op, fused_attention_pattern);
+        fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,
+                              fuse_qkv_ele_add_bias,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_out_node,
+                              fuse_qkv_ele_add_out,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_transpose_out_node,
+                              fuse_qkv_transpose_out,
+                              fused_attention_pattern);
+    fused_attention_op_desc.SetInput("QKVW", {fuse_qkv_matmul_w_node->Name()});
+    fused_attention_op_desc.SetInput("QKVBias",
+                                     {fuse_qkv_ele_add_bias_node->Name()});
+    fused_attention_op_desc.SetOutput("QKVOut",
+                                      {fuse_qkv_matmul_out_node->Name()});
+    fused_attention_op_desc.SetOutput("QKVBiasOut",
+                                      {fuse_qkv_ele_add_out_node->Name()});
+    fused_attention_op_desc.SetOutput("TransposeOut2",
+                                      {fuse_qkv_transpose_out_node->Name()});
 
-    // TODO(Yuang Liu): finish the handler
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qk_matmul_out_node, qk_matmul_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(add_mask_ele_add_mask_node,
+                              add_mask_ele_add_mask,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(add_mask_ele_add_out_node,
+                              add_mask_ele_add_out,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qk_softmax_out_node, qk_softmax_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        attn_dropout_out_node, attn_dropout_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        attn_dropout_mask_node, attn_dropout_mask, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qkv_matmul_out_node, qkv_matmul_out, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        qkv_reshape_out_node, qkv_reshape_out, fused_attention_pattern);
+    fused_attention_op_desc.SetOutput("QKOut", {qk_matmul_out_node->Name()});
+    fused_attention_op_desc.SetInput("SrcMask",
+                                     {add_mask_ele_add_mask_node->Name()});
+    fused_attention_op_desc.SetOutput("SrcMaskOut",
+                                      {add_mask_ele_add_out_node->Name()});
+    fused_attention_op_desc.SetOutput("SoftmaxOut",
+                                      {qk_softmax_out_node->Name()});
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_rate",
+        PADDLE_GET_CONST(float,
+                         attn_dropout_op_node->Op()->GetAttr("dropout_prob")));
+    fused_attention_op_desc.SetAttr(
+        "is_test",
+        PADDLE_GET_CONST(bool, attn_dropout_op_node->Op()->GetAttr("is_test")));
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_fix_seed",
+        PADDLE_GET_CONST(bool,
+                         attn_dropout_op_node->Op()->GetAttr("fix_seed")));
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_seed",
+        PADDLE_GET_CONST(int, attn_dropout_op_node->Op()->GetAttr("seed")));
+    fused_attention_op_desc.SetAttr(
+        "attn_dropout_implementation",
+        PADDLE_GET_CONST(
+            std::string,
+            attn_dropout_op_node->Op()->GetAttr("dropout_implementation")));
+    fused_attention_op_desc.SetOutput("AttnDropoutMaskOut",
+                                      {attn_dropout_mask_node->Name()});
+    fused_attention_op_desc.SetOutput("AttnDropoutOut",
+                                      {attn_dropout_out_node->Name()});
+    fused_attention_op_desc.SetOutput("QKTVOut", {qkv_matmul_out_node->Name()});
+    fused_attention_op_desc.SetOutput("FMHAOut",
+                                      {qkv_reshape_out_node->Name()});
+
+    GET_IR_NODE_FROM_SUBGRAPH(
+        out_linear_matmul_w_node, out_linear_matmul_w, fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_out_node,
+                              out_linear_matmul_out,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_bias_node,
+                              out_linear_ele_add_bias,
+                              fused_attention_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_out_node,
+                              out_linear_ele_add_out,
+                              fused_attention_pattern);
+    fused_attention_op_desc.SetInput("OutLinearW",
+                                     {out_linear_matmul_w_node->Name()});
+    fused_attention_op_desc.SetInput("OutLinearBias",
+                                     {out_linear_ele_add_bias_node->Name()});
+    fused_attention_op_desc.SetOutput("OutLinearOut",
+                                      {out_linear_matmul_out_node->Name()});
+    GET_IR_NODE_FROM_SUBGRAPH(out_linear_dropout_mask_node,
+                              out_linear_dropout_mask,
+                              fused_attention_pattern);
+    fused_attention_op_desc.SetAttr(
+        "dropout_rate",
+        PADDLE_GET_CONST(
+            float, out_linear_dropout_op_node->Op()->GetAttr("dropout_prob")));
+    fused_attention_op_desc.SetAttr(
+        "dropout_fix_seed",
+        PADDLE_GET_CONST(
+            bool, out_linear_dropout_op_node->Op()->GetAttr("fix_seed")));
+    fused_attention_op_desc.SetAttr(
+        "dropout_seed",
+        PADDLE_GET_CONST(int,
+                         out_linear_dropout_op_node->Op()->GetAttr("seed")));
+    fused_attention_op_desc.SetAttr(
+        "dropout_implementation",
+        PADDLE_GET_CONST(std::string,
+                         out_linear_dropout_op_node->Op()->GetAttr(
+                             "dropout_implementation")));
+    fused_attention_op_desc.SetOutput("DropoutMaskOut",
+                                      {out_linear_dropout_mask_node->Name()});
+
+    GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_out_node,
+                              residual_ele_add_out,
+                              fused_attention_pattern);
+    fused_attention_op_desc.SetAttr("add_residual", true);
+    fused_attention_op_desc.SetOutput("Y", {residual_ele_add_out_node->Name()});
+
+    auto fused_attention_node = g->CreateOpNode(&fused_attention_op_desc);
+
+    IR_NODE_LINK_TO(subgraph.at(x), fused_attention_node);
+    IR_NODE_LINK_TO(pre_layer_norm_scale_node, fused_attention_node);
+    IR_NODE_LINK_TO(pre_layer_norm_bias_node, fused_attention_node);
+    IR_NODE_LINK_TO(fuse_qkv_matmul_w_node, fused_attention_node);
+    IR_NODE_LINK_TO(fuse_qkv_ele_add_bias_node, fused_attention_node);
+    IR_NODE_LINK_TO(add_mask_ele_add_mask_node, fused_attention_node);
+    IR_NODE_LINK_TO(out_linear_matmul_w_node, fused_attention_node);
+    IR_NODE_LINK_TO(out_linear_ele_add_bias_node, fused_attention_node);
+
+    IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_mean_node);
+    IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_variance_node);
+    IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_ele_add_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_transpose_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qk_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, add_mask_ele_add_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qk_softmax_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, attn_dropout_mask_node);
+    IR_NODE_LINK_TO(fused_attention_node, attn_dropout_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qkv_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, qkv_reshape_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, out_linear_matmul_out_node);
+    IR_NODE_LINK_TO(fused_attention_node, out_linear_dropout_mask_node);
+    IR_NODE_LINK_TO(fused_attention_node, residual_ele_add_out_node);
 
     GraphSafeRemoveNodes(g,
                          {pre_layer_norm_op_node,
@@ -858,8 +1036,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
                           out_linear_matmul_op_node,
                           out_linear_ele_add_op_node,
                           out_linear_dropout_op_node,
-                          residual_ele_add_op_node,
-                          post_layer_norm_op_node});
+                          residual_ele_add_op_node});
 
     found_fused_attention++;
   };
@@ -869,18 +1046,17 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
   return graph;
 }
 
-ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
+ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
                 ->NewNode(patterns::PDNodeName(name_scope_, "x"))
                 ->AsInput()
-                ->assert_is_op_input("layer_norm_grad", "Y@GRAD");
+                ->assert_is_op_input("elementwise_add_grad", "Out@GRAD");
 
   patterns::FusedAttentionGradPattern fused_attention_grad_pattern(
       gpd.mutable_pattern(), "fused_attention_grad_pattern");
   fused_attention_grad_pattern(x,
                                /* pre_layer_norm */ true,
-                               /* post_layer_norm */ true,
                                /* has_attn_mask */ true,
                                /* do_dropout */ true,
                                /* add_residual */ true);
@@ -891,9 +1067,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
                      Graph* g) {
     VLOG(3) << "handle FusedMultiHeadAttention backward pass's fusion";
 
-    GET_IR_NODE_FROM_SUBGRAPH(post_layer_norm_grad_op_node,
-                              post_layer_norm_grad_op,
-                              fused_attention_grad_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_grad_op_node,
                               residual_ele_add_grad_op,
                               fused_attention_grad_pattern);
@@ -953,17 +1126,26 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
 
     // TODO(Yuang Liu): finish the handler
 
-    GraphSafeRemoveNodes(
-        g, {post_layer_norm_grad_op_node, residual_ele_add_grad_op_node,
-            out_linear_dropout_grad_op_node, out_linear_ele_add_grad_op_node,
-            out_linear_matmul_grad_op_node, qkv_reshape_grad_op_node,
-            qkv_transpose_grad_op_node, qkv_matmul_grad_op_node,
-            attn_dropout_grad_op_node, qk_softmax_grad_op_node,
-            add_mask_ele_add_grad_op_node, qk_scale_grad_op_node,
-            qk_matmul_grad_op_node, fuse_qkv_split_grad_op_node,
-            fuse_qkv_transpose_grad_op_node, fuse_qkv_reshape_grad_op_node,
-            fuse_qkv_ele_add_grad_op_node, fuse_qkv_matmul_grad_op_node,
-            pre_layer_norm_grad_op_node, grad_accumulation_sum_op_node});
+    GraphSafeRemoveNodes(g,
+                         {residual_ele_add_grad_op_node,
+                          out_linear_dropout_grad_op_node,
+                          out_linear_ele_add_grad_op_node,
+                          out_linear_matmul_grad_op_node,
+                          qkv_reshape_grad_op_node,
+                          qkv_transpose_grad_op_node,
+                          qkv_matmul_grad_op_node,
+                          attn_dropout_grad_op_node,
+                          qk_softmax_grad_op_node,
+                          add_mask_ele_add_grad_op_node,
+                          qk_scale_grad_op_node,
+                          qk_matmul_grad_op_node,
+                          fuse_qkv_split_grad_op_node,
+                          fuse_qkv_transpose_grad_op_node,
+                          fuse_qkv_reshape_grad_op_node,
+                          fuse_qkv_ele_add_grad_op_node,
+                          fuse_qkv_matmul_grad_op_node,
+                          pre_layer_norm_grad_op_node,
+                          grad_accumulation_sum_op_node});
 
     found_fused_attention++;
   };
diff --git a/paddle/fluid/framework/ir/fused_attention_pass.h b/paddle/fluid/framework/ir/fused_attention_pass.h
index d360f7f652..41a90bd599 100644
--- a/paddle/fluid/framework/ir/fused_attention_pass.h
+++ b/paddle/fluid/framework/ir/fused_attention_pass.h
@@ -28,7 +28,7 @@ namespace patterns {
 
 // Declare patterns for multi head attention.
 // Can detect:
-// 1. Pre layer norm, post layer norm or sandwich layer norm.
+// 1. Pre layer norm or post layer norm.
 // 2. Add attn mask for qk product before the softmax or not.
 // 3. Do attn dropout or not.
 // 4. Add residual to the out linear result or not.
@@ -37,11 +37,10 @@ struct FusedAttentionPattern : public PatternBase {
       : PatternBase(pattern, name_scope, "fused_attention_pattern") {}
 
   PDNode* operator()(PDNode* x,
-                     bool pre_layer_norm,   // do pre ln or not
-                     bool post_layer_norm,  // do post ln or not
-                     bool has_attn_mask,    // add attn mask to qk or not
-                     bool do_dropout,       // dropout the softmax(qk) or not
-                     bool add_residual);    // add residual to out linear or not
+                     bool pre_layer_norm,  // do pre ln or not
+                     bool has_attn_mask,   // add attn mask to qk or not
+                     bool do_dropout,      // dropout the softmax(qk) or not
+                     bool add_residual);   // add residual to out linear or not
 
   // pre layer norm
   PATTERN_DECL_NODE(pre_layer_norm_op);
@@ -134,11 +133,10 @@ struct FusedAttentionGradPattern : public PatternBase {
      : PatternBase(pattern, name_scope, "fused_attention_pattern") {}
 
   PDNode* operator()(PDNode* x,
-                     bool pre_layer_norm,   // pre ln
-                     bool post_layer_norm,  // post ln
-                     bool has_attn_mask,    // add attn mask to qk or not
-                     bool do_dropout,       // dropout the softmax(qk) or not
-                     bool add_residual);    // add residual to out linear or not
+                     bool pre_layer_norm,  // pre ln
+                     bool has_attn_mask,   // add attn mask to qk or not
+                     bool do_dropout,      // dropout the softmax(qk) or not
+                     bool add_residual);   // add residual to out linear or not
 
   // post layer norm grad
   PATTERN_DECL_NODE(post_layer_norm_grad_op);
@@ -275,9 +273,9 @@ class FusedAttentionsPass : public FusePassBase {
   // If true, the function name will have an abbreviation part.
   // If false, the function name won't contain an abbreviation for it.
 
-  ir::Graph* PreMaskDropResPostFwd(Graph* graph) const;
+  ir::Graph* PreMaskDropResFwd(Graph* graph) const;
 
-  ir::Graph* PreMaskDropResPostBwd(Graph* graph) const;
+  ir::Graph* PreMaskDropResBwd(Graph* graph) const;
 };
 
 }  // namespace ir
diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py b/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
index cce05d8747..12366a574d 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
@@ -31,7 +31,6 @@ class MultiHeadAttention(paddle.nn.Layer):
         num_heads,
         add_residual=True,
         pre_ln=True,
-        post_ln=False,
         attn_dropout=True,
     ):
         super(MultiHeadAttention, self).__init__()
@@ -42,7 +41,6 @@ class MultiHeadAttention(paddle.nn.Layer):
 
         self.add_residual = add_residual
         self.pre_ln = pre_ln
-        self.post_ln = post_ln
         self.attn_dropout = attn_dropout
 
         self.head_dim = embed_dim // num_heads
@@ -90,7 +88,7 @@ class MultiHeadAttention(paddle.nn.Layer):
         if self.add_residual:
             out = residual + out
 
-        if self.post_ln:
+        if not self.pre_ln:
             # post layer norm
             out = self.norm2(out)
@@ -104,7 +102,6 @@ class TestFusedAttentionPass(unittest.TestCase):
     def setUp(self):
         self.add_residual = True
         self.pre_ln = True
-        self.post_ln = True
         self.attn_dropout = True
         self.add_mask = True
@@ -120,6 +117,7 @@ class TestFusedAttentionPass(unittest.TestCase):
         ).astype('float32')
 
         main_prog = paddle.static.Program()
+        main_prog.random_seed = 1234
         startup_prog = paddle.static.Program()
 
         with paddle.static.program_guard(main_prog, startup_prog):
@@ -142,7 +140,6 @@ class TestFusedAttentionPass(unittest.TestCase):
                 num_heads,
                 add_residual=self.add_residual,
                 pre_ln=self.pre_ln,
-                post_ln=self.post_ln,
                 attn_dropout=self.attn_dropout,
             )
@@ -157,13 +154,14 @@ class TestFusedAttentionPass(unittest.TestCase):
         pass_manager.apply([main_prog], [startup_prog])
 
         ops = main_prog.global_block().ops
-        assert ops[2].type == 'reduce_mean'
-        assert ops[4].type == 'reduce_mean_grad'
+        assert ops[2].type == 'fused_attention'
+        assert ops[3].type == 'reduce_mean'
+        assert ops[5].type == 'reduce_mean_grad'
         # two ops for linear, one op for reduce mean
         # one fill constant
         # one op for reduce mean grad, two ops for linear bwd
         # the eighth op should be the optimizer
-        assert ops[7].type == 'sgd'
+        assert ops[8].type == 'sgd'
 
 
 if __name__ == "__main__":
--
GitLab