From ba4fbe71ee47657bba68637dc1774595f645085c Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Thu, 3 Nov 2022 11:08:19 +0800 Subject: [PATCH] [cherry pick] fix memory copy in prepare_data of FusedMultiTransformer pass (#47308) * fix memory copy in prepare_data. test=develop * add cache_kv fp16 support. test=develop * fit for simplify_with_basic_ops_pass. test=develop --- .../fused_multi_transformer_decoder_pass.cc | 278 ++++------------- .../ir/fused_multi_transformer_decoder_pass.h | 18 -- ...d_multi_transformer_decoder_pass_tester.cc | 66 ++-- .../fused_multi_transformer_encoder_pass.cc | 287 ++++-------------- .../ir/fused_multi_transformer_encoder_pass.h | 18 -- ...d_multi_transformer_encoder_pass_tester.cc | 66 ++-- paddle/fluid/framework/ir/pass.cc | 1 + 7 files changed, 154 insertions(+), 580 deletions(-) diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc index 5559499e0b4..ef896e9c7e8 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc @@ -237,15 +237,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) ->assert_is_op_output("softmax") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_qk = - pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout"); - auto* dropout_qk_out_var = - pattern->NewNode(dropout_qk_out_repr()) - ->assert_is_op_output("dropout", "Out") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv + ->assert_is_op_input("matmul_v2", "X"); // QK path Linsk matmul_qk->LinksFrom({transpose2_0_out_var, concat_0_out_var}) @@ -253,7 +245,6 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); - dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var}); // QKV path Nodes auto* matmul_qkv = @@ -294,14 +285,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_linear = - pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout"); - auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() - ->assert_is_op_input("elementwise_add"); + ->assert_is_op_input("elementwise_add"); auto* eltadd_out = pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); @@ -310,7 +294,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { ->AsIntermediate(); // QKV path Links - matmul_qkv->LinksFrom({dropout_qk_out_var, concat_1_out_var}) + matmul_qkv->LinksFrom({softmax_qk_out_var, concat_1_out_var}) .LinksTo({matmul_qkv_out_var}); transpose2_qkv->LinksFrom({matmul_qkv_out_var}) .LinksTo({transpose2_qkv_out_var}); @@ -320,9 +304,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { .LinksTo({matmul_linear_out_var}); eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var}) .LinksTo({eltadd_linear_out_var}); - dropout_linear->LinksFrom({eltadd_linear_out_var}) - .LinksTo({dropout_linear_out_var}); - eltadd_out->LinksFrom({input0, dropout_linear_out_var}) + 
eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) .LinksTo({attention_output}); // Feed Forward LayerNorm Nodes @@ -358,7 +340,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { ffn_layer_norm_mean_var, ffn_layer_norm_variance_var}); - // Feed Forward fc1 -> gelu -> fc2 -> dropout + // Feed Forward fc1 -> gelu -> fc2 auto* ffn_matmul0 = pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) @@ -403,13 +385,6 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* ffn_dropout = - pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout"); - auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() ->assert_is_op_input("elementwise_add"); auto* ffn_eltadd_out = @@ -427,9 +402,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() { .LinksTo({ffn_matmul1_out_var}); ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); - ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var}); - ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) + ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) .LinksTo({ffn_output}); return ffn_output; @@ -575,15 +549,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) ->assert_is_op_output("softmax") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_qk = - pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout"); - auto* dropout_qk_out_var = - pattern->NewNode(dropout_qk_out_repr()) - ->assert_is_op_output("dropout", "Out") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv + ->assert_is_op_input("matmul_v2", "X"); // QK path Linsk matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var}) @@ -591,7 +557,6 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); - dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var}); // QKV path Nodes auto* matmul_qkv = @@ -632,14 +597,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_linear = - pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout"); - auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() - ->assert_is_op_input("elementwise_add"); + ->assert_is_op_input("elementwise_add"); auto* eltadd_out = pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); @@ -648,7 +606,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { ->AsIntermediate(); // QKV path Links - matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var}) + matmul_qkv->LinksFrom({softmax_qk_out_var, concat_v_out_var}) .LinksTo({matmul_qkv_out_var}); transpose2_qkv->LinksFrom({matmul_qkv_out_var}) .LinksTo({transpose2_qkv_out_var}); @@ -658,9 +616,7 @@ PDNode* 
FusedMultiTransformerDecoderFuseQKVPattern::operator()() { .LinksTo({matmul_linear_out_var}); eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var}) .LinksTo({eltadd_linear_out_var}); - dropout_linear->LinksFrom({eltadd_linear_out_var}) - .LinksTo({dropout_linear_out_var}); - eltadd_out->LinksFrom({input0, dropout_linear_out_var}) + eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) .LinksTo({attention_output}); // Feed Forward LayerNorm Nodes @@ -696,7 +652,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { ffn_layer_norm_mean_var, ffn_layer_norm_variance_var}); - // Feed Forward fc1 -> gelu -> fc2 -> dropout + // Feed Forward fc1 -> gelu -> fc2 auto* ffn_matmul0 = pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) @@ -741,13 +697,6 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* ffn_dropout = - pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout"); - auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() ->assert_is_op_input("elementwise_add"); auto* ffn_eltadd_out = @@ -765,9 +714,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { .LinksTo({ffn_matmul1_out_var}); ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); - ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var}); - ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) + ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) .LinksTo({ffn_output}); return ffn_output; @@ -922,15 +870,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) ->assert_is_op_output("softmax") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_qk = - pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout"); - auto* dropout_qk_out_var = - pattern->NewNode(dropout_qk_out_repr()) - ->assert_is_op_output("dropout", "Out") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv + ->assert_is_op_input("matmul_v2", "X"); // QK path Linsk matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var}) @@ -938,7 +878,6 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); - dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var}); // QKV path Nodes auto* matmul_qkv = @@ -987,14 +926,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_linear = - pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout"); - auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() - ->assert_is_op_input("elementwise_add"); + ->assert_is_op_input("elementwise_add"); auto* eltadd_out = 
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); @@ -1003,7 +935,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { ->AsIntermediate(); // QKV path Links - matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var}) + matmul_qkv->LinksFrom({softmax_qk_out_var, concat_v_out_var}) .LinksTo({matmul_qkv_out_var}); transpose2_qkv->LinksFrom({matmul_qkv_out_var}) .LinksTo({transpose2_qkv_out_var}); @@ -1015,9 +947,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { .LinksTo({c_allreduce_sum_out_var}); eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var}) .LinksTo({eltadd_linear_out_var}); - dropout_linear->LinksFrom({eltadd_linear_out_var}) - .LinksTo({dropout_linear_out_var}); - eltadd_out->LinksFrom({input0, dropout_linear_out_var}) + eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) .LinksTo({attention_output}); // Feed Forward LayerNorm Nodes @@ -1063,7 +993,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { ffn_c_identity->LinksFrom({ffn_layer_norm_out_var}) .LinksTo({ffn_c_identity_out_var}); - // Feed Forward fc1 -> gelu -> fc2 -> dropout + // Feed Forward fc1 -> gelu -> fc2 auto* ffn_matmul0 = pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) @@ -1117,13 +1047,6 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* ffn_dropout = - pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout"); - auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() ->assert_is_op_input("elementwise_add"); auto* ffn_eltadd_out = @@ -1143,9 +1066,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { .LinksTo({ffn_c_allreduce_sum_out_var}); ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); - ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var}); - ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) + ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) .LinksTo({ffn_output}); return ffn_output; @@ -1180,11 +1102,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, Node* transpose2_1_out, Node* transpose2_2_out, Node* eltadd_qk_b, - Node* dropout_qk, Node* reshape2_0, Node* matmul_linear_w, Node* eltadd_linear_b, - Node* dropout_linear, Node* ffn_layer_norm, Node* ffn_layer_norm_scale, Node* ffn_layer_norm_bias, @@ -1194,7 +1114,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, - Node* ffn_dropout, Node* ffn_output) { // Calc index of transformer layer by LayerNorm Scale name // This calculation assumes: @@ -1287,14 +1206,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, "epsilon", layer_norm->Op()->GetAttr("epsilon")); // output dropout attribute - auto* dropout_op = dropout_linear->Op(); - fused_multi_transformer_op_desc.SetAttr( - "dropout_rate", dropout_op->GetAttr("dropout_prob")); - fused_multi_transformer_op_desc.SetAttr("is_test", - dropout_op->GetAttr("is_test")); - fused_multi_transformer_op_desc.SetAttr( - "dropout_implementation", - 
dropout_op->GetAttr("dropout_implementation")); + fused_multi_transformer_op_desc.SetAttr("is_test", true); + fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); @@ -1313,6 +1226,15 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, IR_NODE_LINK_TO(slice_op, slice_out); IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer); + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); }; @@ -1451,11 +1373,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, GET_IR_NODE_FROM_SUBGRAPH( ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_dropout, ffn_dropout, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH( @@ -1499,10 +1416,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, softmax_qk, softmax_qk, fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - dropout_qk, dropout_qk, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH( matmul_qkv, matmul_qkv, fused_multi_transformer_pattern); @@ -1531,10 +1444,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH( eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - dropout_linear, dropout_linear, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH( eltadd_out, eltadd_out, fused_multi_transformer_pattern) @@ -1554,11 +1463,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, transpose2_1_out, transpose2_2_out, eltadd_qk_b, - dropout_qk, reshape2_0, matmul_linear_w, eltadd_linear_b, - dropout_linear, ffn_layer_norm, ffn_layer_norm_scale, ffn_layer_norm_bias, @@ -1568,12 +1475,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, - ffn_dropout, ffn_output); std::unordered_set marked_nodes({layer_norm, - layer_norm_scale, - layer_norm_bias, layer_norm_mean, layer_norm_variance, layer_norm_out, @@ -1613,8 +1517,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, eltadd_qk_out, softmax_qk, softmax_qk_out, - dropout_qk, - dropout_qk_out, transpose2_qkv, transpose2_qkv_out, matmul_qkv, @@ -1623,17 +1525,11 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, transpose2_qkv, transpose2_qkv_out, matmul_linear, - matmul_linear_w, matmul_linear_out, eltadd_linear, - eltadd_linear_b, eltadd_linear_out, - dropout_linear, - 
dropout_linear_out, eltadd_out, ffn_layer_norm, - ffn_layer_norm_scale, - ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, ffn_layer_norm_out, @@ -1647,8 +1543,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ffn_eltadd1_out, ffn_gelu, ffn_gelu_out, - ffn_dropout, - ffn_dropout_out, ffn_eltadd_out}); // Remove unneeded nodes. @@ -1850,11 +1744,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( Node* matmul0_w, Node* eltadd0_b, Node* eltadd_qk_b, - Node* dropout_qk, Node* reshape2_0, Node* matmul_linear_w, Node* eltadd_linear_b, - Node* dropout_linear, Node* ffn_layer_norm, Node* ffn_layer_norm_scale, Node* ffn_layer_norm_bias, @@ -1864,7 +1756,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, - Node* ffn_dropout, Node* ffn_output) { // Calc index of transformer layer by LayerNorm Scale name // This calculation assumes: @@ -1957,17 +1848,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( "epsilon", layer_norm->Op()->GetAttr("epsilon")); // output dropout attribute - auto* dropout_op = dropout_linear->Op(); - fused_multi_transformer_op_desc.SetAttr( - "dropout_rate", dropout_op->GetAttr("dropout_prob")); - fused_multi_transformer_op_desc.SetAttr("is_test", - dropout_op->GetAttr("is_test")); - fused_multi_transformer_op_desc.SetAttr( - "dropout_implementation", - dropout_op->GetAttr("dropout_implementation")); - - // fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"}); - // fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true}); + fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); + fused_multi_transformer_op_desc.SetAttr("is_test", true); auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); @@ -1986,6 +1868,15 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( IR_NODE_LINK_TO(slice_op, slice_out); IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer); + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); }; @@ -2116,12 +2007,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_eltadd1_out, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out, - ffn_dropout_out, - fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_fuse_qkv_pattern) @@ -2153,11 +2038,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out, - dropout_qk_out, - fused_multi_transformer_fuse_qkv_pattern) GET_IR_NODE_FROM_SUBGRAPH( matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); @@ -2193,12 +2073,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( 
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(dropout_linear, - dropout_linear, - fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out, - dropout_linear_out, - fused_multi_transformer_fuse_qkv_pattern) GET_IR_NODE_FROM_SUBGRAPH( eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern) @@ -2212,11 +2086,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( matmul0_w, eltadd0_b, eltadd_qk_b, - dropout_qk, reshape2_0, matmul_linear_w, eltadd_linear_b, - dropout_linear, ffn_layer_norm, ffn_layer_norm_scale, ffn_layer_norm_bias, @@ -2226,12 +2098,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, - ffn_dropout, ffn_output); std::unordered_set marked_nodes({layer_norm, - layer_norm_scale, - layer_norm_bias, layer_norm_mean, layer_norm_variance, layer_norm_out, @@ -2261,8 +2130,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( eltadd_qk_out, softmax_qk, softmax_qk_out, - dropout_qk, - dropout_qk_out, transpose2_qkv, transpose2_qkv_out, matmul_qkv, @@ -2271,17 +2138,11 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( transpose2_qkv, transpose2_qkv_out, matmul_linear, - matmul_linear_w, matmul_linear_out, eltadd_linear, - eltadd_linear_b, eltadd_linear_out, - dropout_linear, - dropout_linear_out, eltadd_out, ffn_layer_norm, - ffn_layer_norm_scale, - ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, ffn_layer_norm_out, @@ -2295,8 +2156,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_eltadd1_out, ffn_gelu, ffn_gelu_out, - ffn_dropout, - ffn_dropout_out, ffn_eltadd_out}); // Remove unneeded nodes. @@ -2500,11 +2359,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( Node* matmul0_w, Node* eltadd0_b, Node* eltadd_qk_b, - Node* dropout_qk, Node* reshape2_0, Node* matmul_linear_w, Node* eltadd_linear_b, - Node* dropout_linear, Node* ffn_layer_norm, Node* ffn_layer_norm_scale, Node* ffn_layer_norm_bias, @@ -2514,7 +2371,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, - Node* ffn_dropout, Node* ffn_output) { // Calc index of transformer layer by LayerNorm Scale name // This calculation assumes: @@ -2607,23 +2463,14 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( "epsilon", layer_norm->Op()->GetAttr("epsilon")); // output dropout attribute - auto* dropout_op = dropout_linear->Op(); - fused_multi_transformer_op_desc.SetAttr( - "dropout_rate", dropout_op->GetAttr("dropout_prob")); - fused_multi_transformer_op_desc.SetAttr("is_test", - dropout_op->GetAttr("is_test")); - fused_multi_transformer_op_desc.SetAttr( - "dropout_implementation", - dropout_op->GetAttr("dropout_implementation")); + fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); + fused_multi_transformer_op_desc.SetAttr("is_test", true); // parallel ring id auto* c_identity_op = c_identity->Op(); fused_multi_transformer_op_desc.SetAttr("ring_id", c_identity_op->GetAttr("ring_id")); - // fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"}); - // fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true}); - auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); IR_NODE_LINK_TO(input0, fused_multi_transformer); @@ -2641,6 +2488,15 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( 
IR_NODE_LINK_TO(slice_op, slice_out); IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer); + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); }; @@ -2790,12 +2646,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_eltadd1_out, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out, - ffn_dropout_out, - fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_fuse_qkv_pattern) @@ -2827,11 +2677,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out, - dropout_qk_out, - fused_multi_transformer_fuse_qkv_pattern) GET_IR_NODE_FROM_SUBGRAPH( matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); @@ -2873,12 +2718,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(dropout_linear, - dropout_linear, - fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out, - dropout_linear_out, - fused_multi_transformer_fuse_qkv_pattern) GET_IR_NODE_FROM_SUBGRAPH( eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern) @@ -2893,11 +2732,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( matmul0_w, eltadd0_b, eltadd_qk_b, - dropout_qk, reshape2_0, matmul_linear_w, eltadd_linear_b, - dropout_linear, ffn_layer_norm, ffn_layer_norm_scale, ffn_layer_norm_bias, @@ -2907,12 +2744,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, - ffn_dropout, ffn_output); std::unordered_set marked_nodes({layer_norm, - layer_norm_scale, - layer_norm_bias, layer_norm_mean, layer_norm_variance, layer_norm_out, @@ -2944,8 +2778,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( eltadd_qk_out, softmax_qk, softmax_qk_out, - dropout_qk, - dropout_qk_out, transpose2_qkv, transpose2_qkv_out, matmul_qkv, @@ -2954,19 +2786,13 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( transpose2_qkv, transpose2_qkv_out, matmul_linear, - matmul_linear_w, matmul_linear_out, c_allreduce_sum, c_allreduce_sum_out, eltadd_linear, - eltadd_linear_b, eltadd_linear_out, - dropout_linear, - dropout_linear_out, eltadd_out, ffn_layer_norm, - ffn_layer_norm_scale, - ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, ffn_layer_norm_out, @@ -2984,8 +2810,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_eltadd1_out, ffn_gelu, ffn_gelu_out, - ffn_dropout, - ffn_dropout_out, ffn_eltadd_out}); // Remove unneeded nodes. 
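Note on the node-reuse change above (the added IR_NODE_LINK_TO lines and the trimmed marked_nodes sets): the pass now wires the original persistable weight/bias variable nodes (matmul_linear_w, eltadd_linear_b, the ffn_* weights and biases) directly into the fused_multi_transformer op and keeps them out of the set of nodes scheduled for removal, so the existing parameter memory is reused and no extra copy is needed in prepare_data. The helper below is not part of the patch; it is a minimal, hypothetical sketch of that pattern, assuming the usual Node and IR_NODE_LINK_TO definitions from graph_pattern_detector.h.

// Hypothetical helper, not part of the patch: link the original persistable
// parameter nodes into the fused op and keep them alive, instead of creating
// copies of their tensors.
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"  // IR_NODE_LINK_TO, Node

namespace paddle {
namespace framework {
namespace ir {

void LinkSharedParams(Node* fused_multi_transformer,
                      const std::vector<Node*>& param_nodes,
                      std::unordered_set<const Node*>* marked_nodes) {
  for (Node* param : param_nodes) {
    // Reuse the existing weight/bias var node as an input of the fused op.
    IR_NODE_LINK_TO(param, fused_multi_transformer);
    // Do not delete it together with the rest of the matched subgraph.
    marked_nodes->erase(param);
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

In the same spirit, because the matched subgraph no longer contains dropout ops (so the pattern keeps matching after simplify_with_basic_ops_pass, which strips dropout from inference graphs), the fused op's dropout attributes are set directly in the pass: is_test = true and dropout_rate = 0.0f, rather than being copied from a matched dropout op.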
diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h index 0f9aae4c57d..fd2cfc8c667 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h @@ -88,8 +88,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase { PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk_out); - PATTERN_DECL_NODE(dropout_qk); - PATTERN_DECL_NODE(dropout_qk_out); // QK, V matmul PATTERN_DECL_NODE(matmul_qkv); @@ -106,8 +104,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase { PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_out); - PATTERN_DECL_NODE(dropout_linear); - PATTERN_DECL_NODE(dropout_linear_out); // output elementwise_add PATTERN_DECL_NODE(eltadd_out) @@ -137,8 +133,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase { PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_out); - PATTERN_DECL_NODE(ffn_dropout); - PATTERN_DECL_NODE(ffn_dropout_out); // output elementwise_add PATTERN_DECL_NODE(ffn_eltadd_out) @@ -193,8 +187,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk_out); - PATTERN_DECL_NODE(dropout_qk); - PATTERN_DECL_NODE(dropout_qk_out); // QK, V matmul PATTERN_DECL_NODE(matmul_qkv); @@ -211,8 +203,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_out); - PATTERN_DECL_NODE(dropout_linear); - PATTERN_DECL_NODE(dropout_linear_out); // output elementwise_add PATTERN_DECL_NODE(eltadd_out) @@ -239,8 +229,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_out); - PATTERN_DECL_NODE(ffn_dropout); - PATTERN_DECL_NODE(ffn_dropout_out); // output elementwise_add PATTERN_DECL_NODE(ffn_eltadd_out) @@ -299,8 +287,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk_out); - PATTERN_DECL_NODE(dropout_qk); - PATTERN_DECL_NODE(dropout_qk_out); // QK, V matmul PATTERN_DECL_NODE(matmul_qkv); @@ -319,8 +305,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_out); - PATTERN_DECL_NODE(dropout_linear); - PATTERN_DECL_NODE(dropout_linear_out); // output elementwise_add PATTERN_DECL_NODE(eltadd_out) @@ -351,8 +335,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_out); - PATTERN_DECL_NODE(ffn_dropout); - PATTERN_DECL_NODE(ffn_dropout_out); // output elementwise_add PATTERN_DECL_NODE(ffn_eltadd_out) diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc index 100f2ad8dac..28357f4b20e 100644 --- 
a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc @@ -85,13 +85,11 @@ TEST(FusedMultiTransformerDecoderPass, basic) { // (transpose_0, transpose_1) matmul -> matmul_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (eltadd_qk) softmax -> softmax_qk - // (softmax_qk) dropout -> dropout_qk - // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (matmul_qkv) transpose -> transpose_qkv // (transpose_qkv) reshape -> reshape_qkv // (reshape_qkv) matmul_v2 -> matmul_linear // (matmul_linear) elementwise_add -> eltadd_linear - // (eltadd_linear) dropout -> dropout_linear // (eltadd_out) elementwise_add -> attention_out // // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out @@ -100,8 +98,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) { // (ffn_eltadd0) gelu -> ffn_gelu // (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 - // (ffn_eltadd1) dropout -> ffn_dropout - // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output Layers layers; // MHA: pre LayerNorm @@ -154,10 +151,9 @@ TEST(FusedMultiTransformerDecoderPass, basic) { auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); - auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); // MHA: QKV matmul - auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v); + auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); @@ -170,9 +166,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) { auto* linear_eltadd_out = layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); - auto* dropout_qkv = - layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); - auto* attention_out = layers.elementwise_add(x, dropout_qkv); + auto* attention_out = layers.elementwise_add(x, linear_eltadd_out); // FFN: pre LayerNorm auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); @@ -195,9 +189,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) { auto* ffn_eltadd1_out = layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); - // FFN: dropout -> elementwise_add - auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); - layers.elementwise_add(attention_out, ffn_dropout); + layers.elementwise_add(attention_out, ffn_eltadd1_out); std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); @@ -215,12 +207,12 @@ TEST(FusedMultiTransformerDecoderPass, basic) { int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); PADDLE_ENFORCE_EQ(num_nodes_before, - num_nodes_after + 72, + num_nodes_after + 60, platform::errors::InvalidArgument( "After the fused_multi_transformer_decoder_pass, The " "node num in graph " "should be %d, but the result is %d", - num_nodes_before - 72, + num_nodes_before - 60, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, @@ -253,13 +245,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { // (split_q, split_k) matmul -> matmul_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (eltadd_qk) softmax -> softmax_qk - // (softmax_qk) 
dropout -> dropout_qk - // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (matmul_qkv) transpose -> transpose_qkv // (transpose_qkv) reshape -> reshape_qkv // (reshape_qkv) matmul_v2 -> matmul_linear // (matmul_linear) elementwise_add -> eltadd_linear - // (eltadd_linear) dropout -> dropout_linear // (eltadd_out) elementwise_add -> attention_out // // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out @@ -268,8 +258,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { // (ffn_eltadd0) gelu -> ffn_gelu // (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 - // (ffn_eltadd1) dropout -> ffn_dropout - // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output // // (transpose_1, transpose_2) while -> decoder block @@ -313,10 +302,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); - auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); // MHA: QKV matmul - auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v); + auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); @@ -329,9 +317,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { auto* linear_eltadd_out = layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); - auto* dropout_qkv = - layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); - auto* attention_out = layers.elementwise_add(x, dropout_qkv); + auto* attention_out = layers.elementwise_add(x, linear_eltadd_out); // FFN: pre LayerNorm auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); @@ -354,9 +340,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { auto* ffn_eltadd1_out = layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); - // FFN: dropout -> elementwise_add - auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); - layers.elementwise_add(attention_out, ffn_dropout); + layers.elementwise_add(attention_out, ffn_eltadd1_out); std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); @@ -375,11 +359,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { PADDLE_ENFORCE_EQ( num_nodes_before, - num_nodes_after + 62, + num_nodes_after + 50, platform::errors::InvalidArgument( "After the fused_multi_transformer_decoder_fuse_qkv_pass, " "The node num in graph should be %d, but the result is %d", - num_nodes_before - 62, + num_nodes_before - 50, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, @@ -413,14 +397,12 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { // (split_q, split_k) matmul -> matmul_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (eltadd_qk) softmax -> softmax_qk - // (softmax_qk) dropout -> dropout_qk - // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (matmul_qkv) transpose -> transpose_qkv // (transpose_qkv) reshape -> reshape_qkv // (reshape_qkv) matmul_v2 -> matmul_linear // (matmul_linear) c_allreduce_sum -> c_all_reduce_out // (matmul_linear) elementwise_add -> eltadd_linear - // 
(eltadd_linear) dropout -> dropout_linear // (eltadd_out) elementwise_add -> attention_out // // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out @@ -431,8 +413,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { // (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_matmul1) c_allreduce_sum -> c_allreduce_out // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 - // (ffn_eltadd1) dropout -> ffn_dropout - // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output // // (transpose_1, transpose_2) while -> decoder block @@ -477,10 +458,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); - auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); // MHA: QKV matmul - auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v); + auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); @@ -494,9 +474,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { auto* linear_eltadd_out = layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2); - auto* dropout_qkv = - layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); - auto* attention_out = layers.elementwise_add(x, dropout_qkv); + auto* attention_out = layers.elementwise_add(x, linear_eltadd_out); // FFN: pre LayerNorm auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); @@ -521,9 +499,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { auto* ffn_eltadd1_out = layers.elementwise_add(ffn_c_allreduce_out, ffn_bias1, nullptr, 2); - // FFN: dropout -> elementwise_add - auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); - layers.elementwise_add(attention_out, ffn_dropout); + layers.elementwise_add(attention_out, ffn_eltadd1_out); std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); @@ -544,11 +520,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { PADDLE_ENFORCE_EQ( num_nodes_before, - num_nodes_after + 70, + num_nodes_after + 58, platform::errors::InvalidArgument( "After the fused_multi_transformer_decoder_fuse_qkv_pass, " "The node num in graph should be %d, but the result is %d", - num_nodes_before - 70, + num_nodes_before - 58, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc index 3f163fc3683..099df04e523 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc @@ -227,15 +227,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) ->assert_is_op_output("softmax") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_qk = - pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout"); - auto* dropout_qk_out_var = - pattern->NewNode(dropout_qk_out_repr()) - ->assert_is_op_output("dropout", "Out") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv + 
->assert_is_op_input("matmul_v2", "X"); // QK path Linsk matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var}) @@ -243,7 +235,6 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); - dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var}); // QKV path Nodes auto* matmul_qkv = @@ -284,14 +275,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_linear = - pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout"); - auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() - ->assert_is_op_input("elementwise_add"); + ->assert_is_op_input("elementwise_add"); auto* eltadd_out = pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); @@ -300,7 +284,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ->AsIntermediate(); // QKV path Links - matmul_qkv->LinksFrom({dropout_qk_out_var, transpose2_2_out_var}) + matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var}) .LinksTo({matmul_qkv_out_var}); transpose2_qkv->LinksFrom({matmul_qkv_out_var}) .LinksTo({transpose2_qkv_out_var}); @@ -310,9 +294,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { .LinksTo({matmul_linear_out_var}); eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var}) .LinksTo({eltadd_linear_out_var}); - dropout_linear->LinksFrom({eltadd_linear_out_var}) - .LinksTo({dropout_linear_out_var}); - eltadd_out->LinksFrom({input0, dropout_linear_out_var}) + eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) .LinksTo({attention_output}); // while loop @@ -352,7 +334,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ffn_layer_norm_mean_var, ffn_layer_norm_variance_var}); - // Feed Forward fc1 -> gelu -> fc2 -> dropout + // Feed Forward fc1 -> gelu -> fc2 auto* ffn_matmul0 = pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) @@ -397,13 +379,6 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* ffn_dropout = - pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout"); - auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() ->assert_is_op_input("elementwise_add"); auto* ffn_eltadd_out = @@ -421,9 +396,8 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { .LinksTo({ffn_matmul1_out_var}); ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); - ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var}); - ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) + ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) .LinksTo({ffn_output}); return ffn_output; @@ -545,15 +519,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) ->assert_is_op_output("softmax") ->AsIntermediate() - 
->assert_is_op_input("dropout"); - - auto* dropout_qk = - pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout"); - auto* dropout_qk_out_var = - pattern->NewNode(dropout_qk_out_repr()) - ->assert_is_op_output("dropout", "Out") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv + ->assert_is_op_input("matmul_v2", "X"); // QK path Linsk matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) @@ -561,7 +527,6 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); - dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var}); // QKV path Nodes auto* matmul_qkv = @@ -602,14 +567,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_linear = - pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout"); - auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() - ->assert_is_op_input("elementwise_add"); + ->assert_is_op_input("elementwise_add"); auto* eltadd_out = pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); @@ -618,7 +576,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ->AsIntermediate(); // QKV path Links - matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var}) + matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var}) .LinksTo({matmul_qkv_out_var}); transpose2_qkv->LinksFrom({matmul_qkv_out_var}) .LinksTo({transpose2_qkv_out_var}); @@ -628,9 +586,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { .LinksTo({matmul_linear_out_var}); eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var}) .LinksTo({eltadd_linear_out_var}); - dropout_linear->LinksFrom({eltadd_linear_out_var}) - .LinksTo({dropout_linear_out_var}); - eltadd_out->LinksFrom({input0, dropout_linear_out_var}) + eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) .LinksTo({attention_output}); // Feed Forward LayerNorm Nodes @@ -666,7 +622,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ffn_layer_norm_mean_var, ffn_layer_norm_variance_var}); - // Feed Forward fc1 -> gelu -> fc2 -> dropout + // Feed Forward fc1 -> gelu -> fc2 auto* ffn_matmul0 = pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) @@ -711,13 +667,6 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* ffn_dropout = - pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout"); - auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() ->assert_is_op_input("elementwise_add"); auto* ffn_eltadd_out = @@ -735,9 +684,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { .LinksTo({ffn_matmul1_out_var}); ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); - ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var}); - 
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) + ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) .LinksTo({ffn_output}); return ffn_output; @@ -868,15 +816,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) ->assert_is_op_output("softmax") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_qk = - pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout"); - auto* dropout_qk_out_var = - pattern->NewNode(dropout_qk_out_repr()) - ->assert_is_op_output("dropout", "Out") - ->AsIntermediate() - ->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv + ->assert_is_op_input("matmul_v2", "X"); // QK path Linsk matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) @@ -884,7 +824,6 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); - dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var}); // QKV path Nodes auto* matmul_qkv = @@ -933,14 +872,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* dropout_linear = - pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout"); - auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() - ->assert_is_op_input("elementwise_add"); + ->assert_is_op_input("elementwise_add"); auto* eltadd_out = pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); @@ -949,7 +881,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ->AsIntermediate(); // QKV path Links - matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var}) + matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var}) .LinksTo({matmul_qkv_out_var}); transpose2_qkv->LinksFrom({matmul_qkv_out_var}) .LinksTo({transpose2_qkv_out_var}); @@ -961,9 +893,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { .LinksTo({c_allreduce_sum_out_var}); eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var}) .LinksTo({eltadd_linear_out_var}); - dropout_linear->LinksFrom({eltadd_linear_out_var}) - .LinksTo({dropout_linear_out_var}); - eltadd_out->LinksFrom({input0, dropout_linear_out_var}) + eltadd_out->LinksFrom({input0, eltadd_linear_out_var}) .LinksTo({attention_output}); // Feed Forward LayerNorm Nodes @@ -1009,7 +939,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ffn_c_identity->LinksFrom({ffn_layer_norm_out_var}) .LinksTo({ffn_c_identity_out_var}); - // Feed Forward fc1 -> gelu -> fc2 -> dropout + // Feed Forward fc1 -> gelu -> fc2 auto* ffn_matmul0 = pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) @@ -1063,13 +993,6 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) ->assert_is_op_output("elementwise_add") ->AsIntermediate() - ->assert_is_op_input("dropout"); - - auto* ffn_dropout = - pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout"); - 
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr()) - ->assert_is_op_output("dropout") - ->AsIntermediate() ->assert_is_op_input("elementwise_add"); auto* ffn_eltadd_out = @@ -1089,9 +1012,8 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { .LinksTo({ffn_c_allreduce_sum_out_var}); ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var}) .LinksTo({ffn_eltadd1_out_var}); - ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var}); - ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) + ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var}) .LinksTo({ffn_output}); return ffn_output; @@ -1253,11 +1175,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, Node* transpose2_1_out, Node* transpose2_2_out, Node* eltadd_qk_b, - Node* dropout_qk, Node* reshape2_0, Node* matmul_linear_w, Node* eltadd_linear_b, - Node* dropout_linear, Node* while0, Node* ffn_layer_norm, Node* ffn_layer_norm_scale, @@ -1268,7 +1188,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, - Node* ffn_dropout, Node* ffn_output) { auto reshape_desc = reshape2_0->Op(); int num_head = @@ -1375,7 +1294,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, fill_const_op_desc.SetAttr("input_dim_idx", 0); fill_const_op_desc.SetAttr("output_dim_idx", 1); fill_const_op_desc.SetAttr("value", 0); - fill_const_op_desc.SetAttr("dtype", static_cast(proto::VarType::FP32)); + fill_const_op_desc.SetAttr( + "dtype", + static_cast(framework::TransToProtoVarType(wq_tensor->dtype()))); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); @@ -1409,15 +1330,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, fused_multi_transformer_op_desc.SetAttr( "epsilon", layer_norm->Op()->GetAttr("epsilon")); - // output dropout attribute - auto* dropout_op = dropout_linear->Op(); - fused_multi_transformer_op_desc.SetAttr( - "dropout_rate", dropout_op->GetAttr("dropout_prob")); - fused_multi_transformer_op_desc.SetAttr("is_test", - dropout_op->GetAttr("is_test")); - fused_multi_transformer_op_desc.SetAttr( - "dropout_implementation", - dropout_op->GetAttr("dropout_implementation")); + fused_multi_transformer_op_desc.SetAttr("is_test", true); + fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); @@ -1433,6 +1347,15 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, IR_NODE_LINK_TO(fill_const_op, cache_kv); IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); + IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer); + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); // rewrite while OP input @@ -1620,11 +1543,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, GET_IR_NODE_FROM_SUBGRAPH( ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - 
ffn_dropout, ffn_dropout, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH( @@ -1668,11 +1586,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, softmax_qk, softmax_qk, fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - dropout_qk, dropout_qk, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( matmul_qkv, matmul_qkv, fused_multi_transformer_pattern); GET_IR_NODE_FROM_SUBGRAPH( @@ -1700,11 +1613,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern) GET_IR_NODE_FROM_SUBGRAPH( eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - dropout_linear, dropout_linear, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( - dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern) - GET_IR_NODE_FROM_SUBGRAPH( eltadd_out, eltadd_out, fused_multi_transformer_pattern) @@ -1723,11 +1631,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, transpose2_1_out, transpose2_2_out, eltadd_qk_b, - dropout_qk, reshape2_0, matmul_linear_w, eltadd_linear_b, - dropout_linear, while0, ffn_layer_norm, ffn_layer_norm_scale, @@ -1738,12 +1644,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, - ffn_dropout, ffn_output); std::unordered_set marked_nodes({layer_norm, - layer_norm_scale, - layer_norm_bias, layer_norm_mean, layer_norm_variance, layer_norm_out, @@ -1777,8 +1680,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, eltadd_qk_out, softmax_qk, softmax_qk_out, - dropout_qk, - dropout_qk_out, transpose2_qkv, transpose2_qkv_out, matmul_qkv, @@ -1787,17 +1688,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, transpose2_qkv, transpose2_qkv_out, matmul_linear, - matmul_linear_w, matmul_linear_out, eltadd_linear, - eltadd_linear_b, eltadd_linear_out, - dropout_linear, - dropout_linear_out, eltadd_out, ffn_layer_norm, - ffn_layer_norm_scale, - ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, ffn_layer_norm_out, @@ -1811,8 +1706,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ffn_eltadd1_out, ffn_gelu, ffn_gelu_out, - ffn_dropout, - ffn_dropout_out, ffn_eltadd_out}); // Remove unneeded nodes. 
@@ -2016,11 +1909,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* split0_k_out, Node* split0_v_out, Node* eltadd_qk_b, - Node* dropout_qk, Node* reshape2_0, Node* matmul_linear_w, Node* eltadd_linear_b, - Node* dropout_linear, Node* while0, Node* ffn_layer_norm, Node* ffn_layer_norm_scale, @@ -2031,7 +1922,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, - Node* ffn_dropout, Node* ffn_output) { auto reshape_desc = reshape2_0->Op(); int num_head = @@ -2104,7 +1994,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( fill_const_op_desc.SetAttr("input_dim_idx", 0); fill_const_op_desc.SetAttr("output_dim_idx", 1); fill_const_op_desc.SetAttr("value", 0); - fill_const_op_desc.SetAttr("dtype", static_cast(proto::VarType::FP32)); + fill_const_op_desc.SetAttr("dtype", + static_cast(framework::TransToProtoVarType( + qkv_w_tensor->dtype()))); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); @@ -2139,14 +2031,8 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( "epsilon", layer_norm->Op()->GetAttr("epsilon")); // output dropout attribute - auto* dropout_op = dropout_linear->Op(); - fused_multi_transformer_op_desc.SetAttr( - "dropout_rate", dropout_op->GetAttr("dropout_prob")); - fused_multi_transformer_op_desc.SetAttr("is_test", - dropout_op->GetAttr("is_test")); - fused_multi_transformer_op_desc.SetAttr( - "dropout_implementation", - dropout_op->GetAttr("dropout_implementation")); + fused_multi_transformer_op_desc.SetAttr("is_test", true); + fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); @@ -2162,6 +2048,15 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( IR_NODE_LINK_TO(fill_const_op, cache_kv); IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); + IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer); + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); // rewrite while OP input @@ -2315,12 +2210,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ffn_eltadd1_out, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out, - ffn_dropout_out, - fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_fuse_qkv_pattern) @@ -2352,11 +2241,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out, - dropout_qk_out, - fused_multi_transformer_fuse_qkv_pattern) GET_IR_NODE_FROM_SUBGRAPH( matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); @@ -2392,12 +2276,6 @@ int 
FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(dropout_linear, - dropout_linear, - fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out, - dropout_linear_out, - fused_multi_transformer_fuse_qkv_pattern) GET_IR_NODE_FROM_SUBGRAPH( eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern) @@ -2416,11 +2294,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( split0_k_out, split0_v_out, eltadd_qk_b, - dropout_qk, reshape2_0, matmul_linear_w, eltadd_linear_b, - dropout_linear, while0, ffn_layer_norm, ffn_layer_norm_scale, @@ -2431,12 +2307,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, - ffn_dropout, ffn_output); std::unordered_set marked_nodes({layer_norm, - layer_norm_scale, - layer_norm_bias, layer_norm_mean, layer_norm_variance, layer_norm_out, @@ -2458,8 +2331,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( eltadd_qk_out, softmax_qk, softmax_qk_out, - dropout_qk, - dropout_qk_out, transpose2_qkv, transpose2_qkv_out, matmul_qkv, @@ -2468,17 +2339,11 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( transpose2_qkv, transpose2_qkv_out, matmul_linear, - matmul_linear_w, matmul_linear_out, eltadd_linear, - eltadd_linear_b, eltadd_linear_out, - dropout_linear, - dropout_linear_out, eltadd_out, ffn_layer_norm, - ffn_layer_norm_scale, - ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, ffn_layer_norm_out, @@ -2492,8 +2357,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ffn_eltadd1_out, ffn_gelu, ffn_gelu_out, - ffn_dropout, - ffn_dropout_out, ffn_eltadd_out}); // Remove unneeded nodes. 
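The new IR_NODE_LINK_TO calls, together with dropping the weight and bias nodes from marked_nodes, keep the original persistable parameter nodes in the graph and wire them directly into the fused op; this appears to be the prepare_data memory-copy fix named in the commit message, since the fused op now consumes the existing tensors rather than freshly created ones. A minimal sketch of that wiring, assuming a toy Node type and a Link helper in place of Paddle's ir::Graph and IR_NODE_LINK_TO:

#include <iostream>
#include <string>
#include <vector>

// Toy graph node; the real pass works on paddle::framework::ir::Node.
struct Node {
  std::string name;
  std::vector<Node*> inputs;  // var nodes feeding this op node
};

// Stand-in for IR_NODE_LINK_TO(var, op): just record the edge.
void Link(Node* var, Node* op) { op->inputs.push_back(var); }

int main() {
  Node matmul_linear_w{"matmul_linear_w", {}};
  Node ffn_matmul0_w{"ffn_matmul0_w", {}};
  Node fused_op{"fused_multi_transformer", {}};

  // Link the existing persistable weight nodes instead of removing them,
  // so the fused op reads the original parameters with no extra copy.
  Link(&matmul_linear_w, &fused_op);
  Link(&ffn_matmul0_w, &fused_op);

  for (const Node* in : fused_op.inputs) std::cout << in->name << "\n";
  return 0;
}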
@@ -2700,11 +2563,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* split0_k_out, Node* split0_v_out, Node* eltadd_qk_b, - Node* dropout_qk, Node* reshape2_0, Node* matmul_linear_w, Node* eltadd_linear_b, - Node* dropout_linear, Node* while0, Node* ffn_layer_norm, Node* ffn_layer_norm_scale, @@ -2715,7 +2576,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, - Node* ffn_dropout, Node* ffn_output) { auto reshape_desc = reshape2_0->Op(); int num_head = @@ -2789,7 +2649,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( fill_const_op_desc.SetAttr("input_dim_idx", 0); fill_const_op_desc.SetAttr("output_dim_idx", 1); fill_const_op_desc.SetAttr("value", 0); - fill_const_op_desc.SetAttr("dtype", static_cast(proto::VarType::FP32)); + fill_const_op_desc.SetAttr("dtype", + static_cast(framework::TransToProtoVarType( + qkv_w_tensor->dtype()))); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); @@ -2824,14 +2686,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( "epsilon", layer_norm->Op()->GetAttr("epsilon")); // output dropout attribute - auto* dropout_op = dropout_linear->Op(); - fused_multi_transformer_op_desc.SetAttr( - "dropout_rate", dropout_op->GetAttr("dropout_prob")); - fused_multi_transformer_op_desc.SetAttr("is_test", - dropout_op->GetAttr("is_test")); - fused_multi_transformer_op_desc.SetAttr( - "dropout_implementation", - dropout_op->GetAttr("dropout_implementation")); + fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); + fused_multi_transformer_op_desc.SetAttr("is_test", true); // parallel ring id auto* c_identity_op = c_identity->Op(); @@ -2852,6 +2708,15 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( IR_NODE_LINK_TO(fill_const_op, cache_kv); IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); + IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer); + IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer); + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); // rewrite while OP input @@ -3024,12 +2889,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ffn_eltadd1_out, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out, - ffn_dropout_out, - fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_fuse_qkv_pattern) @@ -3061,11 +2920,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH( - dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern) - GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out, - dropout_qk_out, - fused_multi_transformer_fuse_qkv_pattern) GET_IR_NODE_FROM_SUBGRAPH( matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); @@ -3107,12 
+2961,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(dropout_linear, - dropout_linear, - fused_multi_transformer_fuse_qkv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out, - dropout_linear_out, - fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern); @@ -3132,11 +2980,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( split0_k_out, split0_v_out, eltadd_qk_b, - dropout_qk, reshape2_0, matmul_linear_w, eltadd_linear_b, - dropout_linear, while0, ffn_layer_norm, ffn_layer_norm_scale, @@ -3147,12 +2993,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, - ffn_dropout, ffn_output); std::unordered_set marked_nodes({layer_norm, - layer_norm_scale, - layer_norm_bias, layer_norm_mean, layer_norm_variance, layer_norm_out, @@ -3176,8 +3019,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( eltadd_qk_out, softmax_qk, softmax_qk_out, - dropout_qk, - dropout_qk_out, transpose2_qkv, transpose2_qkv_out, matmul_qkv, @@ -3186,19 +3027,13 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( transpose2_qkv, transpose2_qkv_out, matmul_linear, - matmul_linear_w, matmul_linear_out, c_allreduce_sum, c_allreduce_sum_out, eltadd_linear, - eltadd_linear_b, eltadd_linear_out, - dropout_linear, - dropout_linear_out, eltadd_out, ffn_layer_norm, - ffn_layer_norm_scale, - ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, ffn_layer_norm_out, @@ -3216,8 +3051,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ffn_eltadd1_out, ffn_gelu, ffn_gelu_out, - ffn_dropout, - ffn_dropout_out, ffn_eltadd_out}); // Remove unneeded nodes. 
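Across all three BuildFusion variants, the dropout attributes previously read from the matched dropout op are replaced by hard-coded is_test = true and dropout_rate = 0.0f. That is safe because the patterns now assume simplify_with_basic_ops_pass has already folded the dropout ops away, and with the "upscale_in_train" implementation used by the testers, inference-time dropout is the identity anyway. A small numeric sketch of that equivalence (DropoutInfer is a stand-in, not a Paddle function):

#include <cassert>

// Inference-time dropout output for Paddle's two dropout implementations;
// training-time masking is irrelevant to this pass.
float DropoutInfer(float x, float dropout_prob, bool upscale_in_train) {
  return upscale_in_train ? x : x * (1.0f - dropout_prob);
}

int main() {
  const float x = 0.75f;
  // "upscale_in_train": identity at inference, whatever dropout_prob was,
  // so removing the op and setting dropout_rate = 0 changes nothing.
  assert(DropoutInfer(x, 0.1f, /*upscale_in_train=*/true) == x);
  return 0;
}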
diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h index 6e62a69cdf1..55792456b8c 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h @@ -82,8 +82,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk_out); - PATTERN_DECL_NODE(dropout_qk); - PATTERN_DECL_NODE(dropout_qk_out); // QK, V matmul PATTERN_DECL_NODE(matmul_qkv); @@ -100,8 +98,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_out); - PATTERN_DECL_NODE(dropout_linear); - PATTERN_DECL_NODE(dropout_linear_out); // output elementwise_add PATTERN_DECL_NODE(eltadd_out) @@ -131,8 +127,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_out); - PATTERN_DECL_NODE(ffn_dropout); - PATTERN_DECL_NODE(ffn_dropout_out); // output elementwise_add PATTERN_DECL_NODE(ffn_eltadd_out) @@ -179,8 +173,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk_out); - PATTERN_DECL_NODE(dropout_qk); - PATTERN_DECL_NODE(dropout_qk_out); // QK, V matmul PATTERN_DECL_NODE(matmul_qkv); @@ -200,8 +192,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_out); - PATTERN_DECL_NODE(dropout_linear); - PATTERN_DECL_NODE(dropout_linear_out); // output elementwise_add PATTERN_DECL_NODE(eltadd_out) @@ -228,8 +218,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_out); - PATTERN_DECL_NODE(ffn_dropout); - PATTERN_DECL_NODE(ffn_dropout_out); // output elementwise_add PATTERN_DECL_NODE(ffn_eltadd_out) @@ -280,8 +268,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk_out); - PATTERN_DECL_NODE(dropout_qk); - PATTERN_DECL_NODE(dropout_qk_out); // QK, V matmul PATTERN_DECL_NODE(matmul_qkv); @@ -303,8 +289,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_out); - PATTERN_DECL_NODE(dropout_linear); - PATTERN_DECL_NODE(dropout_linear_out); // output elementwise_add PATTERN_DECL_NODE(eltadd_out) @@ -335,8 +319,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_out); - PATTERN_DECL_NODE(ffn_dropout); - PATTERN_DECL_NODE(ffn_dropout_out); // output elementwise_add PATTERN_DECL_NODE(ffn_eltadd_out) diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc index 61017b273a0..4954bf3576e 100644 --- 
a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc @@ -81,13 +81,11 @@ TEST(FusedMultiTransformerEncoderPass, basic) { // (transpose_0, transpose_1) matmul -> matmul_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (eltadd_qk) softmax -> softmax_qk - // (softmax_qk) dropout -> dropout_qk - // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (matmul_qkv) transpose -> transpose_qkv // (transpose_qkv) reshape -> reshape_qkv // (reshape_qkv) matmul_v2 -> matmul_linear // (matmul_linear) elementwise_add -> eltadd_linear - // (eltadd_linear) dropout -> dropout_linear // (eltadd_out) elementwise_add -> attention_out // // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out @@ -96,8 +94,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { // (ffn_eltadd0) gelu -> ffn_gelu // (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 - // (ffn_eltadd1) dropout -> ffn_dropout - // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output // // (transpose_1, transpose_2) while -> decoder block @@ -149,10 +146,9 @@ TEST(FusedMultiTransformerEncoderPass, basic) { auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1); - auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); // MHA: QKV matmul - auto* matmul_qkv = layers.matmul_v2(dropout_qk, transpose_2); + auto* matmul_qkv = layers.matmul_v2(softmax_qk, transpose_2); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); @@ -165,9 +161,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { auto* linear_eltadd_out = layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); - auto* dropout_qkv = - layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); - auto* attention_out = layers.elementwise_add(x, dropout_qkv); + auto* attention_out = layers.elementwise_add(x, linear_eltadd_out); // FFN: pre LayerNorm auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); @@ -190,9 +184,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { auto* ffn_eltadd1_out = layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); - // FFN: dropout -> elementwise_add - auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); - layers.elementwise_add(attention_out, ffn_dropout); + layers.elementwise_add(attention_out, ffn_eltadd1_out); std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); @@ -210,12 +202,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) { int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); PADDLE_ENFORCE_EQ(num_nodes_before, - num_nodes_after + 68, + num_nodes_after + 56, platform::errors::InvalidArgument( "After the fused_multi_transformer_encoder_pass, The " "node num in graph " "should be %d, but the result is %d", - num_nodes_before - 68, + num_nodes_before - 56, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, @@ -246,13 +238,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { // (split_q, split_k) matmul -> matmul_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (eltadd_qk) 
softmax -> softmax_qk - // (softmax_qk) dropout -> dropout_qk - // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (matmul_qkv) transpose -> transpose_qkv // (transpose_qkv) reshape -> reshape_qkv // (reshape_qkv) matmul_v2 -> matmul_linear // (matmul_linear) elementwise_add -> eltadd_linear - // (eltadd_linear) dropout -> dropout_linear // (eltadd_out) elementwise_add -> attention_out // // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out @@ -261,8 +251,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { // (ffn_eltadd0) gelu -> ffn_gelu // (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 - // (ffn_eltadd1) dropout -> ffn_dropout - // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output // // (transpose_1, transpose_2) while -> decoder block @@ -304,10 +293,9 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); - auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); // MHA: QKV matmul - auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v); + auto* matmul_qkv = layers.matmul_v2(softmax_qk, split_v); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); @@ -320,9 +308,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { auto* linear_eltadd_out = layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); - auto* dropout_qkv = - layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); - auto* attention_out = layers.elementwise_add(x, dropout_qkv); + auto* attention_out = layers.elementwise_add(x, linear_eltadd_out); // FFN: pre LayerNorm auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); @@ -345,9 +331,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { auto* ffn_eltadd1_out = layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); - // FFN: dropout -> elementwise_add - auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); - layers.elementwise_add(attention_out, ffn_dropout); + layers.elementwise_add(attention_out, ffn_eltadd1_out); std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); @@ -366,11 +350,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { PADDLE_ENFORCE_EQ( num_nodes_before, - num_nodes_after + 56, + num_nodes_after + 44, platform::errors::InvalidArgument( "After the fused_multi_transformer_encoder_fuse_qkv_pass, " "The node num in graph should be %d, but the result is %d", - num_nodes_before - 56, + num_nodes_before - 44, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, @@ -402,14 +386,12 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { // (split_q, split_k) matmul -> matmul_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (eltadd_qk) softmax -> softmax_qk - // (softmax_qk) dropout -> dropout_qk - // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (matmul_qkv) transpose -> transpose_qkv // (transpose_qkv) reshape -> reshape_qkv // (reshape_qkv) matmul_v2 -> matmul_linear // (matmul_linear) c_all_reduce -> c_all_reduce_out // (c_all_reduce_out) 
elementwise_add -> eltadd_linear - // (eltadd_linear) dropout -> dropout_linear // (eltadd_out) elementwise_add -> attention_out // // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out @@ -420,8 +402,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { // (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out // (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1 - // (ffn_eltadd1) dropout -> ffn_dropout - // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output // // (transpose_1, transpose_2) while -> decoder block @@ -464,10 +445,9 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); - auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); // MHA: QKV matmul - auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v); + auto* matmul_qkv = layers.matmul_v2(softmax_qk, split_v); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); @@ -481,9 +461,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { auto* linear_eltadd_out = layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2); - auto* dropout_qkv = - layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); - auto* attention_out = layers.elementwise_add(x, dropout_qkv); + auto* attention_out = layers.elementwise_add(x, linear_eltadd_out); // FFN: pre LayerNorm auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); @@ -508,9 +486,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { auto* ffn_eltadd1_out = layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2); - // FFN: dropout -> elementwise_add - auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); - layers.elementwise_add(attention_out, ffn_dropout); + layers.elementwise_add(attention_out, ffn_eltadd1_out); std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); @@ -531,11 +507,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { PADDLE_ENFORCE_EQ( num_nodes_before, - num_nodes_after + 64, + num_nodes_after + 52, platform::errors::InvalidArgument( "After the fused_multi_transformer_encoder_fuse_qkv_pass, " "The node num in graph should be %d, but the result is %d", - num_nodes_before - 64, + num_nodes_before - 52, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 4a1bf4baecc..74ad71f37da 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -39,6 +39,7 @@ namespace ir { static const char kParamScopeAttr[] = "__param_scope__"; static const std::vector support_subgraph_passes = { + "simplify_with_basic_ops_pass", "fused_multi_transformer_encoder_pass", "fused_multi_transformer_decoder_pass", "fused_multi_transformer_encoder_fuse_qkv_pass", -- GitLab
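The one-line change to pass.cc registers simplify_with_basic_ops_pass as a subgraph-capable pass, so dropout is also folded away inside the while block that holds the decoder, and the dropout-free patterns above can match there as well. A sketch of the intended ordering, using an illustrative pass list rather than the actual inference pipeline:

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Illustrative ordering only; the real pipeline is assembled by the
  // inference pass builder, not here.
  const std::vector<std::string> passes = {
      "simplify_with_basic_ops_pass",  // folds dropout before fusion runs
      "fused_multi_transformer_encoder_pass",
      "fused_multi_transformer_decoder_pass",
      "fused_multi_transformer_encoder_fuse_qkv_pass",
  };
  for (const auto& name : passes) std::cout << name << "\n";
  return 0;
}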