From 29eec2dd47ee5d0a6bd0c56f05bd1cbad041db9c Mon Sep 17 00:00:00 2001 From: lzy <569782149@qq.com> Date: Wed, 4 Jan 2023 09:46:55 +0800 Subject: [PATCH] add multi_devices_fused_multi_transformer_encoder_pass and cherry-pick from 48349 (#49383) --- .../fused_multi_transformer_decoder_pass.cc | 86 +- .../ir/fused_multi_transformer_decoder_pass.h | 12 +- .../fused_multi_transformer_encoder_pass.cc | 3481 +++++++++++------ .../ir/fused_multi_transformer_encoder_pass.h | 184 +- ...d_multi_transformer_encoder_pass_tester.cc | 252 +- .../inference/api/paddle_pass_builder.cc | 2 + 6 files changed, 2776 insertions(+), 1241 deletions(-) diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc index b5d0661ae7..87e63a2cb4 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc @@ -31,6 +31,8 @@ namespace framework { namespace ir { namespace patterns { +static const std::unordered_set FFN_ACTS{"relu", "gelu"}; + PDNode* FusedMultiTransformerDecoderPattern::operator()() { auto* input0 = pattern->NewNode(input0_repr()); input0->assert_is_op_input("layer_norm", "X"); @@ -359,13 +361,13 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("gelu"); + ->assert_is_ops_input(FFN_ACTS); - auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); - auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) - ->assert_is_op_output("gelu") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2"); + auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS); + auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr()) + ->assert_is_ops_output(FFN_ACTS) + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); auto* ffn_matmul1 = pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); @@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { .LinksTo({ffn_matmul0_out_var}); ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) .LinksTo({ffn_eltadd0_out_var}); - ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); - ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) + ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var}); + ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var}) .LinksTo({ffn_matmul1_out_var}); ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); @@ -678,13 +680,13 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("gelu"); + ->assert_is_ops_input(FFN_ACTS); - auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); - auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) - ->assert_is_op_output("gelu") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2"); + auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS); + auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr()) + ->assert_is_ops_output(FFN_ACTS) + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); auto* ffn_matmul1 = pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); @@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { .LinksTo({ffn_matmul0_out_var}); ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) .LinksTo({ffn_eltadd0_out_var}); - ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); - ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) + ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var}); + ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var}) .LinksTo({ffn_matmul1_out_var}); ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); @@ -1026,13 +1028,13 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("gelu"); + ->assert_is_ops_input(FFN_ACTS); - auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); - auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) - ->assert_is_op_output("gelu") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2"); + auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS); + auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr()) + ->assert_is_ops_output(FFN_ACTS) + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); auto* ffn_matmul1 = pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); @@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { .LinksTo({ffn_matmul0_out_var}); ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) .LinksTo({ffn_eltadd0_out_var}); - ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); - ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) + ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var}); + ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var}) .LinksTo({ffn_matmul1_out_var}); ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var}) .LinksTo({ffn_c_allreduce_sum_out_var}); @@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, + Node* ffn_act, Node* ffn_output) { auto* matmul0_op = matmul0->Op(); auto* matmul_linear_op = matmul_linear->Op(); @@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr( "epsilon", layer_norm->Op()->GetAttr("epsilon")); + fused_multi_transformer_op_desc.SetAttr("act_method", + ffn_act->Op()->Type()); // output dropout attribute fused_multi_transformer_op_desc.SetAttr("is_test", true); @@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu, ffn_gelu, fused_multi_transformer_pattern); + ffn_act, ffn_act, fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern); + ffn_act_out, ffn_act_out, fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern); @@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, + ffn_act, ffn_output); std::unordered_set marked_nodes({layer_norm, @@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ffn_eltadd1, ffn_eltadd0_out, ffn_eltadd1_out, - ffn_gelu, - ffn_gelu_out, + ffn_act, + ffn_act_out, ffn_eltadd_out}); // Remove unneeded nodes. @@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, + Node* ffn_act, Node* ffn_output) { auto* matmul0_op = matmul0->Op(); auto* matmul_linear_op = matmul_linear->Op(); @@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr( "epsilon", layer_norm->Op()->GetAttr("epsilon")); + fused_multi_transformer_op_desc.SetAttr("act_method", + ffn_act->Op()->Type()); // output dropout attribute fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); @@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); + ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); + ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); @@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, + ffn_act, ffn_output); std::unordered_set marked_nodes({layer_norm, @@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_eltadd1, ffn_eltadd0_out, ffn_eltadd1_out, - ffn_gelu, - ffn_gelu_out, + ffn_act, + ffn_act_out, ffn_eltadd_out}); // Remove unneeded nodes. @@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, + Node* ffn_act, Node* ffn_output) { auto* matmul_linear_op = matmul_linear->Op(); auto* ffn_matmul_1_op = ffn_matmul1->Op(); @@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr( "epsilon", layer_norm->Op()->GetAttr("epsilon")); + fused_multi_transformer_op_desc.SetAttr("act_method", + ffn_act->Op()->Type()); // output dropout attribute fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); @@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); + ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); + ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); @@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, + ffn_act, ffn_output); std::unordered_set marked_nodes({layer_norm, @@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_eltadd1, ffn_eltadd0_out, ffn_eltadd1_out, - ffn_gelu, - ffn_gelu_out, + ffn_act, + ffn_act_out, ffn_eltadd_out}); // Remove unneeded nodes. diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h index bfdc38c708..6e00de0bb8 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h @@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase { PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_out); - PATTERN_DECL_NODE(ffn_gelu); - PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_act); + PATTERN_DECL_NODE(ffn_act_out); PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_out); @@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_out); - PATTERN_DECL_NODE(ffn_gelu); - PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_act); + PATTERN_DECL_NODE(ffn_act_out); PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_out); @@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_out); - PATTERN_DECL_NODE(ffn_gelu); - PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_act); + PATTERN_DECL_NODE(ffn_act_out); PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_out); diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc index b7a723d813..e9e4c32d9e 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc @@ -25,43 +25,19 @@ namespace framework { namespace ir { namespace patterns { -PDNode* FusedMultiTransformerEncoderPattern::operator()() { - auto* input0 = pattern->NewNode(input0_repr()); - input0->assert_is_op_input("layer_norm", "X"); - - // pre-LayerNorm - auto* layer_norm = - pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); - auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("layer_norm", "Scale"); - auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("layer_norm", "Bias"); - auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) - ->AsOutput() - ->assert_is_op_output("layer_norm", "Mean"); - auto* layer_norm_variance_var = - pattern->NewNode(layer_norm_variance_repr()) - ->AsOutput() - ->assert_is_op_output("layer_norm", "Variance"); - auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) - ->AsIntermediate() - ->assert_is_op_output("layer_norm", "Y") - ->assert_is_op_input("matmul_v2", "X") - ->assert_more([](Node* x) { - if (x->outputs.size() == 3) { - return true; - } else { - return false; - } - }); +static const std::unordered_set FFN_ACTS{"relu", "gelu"}; - layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var}) - .LinksTo( - {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); +PDNode* FusedMultiTransformerEncoderPattern::operator()() { + auto* input0 = pattern->NewNode(input0_repr()) + ->assert_is_op_input("matmul_v2", "X") + ->assert_is_op_input("elementwise_add", "X") + ->assert_more([](Node* x) { + if (x->outputs.size() == 4) { + return true; + } else { + return false; + } + }); // Q path Nodes auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2"); @@ -95,15 +71,20 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) ->assert_is_op_output("transpose2") ->AsIntermediate() - ->assert_is_op_input("matmul", "X"); + ->assert_is_op_input("scale"); + auto* scale_q = pattern->NewNode(scale_q_repr())->assert_is_op("scale"); + auto* scale_q_out_var = pattern->NewNode(scale_q_out_repr()) + ->assert_is_op_output("scale") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "X"); // Q path Links - matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var}) - .LinksTo({matmul0_out_var}); + matmul0->LinksFrom({input0, matmul0_w_var}).LinksTo({matmul0_out_var}); eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var}) .LinksTo({eltadd0_out_var}); reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + scale_q->LinksFrom({transpose2_0_out_var}).LinksTo({scale_q_out_var}); // K path Nodes auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2"); @@ -137,20 +118,11 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) ->assert_is_op_output("transpose2") - ->AsOutput() - ->assert_is_op_input("matmul", "Y") - ->assert_is_op_input("while") - ->assert_more([](Node* x) { - if (x->outputs.size() == 2) { - return true; - } else { - return false; - } - }); + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "Y"); // K path Links - matmul1->LinksFrom({layer_norm_out_var, matmul1_w_var}) - .LinksTo({matmul1_out_var}); + matmul1->LinksFrom({input0, matmul1_w_var}).LinksTo({matmul1_out_var}); eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var}) .LinksTo({eltadd1_out_var}); reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); @@ -187,29 +159,21 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) ->assert_is_op_output("transpose2") - ->AsOutput() - ->assert_is_op_input("matmul_v2", "Y") - ->assert_is_op_input("while") - ->assert_more([](Node* x) { - if (x->outputs.size() == 2) { - return true; - } else { - return false; - } - }); + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "Y"); // V path Links - matmul2->LinksFrom({layer_norm_out_var, matmul2_w_var}) - .LinksTo({matmul2_out_var}); + matmul2->LinksFrom({input0, matmul2_w_var}).LinksTo({matmul2_out_var}); eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var}) .LinksTo({eltadd2_out_var}); reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); // QK path Nodes - auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2"); auto* matmul_qk_out_var = - pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2"); matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); auto* eltadd_qk = @@ -230,7 +194,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ->assert_is_op_input("matmul_v2", "X"); // QK path Linsk - matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var}) + matmul_qk->LinksFrom({scale_q_out_var, transpose2_1_out_var}) .LinksTo({matmul_qk_out_var}); eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); @@ -297,42 +261,41 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) .LinksTo({attention_output}); - // while loop - auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while"); - while0->LinksFrom({transpose2_1_out_var, transpose2_2_out_var}); - - // Feed Forward LayerNorm Nodes - auto* ffn_layer_norm = - pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm"); - auto* ffn_layer_norm_scale_var = - pattern->NewNode(ffn_layer_norm_scale_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("layer_norm", "Scale"); - auto* ffn_layer_norm_bias_var = - pattern->NewNode(ffn_layer_norm_bias_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("layer_norm", "Bias"); - auto* ffn_layer_norm_mean_var = - pattern->NewNode(ffn_layer_norm_mean_repr()) - ->AsOutput() - ->assert_is_op_output("layer_norm", "Mean"); - auto* ffn_layer_norm_variance_var = - pattern->NewNode(ffn_layer_norm_variance_repr()) + // post-LayerNorm + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto* layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) ->AsOutput() ->assert_is_op_output("layer_norm", "Variance"); - auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) - ->AsIntermediate() - ->assert_is_op_output("layer_norm", "Y") - ->assert_is_op_input("matmul_v2", "X"); + auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("matmul_v2", "X") + ->assert_is_op_input("elementwise_add", "X") + ->assert_more([](Node* x) { + if (x->outputs.size() == 2) { + return true; + } else { + return false; + } + }); - ffn_layer_norm - ->LinksFrom( - {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) - .LinksTo({ffn_layer_norm_out_var, - ffn_layer_norm_mean_var, - ffn_layer_norm_variance_var}); + layer_norm + ->LinksFrom({attention_output, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); // Feed Forward fc1 -> gelu -> fc2 auto* ffn_matmul0 = @@ -353,13 +316,13 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("gelu"); + ->assert_is_ops_input(FFN_ACTS); - auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); - auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) - ->assert_is_op_output("gelu") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2"); + auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS); + auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr()) + ->assert_is_ops_output(FFN_ACTS) + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); auto* ffn_matmul1 = pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); @@ -385,22 +348,55 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add"); auto* ffn_output = pattern->NewNode(ffn_output_repr()) ->assert_is_op_output("elementwise_add") - ->AsOutput(); + ->AsIntermediate() + ->assert_is_op_input("layer_norm"); - ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var}) + ffn_matmul0->LinksFrom({layer_norm_out_var, ffn_matmul0_w_var}) .LinksTo({ffn_matmul0_out_var}); ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) .LinksTo({ffn_eltadd0_out_var}); - ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); - ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) + ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var}); + ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var}) .LinksTo({ffn_matmul1_out_var}); ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); - ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) + ffn_eltadd_out->LinksFrom({layer_norm_out_var, ffn_eltadd1_out_var}) .LinksTo({ffn_output}); - return ffn_output; + // Feed Forward LayerNorm Nodes + auto* ffn_layer_norm = + pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm"); + auto* ffn_layer_norm_scale_var = + pattern->NewNode(ffn_layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* ffn_layer_norm_bias_var = + pattern->NewNode(ffn_layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* ffn_layer_norm_mean_var = + pattern->NewNode(ffn_layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto* ffn_layer_norm_variance_var = + pattern->NewNode(ffn_layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Y"); + + ffn_layer_norm + ->LinksFrom( + {ffn_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) + .LinksTo({ffn_layer_norm_out_var, + ffn_layer_norm_mean_var, + ffn_layer_norm_variance_var}); + + return ffn_layer_norm_out_var; } PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { @@ -649,13 +645,13 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("gelu"); + ->assert_is_ops_input(FFN_ACTS); - auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); - auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) - ->assert_is_op_output("gelu") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2"); + auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS); + auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr()) + ->assert_is_ops_output(FFN_ACTS) + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); auto* ffn_matmul1 = pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); @@ -687,8 +683,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { .LinksTo({ffn_matmul0_out_var}); ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) .LinksTo({ffn_eltadd0_out_var}); - ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); - ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) + ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var}); + ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var}) .LinksTo({ffn_matmul1_out_var}); ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); @@ -699,47 +695,41 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { return ffn_output; } -PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { - auto* input0 = pattern->NewNode(input0_repr()); - input0->assert_is_op_input("layer_norm", "X"); - - // pre-LayerNorm - auto* layer_norm = - pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); - auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("layer_norm", "Scale"); - auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("layer_norm", "Bias"); - auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) - ->AsOutput() - ->assert_is_op_output("layer_norm", "Mean"); - auto* layer_norm_variance_var = - pattern->NewNode(layer_norm_variance_repr()) - ->AsOutput() - ->assert_is_op_output("layer_norm", "Variance"); - auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) - ->AsIntermediate() - ->assert_is_op_output("layer_norm", "Y") - ->assert_is_op_input("c_identity", "X"); - - layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var}) - .LinksTo( - {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); - +PDNode* MultiDevicesFusedMultiTransformerEncoderPattern::operator()() { + auto* input0 = pattern->NewNode(input0_repr()) + ->assert_is_op_input("c_identity", "X") + ->assert_is_op_input("elementwise_add", "X") + ->assert_more([](Node* x) { + if (x->outputs.size() == 4) { + return true; + } else { + return false; + } + }); // communication c_identity - auto* c_identity = - pattern->NewNode(c_identity_repr())->assert_is_op("c_identity"); - auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr()) - ->AsIntermediate() - ->assert_is_op_output("c_identity", "Out") - ->assert_is_op_input("matmul_v2", "X"); - c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var}); + auto* c_identity0 = + pattern->NewNode(c_identity0_repr())->assert_is_op("c_identity"); + auto* c_identity0_out_var = pattern->NewNode(c_identity0_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("c_identity", "Out") + ->assert_is_op_input("matmul_v2", "X"); + auto* c_identity1 = + pattern->NewNode(c_identity1_repr())->assert_is_op("c_identity"); + auto* c_identity1_out_var = pattern->NewNode(c_identity1_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("c_identity", "Out") + ->assert_is_op_input("matmul_v2", "X"); + auto* c_identity2 = + pattern->NewNode(c_identity2_repr())->assert_is_op("c_identity"); + auto* c_identity2_out_var = pattern->NewNode(c_identity2_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("c_identity", "Out") + ->assert_is_op_input("matmul_v2", "X"); + c_identity0->LinksFrom({input0}).LinksTo({c_identity0_out_var}); + c_identity1->LinksFrom({input0}).LinksTo({c_identity1_out_var}); + c_identity2->LinksFrom({input0}).LinksTo({c_identity2_out_var}); - // QKV fused path Nodes + // Q path Nodes auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2"); auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr()) ->AsInput() @@ -771,62 +761,125 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) ->assert_is_op_output("transpose2") ->AsIntermediate() - ->assert_is_op_input("split", "X"); - - auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split"); - auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) - ->assert_is_op_output("split") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2", "X"); - auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) - ->assert_is_op_output("split") - ->AsOutput() - ->assert_is_op_input("matmul_v2", "Y") - ->assert_is_op_input("while"); - auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr()) - ->assert_is_op_output("split") - ->AsOutput() - ->assert_is_op_input("matmul_v2", "Y") - ->assert_is_op_input("while"); + ->assert_is_op_input("scale"); + auto* scale_q = pattern->NewNode(scale_q_repr())->assert_is_op("scale"); + auto* scale_q_out_var = pattern->NewNode(scale_q_out_repr()) + ->assert_is_op_output("scale") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "X"); - // QKV fused path Links - matmul0->LinksFrom({c_identity_out_var, matmul0_w_var}) + // Q path Links + matmul0->LinksFrom({c_identity0_out_var, matmul0_w_var}) .LinksTo({matmul0_out_var}); eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var}) .LinksTo({eltadd0_out_var}); reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); - split0->LinksFrom({transpose2_0_out_var}) - .LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var}); + scale_q->LinksFrom({transpose2_0_out_var}).LinksTo({scale_q_out_var}); - // while loop - auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while"); - while0->LinksFrom({split0_k_out_var, split0_v_out_var}); + // K path Nodes + auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2"); + auto* matmul1_w_var = pattern->NewNode(matmul1_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul1_out_var = pattern->NewNode(matmul1_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); - // QK path Nodes - auto* matmul_qk = - pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2"); - auto* matmul_qk_out_var = - pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2"); - matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale"); + auto* eltadd1 = + pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); + auto* eltadd1_b_var = pattern->NewNode(eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); - auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale"); - auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr()) - ->assert_is_op_output("scale") - ->AsIntermediate() - ->assert_is_op_input("elementwise_add", "X"); + auto* eltadd1_out_var = pattern->NewNode(eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); - auto* eltadd_qk = - pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); - auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) - ->AsInput() - ->assert_is_op_input("elementwise_add", "Y"); - auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) - ->assert_is_op_output("elementwise_add") - ->AsIntermediate() - ->assert_is_op_input("softmax"); + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); - auto* softmax_qk = + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "Y"); + + // K path Links + matmul1->LinksFrom({c_identity1_out_var, matmul1_w_var}) + .LinksTo({matmul1_out_var}); + eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var}) + .LinksTo({eltadd1_out_var}); + reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); + transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); + + // V path Nodes + auto* matmul2 = pattern->NewNode(matmul2_repr())->assert_is_op("matmul_v2"); + auto* matmul2_w_var = pattern->NewNode(matmul2_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul2_out_var = pattern->NewNode(matmul2_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd2 = + pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); + auto* eltadd2_b_var = pattern->NewNode(eltadd2_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd2_out_var = pattern->NewNode(eltadd2_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); + + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); + auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "Y"); + + // V path Links + matmul2->LinksFrom({c_identity2_out_var, matmul2_w_var}) + .LinksTo({matmul2_out_var}); + eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var}) + .LinksTo({eltadd2_out_var}); + reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); + transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); + + // QK path Nodes + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2"); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + auto* eltadd_qk = + pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); + auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("softmax"); + + auto* softmax_qk = pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) ->assert_is_op_output("softmax") @@ -834,10 +887,9 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ->assert_is_op_input("matmul_v2", "X"); // QK path Linsk - matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) + matmul_qk->LinksFrom({scale_q_out_var, transpose2_1_out_var}) .LinksTo({matmul_qk_out_var}); - scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var}); - eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var}) + eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); @@ -897,7 +949,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ->AsIntermediate(); // QKV path Links - matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var}) + matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var}) .LinksTo({matmul_qkv_out_var}); transpose2_qkv->LinksFrom({matmul_qkv_out_var}) .LinksTo({transpose2_qkv_out_var}); @@ -912,38 +964,41 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) .LinksTo({attention_output}); - // Feed Forward LayerNorm Nodes - auto* ffn_layer_norm = - pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm"); - auto* ffn_layer_norm_scale_var = - pattern->NewNode(ffn_layer_norm_scale_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("layer_norm", "Scale"); - auto* ffn_layer_norm_bias_var = - pattern->NewNode(ffn_layer_norm_bias_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("layer_norm", "Bias"); - auto* ffn_layer_norm_mean_var = - pattern->NewNode(ffn_layer_norm_mean_repr()) - ->AsIntermediate() - ->assert_is_op_output("layer_norm", "Mean"); - auto* ffn_layer_norm_variance_var = - pattern->NewNode(ffn_layer_norm_variance_repr()) - ->AsIntermediate() + // post-LayerNorm + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto* layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() ->assert_is_op_output("layer_norm", "Variance"); - auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) - ->AsIntermediate() - ->assert_is_op_output("layer_norm", "Y") - ->assert_is_op_input("c_identity", "X"); + auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("c_identity", "X") + ->assert_is_op_input("elementwise_add", "X") + ->assert_more([](Node* x) { + if (x->outputs.size() == 2) { + return true; + } else { + return false; + } + }); - ffn_layer_norm - ->LinksFrom( - {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) - .LinksTo({ffn_layer_norm_out_var, - ffn_layer_norm_mean_var, - ffn_layer_norm_variance_var}); + layer_norm + ->LinksFrom({attention_output, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); // communication c_identity auto* ffn_c_identity = @@ -952,7 +1007,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ->assert_is_op_output("c_identity", "Out") ->AsIntermediate() ->assert_is_op_input("matmul_v2", "X"); - ffn_c_identity->LinksFrom({ffn_layer_norm_out_var}) + ffn_c_identity->LinksFrom({layer_norm_out_var}) .LinksTo({ffn_c_identity_out_var}); // Feed Forward fc1 -> gelu -> fc2 @@ -974,13 +1029,13 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("gelu"); + ->assert_is_ops_input(FFN_ACTS); - auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); - auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) - ->assert_is_op_output("gelu") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2"); + auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS); + auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr()) + ->assert_is_ops_output(FFN_ACTS) + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); auto* ffn_matmul1 = pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); @@ -1015,297 +1070,1504 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add"); auto* ffn_output = pattern->NewNode(ffn_output_repr()) ->assert_is_op_output("elementwise_add") - ->AsOutput(); + ->AsIntermediate() + ->assert_is_op_input("layer_norm"); ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var}) .LinksTo({ffn_matmul0_out_var}); ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) .LinksTo({ffn_eltadd0_out_var}); - ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); - ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) + ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var}); + ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var}) .LinksTo({ffn_matmul1_out_var}); ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var}) .LinksTo({ffn_c_allreduce_sum_out_var}); ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); - ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) + ffn_eltadd_out->LinksFrom({layer_norm_out_var, ffn_eltadd1_out_var}) .LinksTo({ffn_output}); - return ffn_output; -} - -} // namespace patterns - -template -inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor, - phi::DenseTensor* wk_tensor, - phi::DenseTensor* wv_tensor, - const int num_head, - const int dim_head, - const int dim_embed) { - auto* wq_data = wq_tensor->mutable_data(platform::CPUPlace()); - auto* wk_data = wk_tensor->mutable_data(platform::CPUPlace()); - auto* wv_data = wv_tensor->mutable_data(platform::CPUPlace()); - - auto combined_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed}); - - phi::DenseTensor tmp_combined_w_tensor; - tmp_combined_w_tensor.Resize(combined_w_dims); - auto* tmp_combined_w_data = - tmp_combined_w_tensor.mutable_data(platform::CPUPlace()); + // Feed Forward LayerNorm Nodes + auto* ffn_layer_norm = + pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm"); + auto* ffn_layer_norm_scale_var = + pattern->NewNode(ffn_layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* ffn_layer_norm_bias_var = + pattern->NewNode(ffn_layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* ffn_layer_norm_mean_var = + pattern->NewNode(ffn_layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto* ffn_layer_norm_variance_var = + pattern->NewNode(ffn_layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Y"); - std::vector w_vec = {wq_data, wk_data, wv_data}; - // Combine the three fc weights together. - for (int i = 0; i < 3; i++) { - for (int j = 0; j < num_head; j++) { - for (int k = 0; k < dim_head; k++) { - for (int l = 0; l < dim_embed; l++) { - int out_idx = i * num_head * dim_head * dim_embed + - j * dim_head * dim_embed + k * dim_embed + l; - int in_idx = l * num_head * dim_head + j * dim_head + k; - tmp_combined_w_data[out_idx] = w_vec[i][in_idx]; - } - } - } - } + ffn_layer_norm + ->LinksFrom( + {ffn_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) + .LinksTo({ffn_layer_norm_out_var, + ffn_layer_norm_mean_var, + ffn_layer_norm_variance_var}); - wq_tensor->Resize(combined_w_dims); - auto* new_combined_w_data = wq_tensor->mutable_data(platform::CPUPlace()); - memcpy( - new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel()); + return ffn_layer_norm_out_var; } -template -inline void QKVBiasProcess(phi::DenseTensor* bq_tensor, - phi::DenseTensor* bk_tensor, - phi::DenseTensor* bv_tensor, - const int num_head, - const int dim_head, - const int dim_embed) { - auto* bq_data = bq_tensor->mutable_data(platform::CPUPlace()); - auto* bk_data = bk_tensor->mutable_data(platform::CPUPlace()); - auto* bv_data = bv_tensor->mutable_data(platform::CPUPlace()); +PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { + auto* input0 = pattern->NewNode(input0_repr()); + input0->assert_is_op_input("layer_norm", "X"); - auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head}); + // pre-LayerNorm + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto* layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("c_identity", "X"); - phi::DenseTensor tmp_combined_bias_tensor; - tmp_combined_bias_tensor.Resize(combined_bias_dims); - auto* tmp_combined_bias_data = - tmp_combined_bias_tensor.mutable_data(platform::CPUPlace()); + layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); - size_t bias_size = bq_tensor->numel(); - memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size); - memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size); - memcpy( - tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size); + // communication c_identity + auto* c_identity = + pattern->NewNode(c_identity_repr())->assert_is_op("c_identity"); + auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("c_identity", "Out") + ->assert_is_op_input("matmul_v2", "X"); + c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var}); - bq_tensor->Resize(combined_bias_dims); - auto* new_combined_bias_data = - bq_tensor->mutable_data(platform::CPUPlace()); - memcpy(new_combined_bias_data, - tmp_combined_bias_data, - sizeof(T) * bq_tensor->numel()); -} + // QKV fused path Nodes + auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2"); + auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd0 = + pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add"); + auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2") + ->AsIntermediate() + ->assert_is_op_input("split", "X"); + + auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split"); + auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) + ->assert_is_op_output("split") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "X"); + auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) + ->assert_is_op_output("split") + ->AsOutput() + ->assert_is_op_input("matmul_v2", "Y") + ->assert_is_op_input("while"); + auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr()) + ->assert_is_op_output("split") + ->AsOutput() + ->assert_is_op_input("matmul_v2", "Y") + ->assert_is_op_input("while"); + + // QKV fused path Links + matmul0->LinksFrom({c_identity_out_var, matmul0_w_var}) + .LinksTo({matmul0_out_var}); + eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var}) + .LinksTo({eltadd0_out_var}); + reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + split0->LinksFrom({transpose2_0_out_var}) + .LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var}); + + // while loop + auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while"); + while0->LinksFrom({split0_k_out_var, split0_v_out_var}); + + // QK path Nodes + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2"); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale"); + + auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale"); + auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr()) + ->assert_is_op_output("scale") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add", "X"); + + auto* eltadd_qk = + pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); + auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) + ->assert_is_op_output("softmax") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "X"); + + // QK path Linsk + matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) + .LinksTo({matmul_qk_out_var}); + scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var}); + eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var}) + .LinksTo({eltadd_qk_out_var}); + softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); + + // QKV path Nodes + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2"); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = + pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); // -> out_linear + + auto* matmul_linear = + pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2"); + auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("c_allreduce_sum"); + + // communication c_allreduce_sum + auto* c_allreduce_sum = + pattern->NewNode(c_allreduce_sum_repr())->assert_is_op("c_allreduce_sum"); + auto* c_allreduce_sum_out_var = pattern->NewNode(c_allreduce_sum_out_repr()) + ->assert_is_op_output("c_allreduce_sum") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd_linear = + pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add"); + auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd_out = + pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); + auto* attention_output = pattern->NewNode(attention_output_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate(); + + // QKV path Links + matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var}) + .LinksTo({matmul_linear_out_var}); + c_allreduce_sum->LinksFrom({matmul_linear_out_var}) + .LinksTo({c_allreduce_sum_out_var}); + eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var}) + .LinksTo({eltadd_linear_out_var}); + eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) + .LinksTo({attention_output}); + + // Feed Forward LayerNorm Nodes + auto* ffn_layer_norm = + pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm"); + auto* ffn_layer_norm_scale_var = + pattern->NewNode(ffn_layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* ffn_layer_norm_bias_var = + pattern->NewNode(ffn_layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* ffn_layer_norm_mean_var = + pattern->NewNode(ffn_layer_norm_mean_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Mean"); + auto* ffn_layer_norm_variance_var = + pattern->NewNode(ffn_layer_norm_variance_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Variance"); + auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("c_identity", "X"); + + ffn_layer_norm + ->LinksFrom( + {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) + .LinksTo({ffn_layer_norm_out_var, + ffn_layer_norm_mean_var, + ffn_layer_norm_variance_var}); + + // communication c_identity + auto* ffn_c_identity = + pattern->NewNode(ffn_c_identity_repr())->assert_is_op("c_identity"); + auto* ffn_c_identity_out_var = pattern->NewNode(ffn_c_identity_out_repr()) + ->assert_is_op_output("c_identity", "Out") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "X"); + ffn_c_identity->LinksFrom({ffn_layer_norm_out_var}) + .LinksTo({ffn_c_identity_out_var}); + + // Feed Forward fc1 -> gelu -> fc2 + auto* ffn_matmul0 = + pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); + auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd0 = + pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add"); + auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_ops_input(FFN_ACTS); + + auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS); + auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr()) + ->assert_is_ops_output(FFN_ACTS) + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); + + auto* ffn_matmul1 = + pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); + auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("c_allreduce_sum"); + + // communication c_allreduce_sum + auto* ffn_c_allreduce_sum = pattern->NewNode(ffn_c_allreduce_sum_repr()) + ->assert_is_op("c_allreduce_sum"); + auto* ffn_c_allreduce_sum_out_var = + pattern->NewNode(ffn_c_allreduce_sum_out_repr()) + ->assert_is_op_output("c_allreduce_sum") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd1 = + pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add"); + auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd_out = + pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add"); + auto* ffn_output = pattern->NewNode(ffn_output_repr()) + ->assert_is_op_output("elementwise_add") + ->AsOutput(); + + ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var}) + .LinksTo({ffn_matmul0_out_var}); + ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) + .LinksTo({ffn_eltadd0_out_var}); + ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var}); + ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var}) + .LinksTo({ffn_matmul1_out_var}); + ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var}) + .LinksTo({ffn_c_allreduce_sum_out_var}); + ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var}) + .LinksTo({ffn_eltadd1_out_var}); + + ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) + .LinksTo({ffn_output}); + + return ffn_output; +} + +} // namespace patterns + +template +inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor, + phi::DenseTensor* wk_tensor, + phi::DenseTensor* wv_tensor, + const int num_head, + const int dim_head, + const int dim_embed) { + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); + auto* wq_data = wq_tensor->data(); + auto* wk_data = wk_tensor->data(); + auto* wv_data = wv_tensor->data(); + + auto combined_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed}); + + phi::DenseTensor tmp_combined_w_tensor; + tmp_combined_w_tensor.Resize(combined_w_dims); + dev_ctx->Alloc(&tmp_combined_w_tensor); + auto* tmp_combined_w_data = tmp_combined_w_tensor.data(); + + std::vector w_vec = {wq_data, wk_data, wv_data}; + // Combine the three fc weights together. + for (int i = 0; i < 3; i++) { + for (int j = 0; j < num_head; j++) { + for (int k = 0; k < dim_head; k++) { + for (int l = 0; l < dim_embed; l++) { + int out_idx = i * num_head * dim_head * dim_embed + + j * dim_head * dim_embed + k * dim_embed + l; + int in_idx = l * num_head * dim_head + j * dim_head + k; + tmp_combined_w_data[out_idx] = w_vec[i][in_idx]; + } + } + } + } + + wq_tensor->Resize(combined_w_dims); + dev_ctx->Alloc(wq_tensor); + auto* new_combined_w_data = wq_tensor->data(); + memcpy( + new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel()); +} + +template +inline void QKVBiasProcess(phi::DenseTensor* bq_tensor, + phi::DenseTensor* bk_tensor, + phi::DenseTensor* bv_tensor, + const int num_head, + const int dim_head, + const int dim_embed) { + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); + auto* bq_data = bq_tensor->data(); + auto* bk_data = bk_tensor->data(); + auto* bv_data = bv_tensor->data(); + + auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head}); + + phi::DenseTensor tmp_combined_bias_tensor; + tmp_combined_bias_tensor.Resize(combined_bias_dims); + dev_ctx->Alloc(&tmp_combined_bias_tensor); + auto* tmp_combined_bias_data = tmp_combined_bias_tensor.data(); + + size_t bias_size = bq_tensor->numel(); + memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size); + memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size); + memcpy( + tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size); + + bq_tensor->Resize(combined_bias_dims); + dev_ctx->Alloc(bq_tensor); + auto* new_combined_bias_data = bq_tensor->data(); + memcpy(new_combined_bias_data, + tmp_combined_bias_data, + sizeof(T) * bq_tensor->numel()); +} + +inline void QKVWeightsBiasProcess(phi::DenseTensor* wq_tensor, + phi::DenseTensor* wk_tensor, + phi::DenseTensor* wv_tensor, + phi::DenseTensor* bq_tensor, + phi::DenseTensor* bk_tensor, + phi::DenseTensor* bv_tensor, + const int num_head, + const int dim_head, + const int dim_embed) { + switch (wq_tensor->dtype()) { + case paddle::experimental::DataType::FLOAT16: + QKVWeightsProcess( + wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::FLOAT32: + QKVWeightsProcess( + wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::INT8: + QKVWeightsProcess( + wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed); + break; + default: + PADDLE_THROW(platform::errors::Unavailable( + "fused_multi_transformer not supported weight dtype. " + "we now only support fp32/fp16/int8.")); + break; + } + switch (bq_tensor->dtype()) { + case paddle::experimental::DataType::FLOAT16: + QKVBiasProcess( + bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::FLOAT32: + QKVBiasProcess( + bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed); + break; + default: + PADDLE_THROW(platform::errors::Unavailable( + "fused_multi_transformer not supported bias dtype. " + "we now only support fp32/fp16.")); + break; + } +} + +template +inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor, + const int num_head, + const int dim_head, + const int dim_embed) { + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); + auto* qkv_w_data = qkv_w_tensor->data(); + auto transpose_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed}); + + phi::DenseTensor tmp_transpose_w_tensor; + tmp_transpose_w_tensor.Resize(transpose_w_dims); + dev_ctx->Alloc(&tmp_transpose_w_tensor); + auto* tmp_transpose_w_data = tmp_transpose_w_tensor.data(); + + // transpose qkv matmul Y to QKVWeights + for (int i = 0; i < 3; i++) { + for (int j = 0; j < num_head; j++) { + for (int k = 0; k < dim_head; k++) { + for (int l = 0; l < dim_embed; l++) { + int out_idx = i * num_head * dim_head * dim_embed + + j * dim_head * dim_embed + k * dim_embed + l; + int in_idx = + l * num_head * 3 * dim_head + j * 3 * dim_head + i * dim_head + k; + tmp_transpose_w_data[out_idx] = qkv_w_data[in_idx]; + } + } + } + } + + qkv_w_tensor->Resize(transpose_w_dims); + dev_ctx->Alloc(qkv_w_tensor); + auto* new_transpose_w_data = qkv_w_tensor->data(); + memcpy(new_transpose_w_data, + tmp_transpose_w_data, + sizeof(T) * qkv_w_tensor->numel()); +} + +template +inline void QKVBiasProcessFuseQKV(phi::DenseTensor* qkv_b_tensor, + const int num_head, + const int dim_head, + const int dim_embed) { + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); + auto* qkv_b_data = qkv_b_tensor->data(); + auto transpose_b_dims = phi::make_ddim({3, num_head, dim_head}); + + phi::DenseTensor tmp_transpose_b_tensor; + tmp_transpose_b_tensor.Resize(transpose_b_dims); + dev_ctx->Alloc(&tmp_transpose_b_tensor); + auto* tmp_transpose_b_data = tmp_transpose_b_tensor.data(); + + // transpose qkv elemenwise_add Y to QKVBias + for (int i = 0; i < 3; i++) { + for (int j = 0; j < num_head; j++) { + for (int k = 0; k < dim_head; k++) { + int out_idx = i * num_head * dim_head + j * dim_head + k; + int in_idx = j * 3 * dim_head + i * dim_head + k; + tmp_transpose_b_data[out_idx] = qkv_b_data[in_idx]; + } + } + } + + qkv_b_tensor->Resize({3, num_head, dim_head}); + dev_ctx->Alloc(qkv_b_tensor); + auto* new_transpose_b_data = qkv_b_tensor->data(); + memcpy(new_transpose_b_data, + tmp_transpose_b_data, + sizeof(T) * qkv_b_tensor->numel()); +} + +inline void QKVWeightsBiasProcessFuseQKV(phi::DenseTensor* qkv_w_tensor, + phi::DenseTensor* qkv_b_tensor, + const int num_head, + const int dim_head, + const int dim_embed) { + switch (qkv_w_tensor->dtype()) { + case paddle::experimental::DataType::FLOAT16: + QKVWeightsProcessFuseQKV( + qkv_w_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::FLOAT32: + QKVWeightsProcessFuseQKV( + qkv_w_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::INT8: + QKVWeightsProcessFuseQKV( + qkv_w_tensor, num_head, dim_head, dim_embed); + break; + default: + PADDLE_THROW(platform::errors::Unavailable( + "fused_multi_transformer not supported weight dtype. " + "we now only support fp32/fp16/int8.")); + break; + } + switch (qkv_b_tensor->dtype()) { + case paddle::experimental::DataType::FLOAT16: + QKVBiasProcessFuseQKV( + qkv_b_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::FLOAT32: + QKVBiasProcessFuseQKV(qkv_b_tensor, num_head, dim_head, dim_embed); + break; + default: + PADDLE_THROW(platform::errors::Unavailable( + "fused_multi_transformer not supported bias dtype. " + "we now only support fp32/fp16.")); + break; + } +} + +// Just use for fused_multi_transformer_int8 +inline void TransposeWeights(phi::DenseTensor* weight_tensor) { + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); + int m = weight_tensor->dims()[0]; + int n = weight_tensor->dims()[1]; + phi::DenseTensor tmp_weight_tensor; + tmp_weight_tensor.Resize({n, m}); + dev_ctx->Alloc(&tmp_weight_tensor); + auto tmp_weight_data = tmp_weight_tensor.data(); + auto weight_data = weight_tensor->data(); + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + int in_idx = i * n + j; + int out_idx = j * m + i; + tmp_weight_data[out_idx] = weight_data[in_idx]; + } + } + weight_tensor->Resize({n, m}); + dev_ctx->Alloc(weight_tensor); + auto new_weight_data = weight_tensor->data(); + memcpy(new_weight_data, tmp_weight_data, sizeof(int8_t) * m * n); +} + +inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) { + auto var_desc = VarDesc(name); + var_desc.SetDataType(framework::proto::VarType::FP32); + var_desc.SetPersistable(true); + auto node = graph->CreateVarNode(&var_desc); + return node; +} + +int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + bool enable_int8 = graph->Get("enable_int8"); + if (enable_int8) { + VLOG(3) << "FusedMultiTransformerEncoderPass with int8"; + } else { + VLOG(3) << "FusedMultiTransformerEncoderPass with fp"; + } + + // Create pattern. + patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern( + pattern, name_scope); + fused_multi_transformer_pattern(); + + // Create New OpDesc + auto fuse_creater = [&](Node* input0, + Node* layer_norm, + Node* layer_norm_scale, + Node* layer_norm_bias, + Node* layer_norm_mean, + Node* layer_norm_variance, + Node* matmul0, + Node* matmul0_w, + Node* matmul1_w, + Node* matmul2_w, + Node* eltadd0_b, + Node* eltadd1_b, + Node* eltadd2_b, + Node* transpose2_1_out, + Node* transpose2_2_out, + Node* eltadd_qk_b, + Node* reshape2_0, + Node* matmul_linear, + Node* matmul_linear_w, + Node* eltadd_linear_b, + Node* ffn_layer_norm, + Node* ffn_layer_norm_scale, + Node* ffn_layer_norm_bias, + Node* ffn_layer_norm_mean, + Node* ffn_layer_norm_variance, + Node* ffn_matmul0, + Node* ffn_matmul0_w, + Node* ffn_matmul1, + Node* ffn_matmul1_w, + Node* ffn_eltadd0_b, + Node* ffn_eltadd1_b, + Node* ffn_act, + Node* ffn_layer_norm_out) { + auto* matmul0_op = matmul0->Op(); + auto* matmul_linear_op = matmul_linear->Op(); + auto* ffn_matmul_0_op = ffn_matmul0->Op(); + auto* ffn_matmul_1_op = ffn_matmul1->Op(); + + // Calc index of transformer layer by LayerNorm Scale name + // This calculation assumes: + // 1. no LayerNorm before all transformer layer + // 2. each transformer layer contains 2 LayerNorm layer + auto ln_scale_name = layer_norm_scale->Name(); + auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.')); + auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1); + int layer_idx = atoi(ln_idx_str.c_str()) / 2; + + auto* wq_tensor = + scope->FindVar(matmul0_w->Name())->GetMutable(); + auto* wk_tensor = + scope->FindVar(matmul1_w->Name())->GetMutable(); + auto* wv_tensor = + scope->FindVar(matmul2_w->Name())->GetMutable(); + + auto* bq_tensor = + scope->FindVar(eltadd0_b->Name())->GetMutable(); + auto* bk_tensor = + scope->FindVar(eltadd1_b->Name())->GetMutable(); + auto* bv_tensor = + scope->FindVar(eltadd2_b->Name())->GetMutable(); + + // NOTE(minghaoBD): to make it compatible with strucutured pruning on + // num_head dimension: + // 1. get dim_head from reshape.shape[3], dim_embed from + // layer_norm_bias.shape[0] + // 2. calculate num_head according to wq_tensor.shape[1] and dim_head + auto reshape_desc = reshape2_0->Op(); + int dim_head = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(3); + auto* layer_norm_bias_tensor = + scope->FindVar(layer_norm_bias->Name())->GetMutable(); + int dim_embed = layer_norm_bias_tensor->dims()[0]; + int num_head = wq_tensor->dims()[1] / dim_head; + + QKVWeightsBiasProcess(wq_tensor, + wk_tensor, + wv_tensor, + bq_tensor, + bk_tensor, + bv_tensor, + num_head, + dim_head, + dim_embed); + + if (enable_int8) { + auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name()) + ->GetMutable(); + auto* ffn0_w_tensor = + scope->FindVar(ffn_matmul0_w->Name())->GetMutable(); + auto* ffn1_w_tensor = + scope->FindVar(ffn_matmul1_w->Name())->GetMutable(); + + TransposeWeights(out_linear_w_tensor); + TransposeWeights(ffn0_w_tensor); + TransposeWeights(ffn1_w_tensor); + } + + // reuse the mul0_w and eltadd_0_b nodes for the combined nodes. + auto* combined_w_desc = matmul0_w->Var(); + combined_w_desc->SetShape({3, num_head, dim_head, dim_embed}); + combined_w_desc->SetPersistable(true); + + auto* combined_bias_desc = eltadd0_b->Var(); + combined_bias_desc->SetShape({3, num_head, dim_head}); + combined_bias_desc->SetPersistable(true); + + scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()}); + scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); + + // create fused_multi_transformer + OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); + fused_multi_transformer_op_desc.SetType(enable_int8 + ? "fused_multi_transformer_int8" + : "fused_multi_transformer"); + + // 1. Input setting + fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); + + // pre-LayerNorm input + fused_multi_transformer_op_desc.SetInput("LnScale", + {layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("LnBias", + {layer_norm_bias->Name()}); + + // QKV computation input + fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()}); + + // CacheKV input + VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); + // FIXME: only support max_seq_len <= 1024 + cache_kv_desc.SetDataType( + framework::TransToProtoVarType(bq_tensor->dtype())); + cache_kv_desc.SetPersistable(false); + auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); + + OpDesc fill_const_op_desc(layer_norm->Op()->Block()); + fill_const_op_desc.SetType("fill_constant_batch_size_like"); + fill_const_op_desc.SetInput("Input", {input0->Name()}); + fill_const_op_desc.SetOutput("Out", {cache_kv->Name()}); + std::vector shape = {2, -1, num_head, 1024, dim_head}; + fill_const_op_desc.SetAttr("shape", shape); + fill_const_op_desc.SetAttr("input_dim_idx", 0); + fill_const_op_desc.SetAttr("output_dim_idx", 1); + fill_const_op_desc.SetAttr("value", 0); + fill_const_op_desc.SetAttr( + "dtype", + static_cast(framework::TransToProtoVarType(bq_tensor->dtype()))); + auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); + + fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); + + // Out Linear input + fused_multi_transformer_op_desc.SetInput("OutLinearW", + {matmul_linear_w->Name()}); + fused_multi_transformer_op_desc.SetInput("OutLinearBias", + {eltadd_linear_b->Name()}); + + // Feed Forward input + fused_multi_transformer_op_desc.SetInput("FFNLnScale", + {ffn_layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("FFNLnBias", + {ffn_layer_norm_bias->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Weight", + {ffn_matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Bias", + {ffn_eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Weight", + {ffn_matmul1_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Bias", + {ffn_eltadd1_b->Name()}); + + // 2. Output setting + fused_multi_transformer_op_desc.SetOutput("Out", + {ffn_layer_norm_out->Name()}); + fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()}); + + // Attribute setting + fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", false); + fused_multi_transformer_op_desc.SetAttr( + "epsilon", layer_norm->Op()->GetAttr("epsilon")); + + fused_multi_transformer_op_desc.SetAttr("is_test", true); + fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); + fused_multi_transformer_op_desc.SetAttr("act_method", + {ffn_act->Op()->Type()}); + + // Quantization attribute/Input + if (enable_int8) { + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); + // Set input scale + std::string qkv_input_name = matmul0_op->Input("X")[0]; + auto qkv_in_scale = PADDLE_GET_CONST( + float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name)); + std::string out_linear_input_name = matmul_linear_op->Input("X")[0]; + auto out_linear_in_scale = PADDLE_GET_CONST( + float, + matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name)); + std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0]; + auto ffn0_in_scale = PADDLE_GET_CONST( + float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name)); + std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0]; + auto ffn1_in_scale = PADDLE_GET_CONST( + float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name)); + + // Calc outscale and Set them + auto qkv_weight_scale = + PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale")); + auto out_weight_scale = + PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale")); + auto ffn0_weight_scale = + PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale")); + auto ffn1_weight_scale = + PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale")); + + auto qkv_out_scales = std::vector( + 3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f)); + auto out_out_scales = std::vector( + dim_embed, + (out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f)); + auto ffn0_out_scales = std::vector( + 4 * dim_embed, + (ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 127.0f)); + auto ffn1_out_scales = std::vector( + dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f)); + + // Inverse input scale + qkv_in_scale = 1.0f / qkv_in_scale; + out_linear_in_scale = 1.0f / out_linear_in_scale; + ffn0_in_scale = 1.0f / ffn0_in_scale; + ffn1_in_scale = 1.0f / ffn1_in_scale; + + fused_multi_transformer_op_desc.SetAttr("qkv_in_scale", + std::vector{qkv_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "out_linear_in_scale", std::vector{out_linear_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn1_in_scale", std::vector{ffn0_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn2_in_scale", std::vector{ffn1_in_scale}); + + auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale"); + auto out_out_scale_var = + scope->Var(matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_var = + scope->Var(ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_var = + scope->Var(ffn_matmul1_w->Name() + "_out_scale"); + + auto* qkv_out_scale_tensor = + qkv_out_scale_var->GetMutable(); + qkv_out_scale_tensor->Resize({3 * dim_embed}); + dev_ctx->Alloc(qkv_out_scale_tensor); + auto qkv_out_scale_data = qkv_out_scale_tensor->data(); + memcpy(qkv_out_scale_data, + qkv_out_scales.data(), + qkv_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "QKVOutScale", {matmul0_w->Name() + "_out_scale"}); + + auto* out_out_scale_tensor = + out_out_scale_var->GetMutable(); + out_out_scale_tensor->Resize({dim_embed}); + dev_ctx->Alloc(out_out_scale_tensor); + auto out_out_scale_data = out_out_scale_tensor->data(); + memcpy(out_out_scale_data, + out_out_scales.data(), + out_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); + + auto* ffn0_out_scale_tensor = + ffn0_out_scale_var->GetMutable(); + ffn0_out_scale_tensor->Resize({4 * dim_embed}); + dev_ctx->Alloc(ffn0_out_scale_tensor); + auto ffn0_out_scale_data = ffn0_out_scale_tensor->data(); + memcpy(ffn0_out_scale_data, + ffn0_out_scales.data(), + ffn0_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); + + auto* ffn1_out_scale_tensor = + ffn1_out_scale_var->GetMutable(); + ffn1_out_scale_tensor->Resize({dim_embed}); + dev_ctx->Alloc(ffn1_out_scale_tensor); + auto ffn1_out_scale_data = ffn1_out_scale_tensor->data(); + memcpy(ffn1_out_scale_data, + ffn1_out_scales.data(), + ffn1_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"}); + } + + auto* fused_multi_transformer = + graph->CreateOpNode(&fused_multi_transformer_op_desc); + + if (enable_int8) { + auto qkv_out_scale_node = + CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale"); + auto out_out_scale_node = CreatePersistableVarNode( + graph, matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale"); + + IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer); + } + + IR_NODE_LINK_TO(input0, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); + + IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); + + IR_NODE_LINK_TO(input0, fill_const_op); + IR_NODE_LINK_TO(fill_const_op, cache_kv); + IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); + + IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer); + + IR_NODE_LINK_TO(fused_multi_transformer, ffn_layer_norm_out); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "fused_multi_transformer_encoder pass in " + "op compat failed."; + return; + } + + VLOG(4) << "handle MultiTransformer encoder fuse"; + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm, layer_norm, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, + layer_norm_variance, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_out, layer_norm_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul0, matmul0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_out, matmul0_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_w, matmul0_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0, reshape2_0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0, transpose2_0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul1, matmul1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul1_out, matmul1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul1_w, matmul1_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1, reshape2_1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1, transpose2_1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + scale_q, scale_q, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + scale_q_out, scale_q_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul2, matmul2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul2_out, matmul2_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul2_w, matmul2_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2, reshape2_2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2, transpose2_2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + attention_output, attention_output, fused_multi_transformer_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, + ffn_layer_norm_scale, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, + ffn_layer_norm_bias, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, + ffn_layer_norm_mean, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, + ffn_layer_norm_variance, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, + ffn_layer_norm_out, + fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_act, ffn_act, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_act_out, ffn_act_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + ffn_output, ffn_output, fused_multi_transformer_pattern) + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0, eltadd0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_b, eltadd0_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_out, eltadd0_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd1, eltadd1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd1_b, eltadd1_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd1_out, eltadd1_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd2, eltadd2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd2_b, eltadd2_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd2_out, eltadd2_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk, matmul_qk, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk, eltadd_qk, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk, softmax_qk, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv, matmul_qkv, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, + transpose2_qkv_out, + fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear, matmul_linear, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear, eltadd_linear, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_out, eltadd_out, fused_multi_transformer_pattern) + + fuse_creater(input0, + layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + matmul0, + matmul0_w, + matmul1_w, + matmul2_w, + eltadd0_b, + eltadd1_b, + eltadd2_b, + transpose2_1_out, + transpose2_2_out, + eltadd_qk_b, + reshape2_0, + matmul_linear, + matmul_linear_w, + eltadd_linear_b, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_matmul0, + ffn_matmul0_w, + ffn_matmul1, + ffn_matmul1_w, + ffn_eltadd0_b, + ffn_eltadd1_b, + ffn_act, + ffn_layer_norm_out); -inline void QKVWeightsBiasProcess(phi::DenseTensor* wq_tensor, - phi::DenseTensor* wk_tensor, - phi::DenseTensor* wv_tensor, - phi::DenseTensor* bq_tensor, - phi::DenseTensor* bk_tensor, - phi::DenseTensor* bv_tensor, - const int num_head, - const int dim_head, - const int dim_embed) { - switch (wq_tensor->dtype()) { - case paddle::experimental::DataType::FLOAT16: - QKVWeightsProcess( - wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed); - break; - case paddle::experimental::DataType::FLOAT32: - QKVWeightsProcess( - wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed); - break; - case paddle::experimental::DataType::INT8: - QKVWeightsProcess( - wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed); - break; - default: - PADDLE_THROW(platform::errors::Unavailable( - "fused_multi_transformer not supported weight dtype. " - "we now only support fp32/fp16/int8.")); - break; - } - switch (bq_tensor->dtype()) { - case paddle::experimental::DataType::FLOAT16: - QKVBiasProcess( - bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed); - break; - case paddle::experimental::DataType::FLOAT32: - QKVBiasProcess( - bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed); - break; - default: - PADDLE_THROW(platform::errors::Unavailable( - "fused_multi_transformer not supported bias dtype. " - "we now only support fp32/fp16.")); - break; - } -} + std::unordered_set marked_nodes({layer_norm, + layer_norm_mean, + layer_norm_variance, + layer_norm_out, + matmul0, + matmul1, + matmul2, + matmul0_out, + matmul1_out, + matmul2_out, + eltadd0, + eltadd1, + eltadd2, + eltadd0_out, + eltadd1_out, + eltadd2_out, + reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + scale_q, + scale_q_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + reshape2_qkv, + transpose2_qkv, + transpose2_qkv_out, + matmul_linear, + matmul_linear_out, + eltadd_linear, + eltadd_linear_out, + eltadd_out, + ffn_layer_norm, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_matmul0, + ffn_matmul1, + ffn_matmul0_out, + ffn_matmul1_out, + ffn_eltadd0, + ffn_eltadd1, + ffn_eltadd0_out, + ffn_eltadd1_out, + ffn_act, + ffn_act_out, + ffn_output, + ffn_eltadd_out}); -template -inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor, - const int num_head, - const int dim_head, - const int dim_embed) { - auto* qkv_w_data = qkv_w_tensor->data(); - auto transpose_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed}); + // Remove unneeded nodes. + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); - phi::DenseTensor tmp_transpose_w_tensor; - tmp_transpose_w_tensor.Resize(transpose_w_dims); - auto* tmp_transpose_w_data = - tmp_transpose_w_tensor.mutable_data(platform::CPUPlace()); + return fusion_count; +} - // transpose qkv matmul Y to QKVWeights - for (int i = 0; i < 3; i++) { - for (int j = 0; j < num_head; j++) { - for (int k = 0; k < dim_head; k++) { - for (int l = 0; l < dim_embed; l++) { - int out_idx = i * num_head * dim_head * dim_embed + - j * dim_head * dim_embed + k * dim_embed + l; - int in_idx = - l * num_head * 3 * dim_head + j * 3 * dim_head + i * dim_head + k; - tmp_transpose_w_data[out_idx] = qkv_w_data[in_idx]; - } - } - } - } +void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal( + "During the multi_transformer pass, The scope should not be null.")); - qkv_w_tensor->Resize(transpose_w_dims); - auto* new_transpose_w_data = - qkv_w_tensor->mutable_data(platform::CPUPlace()); - memcpy(new_transpose_w_data, - tmp_transpose_w_data, - sizeof(T) * qkv_w_tensor->numel()); + int fusion_count = BuildFusion(graph, name_scope_, scope); + if (fusion_count > 0) { + graph->Set(kFusedMultiTransformerEncoderPass, new bool(true)); + graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); + } + AddStatis(fusion_count); } -template -inline void QKVBiasProcessFuseQKV(phi::DenseTensor* qkv_b_tensor, - const int num_head, - const int dim_head, - const int dim_embed) { - auto* qkv_b_data = qkv_b_tensor->data(); - auto transpose_b_dims = phi::make_ddim({3, num_head, dim_head}); +FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() { + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); - phi::DenseTensor tmp_transpose_b_tensor; - tmp_transpose_b_tensor.Resize(transpose_b_dims); - auto* tmp_transpose_b_data = - tmp_transpose_b_tensor.mutable_data(platform::CPUPlace()); + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddInput("Y") // the shape shoule be (N*H, N*H) + .IsTensor() + .End() + .AddOutput("Out") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + .IsType() + .End(); - // transpose qkv elemenwise_add Y to QKVBias - for (int i = 0; i < 3; i++) { - for (int j = 0; j < num_head; j++) { - for (int k = 0; k < dim_head; k++) { - int out_idx = i * num_head * dim_head + j * dim_head + k; - int in_idx = j * 3 * dim_head + i * dim_head + k; - tmp_transpose_b_data[out_idx] = qkv_b_data[in_idx]; - } - } - } + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({2, -1, 0}) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H) + .IsType>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") // {0, 2, 1, 3} + .IsType>() + .End(); - qkv_b_tensor->Resize({3, num_head, dim_head}); - auto* new_transpose_b_data = - qkv_b_tensor->mutable_data(platform::CPUPlace()); - memcpy(new_transpose_b_data, - tmp_transpose_b_data, - sizeof(T) * qkv_b_tensor->numel()); -} + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("scale") + .IsType() // copy to new op. so unconstrained. + .End() + .AddAttr("bias") + .IsNumEQ(0.f) + .End() + .AddAttr("bias_after_scale") // bias is 0, so unconstrained. + .IsType() + .End(); -inline void QKVWeightsBiasProcessFuseQKV(phi::DenseTensor* qkv_w_tensor, - phi::DenseTensor* qkv_b_tensor, - const int num_head, - const int dim_head, - const int dim_embed) { - switch (qkv_w_tensor->dtype()) { - case paddle::experimental::DataType::FLOAT16: - QKVWeightsProcessFuseQKV( - qkv_w_tensor, num_head, dim_head, dim_embed); - break; - case paddle::experimental::DataType::FLOAT32: - QKVWeightsProcessFuseQKV( - qkv_w_tensor, num_head, dim_head, dim_embed); - break; - case paddle::experimental::DataType::INT8: - QKVWeightsProcessFuseQKV( - qkv_w_tensor, num_head, dim_head, dim_embed); - break; - default: - PADDLE_THROW(platform::errors::Unavailable( - "fused_multi_transformer not supported weight dtype. " - "we now only support fp32/fp16/int8.")); - break; - } - switch (qkv_b_tensor->dtype()) { - case paddle::experimental::DataType::FLOAT16: - QKVBiasProcessFuseQKV( - qkv_b_tensor, num_head, dim_head, dim_embed); - break; - case paddle::experimental::DataType::FLOAT32: - QKVBiasProcessFuseQKV(qkv_b_tensor, num_head, dim_head, dim_embed); - break; - default: - PADDLE_THROW(platform::errors::Unavailable( - "fused_multi_transformer not supported bias dtype. " - "we now only support fp32/fp16.")); - break; - } -} + AddOpCompat(OpCompat("softmax")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3 + .End(); -// Just use for fused_multi_transformer_int8 -inline void TransposeWeights(phi::DenseTensor* weight_tensor) { - int m = weight_tensor->dims()[0]; - int n = weight_tensor->dims()[1]; - phi::DenseTensor tmp_weight_tensor; - auto tmp_weight_data = - tmp_weight_tensor.mutable_data({n, m}, platform::CPUPlace()); - auto weight_data = weight_tensor->data(); - for (int i = 0; i < m; ++i) { - for (int j = 0; j < n; ++j) { - int in_idx = i * n + j; - int out_idx = j * m + i; - tmp_weight_data[out_idx] = weight_data[in_idx]; - } - } - weight_tensor->Resize({n, m}); - auto new_weight_data = - weight_tensor->mutable_data(platform::CPUPlace()); - memcpy(new_weight_data, tmp_weight_data, sizeof(int8_t) * m * n); -} + AddOpCompat(OpCompat("gelu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("approximate") + .IsType() + .End(); -inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) { - auto var_desc = VarDesc(name); - var_desc.SetDataType(framework::proto::VarType::FP32); - var_desc.SetPersistable(true); - auto node = graph->CreateVarNode(&var_desc); - return node; + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); } -int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, - const std::string& name_scope, - Scope* scope) const { +int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( + Graph* graph, const std::string& name_scope, Scope* scope) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); bool enable_int8 = graph->Get("enable_int8"); if (enable_int8) { - VLOG(3) << "FusedMultiTransformerEncoderPass with int8"; + VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with int8"; } else { - VLOG(3) << "FusedMultiTransformerEncoderPass with fp"; + VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with fp"; } // Create pattern. - patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern( - pattern, name_scope); - fused_multi_transformer_pattern(); + patterns::FusedMultiTransformerEncoderFuseQKVPattern + fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope); + fused_multi_transformer_fuse_qkv_pattern(); // Create New OpDesc auto fuse_creater = [&](Node* input0, @@ -1316,13 +2578,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, Node* layer_norm_variance, Node* matmul0, Node* matmul0_w, - Node* matmul1_w, - Node* matmul2_w, Node* eltadd0_b, - Node* eltadd1_b, - Node* eltadd2_b, - Node* transpose2_1_out, - Node* transpose2_2_out, + Node* split0_k_out, + Node* split0_v_out, Node* eltadd_qk_b, Node* reshape2_0, Node* matmul_linear, @@ -1340,6 +2598,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, + Node* ffn_act, Node* ffn_output) { auto* matmul0_op = matmul0->Op(); auto* matmul_linear_op = matmul_linear->Op(); @@ -1355,43 +2614,28 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1); int layer_idx = atoi(ln_idx_str.c_str()) / 2; - auto* wq_tensor = + auto* qkv_w_tensor = scope->FindVar(matmul0_w->Name())->GetMutable(); - auto* wk_tensor = - scope->FindVar(matmul1_w->Name())->GetMutable(); - auto* wv_tensor = - scope->FindVar(matmul2_w->Name())->GetMutable(); - - auto* bq_tensor = + auto* qkv_b_tensor = scope->FindVar(eltadd0_b->Name())->GetMutable(); - auto* bk_tensor = - scope->FindVar(eltadd1_b->Name())->GetMutable(); - auto* bv_tensor = - scope->FindVar(eltadd2_b->Name())->GetMutable(); // NOTE(minghaoBD): to make it compatible with strucutured pruning on // num_head dimension: // 1. get dim_head from reshape.shape[3], dim_embed from // layer_norm_bias.shape[0] - // 2. calculate num_head according to wq_tensor.shape[1] and dim_head + // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head auto reshape_desc = reshape2_0->Op(); int dim_head = PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) - .at(3); + .at(3) / + 3; // 3 for qkv auto* layer_norm_bias_tensor = scope->FindVar(layer_norm_bias->Name())->GetMutable(); int dim_embed = layer_norm_bias_tensor->dims()[0]; - int num_head = wq_tensor->dims()[1] / dim_head; + int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head; - QKVWeightsBiasProcess(wq_tensor, - wk_tensor, - wv_tensor, - bq_tensor, - bk_tensor, - bv_tensor, - num_head, - dim_head, - dim_embed); + QKVWeightsBiasProcessFuseQKV( + qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); if (enable_int8) { auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name()) @@ -1406,18 +2650,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, TransposeWeights(ffn1_w_tensor); } - // reuse the mul0_w and eltadd_0_b nodes for the combined nodes. - auto* combined_w_desc = matmul0_w->Var(); - combined_w_desc->SetShape({3, num_head, dim_head, dim_embed}); - combined_w_desc->SetPersistable(true); - - auto* combined_bias_desc = eltadd0_b->Var(); - combined_bias_desc->SetShape({3, num_head, dim_head}); - combined_bias_desc->SetPersistable(true); - - scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()}); - scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); - // create fused_multi_transformer OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); fused_multi_transformer_op_desc.SetType(enable_int8 @@ -1442,7 +2674,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); // FIXME: only support max_seq_len <= 1024 cache_kv_desc.SetDataType( - framework::TransToProtoVarType(bq_tensor->dtype())); + framework::TransToProtoVarType(qkv_b_tensor->dtype())); cache_kv_desc.SetPersistable(false); auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); @@ -1455,9 +2687,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, fill_const_op_desc.SetAttr("input_dim_idx", 0); fill_const_op_desc.SetAttr("output_dim_idx", 1); fill_const_op_desc.SetAttr("value", 0); - fill_const_op_desc.SetAttr( - "dtype", - static_cast(framework::TransToProtoVarType(bq_tensor->dtype()))); + fill_const_op_desc.SetAttr("dtype", + static_cast(framework::TransToProtoVarType( + qkv_b_tensor->dtype()))); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); @@ -1490,12 +2722,17 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr( "epsilon", layer_norm->Op()->GetAttr("epsilon")); + fused_multi_transformer_op_desc.SetAttr("act_method", + ffn_act->Op()->Type()); + // output dropout attribute fused_multi_transformer_op_desc.SetAttr("is_test", true); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); // Quantization attribute/Input if (enable_int8) { + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); // Set input scale std::string qkv_input_name = matmul0_op->Input("X")[0]; auto qkv_in_scale = PADDLE_GET_CONST( @@ -1512,6 +2749,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name)); // Calc outscale and Set them + // TODO(wufeisheng): Currently just match layer-wise weight scale, where + // channel-wise weight scale should also be surpported. auto qkv_weight_scale = PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale")); auto out_weight_scale = @@ -1555,36 +2794,44 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, auto ffn1_out_scale_var = scope->Var(ffn_matmul1_w->Name() + "_out_scale"); - auto qkv_out_scale_data = - qkv_out_scale_var->GetMutable() - ->mutable_data({3 * dim_embed}, platform::CPUPlace()); + auto* qkv_out_scale_tensor = + qkv_out_scale_var->GetMutable(); + qkv_out_scale_tensor->Resize({3 * dim_embed}); + dev_ctx->Alloc(qkv_out_scale_tensor); + auto qkv_out_scale_data = qkv_out_scale_tensor->data(); memcpy(qkv_out_scale_data, qkv_out_scales.data(), qkv_out_scales.size() * sizeof(float)); fused_multi_transformer_op_desc.SetInput( "QKVOutScale", {matmul0_w->Name() + "_out_scale"}); - auto out_out_scale_data = - out_out_scale_var->GetMutable() - ->mutable_data({dim_embed}, platform::CPUPlace()); + auto* out_out_scale_tensor = + out_out_scale_var->GetMutable(); + out_out_scale_tensor->Resize({dim_embed}); + dev_ctx->Alloc(out_out_scale_tensor); + auto out_out_scale_data = out_out_scale_tensor->data(); memcpy(out_out_scale_data, out_out_scales.data(), out_out_scales.size() * sizeof(float)); fused_multi_transformer_op_desc.SetInput( "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); - auto ffn0_out_scale_data = - ffn0_out_scale_var->GetMutable() - ->mutable_data({4 * dim_embed}, platform::CPUPlace()); + auto* ffn0_out_scale_tensor = + ffn0_out_scale_var->GetMutable(); + ffn0_out_scale_tensor->Resize({4 * dim_embed}); + dev_ctx->Alloc(ffn0_out_scale_tensor); + auto ffn0_out_scale_data = ffn0_out_scale_tensor->data(); memcpy(ffn0_out_scale_data, ffn0_out_scales.data(), ffn0_out_scales.size() * sizeof(float)); fused_multi_transformer_op_desc.SetInput( "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); - auto ffn1_out_scale_data = - ffn1_out_scale_var->GetMutable() - ->mutable_data({dim_embed}, platform::CPUPlace()); + auto* ffn1_out_scale_tensor = + ffn1_out_scale_var->GetMutable(); + ffn1_out_scale_tensor->Resize({dim_embed}); + dev_ctx->Alloc(ffn1_out_scale_tensor); + auto ffn1_out_scale_data = ffn1_out_scale_tensor->data(); memcpy(ffn1_out_scale_data, ffn1_out_scales.data(), ffn1_out_scales.size() * sizeof(float)); @@ -1641,27 +2888,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, auto while_Xs = while0->Op()->Input("X"); while_Xs.erase( std::remove( - std::begin(while_Xs), std::end(while_Xs), transpose2_1_out->Name()), - std::end(while_Xs)); - while_Xs.erase( - std::remove( - std::begin(while_Xs), std::end(while_Xs), transpose2_2_out->Name()), - std::end(while_Xs)); - while_Xs.erase( - std::remove( - std::begin(while_Xs), std::end(while_Xs), matmul1_w->Name()), - std::end(while_Xs)); - while_Xs.erase( - std::remove( - std::begin(while_Xs), std::end(while_Xs), matmul2_w->Name()), - std::end(while_Xs)); - while_Xs.erase( - std::remove( - std::begin(while_Xs), std::end(while_Xs), eltadd1_b->Name()), + std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()), std::end(while_Xs)); while_Xs.erase( std::remove( - std::begin(while_Xs), std::end(while_Xs), eltadd2_b->Name()), + std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()), std::end(while_Xs)); while_Xs.emplace_back(cache_kv->Name()); while0->Op()->SetInput("X", while_Xs); @@ -1670,227 +2901,214 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, // 1. delete k, v // 2. add cache_kv auto while_Outs = while0->Op()->Output("Out"); - while_Outs.erase(std::remove(std::begin(while_Outs), - std::end(while_Outs), - transpose2_1_out->Name()), - std::end(while_Outs)); - while_Outs.erase(std::remove(std::begin(while_Outs), - std::end(while_Outs), - transpose2_2_out->Name()), - std::end(while_Outs)); + while_Outs.erase( + std::remove( + std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()), + std::end(while_Outs)); + while_Outs.erase( + std::remove( + std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()), + std::end(while_Outs)); while_Outs.emplace_back(cache_kv->Name()); while0->Op()->SetOutput("Out", while_Outs); // link CacheKV to while IR_NODE_LINK_TO(cache_kv, while0) // unlink origin KV output to while - IR_NODE_UNLINK(transpose2_1_out, while0); - IR_NODE_UNLINK(transpose2_2_out, while0); - IR_NODE_UNLINK(while0, transpose2_1_out); - IR_NODE_UNLINK(while0, transpose2_2_out); - // unlink KV weight/bias to while after merged into Q weight/bias - IR_NODE_UNLINK(matmul1_w, while0); - IR_NODE_UNLINK(matmul2_w, while0); - IR_NODE_UNLINK(eltadd1_b, while0); - IR_NODE_UNLINK(eltadd2_b, while0); + IR_NODE_UNLINK(split0_k_out, while0); + IR_NODE_UNLINK(split0_v_out, while0); + IR_NODE_UNLINK(while0, split0_k_out); + IR_NODE_UNLINK(while0, split0_v_out); }; int fusion_count{0}; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { if (!IsCompat(subgraph, graph)) { - LOG(WARNING) << "fused_multi_transformer_encoder pass in " - "op compat failed."; - return; - } - - VLOG(4) << "handle MultiTransformer encoder fuse"; - GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern); - - GET_IR_NODE_FROM_SUBGRAPH( - layer_norm, layer_norm, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, - layer_norm_variance, - fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - layer_norm_out, layer_norm_out, fused_multi_transformer_pattern); - - GET_IR_NODE_FROM_SUBGRAPH( - matmul0, matmul0, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - matmul0_out, matmul0_out, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - matmul0_w, matmul0_w, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - reshape2_0, reshape2_0, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - transpose2_0, transpose2_0, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern); + LOG(WARNING) << "fused_multi_transformer_encoder_fuse_qkv " + "pass in op compat failed."; + return; + } + VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse"; GET_IR_NODE_FROM_SUBGRAPH( - matmul1, matmul1, fused_multi_transformer_pattern); + input0, input0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( - matmul1_out, matmul1_out, fused_multi_transformer_pattern); + layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, + layer_norm_scale, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, + layer_norm_bias, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, + layer_norm_mean, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, + layer_norm_variance, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, + layer_norm_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( - matmul1_w, matmul1_w, fused_multi_transformer_pattern); + matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - reshape2_1, reshape2_1, fused_multi_transformer_pattern); + matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern); + matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - transpose2_1, transpose2_1, fused_multi_transformer_pattern); + reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, + reshape2_0_out, + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern); + transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, + transpose2_0_out, + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - matmul2, matmul2, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - matmul2_out, matmul2_out, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - matmul2_w, matmul2_w, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - reshape2_2, reshape2_2, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern); + split0, split0, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - transpose2_2, transpose2_2, fused_multi_transformer_pattern); + split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern); - + split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - attention_output, attention_output, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH(while0, while0, fused_multi_transformer_pattern) + split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, + ffn_layer_norm, + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, ffn_layer_norm_scale, - fused_multi_transformer_pattern); + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, ffn_layer_norm_bias, - fused_multi_transformer_pattern); + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, ffn_layer_norm_mean, - fused_multi_transformer_pattern); + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, ffn_layer_norm_variance, - fused_multi_transformer_pattern); + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, ffn_layer_norm_out, - fused_multi_transformer_pattern); + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern); + ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, + ffn_matmul0_out, + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern); + ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern); + ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern); + ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, + ffn_eltadd0_out, + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu, ffn_gelu, fused_multi_transformer_pattern); + ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern); + ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern); + ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, + ffn_matmul1_out, + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern); + ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern); + ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern); + ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, + ffn_eltadd1_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, + ffn_eltadd_out, + fused_multi_transformer_fuse_qkv_pattern) GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - ffn_output, ffn_output, fused_multi_transformer_pattern) + ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern) // nodes need be removed GET_IR_NODE_FROM_SUBGRAPH( - eltadd0, eltadd0, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - eltadd0_b, eltadd0_b, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - eltadd0_out, eltadd0_out, fused_multi_transformer_pattern); - - GET_IR_NODE_FROM_SUBGRAPH( - eltadd1, eltadd1, fused_multi_transformer_pattern); + eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - eltadd1_b, eltadd1_b, fused_multi_transformer_pattern); + eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - eltadd1_out, eltadd1_out, fused_multi_transformer_pattern); + eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - eltadd2, eltadd2, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - eltadd2_b, eltadd2_b, fused_multi_transformer_pattern); + matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - eltadd2_out, eltadd2_out, fused_multi_transformer_pattern); + matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - matmul_qk, matmul_qk, fused_multi_transformer_pattern); + scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern); + scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - eltadd_qk, eltadd_qk, fused_multi_transformer_pattern); + eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern); + eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern); + eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - softmax_qk, softmax_qk, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - matmul_qkv, matmul_qkv, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern); + softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, + softmax_qk_out, + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern); + matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, + matmul_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( - transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern); + reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, + reshape2_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, + transpose2_qkv, + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, - fused_multi_transformer_pattern); + fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - matmul_linear, matmul_linear, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - eltadd_linear, eltadd_linear, fused_multi_transformer_pattern) + matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, + matmul_linear_w, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, + matmul_linear_out, + fused_multi_transformer_fuse_qkv_pattern) GET_IR_NODE_FROM_SUBGRAPH( - eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern) + eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, + eltadd_linear_b, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, + eltadd_linear_out, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH( - eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern) + eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH( - eltadd_out, eltadd_out, fused_multi_transformer_pattern) + while0, while0, fused_multi_transformer_fuse_qkv_pattern) fuse_creater(input0, layer_norm, @@ -1900,13 +3118,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, layer_norm_variance, matmul0, matmul0_w, - matmul1_w, - matmul2_w, eltadd0_b, - eltadd1_b, - eltadd2_b, - transpose2_1_out, - transpose2_2_out, + split0_k_out, + split0_v_out, eltadd_qk_b, reshape2_0, matmul_linear, @@ -1924,6 +3138,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, + ffn_act, ffn_output); std::unordered_set marked_nodes({layer_norm, @@ -1931,31 +3146,21 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, layer_norm_variance, layer_norm_out, matmul0, - matmul1, - matmul2, matmul0_out, - matmul1_out, - matmul2_out, eltadd0, - eltadd1, - eltadd2, eltadd0_out, - eltadd1_out, - eltadd2_out, reshape2_0, - reshape2_1, - reshape2_2, reshape2_0_out, - reshape2_1_out, - reshape2_2_out, transpose2_0, - transpose2_1, - transpose2_2, transpose2_0_out, - transpose2_1_out, - transpose2_2_out, + split0, + split0_q_out, + split0_k_out, + split0_v_out, matmul_qk, matmul_qk_out, + scale_qk, + scale_qk_out, eltadd_qk, eltadd_qk_out, softmax_qk, @@ -1984,8 +3189,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ffn_eltadd1, ffn_eltadd0_out, ffn_eltadd1_out, - ffn_gelu, - ffn_gelu_out, + ffn_act, + ffn_act_out, ffn_eltadd_out}); // Remove unneeded nodes. @@ -1997,23 +3202,25 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, return fusion_count; } -void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const { +void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const { FusePassBase::Init(name_scope_, graph); auto* scope = param_scope(); PADDLE_ENFORCE_NOT_NULL( scope, platform::errors::Fatal( - "During the multi_transformer pass, The scope should not be null.")); + "During the fused_multi_transformer_encoder pass, " + "The scope should not be null.")); int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { - graph->Set(kFusedMultiTransformerEncoderPass, new bool(true)); + graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true)); graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } -FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() { +FusedMultiTransformerEncoderFuseQKVPass:: + FusedMultiTransformerEncoderFuseQKVPass() { AddOpCompat(OpCompat("layer_norm")) .AddInput("X") .IsTensor() @@ -2033,12 +3240,29 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() { .AddOutput("Variance") .IsTensor() .End() - .AddAttr("epsilon") - .IsNumGE(0.0f) - .IsNumLE(0.001f) + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("scale") + .IsType() // copy to new op. so unconstrained. .End() - .AddAttr("begin_norm_axis") - .IsNumGT(0) + .AddAttr("bias") + .IsNumEQ(0.f) + .End() + .AddAttr("bias_after_scale") // bias is 0, so unconstrained. + .IsType() .End(); AddOpCompat(OpCompat("matmul_v2")) @@ -2168,56 +3392,54 @@ FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() { .End(); } -int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( +int MultiDevicesFusedMultiTransformerEncoderPass::BuildFusion( Graph* graph, const std::string& name_scope, Scope* scope) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); - bool enable_int8 = graph->Get("enable_int8"); - if (enable_int8) { - VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with int8"; - } else { - VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with fp"; - } // Create pattern. - patterns::FusedMultiTransformerEncoderFuseQKVPattern - fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope); - fused_multi_transformer_fuse_qkv_pattern(); + patterns::MultiDevicesFusedMultiTransformerEncoderPattern + multi_devices_fused_multi_transformer_pattern(pattern, name_scope); + multi_devices_fused_multi_transformer_pattern(); // Create New OpDesc auto fuse_creater = [&](Node* input0, + Node* c_identity, Node* layer_norm, Node* layer_norm_scale, Node* layer_norm_bias, Node* layer_norm_mean, Node* layer_norm_variance, - Node* matmul0, Node* matmul0_w, + Node* matmul1_w, + Node* matmul2_w, Node* eltadd0_b, - Node* split0_k_out, - Node* split0_v_out, + Node* eltadd1_b, + Node* eltadd2_b, + Node* transpose2_1_out, + Node* transpose2_2_out, Node* eltadd_qk_b, Node* reshape2_0, - Node* matmul_linear, Node* matmul_linear_w, Node* eltadd_linear_b, - Node* while0, Node* ffn_layer_norm, Node* ffn_layer_norm_scale, Node* ffn_layer_norm_bias, Node* ffn_layer_norm_mean, Node* ffn_layer_norm_variance, - Node* ffn_matmul0, Node* ffn_matmul0_w, - Node* ffn_matmul1, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, - Node* ffn_output) { - auto* matmul0_op = matmul0->Op(); - auto* matmul_linear_op = matmul_linear->Op(); - auto* ffn_matmul_0_op = ffn_matmul0->Op(); - auto* ffn_matmul_1_op = ffn_matmul1->Op(); + Node* ffn_act, + Node* ffn_layer_norm_out) { + auto reshape_desc = reshape2_0->Op(); + int num_head = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(2); + int dim_head = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(3); // Calc index of transformer layer by LayerNorm Scale name // This calculation assumes: @@ -2228,47 +3450,47 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1); int layer_idx = atoi(ln_idx_str.c_str()) / 2; - auto* qkv_w_tensor = + auto* wq_tensor = scope->FindVar(matmul0_w->Name())->GetMutable(); - auto* qkv_b_tensor = + auto* wk_tensor = + scope->FindVar(matmul1_w->Name())->GetMutable(); + auto* wv_tensor = + scope->FindVar(matmul2_w->Name())->GetMutable(); + + auto* bq_tensor = scope->FindVar(eltadd0_b->Name())->GetMutable(); + auto* bk_tensor = + scope->FindVar(eltadd1_b->Name())->GetMutable(); + auto* bv_tensor = + scope->FindVar(eltadd2_b->Name())->GetMutable(); - // NOTE(minghaoBD): to make it compatible with strucutured pruning on - // num_head dimension: - // 1. get dim_head from reshape.shape[3], dim_embed from - // layer_norm_bias.shape[0] - // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head - auto reshape_desc = reshape2_0->Op(); - int dim_head = - PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) - .at(3) / - 3; // 3 for qkv - auto* layer_norm_bias_tensor = - scope->FindVar(layer_norm_bias->Name())->GetMutable(); - int dim_embed = layer_norm_bias_tensor->dims()[0]; - int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head; + int dim_embed = wq_tensor->dims()[0]; - QKVWeightsBiasProcessFuseQKV( - qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); + QKVWeightsBiasProcess(wq_tensor, + wk_tensor, + wv_tensor, + bq_tensor, + bk_tensor, + bv_tensor, + num_head, + dim_head, + dim_embed); - if (enable_int8) { - auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name()) - ->GetMutable(); - auto* ffn0_w_tensor = - scope->FindVar(ffn_matmul0_w->Name())->GetMutable(); - auto* ffn1_w_tensor = - scope->FindVar(ffn_matmul1_w->Name())->GetMutable(); + // reuse the mul0_w and eltadd_0_b nodes for the combined nodes. + auto* combined_w_desc = matmul0_w->Var(); + combined_w_desc->SetShape({3, num_head, dim_head, dim_embed}); + combined_w_desc->SetPersistable(true); - TransposeWeights(out_linear_w_tensor); - TransposeWeights(ffn0_w_tensor); - TransposeWeights(ffn1_w_tensor); - } + auto* combined_bias_desc = eltadd0_b->Var(); + combined_bias_desc->SetShape({3, num_head, dim_head}); + combined_bias_desc->SetPersistable(true); + + scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()}); + scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); // create fused_multi_transformer OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); - fused_multi_transformer_op_desc.SetType(enable_int8 - ? "fused_multi_transformer_int8" - : "fused_multi_transformer"); + fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); // 1. Input setting fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); @@ -2288,7 +3510,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); // FIXME: only support max_seq_len <= 1024 cache_kv_desc.SetDataType( - framework::TransToProtoVarType(qkv_b_tensor->dtype())); + framework::TransToProtoVarType(wq_tensor->dtype())); cache_kv_desc.SetPersistable(false); auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); @@ -2301,9 +3523,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( fill_const_op_desc.SetAttr("input_dim_idx", 0); fill_const_op_desc.SetAttr("output_dim_idx", 1); fill_const_op_desc.SetAttr("value", 0); - fill_const_op_desc.SetAttr("dtype", - static_cast(framework::TransToProtoVarType( - qkv_b_tensor->dtype()))); + fill_const_op_desc.SetAttr( + "dtype", + static_cast(framework::TransToProtoVarType(wq_tensor->dtype()))); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); @@ -2329,137 +3551,27 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( {ffn_eltadd1_b->Name()}); // 2. Output setting - fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()}); + fused_multi_transformer_op_desc.SetOutput("Out", + {ffn_layer_norm_out->Name()}); fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()}); // Attribute setting - fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); + fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", false); fused_multi_transformer_op_desc.SetAttr( "epsilon", layer_norm->Op()->GetAttr("epsilon")); - // output dropout attribute fused_multi_transformer_op_desc.SetAttr("is_test", true); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); + fused_multi_transformer_op_desc.SetAttr("act_method", + {ffn_act->Op()->Type()}); - // Quantization attribute/Input - if (enable_int8) { - // Set input scale - std::string qkv_input_name = matmul0_op->Input("X")[0]; - auto qkv_in_scale = PADDLE_GET_CONST( - float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name)); - std::string out_linear_input_name = matmul_linear_op->Input("X")[0]; - auto out_linear_in_scale = PADDLE_GET_CONST( - float, - matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name)); - std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0]; - auto ffn0_in_scale = PADDLE_GET_CONST( - float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name)); - std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0]; - auto ffn1_in_scale = PADDLE_GET_CONST( - float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name)); - - // Calc outscale and Set them - // TODO(wufeisheng): Currently just match layer-wise weight scale, where - // channel-wise weight scale should also be surpported. - auto qkv_weight_scale = - PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale")); - auto out_weight_scale = - PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale")); - auto ffn0_weight_scale = - PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale")); - auto ffn1_weight_scale = - PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale")); - - auto qkv_out_scales = std::vector( - 3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f)); - auto out_out_scales = std::vector( - dim_embed, - (out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f)); - auto ffn0_out_scales = std::vector( - 4 * dim_embed, - (ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 127.0f)); - auto ffn1_out_scales = std::vector( - dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f)); - - // Inverse input scale - qkv_in_scale = 1.0f / qkv_in_scale; - out_linear_in_scale = 1.0f / out_linear_in_scale; - ffn0_in_scale = 1.0f / ffn0_in_scale; - ffn1_in_scale = 1.0f / ffn1_in_scale; - - fused_multi_transformer_op_desc.SetAttr("qkv_in_scale", - std::vector{qkv_in_scale}); - fused_multi_transformer_op_desc.SetAttr( - "out_linear_in_scale", std::vector{out_linear_in_scale}); - fused_multi_transformer_op_desc.SetAttr( - "ffn1_in_scale", std::vector{ffn0_in_scale}); - fused_multi_transformer_op_desc.SetAttr( - "ffn2_in_scale", std::vector{ffn1_in_scale}); - - auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale"); - auto out_out_scale_var = - scope->Var(matmul_linear_w->Name() + "_out_scale"); - auto ffn0_out_scale_var = - scope->Var(ffn_matmul0_w->Name() + "_out_scale"); - auto ffn1_out_scale_var = - scope->Var(ffn_matmul1_w->Name() + "_out_scale"); - - auto qkv_out_scale_data = - qkv_out_scale_var->GetMutable() - ->mutable_data({3 * dim_embed}, platform::CPUPlace()); - memcpy(qkv_out_scale_data, - qkv_out_scales.data(), - qkv_out_scales.size() * sizeof(float)); - fused_multi_transformer_op_desc.SetInput( - "QKVOutScale", {matmul0_w->Name() + "_out_scale"}); - - auto out_out_scale_data = - out_out_scale_var->GetMutable() - ->mutable_data({dim_embed}, platform::CPUPlace()); - memcpy(out_out_scale_data, - out_out_scales.data(), - out_out_scales.size() * sizeof(float)); - fused_multi_transformer_op_desc.SetInput( - "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); - - auto ffn0_out_scale_data = - ffn0_out_scale_var->GetMutable() - ->mutable_data({4 * dim_embed}, platform::CPUPlace()); - memcpy(ffn0_out_scale_data, - ffn0_out_scales.data(), - ffn0_out_scales.size() * sizeof(float)); - fused_multi_transformer_op_desc.SetInput( - "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); - - auto ffn1_out_scale_data = - ffn1_out_scale_var->GetMutable() - ->mutable_data({dim_embed}, platform::CPUPlace()); - memcpy(ffn1_out_scale_data, - ffn1_out_scales.data(), - ffn1_out_scales.size() * sizeof(float)); - fused_multi_transformer_op_desc.SetInput( - "FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"}); - } + // parallel ring id + auto* c_identity_op = c_identity->Op(); + fused_multi_transformer_op_desc.SetAttr("ring_id", + c_identity_op->GetAttr("ring_id")); auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); - - if (enable_int8) { - auto qkv_out_scale_node = - CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale"); - auto out_out_scale_node = CreatePersistableVarNode( - graph, matmul_linear_w->Name() + "_out_scale"); - auto ffn0_out_scale_node = - CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale"); - auto ffn1_out_scale_node = - CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale"); - - IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer); - IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer); - IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer); - IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer); - } - IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); @@ -2468,300 +3580,383 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); - IR_NODE_LINK_TO(input0, fill_const_op); - IR_NODE_LINK_TO(fill_const_op, cache_kv); - IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); - - IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); - IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); - IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer); - IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer); - IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer); - IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer); - IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer); - IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer); - - IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); - - // rewrite while OP input - // 1. delete k, v - // 2. delete matmul1/2_w eltadd1/2_w - // 3. add cache_kv - auto while_Xs = while0->Op()->Input("X"); - while_Xs.erase( - std::remove( - std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()), - std::end(while_Xs)); - while_Xs.erase( - std::remove( - std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()), - std::end(while_Xs)); - while_Xs.emplace_back(cache_kv->Name()); - while0->Op()->SetInput("X", while_Xs); - - // rewrite while OP output - // 1. delete k, v - // 2. add cache_kv - auto while_Outs = while0->Op()->Output("Out"); - while_Outs.erase( - std::remove( - std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()), - std::end(while_Outs)); - while_Outs.erase( - std::remove( - std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()), - std::end(while_Outs)); - while_Outs.emplace_back(cache_kv->Name()); - while0->Op()->SetOutput("Out", while_Outs); + IR_NODE_LINK_TO(input0, fill_const_op); + IR_NODE_LINK_TO(fill_const_op, cache_kv); + IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); - // link CacheKV to while - IR_NODE_LINK_TO(cache_kv, while0) - // unlink origin KV output to while - IR_NODE_UNLINK(split0_k_out, while0); - IR_NODE_UNLINK(split0_v_out, while0); - IR_NODE_UNLINK(while0, split0_k_out); - IR_NODE_UNLINK(while0, split0_v_out); + IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer); + + IR_NODE_LINK_TO(fused_multi_transformer, ffn_layer_norm_out); }; int fusion_count{0}; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { if (!IsCompat(subgraph, graph)) { - LOG(WARNING) << "fused_multi_transformer_encoder_fuse_qkv " - "pass in op compat failed."; + LOG(WARNING) << "fused_multi_transformer_encoder pass in " + "op compat failed."; return; } - VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse"; - GET_IR_NODE_FROM_SUBGRAPH( - input0, input0, fused_multi_transformer_fuse_qkv_pattern); - + VLOG(4) << "handle MultiTransformer encoder fuse"; GET_IR_NODE_FROM_SUBGRAPH( - layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern); + input0, input0, multi_devices_fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(c_identity0, + c_identity0, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_identity0_out, + c_identity0_out, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_identity1, + c_identity1, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_identity1_out, + c_identity1_out, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_identity2, + c_identity2, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_identity2_out, + c_identity2_out, + multi_devices_fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm, layer_norm, multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity, + ffn_c_identity, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity_out, + ffn_c_identity_out, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern); + matmul0, matmul0, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul0_out, + matmul0_out, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern); + matmul0_w, matmul0_w, multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern); + reshape2_0, reshape2_0, multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, - fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, + transpose2_0, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - split0, split0, fused_multi_transformer_fuse_qkv_pattern); + matmul1, matmul1, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul1_out, + matmul1_out, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern); + matmul1_w, matmul1_w, multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern); + reshape2_1, reshape2_1, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, + reshape2_1_out, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, + transpose2_1, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, + transpose2_1_out, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( - split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern); + scale_q, scale_q, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale_q_out, + scale_q_out, + multi_devices_fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul2, matmul2, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul2_out, + matmul2_out, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul2_w, matmul2_w, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2, reshape2_2, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, + reshape2_2_out, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, + transpose2_2, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, + transpose2_2_out, + multi_devices_fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(attention_output, + attention_output, + multi_devices_fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, ffn_layer_norm, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, ffn_layer_norm_scale, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, ffn_layer_norm_bias, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, ffn_layer_norm_mean, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, ffn_layer_norm_variance, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, ffn_layer_norm_out, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0, + ffn_matmul0, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, ffn_matmul0_out, - fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_w, + ffn_matmul0_w, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0, + ffn_eltadd0, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_b, + ffn_eltadd0_b, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, ffn_eltadd0_out, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); + ffn_act, ffn_act, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_act_out, + ffn_act_out, + multi_devices_fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1, + ffn_matmul1, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, ffn_matmul1_out, - fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_w, + ffn_matmul1_w, + multi_devices_fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum, + ffn_c_allreduce_sum, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum_out, + ffn_c_allreduce_sum_out, + multi_devices_fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1, + ffn_eltadd1, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_b, + ffn_eltadd1_b, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, ffn_eltadd1_out, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, ffn_eltadd_out, - fused_multi_transformer_fuse_qkv_pattern) + multi_devices_fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH( - ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern) + ffn_output, ffn_output, multi_devices_fused_multi_transformer_pattern) // nodes need be removed GET_IR_NODE_FROM_SUBGRAPH( - eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + eltadd0, eltadd0, multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern); + eltadd0_b, eltadd0_b, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, + eltadd0_out, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern); + eltadd1, eltadd1, multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); + eltadd1_b, eltadd1_b, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, + eltadd1_out, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern); + eltadd2, eltadd2, multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern); + eltadd2_b, eltadd2_b, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, + eltadd2_out, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern); + matmul_qk, matmul_qk, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, + matmul_qk_out, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( - eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern); + eltadd_qk, eltadd_qk, multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, + eltadd_qk_b, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, + eltadd_qk_out, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern); + softmax_qk, softmax_qk, multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, - fused_multi_transformer_fuse_qkv_pattern); - + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( - matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); + matmul_qkv, matmul_qkv, multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, + reshape2_qkv, + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, - fused_multi_transformer_fuse_qkv_pattern); + multi_devices_fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear, + matmul_linear, + multi_devices_fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, matmul_linear_w, - fused_multi_transformer_fuse_qkv_pattern) + multi_devices_fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, matmul_linear_out, - fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern) + multi_devices_fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum, + c_allreduce_sum, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum_out, + c_allreduce_sum_out, + multi_devices_fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear, + eltadd_linear, + multi_devices_fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, eltadd_linear_b, - fused_multi_transformer_fuse_qkv_pattern) + multi_devices_fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, eltadd_linear_out, - fused_multi_transformer_fuse_qkv_pattern) - - GET_IR_NODE_FROM_SUBGRAPH( - eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern) - + multi_devices_fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH( - while0, while0, fused_multi_transformer_fuse_qkv_pattern) + eltadd_out, eltadd_out, multi_devices_fused_multi_transformer_pattern) fuse_creater(input0, + c_identity0, layer_norm, layer_norm_scale, layer_norm_bias, layer_norm_mean, layer_norm_variance, - matmul0, matmul0_w, + matmul1_w, + matmul2_w, eltadd0_b, - split0_k_out, - split0_v_out, + eltadd1_b, + eltadd2_b, + transpose2_1_out, + transpose2_2_out, eltadd_qk_b, reshape2_0, - matmul_linear, matmul_linear_w, eltadd_linear_b, - while0, ffn_layer_norm, ffn_layer_norm_scale, ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, - ffn_matmul0, ffn_matmul0_w, - ffn_matmul1, ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, - ffn_output); - - std::unordered_set marked_nodes({layer_norm, + ffn_act, + ffn_layer_norm_out); + + std::unordered_set marked_nodes({c_identity0, + c_identity0_out, + c_identity1, + c_identity1_out, + c_identity2, + c_identity2_out, + layer_norm, layer_norm_mean, layer_norm_variance, layer_norm_out, matmul0, + matmul1, + matmul2, matmul0_out, + matmul1_out, + matmul2_out, eltadd0, + eltadd1, + eltadd2, eltadd0_out, + eltadd1_out, + eltadd2_out, reshape2_0, + reshape2_1, + reshape2_2, reshape2_0_out, + reshape2_1_out, + reshape2_2_out, transpose2_0, + transpose2_1, + transpose2_2, transpose2_0_out, - split0, - split0_q_out, - split0_k_out, - split0_v_out, + transpose2_1_out, + transpose2_2_out, + scale_q, + scale_q_out, matmul_qk, matmul_qk_out, - scale_qk, - scale_qk_out, eltadd_qk, eltadd_qk_out, softmax_qk, @@ -2775,23 +3970,29 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( transpose2_qkv_out, matmul_linear, matmul_linear_out, + c_allreduce_sum, + c_allreduce_sum_out, eltadd_linear, eltadd_linear_out, eltadd_out, ffn_layer_norm, ffn_layer_norm_mean, ffn_layer_norm_variance, - ffn_layer_norm_out, + ffn_c_identity, + ffn_c_identity_out, ffn_matmul0, ffn_matmul1, ffn_matmul0_out, ffn_matmul1_out, + ffn_c_allreduce_sum, + ffn_c_allreduce_sum_out, ffn_eltadd0, ffn_eltadd1, ffn_eltadd0_out, ffn_eltadd1_out, - ffn_gelu, - ffn_gelu_out, + ffn_act, + ffn_act_out, + ffn_output, ffn_eltadd_out}); // Remove unneeded nodes. @@ -2803,25 +4004,25 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( return fusion_count; } -void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const { +void MultiDevicesFusedMultiTransformerEncoderPass::ApplyImpl( + Graph* graph) const { FusePassBase::Init(name_scope_, graph); auto* scope = param_scope(); PADDLE_ENFORCE_NOT_NULL( scope, platform::errors::Fatal( - "During the fused_multi_transformer_encoder pass, " - "The scope should not be null.")); + "During the multi_transformer pass, The scope should not be null.")); int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { - graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true)); + graph->Set(kFusedMultiTransformerEncoderPass, new bool(true)); graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } -FusedMultiTransformerEncoderFuseQKVPass:: - FusedMultiTransformerEncoderFuseQKVPass() { +MultiDevicesFusedMultiTransformerEncoderPass:: + MultiDevicesFusedMultiTransformerEncoderPass() { AddOpCompat(OpCompat("layer_norm")) .AddInput("X") .IsTensor() @@ -2849,23 +4050,6 @@ FusedMultiTransformerEncoderFuseQKVPass:: .IsNumGT(0) .End(); - AddOpCompat(OpCompat("scale")) - .AddInput("X") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("scale") - .IsType() // copy to new op. so unconstrained. - .End() - .AddAttr("bias") - .IsNumEQ(0.f) - .End() - .AddAttr("bias_after_scale") // bias is 0, so unconstrained. - .IsType() - .End(); - AddOpCompat(OpCompat("matmul_v2")) .AddInput("X") // the shape shoule be (B, S, N*H) .IsTensor() @@ -2935,24 +4119,20 @@ FusedMultiTransformerEncoderFuseQKVPass:: .IsType>() .End(); - AddOpCompat(OpCompat("matmul")) + AddOpCompat(OpCompat("scale")) .AddInput("X") .IsTensor() .End() - .AddInput("Y") - .IsTensor() - .End() .AddOutput("Out") .IsTensor() .End() - .AddAttr("alpha") - .IsNumGE(0.0f) - .IsNumLE(1.0f) + .AddAttr("scale") + .IsType() // copy to new op. so unconstrained. .End() - .AddAttr("transpose_X") - .IsBoolEQ(false) + .AddAttr("bias") + .IsNumEQ(0.f) .End() - .AddAttr("transpose_Y") + .AddAttr("bias_after_scale") // bias is 0, so unconstrained. .IsType() .End(); @@ -2978,18 +4158,12 @@ FusedMultiTransformerEncoderFuseQKVPass:: .IsType() .End(); - AddOpCompat(OpCompat("while")) - .AddInput("X") // A set of variables, unconstrained - .End() - .AddInput("Condition") // An scalar + AddOpCompat(OpCompat("relu")) + .AddInput("X") .IsTensor() .End() - .AddOutput("Out") // A set of variables, unconstrained - .End() - .AddOutput("StepScopes") // A vector of local scope, unconstrained - .End() - .AddAttr("sub_block") - .IsType() + .AddOutput("Out") + .IsTensor() .End(); } @@ -3040,6 +4214,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, + Node* ffn_act, Node* ffn_output) { auto* matmul0_op = matmul0->Op(); auto* matmul_linear_op = matmul_linear->Op(); @@ -3163,6 +4338,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); fused_multi_transformer_op_desc.SetAttr( "epsilon", layer_norm->Op()->GetAttr("epsilon")); + fused_multi_transformer_op_desc.SetAttr("act_method", + ffn_act->Op()->Type()); // output dropout attribute fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); @@ -3175,6 +4352,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( // Quantization attribute/Input if (enable_int8) { + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); // Set input scale std::string matmul_input_scale_suffix = c_identity_op->Input("X")[0]; auto qkv_in_scale = PADDLE_GET_CONST( @@ -3240,36 +4419,44 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( auto ffn1_out_scale_var = scope->Var(ffn_matmul1_w->Name() + "_out_scale"); - auto qkv_out_scale_data = - qkv_out_scale_var->GetMutable() - ->mutable_data({3 * dim_embed}, platform::CPUPlace()); + auto* qkv_out_scale_tensor = + qkv_out_scale_var->GetMutable(); + qkv_out_scale_tensor->Resize({3 * dim_embed}); + dev_ctx->Alloc(qkv_out_scale_tensor); + auto qkv_out_scale_data = qkv_out_scale_tensor->data(); memcpy(qkv_out_scale_data, qkv_out_scales.data(), qkv_out_scales.size() * sizeof(float)); fused_multi_transformer_op_desc.SetInput( "QKVOutScale", {matmul0_w->Name() + "_out_scale"}); - auto out_out_scale_data = - out_out_scale_var->GetMutable() - ->mutable_data({dim_embed}, platform::CPUPlace()); + auto* out_out_scale_tensor = + out_out_scale_var->GetMutable(); + out_out_scale_tensor->Resize({dim_embed}); + dev_ctx->Alloc(out_out_scale_tensor); + auto out_out_scale_data = out_out_scale_tensor->data(); memcpy(out_out_scale_data, out_out_scales.data(), out_out_scales.size() * sizeof(float)); fused_multi_transformer_op_desc.SetInput( "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); - auto ffn0_out_scale_data = - ffn0_out_scale_var->GetMutable() - ->mutable_data({4 * dim_embed}, platform::CPUPlace()); + auto* ffn0_out_scale_tensor = + ffn0_out_scale_var->GetMutable(); + ffn0_out_scale_tensor->Resize({4 * dim_embed}); + dev_ctx->Alloc(ffn0_out_scale_tensor); + auto ffn0_out_scale_data = ffn0_out_scale_tensor->data(); memcpy(ffn0_out_scale_data, ffn0_out_scales.data(), ffn0_out_scales.size() * sizeof(float)); fused_multi_transformer_op_desc.SetInput( "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); - auto ffn1_out_scale_data = - ffn1_out_scale_var->GetMutable() - ->mutable_data({dim_embed}, platform::CPUPlace()); + auto* ffn1_out_scale_tensor = + ffn1_out_scale_var->GetMutable(); + ffn1_out_scale_tensor->Resize({dim_embed}); + dev_ctx->Alloc(ffn1_out_scale_tensor); + auto ffn1_out_scale_data = ffn1_out_scale_tensor->data(); memcpy(ffn1_out_scale_data, ffn1_out_scales.data(), ffn1_out_scales.size() * sizeof(float)); @@ -3464,9 +4651,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); + ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( - ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); + ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); @@ -3603,6 +4790,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, + ffn_act, ffn_output); std::unordered_set marked_nodes({layer_norm, @@ -3661,8 +4849,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ffn_eltadd1, ffn_eltadd0_out, ffn_eltadd1_out, - ffn_gelu, - ffn_gelu_out, + ffn_act, + ffn_act_out, ffn_eltadd_out}); // Remove unneeded nodes. @@ -3874,6 +5062,9 @@ REGISTER_PASS(fused_multi_transformer_encoder_pass, paddle::framework::ir::FusedMultiTransformerEncoderPass); REGISTER_PASS(fused_multi_transformer_encoder_fuse_qkv_pass, paddle::framework::ir::FusedMultiTransformerEncoderFuseQKVPass); +REGISTER_PASS( + multi_devices_fused_multi_transformer_encoder_pass, + paddle::framework::ir::MultiDevicesFusedMultiTransformerEncoderPass); REGISTER_PASS( multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass, paddle::framework::ir::MultiDevicesFusedMultiTransformerEncoderFuseQKVPass); @@ -3898,6 +5089,16 @@ REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_fuse_qkv_pass) .LE("matmul", 1) .EQ("matmul_v2", 0) .EQ("softmax", 0)); +REGISTER_PASS_CAPABILITY(multi_devices_fused_multi_transformer_encoder_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); REGISTER_PASS_CAPABILITY( multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass) .AddCombination( diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h index ae7f0e9761..6274a69737 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h @@ -37,12 +37,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { // Q, K, V path PATTERN_DECL_NODE(input0); - PATTERN_DECL_NODE(layer_norm); - PATTERN_DECL_NODE(layer_norm_scale); - PATTERN_DECL_NODE(layer_norm_bias); - PATTERN_DECL_NODE(layer_norm_mean); - PATTERN_DECL_NODE(layer_norm_variance); - PATTERN_DECL_NODE(layer_norm_out); PATTERN_DECL_NODE(matmul0); PATTERN_DECL_NODE(matmul1); PATTERN_DECL_NODE(matmul2); @@ -73,6 +67,8 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { PATTERN_DECL_NODE(transpose2_0_out); PATTERN_DECL_NODE(transpose2_1_out); PATTERN_DECL_NODE(transpose2_2_out); + PATTERN_DECL_NODE(scale_q); + PATTERN_DECL_NODE(scale_q_out); // Q, K matmul PATTERN_DECL_NODE(matmul_qk); @@ -98,29 +94,30 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_out); + PATTERN_DECL_NODE(dropout_linear); + PATTERN_DECL_NODE(dropout_linear_out); // output elementwise_add PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(attention_output); - // while loop - PATTERN_DECL_NODE(while0); + // post layer_norm + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); + PATTERN_DECL_NODE(layer_norm_out); // Feed Forward nodes - PATTERN_DECL_NODE(ffn_layer_norm); - PATTERN_DECL_NODE(ffn_layer_norm_scale); - PATTERN_DECL_NODE(ffn_layer_norm_bias); - PATTERN_DECL_NODE(ffn_layer_norm_mean); - PATTERN_DECL_NODE(ffn_layer_norm_variance); - PATTERN_DECL_NODE(ffn_layer_norm_out); PATTERN_DECL_NODE(ffn_matmul0); PATTERN_DECL_NODE(ffn_matmul0_w); PATTERN_DECL_NODE(ffn_matmul0_out); PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_out); - PATTERN_DECL_NODE(ffn_gelu); - PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_act); + PATTERN_DECL_NODE(ffn_act_out); PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_out); @@ -131,6 +128,13 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { // output elementwise_add PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_output); + + PATTERN_DECL_NODE(ffn_layer_norm); + PATTERN_DECL_NODE(ffn_layer_norm_scale); + PATTERN_DECL_NODE(ffn_layer_norm_bias); + PATTERN_DECL_NODE(ffn_layer_norm_mean); + PATTERN_DECL_NODE(ffn_layer_norm_variance); + PATTERN_DECL_NODE(ffn_layer_norm_out); }; struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { @@ -212,11 +216,127 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_out); - PATTERN_DECL_NODE(ffn_gelu); - PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_act); + PATTERN_DECL_NODE(ffn_act_out); + PATTERN_DECL_NODE(ffn_matmul1); + PATTERN_DECL_NODE(ffn_matmul1_w); + PATTERN_DECL_NODE(ffn_matmul1_out); + PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_out); + + // output elementwise_add + PATTERN_DECL_NODE(ffn_eltadd_out) + PATTERN_DECL_NODE(ffn_output); +}; + +struct MultiDevicesFusedMultiTransformerEncoderPattern : public PatternBase { + MultiDevicesFusedMultiTransformerEncoderPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, + name_scope, + "multi_devices_fused_multi_transformer_encoder") {} + + PDNode* operator()(); + + // Q, K, V path + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(c_identity0); + PATTERN_DECL_NODE(c_identity0_out); + PATTERN_DECL_NODE(c_identity1); + PATTERN_DECL_NODE(c_identity1_out); + PATTERN_DECL_NODE(c_identity2); + PATTERN_DECL_NODE(c_identity2_out); + PATTERN_DECL_NODE(matmul0); + PATTERN_DECL_NODE(matmul1); + PATTERN_DECL_NODE(matmul2); + PATTERN_DECL_NODE(matmul0_w); + PATTERN_DECL_NODE(matmul1_w); + PATTERN_DECL_NODE(matmul2_w); + PATTERN_DECL_NODE(matmul0_out); + PATTERN_DECL_NODE(matmul1_out); + PATTERN_DECL_NODE(matmul2_out); + PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(eltadd1_out); + PATTERN_DECL_NODE(eltadd2_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_1); + PATTERN_DECL_NODE(reshape2_2); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(reshape2_1_out); + PATTERN_DECL_NODE(reshape2_2_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(transpose2_2); + PATTERN_DECL_NODE(transpose2_0_out); + PATTERN_DECL_NODE(transpose2_1_out); + PATTERN_DECL_NODE(transpose2_2_out); + PATTERN_DECL_NODE(scale_q); + PATTERN_DECL_NODE(scale_q_out); + + // Q, K matmul + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(eltadd_qk); + PATTERN_DECL_NODE(eltadd_qk_b); + PATTERN_DECL_NODE(eltadd_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + + // QK, V matmul + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_qkv_out); + + // out linear + PATTERN_DECL_NODE(matmul_linear); + PATTERN_DECL_NODE(matmul_linear_w); + PATTERN_DECL_NODE(matmul_linear_out); + PATTERN_DECL_NODE(c_allreduce_sum); + PATTERN_DECL_NODE(c_allreduce_sum_out); + PATTERN_DECL_NODE(eltadd_linear); + PATTERN_DECL_NODE(eltadd_linear_b); + PATTERN_DECL_NODE(eltadd_linear_out); + PATTERN_DECL_NODE(dropout_linear); + PATTERN_DECL_NODE(dropout_linear_out); + + // output elementwise_add + PATTERN_DECL_NODE(eltadd_out) + PATTERN_DECL_NODE(attention_output); + + // post layer_norm + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); + PATTERN_DECL_NODE(layer_norm_out); + + // Feed Forward nodes + PATTERN_DECL_NODE(ffn_c_identity); + PATTERN_DECL_NODE(ffn_c_identity_out); + PATTERN_DECL_NODE(ffn_matmul0); + PATTERN_DECL_NODE(ffn_matmul0_w); + PATTERN_DECL_NODE(ffn_matmul0_out); + PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_out); + PATTERN_DECL_NODE(ffn_act); + PATTERN_DECL_NODE(ffn_act_out); PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_out); + PATTERN_DECL_NODE(ffn_c_allreduce_sum); + PATTERN_DECL_NODE(ffn_c_allreduce_sum_out); PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_out); @@ -224,6 +344,13 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { // output elementwise_add PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_output); + + PATTERN_DECL_NODE(ffn_layer_norm); + PATTERN_DECL_NODE(ffn_layer_norm_scale); + PATTERN_DECL_NODE(ffn_layer_norm_bias); + PATTERN_DECL_NODE(ffn_layer_norm_mean); + PATTERN_DECL_NODE(ffn_layer_norm_variance); + PATTERN_DECL_NODE(ffn_layer_norm_out); }; struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern @@ -313,8 +440,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd0_out); - PATTERN_DECL_NODE(ffn_gelu); - PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_act); + PATTERN_DECL_NODE(ffn_act_out); PATTERN_DECL_NODE(ffn_matmul1); PATTERN_DECL_NODE(ffn_matmul1_w); PATTERN_DECL_NODE(ffn_matmul1_out); @@ -362,6 +489,23 @@ class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase { Scope* scope) const; }; +class MultiDevicesFusedMultiTransformerEncoderPass : public FusePassBase { + public: + MultiDevicesFusedMultiTransformerEncoderPass(); + virtual ~MultiDevicesFusedMultiTransformerEncoderPass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{ + "multi_devices_fused_multi_transformer_encoder"}; + + private: + int BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass : public FusePassBase { public: diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc index 5542a802b2..844c2124bd 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc @@ -56,8 +56,8 @@ Scope* CreateParamScope() { // FFN: fc1 -> (gelu) -> fc2 AddVarToScope(param_scope, "ffn_weights0", {1024, 4096}); AddVarToScope(param_scope, "ffn_weights1", {4096, 1024}); - AddVarToScope(param_scope, "ffn_bias_0", {4096}); - AddVarToScope(param_scope, "ffn_bias_1", {1024}); + AddVarToScope(param_scope, "ffn_bias0", {4096}); + AddVarToScope(param_scope, "ffn_bias1", {1024}); return param_scope; } @@ -65,10 +65,9 @@ Scope* CreateParamScope() { TEST(FusedMultiTransformerEncoderPass, basic) { // inputs operator output // -------------------------------------------------------------------- - // (x, ln_scale, ln_bias) layer_norm -> layer_norm_out - // (layer_norm_out, weights_0) matmul_v2 -> matmul_out0 - // (layer_norm_out, weights_1) matmul_v2 -> matmul_out1 - // (layer_norm_out, weights_2) matmul_v2 -> matmul_out2 + // (x, weights_0) matmul_v2 -> matmul_out0 + // (x, weights_1) matmul_v2 -> matmul_out1 + // (x, weights_2) matmul_v2 -> matmul_out2 // (matmul_out0, bias_0) elementwise_add -> eltadd_0 // (matmul_out1, bias_1) elementwise_add -> eltadd_1 // (matmul_out2, bias_2) elementwise_add -> eltadd_2 @@ -78,7 +77,8 @@ TEST(FusedMultiTransformerEncoderPass, basic) { // (reshape_0) transpose2 -> transpose_0 // (reshape_1) transpose2 -> transpose_1 // (reshape_2) transpose2 -> transpose_2 - // (transpose_0, transpose_1) matmul -> matmul_qk + // (transpose_0) scale -> scale_0 + // (scale_0, transpose_1) matmul -> matmul_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (eltadd_qk) softmax -> softmax_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv @@ -86,35 +86,28 @@ TEST(FusedMultiTransformerEncoderPass, basic) { // (transpose_qkv) reshape -> reshape_qkv // (reshape_qkv) matmul_v2 -> matmul_linear // (matmul_linear) elementwise_add -> eltadd_linear - // (eltadd_out) elementwise_add -> attention_out + // (eltadd_linear) elementwise_add -> attention_out // - // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out + // (attention_out, scale, bias) layer_norm -> layer_norm_out // (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0 // (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0 // (ffn_eltadd0) gelu -> ffn_gelu // (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 - // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output - // - // (transpose_1, transpose_2) while -> decoder block + // (layer_norm_out, ffn_eltadd1) elementwise_add -> ffn_output + // (ffn_output, scale, bias) layer_norm -> ffn_layer_norm_out Layers layers; // MHA: pre LayerNorm auto* x = layers.data("x", {1, 128, 1024}); - auto* ln_scale = layers.data("ln_scale", {1024}, true); - auto* ln_bias = layers.data("ln_bias", {1024}, true); - auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0]; // MHA: QKV fc auto* weights_0 = layers.data("weights0", {1024, 1024}, true); auto* weights_1 = layers.data("weights1", {1024, 1024}, true); auto* weights_2 = layers.data("weights2", {1024, 1024}, true); - auto* matmul_out_0 = - layers.matmul_v2(ln_out, weights_0, nullptr, false, true); - auto* matmul_out_1 = - layers.matmul_v2(ln_out, weights_1, nullptr, false, true); - auto* matmul_out_2 = - layers.matmul_v2(ln_out, weights_2, nullptr, false, true); + auto* matmul_out_0 = layers.matmul_v2(x, weights_0, nullptr, false, false); + auto* matmul_out_1 = layers.matmul_v2(x, weights_1, nullptr, false, false); + auto* matmul_out_2 = layers.matmul_v2(x, weights_2, nullptr, false, false); auto* b0 = layers.data("bias_0", {1024}, true); auto* b1 = layers.data("bias_1", {1024}, true); @@ -136,14 +129,13 @@ TEST(FusedMultiTransformerEncoderPass, basic) { auto* transpose_1 = layers.transpose2(reshape_1, axis, true); auto* transpose_2 = layers.transpose2(reshape_2, axis, true); - // Link to decoder while block - layers.while_loop({transpose_1, transpose_2}); - + // q scale + auto* scale_q = layers.scale(transpose_0, 0.125, 0, false); // MHA: QK matmul auto* matmul_qk = - layers.matmul(transpose_0, transpose_1, nullptr, false, true); + layers.matmul_v2(scale_q, transpose_1, nullptr, false, true); - auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); + auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1); @@ -155,19 +147,18 @@ TEST(FusedMultiTransformerEncoderPass, basic) { // MHA: out Linear auto* weights_l = layers.data("weights_l", {1024, 1024}, true); - auto* bias_l = layers.data("weightsl", {1024, 1024}, true); + auto* bias_l = layers.data("bias_l", {1024}, true); auto* linear_matmut_out = - layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true); + layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, false); auto* linear_eltadd_out = layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); auto* attention_out = layers.elementwise_add(x, linear_eltadd_out); - // FFN: pre LayerNorm - auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); - auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); - auto* ffn_ln_out = - layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0]; + // post LayerNorm + auto* ln_scale = layers.data("ln_scale", {1024}, true); + auto* ln_bias = layers.data("ln_bias", {1024}, true); + auto* ln_out = layers.layer_norm(attention_out, ln_scale, ln_bias)[0]; // FFN: fc1 -> gelu -> fc2 auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true); @@ -175,7 +166,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true); auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true); auto* ffn_matmul0_out = - layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true); + layers.matmul_v2(ln_out, ffn_weights0, nullptr, false, true); auto* ffn_eltadd0_out = layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2); auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out); @@ -184,7 +175,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) { auto* ffn_eltadd1_out = layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); - layers.elementwise_add(attention_out, ffn_eltadd1_out); + auto* ffn_out = layers.elementwise_add(ln_out, ffn_eltadd1_out); + + // FFN: post LayerNorm + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); + layers.layer_norm(ffn_out, ffn_ln_scale, ffn_ln_bias)[0]; std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); @@ -203,12 +199,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) { int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); PADDLE_ENFORCE_EQ(num_nodes_before, - num_nodes_after + 56, + num_nodes_after + 58, platform::errors::InvalidArgument( "After the fused_multi_transformer_encoder_pass, The " "node num in graph " "should be %d, but the result is %d", - num_nodes_before - 56, + num_nodes_before - 58, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, @@ -225,6 +221,183 @@ TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) { .IsPassCompatible("fused_multi_transformer_encoder_pass")); } +TEST(MultiDevicesFusedMultiTransformerEncoderPass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x) c_identity -> c_identity0_out + // (x) c_identity -> c_identity1_out + // (x) c_identity -> c_identity2_out + // (c_identity0_out, weights_0) matmul_v2 -> matmul_out0 + // (c_identity1_out, weights_1) matmul_v2 -> matmul_out1 + // (c_identity2_out, weights_2) matmul_v2 -> matmul_out2 + // (matmul_out0, bias_0) elementwise_add -> eltadd_0 + // (matmul_out1, bias_1) elementwise_add -> eltadd_1 + // (matmul_out2, bias_2) elementwise_add -> eltadd_2 + // (eltadd_0) reshape2 -> reshape_0 + // (eltadd_1) reshape2 -> reshape_1 + // (eltadd_2) reshape2 -> reshape_2 + // (reshape_0) transpose2 -> transpose_0 + // (reshape_1) transpose2 -> transpose_1 + // (reshape_2) transpose2 -> transpose_2 + // (transpose_0) scale -> scale_0 + // (scale_0, transpose_1) matmul -> matmul_qk + // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (eltadd_qk) softmax -> softmax_qk + // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv + // (matmul_qkv) transpose -> transpose_qkv + // (transpose_qkv) reshape -> reshape_qkv + // (reshape_qkv) matmul_v2 -> matmul_linear + // (matmul_linear) c_all_reduce -> c_all_reduce_out + // (c_all_reduce_out) elementwise_add -> eltadd_linear + // (eltadd_linear) elementwise_add -> attention_out + // + // (attention_out, scale, bias) layer_norm -> layer_norm_out + // (layer_norm_out) c_identity -> ffn_c_identity_out + // (ffn_c_identity_out, ffn_matmul0_w)matmul_v2 -> ffn_matmul0 + // (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0 + // (ffn_eltadd0) gelu -> ffn_gelu + // (ffn_gelu) matmul_v2 -> ffn_matmul1 + // (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out + // (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1 + // (layer_norm_out, ffn_eltadd1) elementwise_add -> ffn_output + // (ffn_output, scale, bias) layer_norm -> ffn_layer_norm_out + + Layers layers; + // MHA: pre LayerNorm + auto* x = layers.data("x", {1, 128, 1024}); + auto* c_identity0_out = layers.c_identity(x); + auto* c_identity1_out = layers.c_identity(x); + auto* c_identity2_out = layers.c_identity(x); + + // MHA: QKV fc + auto* weights_0 = layers.data("weights0", {1024, 1024}, true); + auto* weights_1 = layers.data("weights1", {1024, 1024}, true); + auto* weights_2 = layers.data("weights2", {1024, 1024}, true); + auto* matmul_out_0 = + layers.matmul_v2(c_identity0_out, weights_0, nullptr, false, false); + auto* matmul_out_1 = + layers.matmul_v2(c_identity1_out, weights_1, nullptr, false, false); + auto* matmul_out_2 = + layers.matmul_v2(c_identity2_out, weights_2, nullptr, false, false); + + auto* b0 = layers.data("bias_0", {1024}, true); + auto* b1 = layers.data("bias_1", {1024}, true); + auto* b2 = layers.data("bias_2", {1024}, true); + auto* elementwise_out_0 = + layers.elementwise_add(matmul_out_0, b0, nullptr, 2); + auto* elementwise_out_1 = + layers.elementwise_add(matmul_out_1, b1, nullptr, 2); + auto* elementwise_out_2 = + layers.elementwise_add(matmul_out_2, b2, nullptr, 2); + + std::vector shape = {1, 128, 16, 64}; + auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true); + auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true); + auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true); + + std::vector axis = {0, 2, 1, 3}; + auto* transpose_0 = layers.transpose2(reshape_0, axis, true); + auto* transpose_1 = layers.transpose2(reshape_1, axis, true); + auto* transpose_2 = layers.transpose2(reshape_2, axis, true); + + // q scale + auto* scale_q = layers.scale(transpose_0, 0.125, 0, false); + + // MHA: QK matmul + auto* matmul_qk = + layers.matmul_v2(scale_q, transpose_1, nullptr, false, true); + + auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true); + auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1); + auto* softmax_qk = layers.softmax(elementwise_qk, -1); + + // MHA: QKV matmul + auto* matmul_qkv = layers.matmul_v2(softmax_qk, transpose_2); + + auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); + auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); + + // MHA: out Linear + auto* weights_l = layers.data("weights_l", {1024, 1024}, true); + auto* bias_l = layers.data("bias_l", {1024}, true); + auto* linear_matmut_out = + layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, false); + auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmut_out); + auto* linear_eltadd_out = + layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2); + + auto* attention_out = layers.elementwise_add(x, linear_eltadd_out); + + // post LayerNorm + auto* ln_scale = layers.data("ln_scale", {1024}, true); + auto* ln_bias = layers.data("ln_bias", {1024}, true); + auto* ln_out = layers.layer_norm(attention_out, ln_scale, ln_bias)[0]; + auto* ffn_c_identity_out = layers.c_identity(ln_out); + + // FFN: fc1 -> gelu -> fc2 + auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true); + auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true); + auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true); + auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true); + auto* ffn_matmul0_out = + layers.matmul_v2(ffn_c_identity_out, ffn_weights0, nullptr, false, false); + auto* ffn_eltadd0_out = + layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2); + auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out); + auto* ffn_matmul1_out = + layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, false); + auto* ffn_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out); + auto* ffn_eltadd1_out = + layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2); + + auto* ffn_out = layers.elementwise_add(ln_out, ffn_eltadd1_out); + + // FFN: post LayerNorm + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); + layers.layer_norm(ffn_out, ffn_ln_scale, ffn_ln_bias)[0]; + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + graph->Set("enable_int8", new bool(false)); + + auto pass = PassRegistry::Instance().Get( + "multi_devices_fused_multi_transformer_encoder_pass"); + if (pass.get() == nullptr) + LOG(INFO) + << "get multi_devices_fused_multi_transformer_encoder_pass failed"; + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); + + PADDLE_ENFORCE_EQ(num_nodes_before, + num_nodes_after + 70, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_encoder_pass, The " + "node num in graph " + "should be %d, but the result is %d", + num_nodes_before - 70, + num_nodes_after)); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_encoder pass, " + "there should be one fused_multi_transformer op, " + "but the result is %d", + num_fused_nodes_after)); +} + +TEST(MultiDevicesFusedMultiTransformerEncoderPass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible( + "multi_devices_fused_multi_transformer_encoder_pass")); +} + TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { // inputs operator output // -------------------------------------------------------------------- @@ -292,7 +465,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true); auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false); - auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); + auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true); auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); @@ -447,7 +620,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true); auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false); - auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); + auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true); auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); @@ -542,4 +715,5 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, USE_PASS(fused_multi_transformer_encoder_pass); USE_PASS(fused_multi_transformer_encoder_fuse_qkv_pass); +USE_PASS(multi_devices_fused_multi_transformer_encoder_pass); USE_PASS(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index b4018d883a..ac46c486d8 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -179,6 +179,7 @@ const std::vector kGpuLowerPrecisionPasses{ "fused_multi_transformer_decoder_pass", "fused_multi_transformer_encoder_fuse_qkv_pass", "fused_multi_transformer_decoder_fuse_qkv_pass", + "multi_devices_fused_multi_transformer_encoder_pass", "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", "fuse_multi_transformer_layer_pass", @@ -228,6 +229,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "fused_multi_transformer_decoder_pass", // "fused_multi_transformer_encoder_fuse_qkv_pass", // "fused_multi_transformer_decoder_fuse_qkv_pass", // + "multi_devices_fused_multi_transformer_encoder_pass", // "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", // "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", // "fuse_multi_transformer_layer_pass", // -- GitLab