From 5a2e5179148119c745065bfeec6cce9b0ce1700c Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Thu, 20 Oct 2022 20:11:23 +0800 Subject: [PATCH] Add FusedMultiTransformer fuse pass for GPT3 (#45907) * add fused_multi_transformer_encoder/decoder pass, run GPT-3 success --- paddle/fluid/framework/ir/CMakeLists.txt | 10 + .../fused_multi_transformer_decoder_pass.cc | 3214 +++++++++++++++ .../ir/fused_multi_transformer_decoder_pass.h | 416 ++ ...d_multi_transformer_decoder_pass_tester.cc | 576 +++ .../fused_multi_transformer_encoder_pass.cc | 3448 +++++++++++++++++ .../ir/fused_multi_transformer_encoder_pass.h | 398 ++ ...d_multi_transformer_encoder_pass_tester.cc | 563 +++ paddle/fluid/framework/ir/graph_helper.cc | 11 +- .../framework/ir/graph_pattern_detector.cc | 5 +- .../framework/ir/graph_pattern_detector.h | 8 + paddle/fluid/framework/ir/pass.cc | 44 +- paddle/fluid/framework/ir/pass.h | 12 + .../fluid/framework/ir/pass_tester_helper.h | 79 +- .../analysis/passes/memory_optimize_pass.cc | 5 +- .../inference/api/paddle_pass_builder.cc | 38 +- 15 files changed, 8802 insertions(+), 25 deletions(-) create mode 100644 paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc create mode 100644 paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h create mode 100644 paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc create mode 100644 paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc create mode 100644 paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h create mode 100644 paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index c6337a5a304..adf92e57998 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -105,6 +105,8 @@ pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) 
pass_library(skip_layernorm_fuse_pass base) pass_library(multihead_matmul_fuse_pass inference) +pass_library(fused_multi_transformer_encoder_pass inference) +pass_library(fused_multi_transformer_decoder_pass inference) pass_library(adaptive_pool2d_convert_global_pass inference) pass_library(unsqueeze2_eltwise_fuse_pass inference) pass_library(yolo_box_fuse_pass inference) @@ -311,6 +313,14 @@ cc_test( test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.cc DEPS multihead_matmul_fuse_pass) +cc_test( + test_fused_multi_transformer_encoder_pass + SRCS fused_multi_transformer_encoder_pass_tester.cc + DEPS fused_multi_transformer_encoder_pass) +cc_test( + test_fused_multi_transformer_decoder_pass + SRCS fused_multi_transformer_decoder_pass_tester.cc + DEPS fused_multi_transformer_decoder_pass) cc_test( test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc new file mode 100644 index 00000000000..5559499e0b4 --- /dev/null +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc @@ -0,0 +1,3214 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h" + +#include + +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +PDNode* FusedMultiTransformerDecoderPattern::operator()() { + auto* input0 = pattern->NewNode(input0_repr()); + input0->assert_is_op_input("layer_norm", "X"); + + // pre-LayerNorm + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Mean"); + auto* layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Variance"); + auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("matmul_v2", "X") + ->assert_more([](Node* x) { + if (x->outputs.size() == 3) { + return true; + } else { + return false; + } + }); + + layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); + + // Q path Nodes + auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2"); + auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + 
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd0 = + pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add"); + auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2") + ->AsIntermediate() + ->assert_is_op_input("matmul", "X"); + + // Q path Links + matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var}) + .LinksTo({matmul0_out_var}); + eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var}) + .LinksTo({eltadd0_out_var}); + reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + + // K path Nodes + auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2"); + auto* matmul1_w_var = pattern->NewNode(matmul1_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul1_out_var = pattern->NewNode(matmul1_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd1 = + pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); + auto* eltadd1_b_var = pattern->NewNode(eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + auto* 
eltadd1_out_var = pattern->NewNode(eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); + + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); + + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2") + ->AsIntermediate(); + auto* concat_0_in_var = pattern->NewNode(concat_0_in_repr())->AsInput(); + auto* concat_0 = pattern->NewNode(concat_0_repr())->assert_is_op("concat"); + auto* concat_0_out_var = pattern->NewNode(concat_0_out_repr()) + ->assert_is_op_output("concat") + ->AsIntermediate() + ->assert_is_op_input("matmul") + ->assert_is_op_input("assign"); + auto assign_0 = pattern->NewNode(assign_0_repr())->assert_is_op("assign"); + + // K path Links + matmul1->LinksFrom({layer_norm_out_var, matmul1_w_var}) + .LinksTo({matmul1_out_var}); + eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var}) + .LinksTo({eltadd1_out_var}); + reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); + transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); + concat_0->LinksFrom({transpose2_1_out_var, concat_0_in_var}) + .LinksTo({concat_0_out_var}); + assign_0->LinksFrom({concat_0_out_var}); + + // V path Nodes + auto* matmul2 = pattern->NewNode(matmul2_repr())->assert_is_op("matmul_v2"); + auto* matmul2_w_var = pattern->NewNode(matmul2_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul2_out_var = pattern->NewNode(matmul2_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd2 = + 
pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); + auto* eltadd2_b_var = pattern->NewNode(eltadd2_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd2_out_var = pattern->NewNode(eltadd2_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); + + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); + auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2"); + auto* concat_1_in_var = pattern->NewNode(concat_1_in_repr()) + ->AsInput() + ->assert_is_op_input("concat"); + auto* concat_1 = pattern->NewNode(concat_1_repr())->assert_is_op("concat"); + auto* concat_1_out_var = pattern->NewNode(concat_1_out_repr()) + ->assert_is_op_output("concat") + ->assert_is_op_input("matmul_v2") + ->assert_is_op_input("assign"); + auto assign_1 = pattern->NewNode(assign_1_repr())->assert_is_op("assign"); + + // V path Links + matmul2->LinksFrom({layer_norm_out_var, matmul2_w_var}) + .LinksTo({matmul2_out_var}); + eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var}) + .LinksTo({eltadd2_out_var}); + reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); + transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); + concat_1->LinksFrom({transpose2_2_out_var, concat_1_in_var}) + .LinksTo({concat_1_out_var}); + assign_1->LinksFrom({concat_1_out_var}); + + // QK path Nodes + auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); + 
 matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + auto* eltadd_qk = + pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); + auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) + ->assert_is_op_output("softmax") + ->AsIntermediate() + ->assert_is_op_input("dropout"); + + auto* dropout_qk = + pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout"); + auto* dropout_qk_out_var = + pattern->NewNode(dropout_qk_out_repr()) + ->assert_is_op_output("dropout", "Out") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv + + // QK path Links + matmul_qk->LinksFrom({transpose2_0_out_var, concat_0_out_var}) + .LinksTo({matmul_qk_out_var}); + eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + .LinksTo({eltadd_qk_out_var}); + softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); + dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var}); + + // QKV path Nodes + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2"); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + 
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = + pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); // -> out_linear + + auto* matmul_linear = + pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2"); + auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd_linear = + pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add"); + auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("dropout"); + + auto* dropout_linear = + pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout"); + auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr()) + ->assert_is_op_output("dropout") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd_out = + pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); + auto* attention_output = pattern->NewNode(attention_output_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate(); + + // QKV path Links + matmul_qkv->LinksFrom({dropout_qk_out_var, concat_1_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var}) + .LinksTo({matmul_linear_out_var}); + 
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var}) + .LinksTo({eltadd_linear_out_var}); + dropout_linear->LinksFrom({eltadd_linear_out_var}) + .LinksTo({dropout_linear_out_var}); + eltadd_out->LinksFrom({input0, dropout_linear_out_var}) + .LinksTo({attention_output}); + + // Feed Forward LayerNorm Nodes + auto* ffn_layer_norm = + pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm"); + auto* ffn_layer_norm_scale_var = + pattern->NewNode(ffn_layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* ffn_layer_norm_bias_var = + pattern->NewNode(ffn_layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* ffn_layer_norm_mean_var = + pattern->NewNode(ffn_layer_norm_mean_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Mean"); + auto* ffn_layer_norm_variance_var = + pattern->NewNode(ffn_layer_norm_variance_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Variance"); + auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("matmul_v2", "X"); + + ffn_layer_norm + ->LinksFrom( + {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) + .LinksTo({ffn_layer_norm_out_var, + ffn_layer_norm_mean_var, + ffn_layer_norm_variance_var}); + + // Feed Forward fc1 -> gelu -> fc2 -> dropout + auto* ffn_matmul0 = + pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); + auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd0 = + pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add"); + 
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("gelu"); + + auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); + auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) + ->assert_is_op_output("gelu") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); + + auto* ffn_matmul1 = + pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); + auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd1 = + pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add"); + auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("dropout"); + + auto* ffn_dropout = + pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout"); + auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr()) + ->assert_is_op_output("dropout") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd_out = + pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add"); + auto* ffn_output = pattern->NewNode(ffn_output_repr()) + ->assert_is_op_output("elementwise_add") + ->AsOutput(); + + ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var}) + .LinksTo({ffn_matmul0_out_var}); + ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) + .LinksTo({ffn_eltadd0_out_var}); + 
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); + ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) + .LinksTo({ffn_matmul1_out_var}); + ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) + .LinksTo({ffn_eltadd1_out_var}); + ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var}); + + ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) + .LinksTo({ffn_output}); + + return ffn_output; +} + +PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { + auto* input0 = pattern->NewNode(input0_repr()); + input0->assert_is_op_input("layer_norm", "X"); + + // pre-LayerNorm + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto* layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("matmul_v2", "X"); + + layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); + + // QKV fused path Nodes + auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2"); + auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr()) + 
->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd0 = + pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add"); + auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2") + ->AsIntermediate() + ->assert_is_op_input("split", "X"); + + auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split"); + auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) + ->assert_is_op_output("split") + ->AsIntermediate() + ->assert_is_op_input("matmul", "X"); + auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) + ->assert_is_op_output("split") + ->AsIntermediate() + ->assert_is_op_input("concat"); + auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr()) + ->assert_is_op_output("split") + ->AsIntermediate() + ->assert_is_op_input("concat"); + + auto* concat_k_in_var = pattern + ->NewNode(concat_k_in_repr()) + // ->AsInput() + ->assert_is_op_input("concat"); + auto* concat_k = pattern->NewNode(concat_k_repr())->assert_is_op("concat"); + auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr()) + ->assert_is_op_output("concat") + ->AsIntermediate() + ->assert_is_op_input("matmul") + ->assert_is_op_input("assign"); + auto* concat_v_in_var = pattern + ->NewNode(concat_v_in_repr()) + // ->AsInput() + 
->assert_is_op_input("concat"); + auto* concat_v = pattern->NewNode(concat_v_repr())->assert_is_op("concat"); + auto* concat_v_out_var = pattern->NewNode(concat_v_out_repr()) + ->assert_is_op_output("concat") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2") + ->assert_is_op_input("assign"); + + auto* assign_k = pattern->NewNode(assign_k_repr())->assert_is_op("assign"); + auto* assign_v = pattern->NewNode(assign_v_repr())->assert_is_op("assign"); + + // QKV fused path Links + matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var}) + .LinksTo({matmul0_out_var}); + eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var}) + .LinksTo({eltadd0_out_var}); + reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + split0->LinksFrom({transpose2_0_out_var}) + .LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var}); + concat_k->LinksFrom({concat_k_in_var, split0_k_out_var}) + .LinksTo({concat_k_out_var}); + concat_v->LinksFrom({concat_v_in_var, split0_v_out_var}) + .LinksTo({concat_v_out_var}); + assign_k->LinksFrom({concat_k_out_var}); + assign_v->LinksFrom({concat_v_out_var}); + + // QK path Nodes + auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + auto* eltadd_qk = + pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); + auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = 
 pattern->NewNode(softmax_qk_out_repr()) + ->assert_is_op_output("softmax") + ->AsIntermediate() + ->assert_is_op_input("dropout"); + + auto* dropout_qk = + pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout"); + auto* dropout_qk_out_var = + pattern->NewNode(dropout_qk_out_repr()) + ->assert_is_op_output("dropout", "Out") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv + + // QK path Links + matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var}) + .LinksTo({matmul_qk_out_var}); + eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + .LinksTo({eltadd_qk_out_var}); + softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); + dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var}); + + // QKV path Nodes + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2"); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = + pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); // -> out_linear + + auto* matmul_linear = + pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2"); + auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + 
->assert_is_op_input("elementwise_add"); + + auto* eltadd_linear = + pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add"); + auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("dropout"); + + auto* dropout_linear = + pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout"); + auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr()) + ->assert_is_op_output("dropout") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd_out = + pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); + auto* attention_output = pattern->NewNode(attention_output_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate(); + + // QKV path Links + matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var}) + .LinksTo({matmul_linear_out_var}); + eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var}) + .LinksTo({eltadd_linear_out_var}); + dropout_linear->LinksFrom({eltadd_linear_out_var}) + .LinksTo({dropout_linear_out_var}); + eltadd_out->LinksFrom({input0, dropout_linear_out_var}) + .LinksTo({attention_output}); + + // Feed Forward LayerNorm Nodes + auto* ffn_layer_norm = + pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm"); + auto* ffn_layer_norm_scale_var = + pattern->NewNode(ffn_layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* ffn_layer_norm_bias_var = + 
pattern->NewNode(ffn_layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* ffn_layer_norm_mean_var = + pattern->NewNode(ffn_layer_norm_mean_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Mean"); + auto* ffn_layer_norm_variance_var = + pattern->NewNode(ffn_layer_norm_variance_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Variance"); + auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("matmul_v2", "X"); + + ffn_layer_norm + ->LinksFrom( + {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) + .LinksTo({ffn_layer_norm_out_var, + ffn_layer_norm_mean_var, + ffn_layer_norm_variance_var}); + + // Feed Forward fc1 -> gelu -> fc2 -> dropout + auto* ffn_matmul0 = + pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); + auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd0 = + pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add"); + auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("gelu"); + + auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); + auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) + ->assert_is_op_output("gelu") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); + + auto* ffn_matmul1 = + pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); + auto* 
ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd1 = + pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add"); + auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("dropout"); + + auto* ffn_dropout = + pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout"); + auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr()) + ->assert_is_op_output("dropout") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd_out = + pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add"); + auto* ffn_output = pattern->NewNode(ffn_output_repr()) + ->assert_is_op_output("elementwise_add") + ->AsOutput(); + + ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var}) + .LinksTo({ffn_matmul0_out_var}); + ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) + .LinksTo({ffn_eltadd0_out_var}); + ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); + ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) + .LinksTo({ffn_matmul1_out_var}); + ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) + .LinksTo({ffn_eltadd1_out_var}); + ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var}); + + ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) + .LinksTo({ffn_output}); + + return ffn_output; +} + +PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { + auto* input0 = pattern->NewNode(input0_repr()); + 
input0->assert_is_op_input("layer_norm", "X"); + + // pre-LayerNorm + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto* layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("c_identity", "X"); + + layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); + + // communication c_identity + auto* c_identity = + pattern->NewNode(c_identity_repr())->assert_is_op("c_identity"); + auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("c_identity", "Out") + ->assert_is_op_input("matmul_v2", "X"); + c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var}); + + // QKV fused path Nodes + auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2"); + auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd0 = + pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add"); + auto* eltadd0_b_var = 
pattern->NewNode(eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2") + ->AsIntermediate() + ->assert_is_op_input("split", "X"); + + auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split"); + auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) + ->assert_is_op_output("split") + ->AsIntermediate() + ->assert_is_op_input("matmul", "X"); + auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) + ->assert_is_op_output("split") + ->AsIntermediate() + ->assert_is_op_input("concat"); + auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr()) + ->assert_is_op_output("split") + ->AsIntermediate() + ->assert_is_op_input("concat"); + + auto* concat_k_in_var = pattern + ->NewNode(concat_k_in_repr()) + // ->AsInput() + ->assert_is_op_input("concat"); + auto* concat_k = pattern->NewNode(concat_k_repr())->assert_is_op("concat"); + auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr()) + ->assert_is_op_output("concat") + ->AsIntermediate() + ->assert_is_op_input("matmul") + ->assert_is_op_input("assign"); + auto* concat_v_in_var = pattern + ->NewNode(concat_v_in_repr()) + // ->AsInput() + ->assert_is_op_input("concat"); + auto* concat_v = pattern->NewNode(concat_v_repr())->assert_is_op("concat"); + auto* concat_v_out_var = pattern->NewNode(concat_v_out_repr()) + ->assert_is_op_output("concat") + 
->AsIntermediate() + ->assert_is_op_input("matmul_v2") + ->assert_is_op_input("assign"); + + auto* assign_k = pattern->NewNode(assign_k_repr())->assert_is_op("assign"); + auto* assign_v = pattern->NewNode(assign_v_repr())->assert_is_op("assign"); + + // QKV fused path Links + matmul0->LinksFrom({c_identity_out_var, matmul0_w_var}) + .LinksTo({matmul0_out_var}); + eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var}) + .LinksTo({eltadd0_out_var}); + reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + split0->LinksFrom({transpose2_0_out_var}) + .LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var}); + concat_k->LinksFrom({concat_k_in_var, split0_k_out_var}) + .LinksTo({concat_k_out_var}); + concat_v->LinksFrom({concat_v_in_var, split0_v_out_var}) + .LinksTo({concat_v_out_var}); + assign_k->LinksFrom({concat_k_out_var}); + assign_v->LinksFrom({concat_v_out_var}); + + // QK path Nodes + auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + auto* eltadd_qk = + pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); + auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) + ->assert_is_op_output("softmax") + ->AsIntermediate() + ->assert_is_op_input("dropout"); + + auto* dropout_qk = + pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout"); + auto* 
dropout_qk_out_var = + pattern->NewNode(dropout_qk_out_repr()) + ->assert_is_op_output("dropout", "Out") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv + + // QK path Linsk + matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var}) + .LinksTo({matmul_qk_out_var}); + eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + .LinksTo({eltadd_qk_out_var}); + softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); + dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var}); + + // QKV path Nodes + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2"); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = + pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); // -> out_linear + + auto* matmul_linear = + pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2"); + auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("c_allreduce_sum"); + + // communication c_allreduce_sum + auto* c_allreduce_sum = + pattern->NewNode(c_allreduce_sum_repr())->assert_is_op("c_allreduce_sum"); + auto* c_allreduce_sum_out_var = 
pattern->NewNode(c_allreduce_sum_out_repr()) + ->assert_is_op_output("c_allreduce_sum") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd_linear = + pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add"); + auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("dropout"); + + auto* dropout_linear = + pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout"); + auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr()) + ->assert_is_op_output("dropout") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd_out = + pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); + auto* attention_output = pattern->NewNode(attention_output_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate(); + + // QKV path Links + matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var}) + .LinksTo({matmul_linear_out_var}); + c_allreduce_sum->LinksFrom({matmul_linear_out_var}) + .LinksTo({c_allreduce_sum_out_var}); + eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var}) + .LinksTo({eltadd_linear_out_var}); + dropout_linear->LinksFrom({eltadd_linear_out_var}) + .LinksTo({dropout_linear_out_var}); + eltadd_out->LinksFrom({input0, dropout_linear_out_var}) + .LinksTo({attention_output}); + + // Feed Forward LayerNorm Nodes + auto* ffn_layer_norm = + pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm"); + auto* 
ffn_layer_norm_scale_var = + pattern->NewNode(ffn_layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* ffn_layer_norm_bias_var = + pattern->NewNode(ffn_layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* ffn_layer_norm_mean_var = + pattern->NewNode(ffn_layer_norm_mean_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Mean"); + auto* ffn_layer_norm_variance_var = + pattern->NewNode(ffn_layer_norm_variance_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Variance"); + auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("c_identity", "X"); + + ffn_layer_norm + ->LinksFrom( + {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) + .LinksTo({ffn_layer_norm_out_var, + ffn_layer_norm_mean_var, + ffn_layer_norm_variance_var}); + + // communication c_identity + auto* ffn_c_identity = + pattern->NewNode(ffn_c_identity_repr())->assert_is_op("c_identity"); + auto* ffn_c_identity_out_var = pattern->NewNode(ffn_c_identity_out_repr()) + ->assert_is_op_output("c_identity", "Out") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "X"); + ffn_c_identity->LinksFrom({ffn_layer_norm_out_var}) + .LinksTo({ffn_c_identity_out_var}); + + // Feed Forward fc1 -> gelu -> fc2 -> dropout + auto* ffn_matmul0 = + pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); + auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd0 = + pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add"); + auto* ffn_eltadd0_b_var 
= pattern->NewNode(ffn_eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("gelu"); + + auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); + auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) + ->assert_is_op_output("gelu") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); + + auto* ffn_matmul1 = + pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); + auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("c_allreduce_sum"); + + // communication c_allreduce_sum + auto* ffn_c_allreduce_sum = pattern->NewNode(ffn_c_allreduce_sum_repr()) + ->assert_is_op("c_allreduce_sum"); + auto* ffn_c_allreduce_sum_out_var = + pattern->NewNode(ffn_c_allreduce_sum_out_repr()) + ->assert_is_op_output("c_allreduce_sum") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd1 = + pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add"); + auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("dropout"); + + auto* ffn_dropout = + pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout"); + auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr()) + ->assert_is_op_output("dropout") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd_out = + pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add"); + auto* 
ffn_output = pattern->NewNode(ffn_output_repr()) + ->assert_is_op_output("elementwise_add") + ->AsOutput(); + + ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var}) + .LinksTo({ffn_matmul0_out_var}); + ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var}) + .LinksTo({ffn_eltadd0_out_var}); + ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var}); + ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var}) + .LinksTo({ffn_matmul1_out_var}); + ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var}) + .LinksTo({ffn_c_allreduce_sum_out_var}); + ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var}) + .LinksTo({ffn_eltadd1_out_var}); + ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var}); + + ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) + .LinksTo({ffn_output}); + + return ffn_output; +} + +} // namespace patterns + +int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. 
+ patterns::FusedMultiTransformerDecoderPattern fused_multi_transformer_pattern( + pattern, name_scope); + fused_multi_transformer_pattern(); + + // Create New OpDesc + auto fuse_creater = [&](Node* input0, + Node* layer_norm, + Node* layer_norm_scale, + Node* layer_norm_bias, + Node* layer_norm_mean, + Node* layer_norm_variance, + Node* matmul0_w, + Node* matmul1_w, + Node* matmul2_w, + Node* eltadd0_b, + Node* eltadd1_b, + Node* eltadd2_b, + Node* transpose2_1_out, + Node* transpose2_2_out, + Node* eltadd_qk_b, + Node* dropout_qk, + Node* reshape2_0, + Node* matmul_linear_w, + Node* eltadd_linear_b, + Node* dropout_linear, + Node* ffn_layer_norm, + Node* ffn_layer_norm_scale, + Node* ffn_layer_norm_bias, + Node* ffn_layer_norm_mean, + Node* ffn_layer_norm_variance, + Node* ffn_matmul0_w, + Node* ffn_matmul1_w, + Node* ffn_eltadd0_b, + Node* ffn_eltadd1_b, + Node* ffn_dropout, + Node* ffn_output) { + // Calc index of transformer layer by LayerNorm Scale name + // This calculation assumes: + // 1. no LayerNorm before all transformer layer + // 2. each transformer layer contains 2 LayerNorm layer + auto ln_scale_name = layer_norm_scale->Name(); + auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.')); + auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1); + int layer_idx = atoi(ln_idx_str.c_str()) / 2; + + // create fused_multi_transformer + OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); + fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + + // 1. 
Input setting + fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); + + // pre-LayerNorm input + fused_multi_transformer_op_desc.SetInput("LnScale", + {layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("LnBias", + {layer_norm_bias->Name()}); + + // QKV computation input + fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()}); + + // Cache KV use cache_kv in encoder + auto cache_kv_name = "cache_kv" + std::to_string(layer_idx); + fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name}); + + VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx)); + shape_out_desc.SetDataType(proto::VarType::INT32); + shape_out_desc.SetPersistable(false); + auto* shape_out = graph->CreateVarNode(&shape_out_desc); + + OpDesc shape_op_desc(layer_norm->Op()->Block()); + shape_op_desc.SetType("shape"); + shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()}); + shape_op_desc.SetOutput("Out", {shape_out->Name()}); + auto* shape_op = graph->CreateOpNode(&shape_op_desc); + + VarDesc slice_out_desc("slice_out." 
+ std::to_string(layer_idx)); + slice_out_desc.SetDataType(proto::VarType::INT32); + slice_out_desc.SetPersistable(false); + auto* slice_out = graph->CreateVarNode(&slice_out_desc); + + OpDesc slice_op_desc(layer_norm->Op()->Block()); + slice_op_desc.SetType("slice"); + slice_op_desc.SetInput("Input", {shape_out->Name()}); + slice_op_desc.SetOutput("Out", {slice_out->Name()}); + std::vector axes = {0}; + std::vector starts = {3}; + std::vector ends = {4}; + slice_op_desc.SetAttr("axes", axes); + slice_op_desc.SetAttr("starts", starts); + slice_op_desc.SetAttr("ends", ends); + auto* slice_op = graph->CreateOpNode(&slice_op_desc); + + fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()}); + + // Out Linear input + fused_multi_transformer_op_desc.SetInput("OutLinearW", + {matmul_linear_w->Name()}); + fused_multi_transformer_op_desc.SetInput("OutLinearBias", + {eltadd_linear_b->Name()}); + + // Feed Forward input + fused_multi_transformer_op_desc.SetInput("FFNLnScale", + {ffn_layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("FFNLnBias", + {ffn_layer_norm_bias->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Weight", + {ffn_matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Bias", + {ffn_eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Weight", + {ffn_matmul1_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Bias", + {ffn_eltadd1_b->Name()}); + + // 2. 
Output setting + fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()}); + fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv_name}); + + // Attribute setting + fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); + fused_multi_transformer_op_desc.SetAttr( + "epsilon", layer_norm->Op()->GetAttr("epsilon")); + + // output dropout attribute + auto* dropout_op = dropout_linear->Op(); + fused_multi_transformer_op_desc.SetAttr( + "dropout_rate", dropout_op->GetAttr("dropout_prob")); + fused_multi_transformer_op_desc.SetAttr("is_test", + dropout_op->GetAttr("is_test")); + fused_multi_transformer_op_desc.SetAttr( + "dropout_implementation", + dropout_op->GetAttr("dropout_implementation")); + + auto* fused_multi_transformer = + graph->CreateOpNode(&fused_multi_transformer_op_desc); + IR_NODE_LINK_TO(input0, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); + + IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); + + // TimeStep link + IR_NODE_LINK_TO(eltadd_qk_b, shape_op); + IR_NODE_LINK_TO(shape_op, shape_out); + IR_NODE_LINK_TO(shape_out, slice_op); + IR_NODE_LINK_TO(slice_op, slice_out); + IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "fused_multi_transformer_decoder " + "pass in op compat failed."; + return; + } + + VLOG(4) << "handle MultiTransformer decoder fuse"; + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm, layer_norm, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, 
layer_norm_scale, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, + layer_norm_variance, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_out, layer_norm_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul0, matmul0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_out, matmul0_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_w, matmul0_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0, reshape2_0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0, transpose2_0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul1, matmul1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul1_out, matmul1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul1_w, matmul1_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1, reshape2_1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1, transpose2_1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_0, concat_0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_0_out, concat_0_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + assign_0, assign_0, 
fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul2, matmul2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul2_out, matmul2_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul2_w, matmul2_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2, reshape2_2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2, transpose2_2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_1, concat_1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_1_out, concat_1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + assign_1, assign_1, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + attention_output, attention_output, fused_multi_transformer_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, + ffn_layer_norm_scale, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, + ffn_layer_norm_bias, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, + ffn_layer_norm_mean, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, + ffn_layer_norm_variance, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, + ffn_layer_norm_out, + fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu, ffn_gelu, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_dropout, ffn_dropout, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + ffn_output, ffn_output, fused_multi_transformer_pattern) + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0, eltadd0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_b, eltadd0_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_out, eltadd0_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd1, eltadd1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd1_b, eltadd1_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd1_out, eltadd1_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + 
eltadd2, eltadd2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd2_b, eltadd2_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd2_out, eltadd2_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk, matmul_qk, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk, eltadd_qk, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk, softmax_qk, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_qk, dropout_qk, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv, matmul_qkv, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, + transpose2_qkv_out, + fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear, matmul_linear, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear, 
eltadd_linear, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + dropout_linear, dropout_linear, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_out, eltadd_out, fused_multi_transformer_pattern) + + fuse_creater(input0, + layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + matmul0_w, + matmul1_w, + matmul2_w, + eltadd0_b, + eltadd1_b, + eltadd2_b, + transpose2_1_out, + transpose2_2_out, + eltadd_qk_b, + dropout_qk, + reshape2_0, + matmul_linear_w, + eltadd_linear_b, + dropout_linear, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_matmul0_w, + ffn_matmul1_w, + ffn_eltadd0_b, + ffn_eltadd1_b, + ffn_dropout, + ffn_output); + + std::unordered_set marked_nodes({layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + layer_norm_out, + matmul0, + matmul1, + matmul2, + matmul0_out, + matmul1_out, + matmul2_out, + eltadd0, + eltadd1, + eltadd2, + eltadd0_out, + eltadd1_out, + eltadd2_out, + reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + concat_0, + concat_1, + concat_0_out, + concat_1_out, + assign_0, + assign_1, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + dropout_qk, + dropout_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + reshape2_qkv, + transpose2_qkv, + transpose2_qkv_out, + matmul_linear, + matmul_linear_w, + matmul_linear_out, + 
eltadd_linear, + eltadd_linear_b, + eltadd_linear_out, + dropout_linear, + dropout_linear_out, + eltadd_out, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_layer_norm_out, + ffn_matmul0, + ffn_matmul1, + ffn_matmul0_out, + ffn_matmul1_out, + ffn_eltadd0, + ffn_eltadd1, + ffn_eltadd0_out, + ffn_eltadd1_out, + ffn_gelu, + ffn_gelu_out, + ffn_dropout, + ffn_dropout_out, + ffn_eltadd_out}); + + // Remove unneeded nodes. + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void FusedMultiTransformerDecoderPass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal("During the multi_transformer pass, " + "The scope should not be null.")); + + int fusion_count = BuildFusion(graph, name_scope_, scope); + if (fusion_count > 0) { + graph->Set(kFusedMultiTransformerDecoderPass, new bool(true)); + } + AddStatis(fusion_count); +} + +FusedMultiTransformerDecoderPass::FusedMultiTransformerDecoderPass() { + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddInput("Y") // the shape shoule be (N*H, N*H) + .IsTensor() + .End() + .AddOutput("Out") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + .IsType() + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + 
.IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({2, -1, 0}) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H) + .IsType>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") // {0, 2, 1, 3} + .IsType>() + .End(); + + AddOpCompat(OpCompat("concat")) + .AddInput("X") // Input("X"): vector + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(2) + .End(); + + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.0f) + .IsNumLE(1.0f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsType() + .End(); + + AddOpCompat(OpCompat("softmax")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3 + .End(); + + AddOpCompat(OpCompat("gelu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("approximate") + .IsType() + .End(); +} + +int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( + Graph* graph, const std::string& name_scope, Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. 
+ patterns::FusedMultiTransformerDecoderFuseQKVPattern + fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope); + fused_multi_transformer_fuse_qkv_pattern(); + + // Create New OpDesc + auto fuse_creater = [&](Node* input0, + Node* layer_norm, + Node* layer_norm_scale, + Node* layer_norm_bias, + Node* layer_norm_mean, + Node* layer_norm_variance, + Node* matmul0_w, + Node* eltadd0_b, + Node* eltadd_qk_b, + Node* dropout_qk, + Node* reshape2_0, + Node* matmul_linear_w, + Node* eltadd_linear_b, + Node* dropout_linear, + Node* ffn_layer_norm, + Node* ffn_layer_norm_scale, + Node* ffn_layer_norm_bias, + Node* ffn_layer_norm_mean, + Node* ffn_layer_norm_variance, + Node* ffn_matmul0_w, + Node* ffn_matmul1_w, + Node* ffn_eltadd0_b, + Node* ffn_eltadd1_b, + Node* ffn_dropout, + Node* ffn_output) { + // Calc index of transformer layer by LayerNorm Scale name + // This calculation assumes: + // 1. no LayerNorm before all transformer layer + // 2. each transformer layer contains 2 LayerNorm layer + auto ln_scale_name = layer_norm_scale->Name(); + auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.')); + auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1); + int layer_idx = atoi(ln_idx_str.c_str()) / 2; + + // create fused_multi_transformer + OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); + fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + + // 1. 
Input setting + fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); + + // pre-LayerNorm input + fused_multi_transformer_op_desc.SetInput("LnScale", + {layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("LnBias", + {layer_norm_bias->Name()}); + + // QKV computation input + fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()}); + + // Cache KV use cache_kv in encoder + auto cache_kv_name = "cache_kv" + std::to_string(layer_idx); + fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name}); + + VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx)); + shape_out_desc.SetDataType(proto::VarType::INT32); + shape_out_desc.SetPersistable(false); + auto* shape_out = graph->CreateVarNode(&shape_out_desc); + + OpDesc shape_op_desc(layer_norm->Op()->Block()); + shape_op_desc.SetType("shape"); + shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()}); + shape_op_desc.SetOutput("Out", {shape_out->Name()}); + auto* shape_op = graph->CreateOpNode(&shape_op_desc); + + VarDesc slice_out_desc("slice_out." 
+ std::to_string(layer_idx)); + slice_out_desc.SetDataType(proto::VarType::INT32); + slice_out_desc.SetPersistable(false); + auto* slice_out = graph->CreateVarNode(&slice_out_desc); + + OpDesc slice_op_desc(layer_norm->Op()->Block()); + slice_op_desc.SetType("slice"); + slice_op_desc.SetInput("Input", {shape_out->Name()}); + slice_op_desc.SetOutput("Out", {slice_out->Name()}); + std::vector axes = {0}; + std::vector starts = {3}; + std::vector ends = {4}; + slice_op_desc.SetAttr("axes", axes); + slice_op_desc.SetAttr("starts", starts); + slice_op_desc.SetAttr("ends", ends); + auto* slice_op = graph->CreateOpNode(&slice_op_desc); + + fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()}); + + // Out Linear input + fused_multi_transformer_op_desc.SetInput("OutLinearW", + {matmul_linear_w->Name()}); + fused_multi_transformer_op_desc.SetInput("OutLinearBias", + {eltadd_linear_b->Name()}); + + // Feed Forward input + fused_multi_transformer_op_desc.SetInput("FFNLnScale", + {ffn_layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("FFNLnBias", + {ffn_layer_norm_bias->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Weight", + {ffn_matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Bias", + {ffn_eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Weight", + {ffn_matmul1_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Bias", + {ffn_eltadd1_b->Name()}); + + // 2. 
Output setting + fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()}); + fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv_name}); + + // Attribute setting + fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); + fused_multi_transformer_op_desc.SetAttr( + "epsilon", layer_norm->Op()->GetAttr("epsilon")); + + // output dropout attribute + auto* dropout_op = dropout_linear->Op(); + fused_multi_transformer_op_desc.SetAttr( + "dropout_rate", dropout_op->GetAttr("dropout_prob")); + fused_multi_transformer_op_desc.SetAttr("is_test", + dropout_op->GetAttr("is_test")); + fused_multi_transformer_op_desc.SetAttr( + "dropout_implementation", + dropout_op->GetAttr("dropout_implementation")); + + // fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"}); + // fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true}); + + auto* fused_multi_transformer = + graph->CreateOpNode(&fused_multi_transformer_op_desc); + IR_NODE_LINK_TO(input0, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); + + IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); + + // TimeStep link + IR_NODE_LINK_TO(eltadd_qk_b, shape_op); + IR_NODE_LINK_TO(shape_op, shape_out); + IR_NODE_LINK_TO(shape_out, slice_op); + IR_NODE_LINK_TO(slice_op, slice_out); + IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "fused_multi_transformer_decoder_fuse_qkv " + "pass in op compat failed."; + return; + } + + VLOG(4) << "handle MultiTransformer decoder(Fuse-QKV) fuse"; + GET_IR_NODE_FROM_SUBGRAPH( + input0, input0, 
fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, + layer_norm_scale, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, + layer_norm_bias, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, + layer_norm_mean, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, + layer_norm_variance, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, + layer_norm_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, + reshape2_0_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, + transpose2_0_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + split0, split0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_k_in, concat_k_in, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_k, concat_k, fused_multi_transformer_fuse_qkv_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH( + concat_k_out, concat_k_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_v_in, concat_v_in, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_v, concat_v, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_v_out, concat_v_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + assign_k, assign_k, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + assign_v, assign_v, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, + ffn_layer_norm, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, + ffn_layer_norm_scale, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, + ffn_layer_norm_bias, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, + ffn_layer_norm_mean, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, + ffn_layer_norm_variance, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, + ffn_layer_norm_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, + ffn_matmul0_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, + ffn_eltadd0_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu, ffn_gelu, 
fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, + ffn_matmul1_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, + ffn_eltadd1_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out, + ffn_dropout_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, + ffn_eltadd_out, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern) + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_out, eltadd_qk_out, 
fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, + softmax_qk_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out, + dropout_qk_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, + matmul_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, + reshape2_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, + transpose2_qkv, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, + transpose2_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, + matmul_linear_w, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, + matmul_linear_out, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, + eltadd_linear_b, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, + eltadd_linear_out, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(dropout_linear, + dropout_linear, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out, + dropout_linear_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + 
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern) + + fuse_creater(input0, + layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + matmul0_w, + eltadd0_b, + eltadd_qk_b, + dropout_qk, + reshape2_0, + matmul_linear_w, + eltadd_linear_b, + dropout_linear, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_matmul0_w, + ffn_matmul1_w, + ffn_eltadd0_b, + ffn_eltadd1_b, + ffn_dropout, + ffn_output); + + std::unordered_set marked_nodes({layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + layer_norm_out, + matmul0, + matmul0_out, + eltadd0, + eltadd0_out, + reshape2_0, + reshape2_0_out, + transpose2_0, + transpose2_0_out, + split0, + split0_q_out, + split0_k_out, + split0_v_out, + concat_k_in, + concat_k, + concat_k_out, + concat_v_in, + concat_v, + concat_v_out, + assign_k, + assign_v, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + dropout_qk, + dropout_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + reshape2_qkv, + transpose2_qkv, + transpose2_qkv_out, + matmul_linear, + matmul_linear_w, + matmul_linear_out, + eltadd_linear, + eltadd_linear_b, + eltadd_linear_out, + dropout_linear, + dropout_linear_out, + eltadd_out, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_layer_norm_out, + ffn_matmul0, + ffn_matmul1, + ffn_matmul0_out, + ffn_matmul1_out, + ffn_eltadd0, + ffn_eltadd1, + ffn_eltadd0_out, + ffn_eltadd1_out, + ffn_gelu, + ffn_gelu_out, + ffn_dropout, + ffn_dropout_out, + ffn_eltadd_out}); + + // Remove unneeded nodes. 
+ GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void FusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal("During the fused_multi_transformer_decoder " + "pass, The scope should not be null.")); + + int fusion_count = BuildFusion(graph, name_scope_, scope); + if (fusion_count > 0) { + graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true)); + } + AddStatis(fusion_count); +} + +FusedMultiTransformerDecoderFuseQKVPass:: + FusedMultiTransformerDecoderFuseQKVPass() { + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddInput("Y") // the shape shoule be (N*H, N*H) + .IsTensor() + .End() + .AddOutput("Out") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + .IsType() + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({2, -1, 0}) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + 
.IsTensor() + .End() + .AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H) + .IsType>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") // {0, 2, 1, 3} + .IsType>() + .End(); + + AddOpCompat(OpCompat("concat")) + .AddInput("X") // Input("X"): vector + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(2) + .End(); + + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.0f) + .IsNumLE(1.0f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsType() + .End(); + + AddOpCompat(OpCompat("softmax")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3 + .End(); + + AddOpCompat(OpCompat("gelu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("approximate") + .IsType() + .End(); +} + +int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( + Graph* graph, const std::string& name_scope, Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. 
+ patterns::MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern + fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope); + fused_multi_transformer_fuse_qkv_pattern(); + + // Create New OpDesc + auto fuse_creater = [&](Node* input0, + Node* layer_norm, + Node* layer_norm_scale, + Node* layer_norm_bias, + Node* layer_norm_mean, + Node* layer_norm_variance, + Node* c_identity, + Node* matmul0_w, + Node* eltadd0_b, + Node* eltadd_qk_b, + Node* dropout_qk, + Node* reshape2_0, + Node* matmul_linear_w, + Node* eltadd_linear_b, + Node* dropout_linear, + Node* ffn_layer_norm, + Node* ffn_layer_norm_scale, + Node* ffn_layer_norm_bias, + Node* ffn_layer_norm_mean, + Node* ffn_layer_norm_variance, + Node* ffn_matmul0_w, + Node* ffn_matmul1_w, + Node* ffn_eltadd0_b, + Node* ffn_eltadd1_b, + Node* ffn_dropout, + Node* ffn_output) { + // Calc index of transformer layer by LayerNorm Scale name + // This calculation assumes: + // 1. no LayerNorm before all transformer layer + // 2. each transformer layer contains 2 LayerNorm layer + auto ln_scale_name = layer_norm_scale->Name(); + auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.')); + auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1); + int layer_idx = atoi(ln_idx_str.c_str()) / 2; + + // create fused_multi_transformer + OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); + fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + + // 1. 
Input setting + fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); + + // pre-LayerNorm input + fused_multi_transformer_op_desc.SetInput("LnScale", + {layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("LnBias", + {layer_norm_bias->Name()}); + + // QKV computation input + fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()}); + + // Cache KV use cache_kv in encoder + auto cache_kv_name = "cache_kv" + std::to_string(layer_idx); + fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name}); + + VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx)); + shape_out_desc.SetDataType(proto::VarType::INT32); + shape_out_desc.SetPersistable(false); + auto* shape_out = graph->CreateVarNode(&shape_out_desc); + + OpDesc shape_op_desc(layer_norm->Op()->Block()); + shape_op_desc.SetType("shape"); + shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()}); + shape_op_desc.SetOutput("Out", {shape_out->Name()}); + auto* shape_op = graph->CreateOpNode(&shape_op_desc); + + VarDesc slice_out_desc("slice_out." 
+ std::to_string(layer_idx)); + slice_out_desc.SetDataType(proto::VarType::INT32); + slice_out_desc.SetPersistable(false); + auto* slice_out = graph->CreateVarNode(&slice_out_desc); + + OpDesc slice_op_desc(layer_norm->Op()->Block()); + slice_op_desc.SetType("slice"); + slice_op_desc.SetInput("Input", {shape_out->Name()}); + slice_op_desc.SetOutput("Out", {slice_out->Name()}); + std::vector axes = {0}; + std::vector starts = {3}; + std::vector ends = {4}; + slice_op_desc.SetAttr("axes", axes); + slice_op_desc.SetAttr("starts", starts); + slice_op_desc.SetAttr("ends", ends); + auto* slice_op = graph->CreateOpNode(&slice_op_desc); + + fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()}); + + // Out Linear input + fused_multi_transformer_op_desc.SetInput("OutLinearW", + {matmul_linear_w->Name()}); + fused_multi_transformer_op_desc.SetInput("OutLinearBias", + {eltadd_linear_b->Name()}); + + // Feed Forward input + fused_multi_transformer_op_desc.SetInput("FFNLnScale", + {ffn_layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("FFNLnBias", + {ffn_layer_norm_bias->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Weight", + {ffn_matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Bias", + {ffn_eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Weight", + {ffn_matmul1_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Bias", + {ffn_eltadd1_b->Name()}); + + // 2. 
Output setting + fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()}); + fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv_name}); + + // Attribute setting + fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); + fused_multi_transformer_op_desc.SetAttr( + "epsilon", layer_norm->Op()->GetAttr("epsilon")); + + // output dropout attribute + auto* dropout_op = dropout_linear->Op(); + fused_multi_transformer_op_desc.SetAttr( + "dropout_rate", dropout_op->GetAttr("dropout_prob")); + fused_multi_transformer_op_desc.SetAttr("is_test", + dropout_op->GetAttr("is_test")); + fused_multi_transformer_op_desc.SetAttr( + "dropout_implementation", + dropout_op->GetAttr("dropout_implementation")); + + // parallel ring id + auto* c_identity_op = c_identity->Op(); + fused_multi_transformer_op_desc.SetAttr("ring_id", + c_identity_op->GetAttr("ring_id")); + + // fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"}); + // fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true}); + + auto* fused_multi_transformer = + graph->CreateOpNode(&fused_multi_transformer_op_desc); + IR_NODE_LINK_TO(input0, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); + + IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); + + // TimeStep link + IR_NODE_LINK_TO(eltadd_qk_b, shape_op); + IR_NODE_LINK_TO(shape_op, shape_out); + IR_NODE_LINK_TO(shape_out, slice_op); + IR_NODE_LINK_TO(slice_op, slice_out); + IR_NODE_LINK_TO(slice_out, fused_multi_transformer) + + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "fused_multi_transformer_decoder_fuse_qkv " + "pass in op 
compat failed."; + return; + } + + VLOG(4) << "handle MultiTransformer decoder(Fuse-QKV) fuse"; + GET_IR_NODE_FROM_SUBGRAPH( + input0, input0, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, + layer_norm_scale, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, + layer_norm_bias, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, + layer_norm_mean, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, + layer_norm_variance, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, + layer_norm_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + c_identity, c_identity, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_identity_out, + c_identity_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, + reshape2_0_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, + transpose2_0_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + split0, split0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_k_out, 
split0_k_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_k_in, concat_k_in, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_k, concat_k, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_k_out, concat_k_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_v_in, concat_v_in, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_v, concat_v, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + concat_v_out, concat_v_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + assign_k, assign_k, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + assign_v, assign_v, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, + ffn_layer_norm, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, + ffn_layer_norm_scale, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, + ffn_layer_norm_bias, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, + ffn_layer_norm_mean, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, + ffn_layer_norm_variance, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, + ffn_layer_norm_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity, + ffn_c_identity, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity_out, + ffn_c_identity_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, + ffn_matmul0_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, + ffn_eltadd0_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, + ffn_matmul1_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum, + ffn_c_allreduce_sum, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum_out, + ffn_c_allreduce_sum_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, + ffn_eltadd1_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out, + ffn_dropout_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, + ffn_eltadd_out, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + ffn_output, ffn_output, 
fused_multi_transformer_fuse_qkv_pattern) + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, + softmax_qk_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out, + dropout_qk_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, + matmul_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, + reshape2_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, + transpose2_qkv, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, + transpose2_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear, matmul_linear, 
fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, + matmul_linear_w, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, + matmul_linear_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum, + c_allreduce_sum, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum_out, + c_allreduce_sum_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, + eltadd_linear_b, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, + eltadd_linear_out, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(dropout_linear, + dropout_linear, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out, + dropout_linear_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern) + + fuse_creater(input0, + layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + c_identity, + matmul0_w, + eltadd0_b, + eltadd_qk_b, + dropout_qk, + reshape2_0, + matmul_linear_w, + eltadd_linear_b, + dropout_linear, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_matmul0_w, + ffn_matmul1_w, + ffn_eltadd0_b, + ffn_eltadd1_b, + ffn_dropout, + ffn_output); + + std::unordered_set marked_nodes({layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + layer_norm_out, + c_identity, + c_identity_out, + matmul0, + matmul0_out, + eltadd0, + eltadd0_out, + reshape2_0, + reshape2_0_out, + transpose2_0, + transpose2_0_out, + split0, + split0_q_out, + split0_k_out, + split0_v_out, + concat_k_in, + concat_k, + 
concat_k_out, + concat_v_in, + concat_v, + concat_v_out, + assign_k, + assign_v, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + dropout_qk, + dropout_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + reshape2_qkv, + transpose2_qkv, + transpose2_qkv_out, + matmul_linear, + matmul_linear_w, + matmul_linear_out, + c_allreduce_sum, + c_allreduce_sum_out, + eltadd_linear, + eltadd_linear_b, + eltadd_linear_out, + dropout_linear, + dropout_linear_out, + eltadd_out, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_layer_norm_out, + ffn_c_identity, + ffn_c_identity_out, + ffn_matmul0, + ffn_matmul1, + ffn_matmul0_out, + ffn_matmul1_out, + ffn_c_allreduce_sum, + ffn_c_allreduce_sum_out, + ffn_eltadd0, + ffn_eltadd1, + ffn_eltadd0_out, + ffn_eltadd1_out, + ffn_gelu, + ffn_gelu_out, + ffn_dropout, + ffn_dropout_out, + ffn_eltadd_out}); + + // Remove unneeded nodes. 
+ GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::ApplyImpl( + Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal("During the fused_multi_transformer_decoder " + "pass, The scope should not be null.")); + + int fusion_count = BuildFusion(graph, name_scope_, scope); + if (fusion_count > 0) { + graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true)); + } + AddStatis(fusion_count); +} + +MultiDevicesFusedMultiTransformerDecoderFuseQKVPass:: + MultiDevicesFusedMultiTransformerDecoderFuseQKVPass() { + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddInput("Y") // the shape shoule be (N*H, N*H) + .IsTensor() + .End() + .AddOutput("Out") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + .IsType() + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({2, -1, 0}) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + 
.AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H) + .IsType>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") // {0, 2, 1, 3} + .IsType>() + .End(); + + AddOpCompat(OpCompat("concat")) + .AddInput("X") // Input("X"): vector + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(2) + .End(); + + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.0f) + .IsNumLE(1.0f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsType() + .End(); + + AddOpCompat(OpCompat("softmax")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3 + .End(); + + AddOpCompat(OpCompat("gelu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("approximate") + .IsType() + .End(); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fused_multi_transformer_decoder_pass, + paddle::framework::ir::FusedMultiTransformerDecoderPass); +REGISTER_PASS(fused_multi_transformer_decoder_fuse_qkv_pass, + paddle::framework::ir::FusedMultiTransformerDecoderFuseQKVPass); +REGISTER_PASS( + multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass, + paddle::framework::ir::MultiDevicesFusedMultiTransformerDecoderFuseQKVPass); + +REGISTER_PASS_CAPABILITY(fused_multi_transformer_decoder_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + 
.EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); +REGISTER_PASS_CAPABILITY(fused_multi_transformer_decoder_fuse_qkv_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); +REGISTER_PASS_CAPABILITY( + multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h new file mode 100644 index 00000000000..0f9aae4c57d --- /dev/null +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h @@ -0,0 +1,416 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct FusedMultiTransformerDecoderPattern : public PatternBase { + FusedMultiTransformerDecoderPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "fused_multi_transformer_decoder") {} + + PDNode* operator()(); + + // Q, K, V path + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(matmul0); + PATTERN_DECL_NODE(matmul1); + PATTERN_DECL_NODE(matmul2); + PATTERN_DECL_NODE(matmul0_w); + PATTERN_DECL_NODE(matmul1_w); + PATTERN_DECL_NODE(matmul2_w); + PATTERN_DECL_NODE(matmul0_out); + PATTERN_DECL_NODE(matmul1_out); + PATTERN_DECL_NODE(matmul2_out); + PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(eltadd1_out); + PATTERN_DECL_NODE(eltadd2_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_1); + PATTERN_DECL_NODE(reshape2_2); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(reshape2_1_out); + PATTERN_DECL_NODE(reshape2_2_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(transpose2_2); + PATTERN_DECL_NODE(transpose2_0_out); + PATTERN_DECL_NODE(transpose2_1_out); + PATTERN_DECL_NODE(transpose2_2_out); + + PATTERN_DECL_NODE(concat_0_in); + PATTERN_DECL_NODE(concat_0); + PATTERN_DECL_NODE(concat_0_out); + 
PATTERN_DECL_NODE(assign_0); + PATTERN_DECL_NODE(concat_1_in); + PATTERN_DECL_NODE(concat_1); + PATTERN_DECL_NODE(concat_1_out); + PATTERN_DECL_NODE(assign_1); + + // Q, K matmul + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(eltadd_qk); + PATTERN_DECL_NODE(eltadd_qk_b); + PATTERN_DECL_NODE(eltadd_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + PATTERN_DECL_NODE(dropout_qk); + PATTERN_DECL_NODE(dropout_qk_out); + + // QK, V matmul + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_qkv_out); + + // out linear + PATTERN_DECL_NODE(matmul_linear); + PATTERN_DECL_NODE(matmul_linear_w); + PATTERN_DECL_NODE(matmul_linear_out); + PATTERN_DECL_NODE(eltadd_linear); + PATTERN_DECL_NODE(eltadd_linear_b); + PATTERN_DECL_NODE(eltadd_linear_out); + PATTERN_DECL_NODE(dropout_linear); + PATTERN_DECL_NODE(dropout_linear_out); + + // output elementwise_add + PATTERN_DECL_NODE(eltadd_out) + PATTERN_DECL_NODE(attention_output); + + // while loop + PATTERN_DECL_NODE(while0); + + // Feed Forward nodes + PATTERN_DECL_NODE(ffn_layer_norm); + PATTERN_DECL_NODE(ffn_layer_norm_scale); + PATTERN_DECL_NODE(ffn_layer_norm_bias); + PATTERN_DECL_NODE(ffn_layer_norm_mean); + PATTERN_DECL_NODE(ffn_layer_norm_variance); + PATTERN_DECL_NODE(ffn_layer_norm_out); + PATTERN_DECL_NODE(ffn_matmul0); + PATTERN_DECL_NODE(ffn_matmul0_w); + PATTERN_DECL_NODE(ffn_matmul0_out); + PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_out); + PATTERN_DECL_NODE(ffn_gelu); + PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_matmul1); + PATTERN_DECL_NODE(ffn_matmul1_w); + PATTERN_DECL_NODE(ffn_matmul1_out); + PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_b); // 
ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_out); + PATTERN_DECL_NODE(ffn_dropout); + PATTERN_DECL_NODE(ffn_dropout_out); + + // output elementwise_add + PATTERN_DECL_NODE(ffn_eltadd_out) + PATTERN_DECL_NODE(ffn_output); +}; + +struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { + FusedMultiTransformerDecoderFuseQKVPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase( + pattern, name_scope, "fused_multi_transformer_decoder_fuse_qkv") {} + + PDNode* operator()(); + + // Q, K, V path + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(matmul0); + PATTERN_DECL_NODE(matmul0_w); + PATTERN_DECL_NODE(matmul0_out); + PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_0_out); + + PATTERN_DECL_NODE(split0) + PATTERN_DECL_NODE(split0_q_out) + PATTERN_DECL_NODE(split0_k_out) + PATTERN_DECL_NODE(split0_v_out) + PATTERN_DECL_NODE(concat_k_in) + PATTERN_DECL_NODE(concat_v_in) + PATTERN_DECL_NODE(concat_k) + PATTERN_DECL_NODE(concat_v) + PATTERN_DECL_NODE(concat_k_out) + PATTERN_DECL_NODE(concat_v_out) + PATTERN_DECL_NODE(assign_k) + PATTERN_DECL_NODE(assign_v) + + // Q, K matmul + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(eltadd_qk); + PATTERN_DECL_NODE(eltadd_qk_b); + PATTERN_DECL_NODE(eltadd_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + PATTERN_DECL_NODE(dropout_qk); + PATTERN_DECL_NODE(dropout_qk_out); + + // QK, V matmul + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); + 
PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_qkv_out); + + // out linear + PATTERN_DECL_NODE(matmul_linear); + PATTERN_DECL_NODE(matmul_linear_w); + PATTERN_DECL_NODE(matmul_linear_out); + PATTERN_DECL_NODE(eltadd_linear); + PATTERN_DECL_NODE(eltadd_linear_b); + PATTERN_DECL_NODE(eltadd_linear_out); + PATTERN_DECL_NODE(dropout_linear); + PATTERN_DECL_NODE(dropout_linear_out); + + // output elementwise_add + PATTERN_DECL_NODE(eltadd_out) + PATTERN_DECL_NODE(attention_output); + + // Feed Forward nodes + PATTERN_DECL_NODE(ffn_layer_norm); + PATTERN_DECL_NODE(ffn_layer_norm_scale); + PATTERN_DECL_NODE(ffn_layer_norm_bias); + PATTERN_DECL_NODE(ffn_layer_norm_mean); + PATTERN_DECL_NODE(ffn_layer_norm_variance); + PATTERN_DECL_NODE(ffn_layer_norm_out); + PATTERN_DECL_NODE(ffn_matmul0); + PATTERN_DECL_NODE(ffn_matmul0_w); + PATTERN_DECL_NODE(ffn_matmul0_out); + PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_out); + PATTERN_DECL_NODE(ffn_gelu); + PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_matmul1); + PATTERN_DECL_NODE(ffn_matmul1_w); + PATTERN_DECL_NODE(ffn_matmul1_out); + PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_out); + PATTERN_DECL_NODE(ffn_dropout); + PATTERN_DECL_NODE(ffn_dropout_out); + + // output elementwise_add + PATTERN_DECL_NODE(ffn_eltadd_out) + PATTERN_DECL_NODE(ffn_output); +}; + +struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern + : public PatternBase { + MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, + name_scope, + "multi_devices_fused_multi_transformer_decoder_fuse_qkv") {} + + PDNode* operator()(); + + // Q, K, V path + PATTERN_DECL_NODE(input0); + 
PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(c_identity); + PATTERN_DECL_NODE(c_identity_out); + PATTERN_DECL_NODE(matmul0); + PATTERN_DECL_NODE(matmul0_w); + PATTERN_DECL_NODE(matmul0_out); + PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_0_out); + + PATTERN_DECL_NODE(split0) + PATTERN_DECL_NODE(split0_q_out) + PATTERN_DECL_NODE(split0_k_out) + PATTERN_DECL_NODE(split0_v_out) + PATTERN_DECL_NODE(concat_k_in) + PATTERN_DECL_NODE(concat_v_in) + PATTERN_DECL_NODE(concat_k) + PATTERN_DECL_NODE(concat_v) + PATTERN_DECL_NODE(concat_k_out) + PATTERN_DECL_NODE(concat_v_out) + PATTERN_DECL_NODE(assign_k) + PATTERN_DECL_NODE(assign_v) + + // Q, K matmul + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(eltadd_qk); + PATTERN_DECL_NODE(eltadd_qk_b); + PATTERN_DECL_NODE(eltadd_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + PATTERN_DECL_NODE(dropout_qk); + PATTERN_DECL_NODE(dropout_qk_out); + + // QK, V matmul + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_qkv_out); + + // out linear + PATTERN_DECL_NODE(matmul_linear); + PATTERN_DECL_NODE(matmul_linear_w); + PATTERN_DECL_NODE(matmul_linear_out); + PATTERN_DECL_NODE(c_allreduce_sum); + PATTERN_DECL_NODE(c_allreduce_sum_out); + PATTERN_DECL_NODE(eltadd_linear); + PATTERN_DECL_NODE(eltadd_linear_b); + PATTERN_DECL_NODE(eltadd_linear_out); + PATTERN_DECL_NODE(dropout_linear); + 
PATTERN_DECL_NODE(dropout_linear_out); + + // output elementwise_add + PATTERN_DECL_NODE(eltadd_out) + PATTERN_DECL_NODE(attention_output); + + // Feed Forward nodes + PATTERN_DECL_NODE(ffn_layer_norm); + PATTERN_DECL_NODE(ffn_layer_norm_scale); + PATTERN_DECL_NODE(ffn_layer_norm_bias); + PATTERN_DECL_NODE(ffn_layer_norm_mean); + PATTERN_DECL_NODE(ffn_layer_norm_variance); + PATTERN_DECL_NODE(ffn_layer_norm_out); + PATTERN_DECL_NODE(ffn_c_identity); + PATTERN_DECL_NODE(ffn_c_identity_out); + PATTERN_DECL_NODE(ffn_matmul0); + PATTERN_DECL_NODE(ffn_matmul0_w); + PATTERN_DECL_NODE(ffn_matmul0_out); + PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_out); + PATTERN_DECL_NODE(ffn_gelu); + PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_matmul1); + PATTERN_DECL_NODE(ffn_matmul1_w); + PATTERN_DECL_NODE(ffn_matmul1_out); + PATTERN_DECL_NODE(ffn_c_allreduce_sum); + PATTERN_DECL_NODE(ffn_c_allreduce_sum_out); + PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_out); + PATTERN_DECL_NODE(ffn_dropout); + PATTERN_DECL_NODE(ffn_dropout_out); + + // output elementwise_add + PATTERN_DECL_NODE(ffn_eltadd_out) + PATTERN_DECL_NODE(ffn_output); +}; + +} // namespace patterns + +class FusedMultiTransformerDecoderPass : public FusePassBase { + public: + FusedMultiTransformerDecoderPass(); + virtual ~FusedMultiTransformerDecoderPass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"fused_multi_transformer_decoder"}; + + private: + int BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +class FusedMultiTransformerDecoderFuseQKVPass : public FusePassBase { + public: + FusedMultiTransformerDecoderFuseQKVPass(); + virtual ~FusedMultiTransformerDecoderFuseQKVPass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const 
std::string name_scope_{"fused_multi_transformer_decoder_fuse_qkv"}; + + private: + int BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +class MultiDevicesFusedMultiTransformerDecoderFuseQKVPass + : public FusePassBase { + public: + MultiDevicesFusedMultiTransformerDecoderFuseQKVPass(); + virtual ~MultiDevicesFusedMultiTransformerDecoderFuseQKVPass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{ + "multi_devices_fused_multi_transformer_decoder_fuse_qkv"}; + + private: + int BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc new file mode 100644 index 00000000000..100f2ad8dac --- /dev/null +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc @@ -0,0 +1,576 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h" // NOLINT +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +void AddVarToScope(Scope* param_scope, + const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + + // MHA: pre Layer Norm + AddVarToScope(param_scope, "ln_scale", {1024}); + AddVarToScope(param_scope, "ln_bias", {1024}); + + // MHA: QKV fc + AddVarToScope(param_scope, "weights0", {1024, 1024}); + AddVarToScope(param_scope, "weights1", {1024, 1024}); + AddVarToScope(param_scope, "weights2", {1024, 1024}); + AddVarToScope(param_scope, "bias_0", {1024}); + AddVarToScope(param_scope, "bias_1", {1024}); + AddVarToScope(param_scope, "bias_2", {1024}); + + // MHA: QK bias + AddVarToScope(param_scope, "biasqk", {1024}); + + // MHA: out Linear + AddVarToScope(param_scope, "weights_l", {1024, 1024}); + AddVarToScope(param_scope, "bias_l", {1024}); + + // MHA: pre Layer Norm + AddVarToScope(param_scope, "ffn_ln_scale", {1024}); + AddVarToScope(param_scope, "ffn_ln_bias", {1024}); + + // FFN: fc1 -> (gelu) -> fc2 + AddVarToScope(param_scope, "ffn_weights0", {1024, 4096}); + AddVarToScope(param_scope, "ffn_weights1", {4096, 1024}); + AddVarToScope(param_scope, "ffn_bias_0", {4096}); + AddVarToScope(param_scope, "ffn_bias_1", {1024}); + + return param_scope; +} + +TEST(FusedMultiTransformerDecoderPass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x, ln_scale, ln_bias) layer_norm -> layer_norm_out + // (layer_norm_out, weights_0) matmul_v2 -> matmul_out0 + // (layer_norm_out, weights_1) matmul_v2 -> matmul_out1 + // (layer_norm_out, weights_2) 
matmul_v2 -> matmul_out2 + // (matmul_out0, bias_0) elementwise_add -> eltadd_0 + // (matmul_out1, bias_1) elementwise_add -> eltadd_1 + // (matmul_out2, bias_2) elementwise_add -> eltadd_2 + // (eltadd_0) reshape2 -> reshape_0 + // (eltadd_1) reshape2 -> reshape_1 + // (eltadd_2) reshape2 -> reshape_2 + // (reshape_0) transpose2 -> transpose_0 + // (reshape_1) transpose2 -> transpose_1 + // (reshape_2) transpose2 -> transpose_2 + // (transpose_1) concat -> concat_0 + // (transpose_2) concat -> concat_2 + // (concat_0) assign -> assign_0 + // (concat_1) assign -> assign_2 + // (transpose_0, transpose_1) matmul -> matmul_qk + // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (eltadd_qk) softmax -> softmax_qk + // (softmax_qk) dropout -> dropout_qk + // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (matmul_qkv) transpose -> transpose_qkv + // (transpose_qkv) reshape -> reshape_qkv + // (reshape_qkv) matmul_v2 -> matmul_linear + // (matmul_linear) elementwise_add -> eltadd_linear + // (eltadd_linear) dropout -> dropout_linear + // (eltadd_out) elementwise_add -> attention_out + // + // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out + // (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0 + // (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0 + // (ffn_eltadd0) gelu -> ffn_gelu + // (ffn_gelu) matmul_v2 -> ffn_matmul1 + // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 + // (ffn_eltadd1) dropout -> ffn_dropout + // (attention_out, ffn_dropout) elementwise_add -> ffn_output + + Layers layers; + // MHA: pre LayerNorm + auto* x = layers.data("x", {1, 128, 1024}); + auto* ln_scale = layers.data("ln_scale", {1024}, true); + auto* ln_bias = layers.data("ln_bias", {1024}, true); + auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0]; + + // MHA: QKV fc + auto* weights_0 = layers.data("weights0", {1024, 1024}, true); + auto* weights_1 = layers.data("weights1", {1024, 1024}, true); + auto* weights_2 = 
layers.data("weights2", {1024, 1024}, true); + auto* matmul_out_0 = + layers.matmul_v2(ln_out, weights_0, nullptr, false, true); + auto* matmul_out_1 = + layers.matmul_v2(ln_out, weights_1, nullptr, false, true); + auto* matmul_out_2 = + layers.matmul_v2(ln_out, weights_2, nullptr, false, true); + + auto* b0 = layers.data("bias_0", {1024}, true); + auto* b1 = layers.data("bias_1", {1024}, true); + auto* b2 = layers.data("bias_2", {1024}, true); + auto* elementwise_out_0 = + layers.elementwise_add(matmul_out_0, b0, nullptr, 2); + auto* elementwise_out_1 = + layers.elementwise_add(matmul_out_1, b1, nullptr, 2); + auto* elementwise_out_2 = + layers.elementwise_add(matmul_out_2, b2, nullptr, 2); + + std::vector shape = {1, 128, 16, 64}; + auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true); + auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true); + auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true); + + std::vector axis = {0, 2, 1, 3}; + auto* transpose_0 = layers.transpose2(reshape_0, axis, true); + auto* transpose_1 = layers.transpose2(reshape_1, axis, true); + auto* transpose_2 = layers.transpose2(reshape_2, axis, true); + + auto* cache_k = layers.data("cache_k", {1, 16, 128, 64}); + auto* cache_v = layers.data("cache_v", {1, 16, 128, 64}); + auto* concat_k = layers.concat({cache_k, transpose_1}, 2); + auto* concat_v = layers.concat({cache_v, transpose_2}, 2); + layers.assign(concat_k); + layers.assign(concat_v); + + // MHA: QK matmul + auto* matmul_qk = layers.matmul(transpose_0, concat_k, nullptr, false, true); + + auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); + auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* softmax_qk = layers.softmax(elementwise_qk, -1); + auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); + + // MHA: QKV matmul + auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v); + + auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); 
+ auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); + + // MHA: out Linear + auto* weights_l = layers.data("weights_l", {1024, 1024}, true); + auto* bias_l = layers.data("weightsl", {1024, 1024}, true); + auto* linear_matmut_out = + layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true); + auto* linear_eltadd_out = + layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); + + auto* dropout_qkv = + layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); + auto* attention_out = layers.elementwise_add(x, dropout_qkv); + + // FFN: pre LayerNorm + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); + auto* ffn_ln_out = + layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0]; + + // FFN: fc1 -> gelu -> fc2 + auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true); + auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true); + auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true); + auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true); + auto* ffn_matmul0_out = + layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true); + auto* ffn_eltadd0_out = + layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2); + auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out); + auto* ffn_matmul1_out = + layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true); + auto* ffn_eltadd1_out = + layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); + + // FFN: dropout -> elementwise_add + auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); + layers.elementwise_add(attention_out, ffn_dropout); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + auto pass = + PassRegistry::Instance().Get("fused_multi_transformer_decoder_pass"); + if (pass.get() == nullptr) + LOG(INFO) << "get 
fused_multi_transformer_decoder_pass failed"; + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); + + PADDLE_ENFORCE_EQ(num_nodes_before, + num_nodes_after + 72, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_decoder_pass, The " + "node num in graph " + "should be %d, but the result is %d", + num_nodes_before - 72, + num_nodes_after)); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_decoder pass, " + "there should be one fused_multi_transformer op, " + "but the result is %d", + num_fused_nodes_after)); +} + +TEST(FusedMultiTransformerDecoderPass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("fused_multi_transformer_decoder_pass")); +} + +TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x, ln_scale, ln_bias) layer_norm -> layer_norm_out + // (layer_norm_out, weights_0) matmul_v2 -> matmul_out0 + // (matmul_out0, bias_0) elementwise_add -> eltadd_0 + // (eltadd_0) reshape2 -> reshape_0 + // (reshape_0) transpose2 -> transpose_0 + // (transpose_0) split -> split_q, split_k, + // split_v (split_k) concat -> concat_k + // (split_v) concat -> concat_v + // (concat_k) assign -> assign_k + // (concat_v) assign -> assign_v + // (split_q, split_k) matmul -> matmul_qk + // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (eltadd_qk) softmax -> softmax_qk + // (softmax_qk) dropout -> dropout_qk + // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (matmul_qkv) transpose -> transpose_qkv + // (transpose_qkv) reshape -> reshape_qkv + // 
(reshape_qkv) matmul_v2 -> matmul_linear + // (matmul_linear) elementwise_add -> eltadd_linear + // (eltadd_linear) dropout -> dropout_linear + // (eltadd_out) elementwise_add -> attention_out + // + // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out + // (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0 + // (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0 + // (ffn_eltadd0) gelu -> ffn_gelu + // (ffn_gelu) matmul_v2 -> ffn_matmul1 + // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 + // (ffn_eltadd1) dropout -> ffn_dropout + // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // + // (transpose_1, transpose_2) while -> decoder block + + Layers layers; + // MHA: pre LayerNorm + auto* x = layers.data("x", {1, 128, 1024}); + auto* ln_scale = layers.data("ln_scale", {1024}, true); + auto* ln_bias = layers.data("ln_bias", {1024}, true); + auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0]; + + // MHA: QKV fc + auto* weights_0 = layers.data("weights0", {1024, 3072}, true); + auto* matmul_out_0 = + layers.matmul_v2(ln_out, weights_0, nullptr, false, true); + + auto* b0 = layers.data("bias_0", {3072}, true); + auto* elementwise_out_0 = + layers.elementwise_add(matmul_out_0, b0, nullptr, 2); + + std::vector shape = {1, 128, 16, 64}; + auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true); + + std::vector axis = {0, 2, 1, 3}; + auto* transpose_0 = layers.transpose2(reshape_0, axis, true); + + auto split_outs = layers.split(transpose_0, 3, 3); + auto* split_q = split_outs[0]; + auto* split_k = split_outs[1]; + auto* split_v = split_outs[2]; + + auto* cache_k = layers.data("cache_k", {1, 16, 128, 64}); + auto* cache_v = layers.data("cache_v", {1, 16, 128, 64}); + auto* concat_k = layers.concat({cache_k, split_k}, 2); + auto* concat_v = layers.concat({cache_v, split_v}, 2); + layers.assign(concat_k); + layers.assign(concat_v); + + // MHA: QK matmul + auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, 
false, true); + + auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); + auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* softmax_qk = layers.softmax(elementwise_qk, -1); + auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); + + // MHA: QKV matmul + auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v); + + auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); + auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); + + // MHA: out Linear + auto* weights_l = layers.data("weights_l", {1024, 1024}, true); + auto* bias_l = layers.data("weightsl", {1024, 1024}, true); + auto* linear_matmut_out = + layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true); + auto* linear_eltadd_out = + layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); + + auto* dropout_qkv = + layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); + auto* attention_out = layers.elementwise_add(x, dropout_qkv); + + // FFN: pre LayerNorm + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); + auto* ffn_ln_out = + layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0]; + + // FFN: fc1 -> gelu -> fc2 + auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true); + auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true); + auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true); + auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true); + auto* ffn_matmul0_out = + layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true); + auto* ffn_eltadd0_out = + layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2); + auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out); + auto* ffn_matmul1_out = + layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true); + auto* ffn_eltadd1_out = + layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); + + // FFN: dropout -> 
elementwise_add + auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); + layers.elementwise_add(attention_out, ffn_dropout); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + auto pass = PassRegistry::Instance().Get( + "fused_multi_transformer_decoder_fuse_qkv_pass"); + if (pass.get() == nullptr) + LOG(INFO) << "get fused_multi_transformer_decoder_fuse_qkv_pass failed"; + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); + + PADDLE_ENFORCE_EQ( + num_nodes_before, + num_nodes_after + 62, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_decoder_fuse_qkv_pass, " + "The node num in graph should be %d, but the result is %d", + num_nodes_before - 62, + num_nodes_after)); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_decoder_fuse_qkv " + "pass, there should be one fused_multi_transformer " + "op, but the result is %d", + num_fused_nodes_after)); +} + +TEST(FusedMultiTransformerDecoderFuseQKVPass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("fused_multi_transformer_decoder_fuse_qkv_pass")); +} + +TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x, ln_scale, ln_bias) layer_norm -> layer_norm_out + // (layer_norm_out) c_identity -> c_identity_out + // (c_identity_out, weights_0) matmul_v2 -> matmul_out0 + // (matmul_out0, bias_0) elementwise_add -> eltadd_0 + // (eltadd_0) reshape2 -> reshape_0 + // (reshape_0) transpose2 -> transpose_0 
+ // (transpose_0) split -> split_q, split_k, + // split_v (split_k) concat -> concat_k + // (split_v) concat -> concat_v + // (concat_k) assign -> assign_k + // (concat_v) assign -> assign_v + // (split_q, split_k) matmul -> matmul_qk + // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (eltadd_qk) softmax -> softmax_qk + // (softmax_qk) dropout -> dropout_qk + // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (matmul_qkv) transpose -> transpose_qkv + // (transpose_qkv) reshape -> reshape_qkv + // (reshape_qkv) matmul_v2 -> matmul_linear + // (matmul_linear) c_allreduce_sum -> c_all_reduce_out + // (matmul_linear) elementwise_add -> eltadd_linear + // (eltadd_linear) dropout -> dropout_linear + // (eltadd_out) elementwise_add -> attention_out + // + // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out + // (ffn_layer_norm_out) c_identity -> ffn_c_identity_out + // (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0 + // (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0 + // (ffn_eltadd0) gelu -> ffn_gelu + // (ffn_gelu) matmul_v2 -> ffn_matmul1 + // (ffn_matmul1) c_allreduce_sum -> c_allreduce_out + // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 + // (ffn_eltadd1) dropout -> ffn_dropout + // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // + // (transpose_1, transpose_2) while -> decoder block + + Layers layers; + // MHA: pre LayerNorm + auto* x = layers.data("x", {1, 128, 1024}); + auto* ln_scale = layers.data("ln_scale", {1024}, true); + auto* ln_bias = layers.data("ln_bias", {1024}, true); + auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0]; + auto* c_identity_out = layers.c_identity(ln_out); + + // MHA: QKV fc + auto* weights_0 = layers.data("weights0", {1024, 3072}, true); + auto* matmul_out_0 = + layers.matmul_v2(c_identity_out, weights_0, nullptr, false, true); + + auto* b0 = layers.data("bias_0", {3072}, true); + auto* elementwise_out_0 = + layers.elementwise_add(matmul_out_0, b0, 
nullptr, 2); + + std::vector shape = {1, 128, 16, 64}; + auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true); + + std::vector axis = {0, 2, 1, 3}; + auto* transpose_0 = layers.transpose2(reshape_0, axis, true); + + auto split_outs = layers.split(transpose_0, 3, 3); + auto* split_q = split_outs[0]; + auto* split_k = split_outs[1]; + auto* split_v = split_outs[2]; + + auto* cache_k = layers.data("cache_k", {1, 16, 128, 64}); + auto* cache_v = layers.data("cache_v", {1, 16, 128, 64}); + auto* concat_k = layers.concat({cache_k, split_k}, 2); + auto* concat_v = layers.concat({cache_v, split_v}, 2); + layers.assign(concat_k); + layers.assign(concat_v); + + // MHA: QK matmul + auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true); + + auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); + auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* softmax_qk = layers.softmax(elementwise_qk, -1); + auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); + + // MHA: QKV matmul + auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v); + + auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); + auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); + + // MHA: out Linear + auto* weights_l = layers.data("weights_l", {1024, 1024}, true); + auto* bias_l = layers.data("weightsl", {1024, 1024}, true); + auto* linear_matmut_out = + layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true); + auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmut_out); + auto* linear_eltadd_out = + layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2); + + auto* dropout_qkv = + layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); + auto* attention_out = layers.elementwise_add(x, dropout_qkv); + + // FFN: pre LayerNorm + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); + auto* 
ffn_ln_out = + layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0]; + auto* ffn_c_identity_out = layers.c_identity(ffn_ln_out); + + // FFN: fc1 -> gelu -> fc2 + auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true); + auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true); + auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true); + auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true); + auto* ffn_matmul0_out = + layers.matmul_v2(ffn_c_identity_out, ffn_weights0, nullptr, false, true); + auto* ffn_eltadd0_out = + layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2); + auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out); + auto* ffn_matmul1_out = + layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true); + auto* ffn_c_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out); + auto* ffn_eltadd1_out = + layers.elementwise_add(ffn_c_allreduce_out, ffn_bias1, nullptr, 2); + + // FFN: dropout -> elementwise_add + auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); + layers.elementwise_add(attention_out, ffn_dropout); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + auto pass = PassRegistry::Instance().Get( + "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"); + if (pass.get() == nullptr) + LOG(INFO) + << "get multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass " + "failed"; + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); + + PADDLE_ENFORCE_EQ( + num_nodes_before, + num_nodes_after + 70, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_decoder_fuse_qkv_pass, " + "The node num in graph should be %d, but the result is %d", + 
num_nodes_before - 70, + num_nodes_after)); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_decoder_fuse_qkv " + "multi-devices pass, there should be one " + "fused_multi_transformer op, but the result is %d", + num_fused_nodes_after)); +} + +TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, + pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible( + "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass")); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(fused_multi_transformer_decoder_pass); +USE_PASS(fused_multi_transformer_decoder_fuse_qkv_pass); +USE_PASS(multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass); diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc new file mode 100644 index 00000000000..3f163fc3683 --- /dev/null +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc @@ -0,0 +1,3448 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h" + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +PDNode* FusedMultiTransformerEncoderPattern::operator()() { + auto* input0 = pattern->NewNode(input0_repr()); + input0->assert_is_op_input("layer_norm", "X"); + + // pre-LayerNorm + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto* layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("matmul_v2", "X") + ->assert_more([](Node* x) { + if (x->outputs.size() == 3) { + return true; + } else { + return false; + } + }); + + layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); + + // Q path Nodes + auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2"); + auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd0 = + 
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add"); + auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2") + ->AsIntermediate() + ->assert_is_op_input("matmul", "X"); + + // Q path Links + matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var}) + .LinksTo({matmul0_out_var}); + eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var}) + .LinksTo({eltadd0_out_var}); + reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + + // K path Nodes + auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2"); + auto* matmul1_w_var = pattern->NewNode(matmul1_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul1_out_var = pattern->NewNode(matmul1_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd1 = + pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); + auto* eltadd1_b_var = pattern->NewNode(eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + auto* eltadd1_out_var = pattern->NewNode(eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); + + auto* reshape2_1 = + 
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); + + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2") + ->AsOutput() + ->assert_is_op_input("matmul", "Y") + ->assert_is_op_input("while") + ->assert_more([](Node* x) { + if (x->outputs.size() == 2) { + return true; + } else { + return false; + } + }); + + // K path Links + matmul1->LinksFrom({layer_norm_out_var, matmul1_w_var}) + .LinksTo({matmul1_out_var}); + eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var}) + .LinksTo({eltadd1_out_var}); + reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); + transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); + + // V path Nodes + auto* matmul2 = pattern->NewNode(matmul2_repr())->assert_is_op("matmul_v2"); + auto* matmul2_w_var = pattern->NewNode(matmul2_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul2_out_var = pattern->NewNode(matmul2_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd2 = + pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); + auto* eltadd2_b_var = pattern->NewNode(eltadd2_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd2_out_var = pattern->NewNode(eltadd2_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("reshape2"); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("transpose2"); + + auto* 
transpose2_2 = + pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); + auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2") + ->AsOutput() + ->assert_is_op_input("matmul_v2", "Y") + ->assert_is_op_input("while") + ->assert_more([](Node* x) { + if (x->outputs.size() == 2) { + return true; + } else { + return false; + } + }); + + // V path Links + matmul2->LinksFrom({layer_norm_out_var, matmul2_w_var}) + .LinksTo({matmul2_out_var}); + eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var}) + .LinksTo({eltadd2_out_var}); + reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); + transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); + + // QK path Nodes + auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + auto* eltadd_qk = + pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); + auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) + ->assert_is_op_output("softmax") + ->AsIntermediate() + ->assert_is_op_input("dropout"); + + auto* dropout_qk = + pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout"); + auto* dropout_qk_out_var = + pattern->NewNode(dropout_qk_out_repr()) + ->assert_is_op_output("dropout", "Out") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv + + // QK path Links + 
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var}) + .LinksTo({matmul_qk_out_var}); + eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + .LinksTo({eltadd_qk_out_var}); + softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); + dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var}); + + // QKV path Nodes + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2"); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = + pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); // -> out_linear + + auto* matmul_linear = + pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2"); + auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd_linear = + pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add"); + auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + 
->assert_is_op_input("dropout"); + + auto* dropout_linear = + pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout"); + auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr()) + ->assert_is_op_output("dropout") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* eltadd_out = + pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); + auto* attention_output = pattern->NewNode(attention_output_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate(); + + // QKV path Links + matmul_qkv->LinksFrom({dropout_qk_out_var, transpose2_2_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var}) + .LinksTo({matmul_linear_out_var}); + eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var}) + .LinksTo({eltadd_linear_out_var}); + dropout_linear->LinksFrom({eltadd_linear_out_var}) + .LinksTo({dropout_linear_out_var}); + eltadd_out->LinksFrom({input0, dropout_linear_out_var}) + .LinksTo({attention_output}); + + // while loop + auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while"); + while0->LinksFrom({transpose2_1_out_var, transpose2_2_out_var}); + + // Feed Forward LayerNorm Nodes + auto* ffn_layer_norm = + pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm"); + auto* ffn_layer_norm_scale_var = + pattern->NewNode(ffn_layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* ffn_layer_norm_bias_var = + pattern->NewNode(ffn_layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* ffn_layer_norm_mean_var = + pattern->NewNode(ffn_layer_norm_mean_repr()) + ->AsOutput() + 
->assert_is_op_output("layer_norm", "Mean"); + auto* ffn_layer_norm_variance_var = + pattern->NewNode(ffn_layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("matmul_v2", "X"); + + ffn_layer_norm + ->LinksFrom( + {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var}) + .LinksTo({ffn_layer_norm_out_var, + ffn_layer_norm_mean_var, + ffn_layer_norm_variance_var}); + + // Feed Forward fc1 -> gelu -> fc2 -> dropout + auto* ffn_matmul0 = + pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); + auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add"); + + auto* ffn_eltadd0 = + pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add"); + auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate() + ->assert_is_op_input("gelu"); + + auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu"); + auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr()) + ->assert_is_op_output("gelu") + ->AsIntermediate() + ->assert_is_op_input("matmul_v2"); + + auto* ffn_matmul1 = + pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2"); + auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr()) + ->assert_is_op_output("matmul_v2") + ->AsIntermediate() + 
->assert_is_op_input("elementwise_add");
+
+  auto* ffn_eltadd1 =
+      pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
+  auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
+                                ->AsInput()
+                                ->assert_is_op_input("elementwise_add", "Y");
+  auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
+                                  ->assert_is_op_output("elementwise_add")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("dropout");
+
+  auto* ffn_dropout =
+      pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
+  auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
+                                  ->assert_is_op_output("dropout")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("elementwise_add");
+
+  auto* ffn_eltadd_out =
+      pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
+  auto* ffn_output = pattern->NewNode(ffn_output_repr())
+                         ->assert_is_op_output("elementwise_add")
+                         ->AsOutput();
+
+  ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
+      .LinksTo({ffn_matmul0_out_var});
+  ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
+      .LinksTo({ffn_eltadd0_out_var});
+  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
+  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+      .LinksTo({ffn_matmul1_out_var});
+  ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
+      .LinksTo({ffn_eltadd1_out_var});
+  ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
+
+  ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
+      .LinksTo({ffn_output});
+
+  return ffn_output;
+}
+
+// Matches one encoder layer whose Q/K/V projection is a single fused GEMM
+// followed by a 3-way split (pre-LN -> matmul_v2 -> split -> attn -> FFN).
+PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
+  auto* input0 = pattern->NewNode(input0_repr());
+  input0->assert_is_op_input("layer_norm", "X");
+
+  // pre-LayerNorm
+  auto* layer_norm =
+      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
+  auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
+                                   ->AsInput()
+                                   ->assert_is_persistable_var()
+                                   ->assert_is_op_input("layer_norm", "Scale");
+  auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
+                                  ->AsInput()
+                                  ->assert_is_persistable_var()
+                                  ->assert_is_op_input("layer_norm", "Bias");
+  auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
+                                  ->AsOutput()
+                                  ->assert_is_op_output("layer_norm", "Mean");
+  auto* layer_norm_variance_var =
+      pattern->NewNode(layer_norm_variance_repr())
+          ->AsOutput()
+          ->assert_is_op_output("layer_norm", "Variance");
+  auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
+                                 ->AsIntermediate()
+                                 ->assert_is_op_output("layer_norm", "Y")
+                                 ->assert_is_op_input("matmul_v2", "X");
+
+  layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
+      .LinksTo(
+          {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
+
+  // QKV fused path Nodes
+  auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
+  auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
+                            ->AsInput()
+                            ->assert_is_op_input("matmul_v2", "Y");
+  auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
+                              ->assert_is_op_output("matmul_v2")
+                              ->AsIntermediate()
+                              ->assert_is_op_input("elementwise_add");
+
+  auto* eltadd0 =
+      pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
+  auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
+                            ->AsInput()
+                            ->assert_is_op_input("elementwise_add", "Y");
+  auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
+                              ->assert_is_op_output("elementwise_add")
+                              ->AsIntermediate()
+                              ->assert_is_op_input("reshape2");
+
+  auto* reshape2_0 =
+      pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
+  auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
+                                 ->assert_is_op_output("reshape2")
+                                 ->AsIntermediate()
+                                 ->assert_is_op_input("transpose2");
+
+  auto* transpose2_0 =
+      pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
+  auto* transpose2_0_out_var =
+      pattern->NewNode(transpose2_0_out_repr())
+          ->assert_is_op_output("transpose2")
+          ->AsIntermediate()
+          ->assert_is_op_input("split", "X");
+
+  auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
+  auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
+                               ->assert_is_op_output("split")
+                               ->AsIntermediate()
+                               ->assert_is_op_input("matmul", "X");
+  auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
+                               ->assert_is_op_output("split")
+                               ->AsOutput()
+                               ->assert_is_op_input("matmul", "Y")
+                               ->assert_is_op_input("while");
+  auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
+                               ->assert_is_op_output("split")
+                               ->AsOutput()
+                               ->assert_is_op_input("matmul_v2", "Y")
+                               ->assert_is_op_input("while");
+
+  // QKV fused path Links
+  matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var})
+      .LinksTo({matmul0_out_var});
+  eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
+      .LinksTo({eltadd0_out_var});
+  reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
+  transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
+  split0->LinksFrom({transpose2_0_out_var})
+      .LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
+
+  // while loop
+  auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
+  while0->LinksFrom({split0_k_out_var, split0_v_out_var});  // K/V also feed the decoding while loop
+
+  // QK path Nodes
+  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
+  auto* matmul_qk_out_var =
+      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
+  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+
+  auto* eltadd_qk =
+      pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
+  auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
+                              ->AsInput()
+                              ->assert_is_op_input("elementwise_add", "Y");
+  auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
+                                ->assert_is_op_output("elementwise_add")
+                                ->AsIntermediate()
+                                ->assert_is_op_input("softmax");
+
+  auto* softmax_qk =
+      pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
+  auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
+                                 ->assert_is_op_output("softmax")
+                                 ->AsIntermediate()
+                                 ->assert_is_op_input("dropout");
+
+  auto* dropout_qk =
+      pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
+  auto* dropout_qk_out_var =
+      pattern->NewNode(dropout_qk_out_repr())
+          ->assert_is_op_output("dropout", "Out")
+          ->AsIntermediate()
+          ->assert_is_op_input("matmul_v2", "X");  // -> matmul_qkv
+
+  // QK path Links
+  matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
+      .LinksTo({matmul_qk_out_var});
+  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
+      .LinksTo({eltadd_qk_out_var});
+  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
+  dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
+
+  // QKV path Nodes
+  auto* matmul_qkv =
+      pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
+  auto* matmul_qkv_out_var =
+      pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
+  matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
+
+  auto* transpose2_qkv =
+      pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
+  auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
+                                     ->assert_is_op_output("transpose2");
+  transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
+
+  auto* reshape2_qkv =
+      pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
+  auto* reshape2_qkv_out_var =
+      pattern->NewNode(reshape2_qkv_out_repr())
+          ->assert_is_op_output("reshape2")
+          ->AsIntermediate()
+          ->assert_is_op_input("matmul_v2");  // -> out_linear
+
+  auto* matmul_linear =
+      pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
+  auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
+                                  ->AsInput()
+                                  ->assert_is_op_input("matmul_v2", "Y");
+  auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
+                                    ->assert_is_op_output("matmul_v2")
+                                    ->AsIntermediate()
+                                    ->assert_is_op_input("elementwise_add");
+
+  auto* eltadd_linear =
+      pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
+  auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
+                                  ->AsInput()
+                                  ->assert_is_op_input("elementwise_add", "Y");
+  auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
+                                    ->assert_is_op_output("elementwise_add")
+                                    ->AsIntermediate()
+                                    ->assert_is_op_input("dropout");
+
+  auto* dropout_linear =
+      pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
+  auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
+                                     ->assert_is_op_output("dropout")
+                                     ->AsIntermediate()
+                                     ->assert_is_op_input("elementwise_add");
+
+  auto* eltadd_out =
+      pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
+  auto* attention_output = pattern->NewNode(attention_output_repr())
+                               ->assert_is_op_output("elementwise_add")
+                               ->AsIntermediate();  // residual add of input0 + attention branch
+
+  // QKV path Links
+  matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var})
+      .LinksTo({matmul_qkv_out_var});
+  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
+      .LinksTo({transpose2_qkv_out_var});
+  reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
+      .LinksTo({reshape2_qkv_out_var});
+  matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
+      .LinksTo({matmul_linear_out_var});
+  eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
+      .LinksTo({eltadd_linear_out_var});
+  dropout_linear->LinksFrom({eltadd_linear_out_var})
+      .LinksTo({dropout_linear_out_var});
+  eltadd_out->LinksFrom({input0, dropout_linear_out_var})
+      .LinksTo({attention_output});
+
+  // Feed Forward LayerNorm Nodes
+  auto* ffn_layer_norm =
+      pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
+  auto* ffn_layer_norm_scale_var =
+      pattern->NewNode(ffn_layer_norm_scale_repr())
+          ->AsInput()
+          ->assert_is_persistable_var()
+          ->assert_is_op_input("layer_norm", "Scale");
+  auto* ffn_layer_norm_bias_var =
+      pattern->NewNode(ffn_layer_norm_bias_repr())
+          ->AsInput()
+          ->assert_is_persistable_var()
+          ->assert_is_op_input("layer_norm", "Bias");
+  auto* ffn_layer_norm_mean_var =
+      pattern->NewNode(ffn_layer_norm_mean_repr())
+          ->AsIntermediate()
+          ->assert_is_op_output("layer_norm", "Mean");
+  auto* ffn_layer_norm_variance_var =
+      pattern->NewNode(ffn_layer_norm_variance_repr())
+          ->AsIntermediate()
+          ->assert_is_op_output("layer_norm", "Variance");
+  auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
+                                     ->AsIntermediate()
+                                     ->assert_is_op_output("layer_norm", "Y")
+                                     ->assert_is_op_input("matmul_v2", "X");
+
+  ffn_layer_norm
+      ->LinksFrom(
+          {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
+      .LinksTo({ffn_layer_norm_out_var,
+                ffn_layer_norm_mean_var,
+                ffn_layer_norm_variance_var});
+
+  // Feed Forward fc1 -> gelu -> fc2 -> dropout
+  auto* ffn_matmul0 =
+      pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
+  auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
+                                ->AsInput()
+                                ->assert_is_op_input("matmul_v2", "Y");
+  auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
+                                  ->assert_is_op_output("matmul_v2")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("elementwise_add");
+
+  auto* ffn_eltadd0 =
+      pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
+  auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
+                                ->AsInput()
+                                ->assert_is_op_input("elementwise_add", "Y");
+  auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
+                                  ->assert_is_op_output("elementwise_add")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("gelu");
+
+  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
+  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
+                               ->assert_is_op_output("gelu")
+                               ->AsIntermediate()
+                               ->assert_is_op_input("matmul_v2");
+
+  auto* ffn_matmul1 =
+      pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
+  auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
+                                ->AsInput()
+                                ->assert_is_op_input("matmul_v2", "Y");
+  auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
+                                  ->assert_is_op_output("matmul_v2")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("elementwise_add");
+
+  auto* ffn_eltadd1 =
+      pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
+  auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
+                                ->AsInput()
+                                ->assert_is_op_input("elementwise_add", "Y");
+  auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
+                                  ->assert_is_op_output("elementwise_add")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("dropout");
+
+  auto* ffn_dropout =
+      pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
+  auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
+                                  ->assert_is_op_output("dropout")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("elementwise_add");
+
+  auto* ffn_eltadd_out =
+      pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
+  auto* ffn_output = pattern->NewNode(ffn_output_repr())
+                         ->assert_is_op_output("elementwise_add")
+                         ->AsOutput();
+
+  ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
+      .LinksTo({ffn_matmul0_out_var});
+  ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
+      .LinksTo({ffn_eltadd0_out_var});
+  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
+  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+      .LinksTo({ffn_matmul1_out_var});
+  ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
+      .LinksTo({ffn_eltadd1_out_var});
+  ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
+
+  ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
+
.LinksTo({ffn_output});
+
+  return ffn_output;
+}
+
+// Tensor-parallel variant of the fused-QKV encoder pattern: the per-layer
+// GEMMs are bracketed by c_identity / c_allreduce_sum collectives.
+PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
+  auto* input0 = pattern->NewNode(input0_repr());
+  input0->assert_is_op_input("layer_norm", "X");
+
+  // pre-LayerNorm
+  auto* layer_norm =
+      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
+  auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
+                                   ->AsInput()
+                                   ->assert_is_persistable_var()
+                                   ->assert_is_op_input("layer_norm", "Scale");
+  auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
+                                  ->AsInput()
+                                  ->assert_is_persistable_var()
+                                  ->assert_is_op_input("layer_norm", "Bias");
+  auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
+                                  ->AsOutput()
+                                  ->assert_is_op_output("layer_norm", "Mean");
+  auto* layer_norm_variance_var =
+      pattern->NewNode(layer_norm_variance_repr())
+          ->AsOutput()
+          ->assert_is_op_output("layer_norm", "Variance");
+  auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
+                                 ->AsIntermediate()
+                                 ->assert_is_op_output("layer_norm", "Y")
+                                 ->assert_is_op_input("c_identity", "X");
+
+  layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
+      .LinksTo(
+          {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
+
+  // communication c_identity
+  auto* c_identity =
+      pattern->NewNode(c_identity_repr())->assert_is_op("c_identity");
+  auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr())
+                                 ->AsIntermediate()
+                                 ->assert_is_op_output("c_identity", "Out")
+                                 ->assert_is_op_input("matmul_v2", "X");
+  c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var});
+
+  // QKV fused path Nodes
+  auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
+  auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
+                            ->AsInput()
+                            ->assert_is_op_input("matmul_v2", "Y");
+  auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
+                              ->assert_is_op_output("matmul_v2")
+                              ->AsIntermediate()
+                              ->assert_is_op_input("elementwise_add");
+
+  auto* eltadd0 =
+      pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
+  auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
+                            ->AsInput()
+                            ->assert_is_op_input("elementwise_add", "Y");
+  auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
+                              ->assert_is_op_output("elementwise_add")
+                              ->AsIntermediate()
+                              ->assert_is_op_input("reshape2");
+
+  auto* reshape2_0 =
+      pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
+  auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
+                                 ->assert_is_op_output("reshape2")
+                                 ->AsIntermediate()
+                                 ->assert_is_op_input("transpose2");
+
+  auto* transpose2_0 =
+      pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
+  auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
+                                   ->assert_is_op_output("transpose2")
+                                   ->AsIntermediate()
+                                   ->assert_is_op_input("split", "X");
+
+  auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
+  auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
+                               ->assert_is_op_output("split")
+                               ->AsIntermediate()
+                               ->assert_is_op_input("matmul", "X");
+  auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
+                               ->assert_is_op_output("split")
+                               ->AsOutput()
+                               ->assert_is_op_input("matmul", "Y")
+                               ->assert_is_op_input("while");
+  auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
+                               ->assert_is_op_output("split")
+                               ->AsOutput()
+                               ->assert_is_op_input("matmul_v2", "Y")
+                               ->assert_is_op_input("while");
+
+  // QKV fused path Links
+  matmul0->LinksFrom({c_identity_out_var, matmul0_w_var})
+      .LinksTo({matmul0_out_var});
+  eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
+      .LinksTo({eltadd0_out_var});
+  reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
+  transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
+  split0->LinksFrom({transpose2_0_out_var})
+      .LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
+
+  // while loop
+  auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
+  while0->LinksFrom({split0_k_out_var, split0_v_out_var});  // K/V also feed the decoding while loop
+
+  // QK path Nodes
+  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
+  auto* matmul_qk_out_var =
+      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
+  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+
+  auto* eltadd_qk =
+      pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
+  auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
+                              ->AsInput()
+                              ->assert_is_op_input("elementwise_add", "Y");
+  auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
+                                ->assert_is_op_output("elementwise_add")
+                                ->AsIntermediate()
+                                ->assert_is_op_input("softmax");
+
+  auto* softmax_qk =
+      pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
+  auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
+                                 ->assert_is_op_output("softmax")
+                                 ->AsIntermediate()
+                                 ->assert_is_op_input("dropout");
+
+  auto* dropout_qk =
+      pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
+  auto* dropout_qk_out_var =
+      pattern->NewNode(dropout_qk_out_repr())
+          ->assert_is_op_output("dropout", "Out")
+          ->AsIntermediate()
+          ->assert_is_op_input("matmul_v2", "X");  // -> matmul_qkv
+
+  // QK path Links
+  matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
+      .LinksTo({matmul_qk_out_var});
+  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
+      .LinksTo({eltadd_qk_out_var});
+  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
+  dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
+
+  // QKV path Nodes
+  auto* matmul_qkv =
+      pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
+  auto* matmul_qkv_out_var =
+      pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
+  matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
+
+  auto* transpose2_qkv =
+      pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
+  auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
+                                     ->assert_is_op_output("transpose2");
+  transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
+
+  auto* reshape2_qkv =
+      pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
+  auto* reshape2_qkv_out_var =
+      pattern->NewNode(reshape2_qkv_out_repr())
+          ->assert_is_op_output("reshape2")
+          ->AsIntermediate()
+          ->assert_is_op_input("matmul_v2");  // -> out_linear
+
+  auto* matmul_linear =
+      pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
+  auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
+                                  ->AsInput()
+                                  ->assert_is_op_input("matmul_v2", "Y");
+  auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
+                                    ->assert_is_op_output("matmul_v2")
+                                    ->AsIntermediate()
+                                    ->assert_is_op_input("c_allreduce_sum");
+
+  // communication c_allreduce_sum
+  auto* c_allreduce_sum =
+      pattern->NewNode(c_allreduce_sum_repr())->assert_is_op("c_allreduce_sum");
+  auto* c_allreduce_sum_out_var = pattern->NewNode(c_allreduce_sum_out_repr())
+                                      ->assert_is_op_output("c_allreduce_sum")
+                                      ->AsIntermediate()
+                                      ->assert_is_op_input("elementwise_add");
+
+  auto* eltadd_linear =
+      pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
+  auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
+                                  ->AsInput()
+                                  ->assert_is_op_input("elementwise_add", "Y");
+  auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
+                                    ->assert_is_op_output("elementwise_add")
+                                    ->AsIntermediate()
+                                    ->assert_is_op_input("dropout");
+
+  auto* dropout_linear =
+      pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
+  auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
+                                     ->assert_is_op_output("dropout")
+                                     ->AsIntermediate()
+                                     ->assert_is_op_input("elementwise_add");
+
+  auto* eltadd_out =
+      pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
+  auto* attention_output = pattern->NewNode(attention_output_repr())
+                               ->assert_is_op_output("elementwise_add")
+                               ->AsIntermediate();  // residual add of input0 + attention branch
+
+  // QKV path Links
+  matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var})
+      .LinksTo({matmul_qkv_out_var});
+  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
+      .LinksTo({transpose2_qkv_out_var});
+  reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
+      .LinksTo({reshape2_qkv_out_var});
+  matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
+      .LinksTo({matmul_linear_out_var});
+  c_allreduce_sum->LinksFrom({matmul_linear_out_var})
+      .LinksTo({c_allreduce_sum_out_var});
+  eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
+      .LinksTo({eltadd_linear_out_var});
+  dropout_linear->LinksFrom({eltadd_linear_out_var})
+      .LinksTo({dropout_linear_out_var});
+  eltadd_out->LinksFrom({input0, dropout_linear_out_var})
+      .LinksTo({attention_output});
+
+  // Feed Forward LayerNorm Nodes
+  auto* ffn_layer_norm =
+      pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
+  auto* ffn_layer_norm_scale_var =
+      pattern->NewNode(ffn_layer_norm_scale_repr())
+          ->AsInput()
+          ->assert_is_persistable_var()
+          ->assert_is_op_input("layer_norm", "Scale");
+  auto* ffn_layer_norm_bias_var =
+      pattern->NewNode(ffn_layer_norm_bias_repr())
+          ->AsInput()
+          ->assert_is_persistable_var()
+          ->assert_is_op_input("layer_norm", "Bias");
+  auto* ffn_layer_norm_mean_var =
+      pattern->NewNode(ffn_layer_norm_mean_repr())
+          ->AsIntermediate()
+          ->assert_is_op_output("layer_norm", "Mean");
+  auto* ffn_layer_norm_variance_var =
+      pattern->NewNode(ffn_layer_norm_variance_repr())
+          ->AsIntermediate()
+          ->assert_is_op_output("layer_norm", "Variance");
+  auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
+                                     ->AsIntermediate()
+                                     ->assert_is_op_output("layer_norm", "Y")
+                                     ->assert_is_op_input("c_identity", "X");
+
+  ffn_layer_norm
+      ->LinksFrom(
+          {attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
+      .LinksTo({ffn_layer_norm_out_var,
+                ffn_layer_norm_mean_var,
+                ffn_layer_norm_variance_var});
+
+  // communication c_identity
+  auto* ffn_c_identity =
+      pattern->NewNode(ffn_c_identity_repr())->assert_is_op("c_identity");
+  auto* ffn_c_identity_out_var = pattern->NewNode(ffn_c_identity_out_repr())
+                                     ->assert_is_op_output("c_identity", "Out")
+                                     ->AsIntermediate()
+                                     ->assert_is_op_input("matmul_v2", "X");
+  ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
+      .LinksTo({ffn_c_identity_out_var});
+
+  // Feed Forward fc1 -> gelu -> fc2 -> dropout
+  auto* ffn_matmul0 =
+      pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
+  auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
+                                ->AsInput()
+                                ->assert_is_op_input("matmul_v2", "Y");
+  auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
+                                  ->assert_is_op_output("matmul_v2")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("elementwise_add");
+
+  auto* ffn_eltadd0 =
+      pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
+  auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
+                                ->AsInput()
+                                ->assert_is_op_input("elementwise_add", "Y");
+  auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
+                                  ->assert_is_op_output("elementwise_add")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("gelu");
+
+  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
+  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
+                               ->assert_is_op_output("gelu")
+                               ->AsIntermediate()
+                               ->assert_is_op_input("matmul_v2");
+
+  auto* ffn_matmul1 =
+      pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
+  auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
+                                ->AsInput()
+                                ->assert_is_op_input("matmul_v2", "Y");
+  auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
+                                  ->assert_is_op_output("matmul_v2")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("c_allreduce_sum");
+
+  // communication c_allreduce_sum
+  auto* ffn_c_allreduce_sum = pattern->NewNode(ffn_c_allreduce_sum_repr())
+                                  ->assert_is_op("c_allreduce_sum");
+  auto* ffn_c_allreduce_sum_out_var =
+      pattern->NewNode(ffn_c_allreduce_sum_out_repr())
+          ->assert_is_op_output("c_allreduce_sum")
+          ->AsIntermediate()
+          ->assert_is_op_input("elementwise_add");
+
+  auto* ffn_eltadd1 =
+      pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
+  auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
+                                ->AsInput()
+                                ->assert_is_op_input("elementwise_add", "Y");
+  auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
+                                  ->assert_is_op_output("elementwise_add")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("dropout");
+
+  auto* ffn_dropout =
+      pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
+  auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
+                                  ->assert_is_op_output("dropout")
+                                  ->AsIntermediate()
+                                  ->assert_is_op_input("elementwise_add");
+
+  auto* ffn_eltadd_out =
+      pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
+  auto* ffn_output = pattern->NewNode(ffn_output_repr())
+                         ->assert_is_op_output("elementwise_add")
+                         ->AsOutput();
+
+  ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var})
+      .LinksTo({ffn_matmul0_out_var});
+  ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
+      .LinksTo({ffn_eltadd0_out_var});
+  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
+  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+      .LinksTo({ffn_matmul1_out_var});
+  ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
+      .LinksTo({ffn_c_allreduce_sum_out_var});
+  ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
+      .LinksTo({ffn_eltadd1_out_var});
+  ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
+
+
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
+      .LinksTo({ffn_output});
+
+  return ffn_output;
+}
+
+}  // namespace patterns
+
+template <typename T>  // restored: template argument list was missing; T is the weight element type (float / float16)
+inline void QKVWeightsProcess(framework::LoDTensor* wq_tensor,
+                              framework::LoDTensor* wk_tensor,
+                              framework::LoDTensor* wv_tensor,
+                              framework::LoDTensor* bq_tensor,
+                              framework::LoDTensor* bk_tensor,
+                              framework::LoDTensor* bv_tensor,
+                              const int num_head,
+                              const int dim_head,
+                              const int dim_embed) {
+  // Combines separate Q/K/V weights into a [3, num_head, dim_head, dim_embed]
+  // tensor and the biases into [3, num_head, dim_head], written back in place
+  // into wq_tensor / bq_tensor for the fused_multi_transformer op.
+  auto* wq_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
+  auto* wk_data = wk_tensor->mutable_data<T>(platform::CPUPlace());
+  auto* wv_data = wv_tensor->mutable_data<T>(platform::CPUPlace());
+  auto* bq_data = bq_tensor->mutable_data<T>(platform::CPUPlace());
+  auto* bk_data = bk_tensor->mutable_data<T>(platform::CPUPlace());
+  auto* bv_data = bv_tensor->mutable_data<T>(platform::CPUPlace());
+
+  auto combined_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
+  auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head});
+
+  framework::LoDTensor tmp_combined_w_tensor;
+  tmp_combined_w_tensor.Resize(combined_w_dims);
+  auto* tmp_combined_w_data =
+      tmp_combined_w_tensor.mutable_data<T>(platform::CPUPlace());
+
+  std::vector<T*> w_vec = {wq_data, wk_data, wv_data};
+  // Combine the three fc weights together; also transposes each weight from
+  // [dim_embed, num_head * dim_head] to [num_head, dim_head, dim_embed].
+  for (int i = 0; i < 3; i++) {
+    for (int j = 0; j < num_head; j++) {
+      for (int k = 0; k < dim_head; k++) {
+        for (int l = 0; l < dim_embed; l++) {
+          int out_idx = i * num_head * dim_head * dim_embed +
+                        j * dim_head * dim_embed + k * dim_embed + l;
+          int in_idx = l * num_head * dim_head + j * dim_head + k;
+          tmp_combined_w_data[out_idx] = w_vec[i][in_idx];
+        }
+      }
+    }
+  }
+
+  wq_tensor->Resize(combined_w_dims);  // reuse wq as the combined QKV weight
+  auto* new_combined_w_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
+  memcpy(
+      new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());
+
+  framework::LoDTensor tmp_combined_bias_tensor;
+  tmp_combined_bias_tensor.Resize(combined_bias_dims);
+  auto* tmp_combined_bias_data =
+      tmp_combined_bias_tensor.mutable_data<T>(platform::CPUPlace());
+
+  size_t bias_size = bq_tensor->numel();
+  memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
+  memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
+  memcpy(
+      tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size);
+
+  bq_tensor->Resize(combined_bias_dims);  // reuse bq as the combined QKV bias
+  auto* new_combined_bias_data =
+      bq_tensor->mutable_data<T>(platform::CPUPlace());
+  memcpy(new_combined_bias_data,
+         tmp_combined_bias_data,
+         sizeof(T) * bq_tensor->numel());
+}
+
+template <typename T>  // restored: template argument list was missing
+inline void QKVWeightsProcessFuseQKV(framework::LoDTensor* qkv_w_tensor,
+                                     framework::LoDTensor* qkv_b_tensor,
+                                     const int num_head,
+                                     const int dim_head,
+                                     const int dim_embed) {
+  // Re-lays out the already-fused QKV weight/bias in place to the
+  // [3, num_head, dim_head, ...] layout expected by fused_multi_transformer.
+  auto* qkv_w_data = qkv_w_tensor->mutable_data<T>(platform::CPUPlace());
+  auto transpose_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
+
+  framework::LoDTensor tmp_transpose_w_tensor;
+  tmp_transpose_w_tensor.Resize(transpose_w_dims);
+  auto* tmp_transpose_w_data =
+      tmp_transpose_w_tensor.mutable_data<T>(platform::CPUPlace());
+
+  // transpose qkv matmul Y to QKVWeights
+  for (int i = 0; i < 3; i++) {
+    for (int j = 0; j < num_head; j++) {
+      for (int k = 0; k < dim_head; k++) {
+        for (int l = 0; l < dim_embed;
l++) { + int out_idx = i * num_head * dim_head * dim_embed + + j * dim_head * dim_embed + k * dim_embed + l; + int in_idx = + l * num_head * 3 * dim_head + j * 3 * dim_head + i * dim_head + k; + tmp_transpose_w_data[out_idx] = qkv_w_data[in_idx]; + } + } + } + } + + qkv_w_tensor->Resize(transpose_w_dims); + auto* new_transpose_w_data = + qkv_w_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_transpose_w_data, + tmp_transpose_w_data, + sizeof(T) * qkv_w_tensor->numel()); + + auto* qkv_b_data = qkv_b_tensor->mutable_data(platform::CPUPlace()); + auto transpose_b_dims = phi::make_ddim({3, num_head, dim_head}); + + framework::LoDTensor tmp_transpose_b_tensor; + tmp_transpose_b_tensor.Resize(transpose_b_dims); + auto* tmp_transpose_b_data = + tmp_transpose_b_tensor.mutable_data(platform::CPUPlace()); + + // transpose qkv elemenwise_add Y to QKVBias + for (int i = 0; i < 3; i++) { + for (int j = 0; j < num_head; j++) { + for (int k = 0; k < dim_head; k++) { + int out_idx = i * num_head * dim_head + j * dim_head + k; + int in_idx = j * 3 * dim_head + i * dim_head + k; + tmp_transpose_b_data[out_idx] = qkv_b_data[in_idx]; + } + } + } + + qkv_b_tensor->Resize({3, num_head, dim_head}); + auto* new_transpose_b_data = + qkv_b_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_transpose_b_data, + tmp_transpose_b_data, + sizeof(T) * qkv_b_tensor->numel()); +} + +int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. 
+ patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern( + pattern, name_scope); + fused_multi_transformer_pattern(); + + // Create New OpDesc + auto fuse_creater = [&](Node* input0, + Node* layer_norm, + Node* layer_norm_scale, + Node* layer_norm_bias, + Node* layer_norm_mean, + Node* layer_norm_variance, + Node* matmul0_w, + Node* matmul1_w, + Node* matmul2_w, + Node* eltadd0_b, + Node* eltadd1_b, + Node* eltadd2_b, + Node* transpose2_1_out, + Node* transpose2_2_out, + Node* eltadd_qk_b, + Node* dropout_qk, + Node* reshape2_0, + Node* matmul_linear_w, + Node* eltadd_linear_b, + Node* dropout_linear, + Node* while0, + Node* ffn_layer_norm, + Node* ffn_layer_norm_scale, + Node* ffn_layer_norm_bias, + Node* ffn_layer_norm_mean, + Node* ffn_layer_norm_variance, + Node* ffn_matmul0_w, + Node* ffn_matmul1_w, + Node* ffn_eltadd0_b, + Node* ffn_eltadd1_b, + Node* ffn_dropout, + Node* ffn_output) { + auto reshape_desc = reshape2_0->Op(); + int num_head = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(2); + int dim_head = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(3); + int dim_embed = num_head * dim_head; + + // Calc index of transformer layer by LayerNorm Scale name + // This calculation assumes: + // 1. no LayerNorm before all transformer layer + // 2. 
each transformer layer contains 2 LayerNorm layer + auto ln_scale_name = layer_norm_scale->Name(); + auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.')); + auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1); + int layer_idx = atoi(ln_idx_str.c_str()) / 2; + + auto* wq_tensor = + scope->FindVar(matmul0_w->Name())->GetMutable(); + auto* wk_tensor = + scope->FindVar(matmul1_w->Name())->GetMutable(); + auto* wv_tensor = + scope->FindVar(matmul2_w->Name())->GetMutable(); + + auto* bq_tensor = + scope->FindVar(eltadd0_b->Name())->GetMutable(); + auto* bk_tensor = + scope->FindVar(eltadd1_b->Name())->GetMutable(); + auto* bv_tensor = + scope->FindVar(eltadd2_b->Name())->GetMutable(); + + if (wq_tensor->dtype() == phi::DataType::FLOAT32) { + QKVWeightsProcess(wq_tensor, + wk_tensor, + wv_tensor, + bq_tensor, + bk_tensor, + bv_tensor, + num_head, + dim_head, + dim_embed); + } else if (wq_tensor->dtype() == phi::DataType::FLOAT16) { + QKVWeightsProcess(wq_tensor, + wk_tensor, + wv_tensor, + bq_tensor, + bk_tensor, + bv_tensor, + num_head, + dim_head, + dim_embed); + } else { + PADDLE_THROW(platform::errors::Unavailable( + "fused_multi_transformer not supported weight dtype. " + "we now only support fp32 and fp16.")); + } + + // reuse the mul0_w and eltadd_0_b nodes for the combined nodes. + auto* combined_w_desc = matmul0_w->Var(); + combined_w_desc->SetShape({3, num_head, dim_head, dim_embed}); + combined_w_desc->SetPersistable(true); + + auto* combined_bias_desc = eltadd0_b->Var(); + combined_bias_desc->SetShape({3, num_head, dim_head}); + combined_bias_desc->SetPersistable(true); + + scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()}); + scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); + + // create fused_multi_transformer + OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); + fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + + // 1. 
Input setting + fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); + + // pre-LayerNorm input + fused_multi_transformer_op_desc.SetInput("LnScale", + {layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("LnBias", + {layer_norm_bias->Name()}); + + // QKV computation input + fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()}); + + // CacheKV input + VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); + // FIXME: only support max_seq_len <= 1024 + cache_kv_desc.SetDataType( + framework::TransToProtoVarType(wq_tensor->dtype())); + cache_kv_desc.SetPersistable(false); + auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); + + OpDesc fill_const_op_desc(layer_norm->Op()->Block()); + fill_const_op_desc.SetType("fill_constant_batch_size_like"); + fill_const_op_desc.SetInput("Input", {input0->Name()}); + fill_const_op_desc.SetOutput("Out", {cache_kv->Name()}); + std::vector shape = {2, -1, num_head, 1024, dim_head}; + fill_const_op_desc.SetAttr("shape", shape); + fill_const_op_desc.SetAttr("input_dim_idx", 0); + fill_const_op_desc.SetAttr("output_dim_idx", 1); + fill_const_op_desc.SetAttr("value", 0); + fill_const_op_desc.SetAttr("dtype", static_cast(proto::VarType::FP32)); + auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); + + fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); + + // Out Linear input + fused_multi_transformer_op_desc.SetInput("OutLinearW", + {matmul_linear_w->Name()}); + fused_multi_transformer_op_desc.SetInput("OutLinearBias", + {eltadd_linear_b->Name()}); + + // Feed Forward input + fused_multi_transformer_op_desc.SetInput("FFNLnScale", + {ffn_layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("FFNLnBias", + {ffn_layer_norm_bias->Name()}); + 
fused_multi_transformer_op_desc.SetInput("FFN1Weight", + {ffn_matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Bias", + {ffn_eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Weight", + {ffn_matmul1_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Bias", + {ffn_eltadd1_b->Name()}); + + // 2. Output setting + fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()}); + fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()}); + + // Attribute setting + fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); + fused_multi_transformer_op_desc.SetAttr( + "epsilon", layer_norm->Op()->GetAttr("epsilon")); + + // output dropout attribute + auto* dropout_op = dropout_linear->Op(); + fused_multi_transformer_op_desc.SetAttr( + "dropout_rate", dropout_op->GetAttr("dropout_prob")); + fused_multi_transformer_op_desc.SetAttr("is_test", + dropout_op->GetAttr("is_test")); + fused_multi_transformer_op_desc.SetAttr( + "dropout_implementation", + dropout_op->GetAttr("dropout_implementation")); + + auto* fused_multi_transformer = + graph->CreateOpNode(&fused_multi_transformer_op_desc); + IR_NODE_LINK_TO(input0, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); + + IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); + + IR_NODE_LINK_TO(input0, fill_const_op); + IR_NODE_LINK_TO(fill_const_op, cache_kv); + IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); + + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); + + // rewrite while OP input + // 1. delete k, v + // 2. delete matmul1/2_w eltadd1/2_w + // 3. 
add cache_kv + auto while_Xs = while0->Op()->Input("X"); + while_Xs.erase( + std::remove( + std::begin(while_Xs), std::end(while_Xs), transpose2_1_out->Name()), + std::end(while_Xs)); + while_Xs.erase( + std::remove( + std::begin(while_Xs), std::end(while_Xs), transpose2_2_out->Name()), + std::end(while_Xs)); + while_Xs.erase( + std::remove( + std::begin(while_Xs), std::end(while_Xs), matmul1_w->Name()), + std::end(while_Xs)); + while_Xs.erase( + std::remove( + std::begin(while_Xs), std::end(while_Xs), matmul2_w->Name()), + std::end(while_Xs)); + while_Xs.erase( + std::remove( + std::begin(while_Xs), std::end(while_Xs), eltadd1_b->Name()), + std::end(while_Xs)); + while_Xs.erase( + std::remove( + std::begin(while_Xs), std::end(while_Xs), eltadd2_b->Name()), + std::end(while_Xs)); + while_Xs.emplace_back(cache_kv->Name()); + while0->Op()->SetInput("X", while_Xs); + + // rewrite while OP output + // 1. delete k, v + // 2. add cache_kv + auto while_Outs = while0->Op()->Output("Out"); + while_Outs.erase(std::remove(std::begin(while_Outs), + std::end(while_Outs), + transpose2_1_out->Name()), + std::end(while_Outs)); + while_Outs.erase(std::remove(std::begin(while_Outs), + std::end(while_Outs), + transpose2_2_out->Name()), + std::end(while_Outs)); + while_Outs.emplace_back(cache_kv->Name()); + while0->Op()->SetOutput("Out", while_Outs); + + // link CacheKV to while + IR_NODE_LINK_TO(cache_kv, while0) + // unlink origin KV output to while + IR_NODE_UNLINK(transpose2_1_out, while0); + IR_NODE_UNLINK(transpose2_2_out, while0); + IR_NODE_UNLINK(while0, transpose2_1_out); + IR_NODE_UNLINK(while0, transpose2_2_out); + // unlink KV weight/bias to while after merged into Q weight/bias + IR_NODE_UNLINK(matmul1_w, while0); + IR_NODE_UNLINK(matmul2_w, while0); + IR_NODE_UNLINK(eltadd1_b, while0); + IR_NODE_UNLINK(eltadd2_b, while0); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + if (!IsCompat(subgraph, 
graph)) { + LOG(WARNING) << "fused_multi_transformer_encoder pass in " + "op compat failed."; + return; + } + + VLOG(4) << "handle MultiTransformer encoder fuse"; + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm, layer_norm, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, + layer_norm_variance, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_out, layer_norm_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul0, matmul0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_out, matmul0_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_w, matmul0_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0, reshape2_0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0, transpose2_0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul1, matmul1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul1_out, matmul1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul1_w, matmul1_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1, reshape2_1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1, transpose2_1, 
fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul2, matmul2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul2_out, matmul2_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul2_w, matmul2_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2, reshape2_2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2, transpose2_2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + attention_output, attention_output, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH(while0, while0, fused_multi_transformer_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, + ffn_layer_norm_scale, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, + ffn_layer_norm_bias, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, + ffn_layer_norm_mean, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, + ffn_layer_norm_variance, + fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, + ffn_layer_norm_out, + fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu, ffn_gelu, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_dropout, ffn_dropout, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + ffn_output, ffn_output, fused_multi_transformer_pattern) + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0, eltadd0, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_b, eltadd0_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_out, eltadd0_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd1, eltadd1, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd1_b, eltadd1_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd1_out, eltadd1_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd2, eltadd2, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd2_b, 
eltadd2_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd2_out, eltadd2_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk, matmul_qk, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk, eltadd_qk, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk, softmax_qk, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_qk, dropout_qk, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv, matmul_qkv, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, + transpose2_qkv_out, + fused_multi_transformer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear, matmul_linear, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear, eltadd_linear, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear_b, 
eltadd_linear_b, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + dropout_linear, dropout_linear, fused_multi_transformer_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_out, eltadd_out, fused_multi_transformer_pattern) + + fuse_creater(input0, + layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + matmul0_w, + matmul1_w, + matmul2_w, + eltadd0_b, + eltadd1_b, + eltadd2_b, + transpose2_1_out, + transpose2_2_out, + eltadd_qk_b, + dropout_qk, + reshape2_0, + matmul_linear_w, + eltadd_linear_b, + dropout_linear, + while0, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_matmul0_w, + ffn_matmul1_w, + ffn_eltadd0_b, + ffn_eltadd1_b, + ffn_dropout, + ffn_output); + + std::unordered_set marked_nodes({layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + layer_norm_out, + matmul0, + matmul1, + matmul2, + matmul0_out, + matmul1_out, + matmul2_out, + eltadd0, + eltadd1, + eltadd2, + eltadd0_out, + eltadd1_out, + eltadd2_out, + reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + dropout_qk, + dropout_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + reshape2_qkv, + transpose2_qkv, + transpose2_qkv_out, + matmul_linear, + matmul_linear_w, + matmul_linear_out, + eltadd_linear, + eltadd_linear_b, + eltadd_linear_out, + dropout_linear, + dropout_linear_out, + eltadd_out, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + 
ffn_layer_norm_mean,
+                                           ffn_layer_norm_variance,
+                                           ffn_layer_norm_out,
+                                           ffn_matmul0,
+                                           ffn_matmul1,
+                                           ffn_matmul0_out,
+                                           ffn_matmul1_out,
+                                           ffn_eltadd0,
+                                           ffn_eltadd1,
+                                           ffn_eltadd0_out,
+                                           ffn_eltadd1_out,
+                                           ffn_gelu,
+                                           ffn_gelu_out,
+                                           ffn_dropout,
+                                           ffn_dropout_out,
+                                           ffn_eltadd_out});
+
+    // Remove unneeded nodes.
+    // All intermediate nodes of the matched subgraph are deleted; only the
+    // persistable inputs reused by the fused op (QKV weight/bias, LN params,
+    // linear weight/bias, FFN params) survive outside marked_nodes.
+    GraphSafeRemoveNodes(graph, marked_nodes);
+    ++fusion_count;
+  };
+  gpd(graph, handler);
+
+  return fusion_count;
+}
+
+// Pass entry point. Requires a parameter scope because BuildFusion merges the
+// separate Q/K/V weight/bias tensors in place; marks the graph with
+// kFusedMultiTransformerEncoderPass when at least one layer was fused so
+// later passes can detect that this fusion already ran.
+void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  auto* scope = param_scope();
+  PADDLE_ENFORCE_NOT_NULL(
+      scope,
+      platform::errors::Fatal(
+          "During the multi_transformer pass, The scope should not be null."));
+
+  int fusion_count = BuildFusion(graph, name_scope_, scope);
+  if (fusion_count > 0) {
+    graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
+  }
+  AddStatis(fusion_count);
+}
+
+// The constructor registers OpCompat checkers for every op type that appears
+// in the matched subgraph; IsCompat() in the handler rejects subgraphs whose
+// attributes fall outside these constraints.
+// NOTE(review): template arguments on IsType()/IsType>() calls below appear
+// to have been stripped by extraction (likely IsType<bool>(),
+// IsType<std::vector<int>>(), etc.) -- verify against upstream before use.
+FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
+  AddOpCompat(OpCompat("layer_norm"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Scale")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Mean")
+      .IsTensor()
+      .End()
+      .AddOutput("Variance")
+      .IsTensor()
+      .End()
+      .AddAttr("epsilon")
+      .IsNumGE(0.0f)
+      .IsNumLE(0.001f)
+      .End()
+      .AddAttr("begin_norm_axis")
+      .IsNumGT(0)
+      .End();
+
+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")  // the shape should be (B, S, N*H)
+      .IsTensor()
+      .End()
+      .AddInput("Y")  // the shape should be (N*H, N*H)
+      .IsTensor()
+      .End()
+      .AddOutput("Out")  // the shape should be (B, S, N*H)
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType()
+      .End()
+      .AddAttr("trans_y")
+      .IsType()
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsIntIn({2, -1, 0})
+      .End();
+
+  AddOpCompat(OpCompat("reshape2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Shape")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ShapeTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddAttr("shape")  // -->(B, S, H, N) <--(B, S, N*H)
+      .IsType>()
+      .End();
+
+  AddOpCompat(OpCompat("transpose2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddAttr("axis")  // {0, 2, 1, 3}
+      .IsType>()
+      .End();
+
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumGE(0.0f)
+      .IsNumLE(1.0f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")
+      .IsType()
+      .End();
+
+  AddOpCompat(OpCompat("softmax"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsIntIn({-1, 3})  // shape is (B, H, S, S), so axis is -1 or 3
+      .End();
+
+  AddOpCompat(OpCompat("gelu"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("approximate")
+      .IsType()
+      .End();
+
+  AddOpCompat(OpCompat("while"))
+      .AddInput("X")  // A set of variables, unconstrained
+      .End()
+      .AddInput("Condition")  // A scalar
+      .IsTensor()
+      .End()
+      .AddOutput("Out")  // A set of variables, unconstrained
+      .End()
+      .AddOutput("StepScopes")  // A vector of local scope, unconstrained
+      .End()
+      .AddAttr("sub_block")
+      .IsType()
+      .End();
+}
+
+int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
+    Graph* graph, const std::string& name_scope, Scope* scope) const {
+  GraphPatternDetector gpd;
+  auto* pattern = gpd.mutable_pattern();
+
+  // Create pattern.
+ patterns::FusedMultiTransformerEncoderFuseQKVPattern + fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope); + fused_multi_transformer_fuse_qkv_pattern(); + + // Create New OpDesc + auto fuse_creater = [&](Node* input0, + Node* layer_norm, + Node* layer_norm_scale, + Node* layer_norm_bias, + Node* layer_norm_mean, + Node* layer_norm_variance, + Node* matmul0_w, + Node* eltadd0_b, + Node* split0_k_out, + Node* split0_v_out, + Node* eltadd_qk_b, + Node* dropout_qk, + Node* reshape2_0, + Node* matmul_linear_w, + Node* eltadd_linear_b, + Node* dropout_linear, + Node* while0, + Node* ffn_layer_norm, + Node* ffn_layer_norm_scale, + Node* ffn_layer_norm_bias, + Node* ffn_layer_norm_mean, + Node* ffn_layer_norm_variance, + Node* ffn_matmul0_w, + Node* ffn_matmul1_w, + Node* ffn_eltadd0_b, + Node* ffn_eltadd1_b, + Node* ffn_dropout, + Node* ffn_output) { + auto reshape_desc = reshape2_0->Op(); + int num_head = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(2); + int dim_head = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(3) / + 3; // 3 for qkv + int dim_embed = num_head * dim_head; + + // Calc index of transformer layer by LayerNorm Scale name + // This calculation assumes: + // 1. no LayerNorm before all transformer layer + // 2. 
each transformer layer contains 2 LayerNorm layer + auto ln_scale_name = layer_norm_scale->Name(); + auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.')); + auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1); + int layer_idx = atoi(ln_idx_str.c_str()) / 2; + + auto* qkv_w_tensor = + scope->FindVar(matmul0_w->Name())->GetMutable(); + auto* qkv_b_tensor = + scope->FindVar(eltadd0_b->Name())->GetMutable(); + + if (qkv_w_tensor->dtype() == phi::DataType::FLOAT32) { + QKVWeightsProcessFuseQKV( + qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); + } else if (qkv_w_tensor->dtype() == phi::DataType::FLOAT16) { + QKVWeightsProcessFuseQKV( + qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); + } else { + PADDLE_THROW(platform::errors::Unavailable( + "fused_multi_transformer not supported weight dtype. " + "we now only support fp32 and fp16.")); + } + + // create fused_multi_transformer + OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); + fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + + // 1. 
Input setting + fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); + + // pre-LayerNorm input + fused_multi_transformer_op_desc.SetInput("LnScale", + {layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("LnBias", + {layer_norm_bias->Name()}); + + // QKV computation input + fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()}); + + // CacheKV input + VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); + // FIXME: only support max_seq_len <= 1024 + cache_kv_desc.SetDataType( + framework::TransToProtoVarType(qkv_w_tensor->dtype())); + cache_kv_desc.SetPersistable(false); + auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); + + OpDesc fill_const_op_desc(layer_norm->Op()->Block()); + fill_const_op_desc.SetType("fill_constant_batch_size_like"); + fill_const_op_desc.SetInput("Input", {input0->Name()}); + fill_const_op_desc.SetOutput("Out", {cache_kv->Name()}); + std::vector shape = {2, -1, num_head, 1024, dim_head}; + fill_const_op_desc.SetAttr("shape", shape); + fill_const_op_desc.SetAttr("input_dim_idx", 0); + fill_const_op_desc.SetAttr("output_dim_idx", 1); + fill_const_op_desc.SetAttr("value", 0); + fill_const_op_desc.SetAttr("dtype", static_cast(proto::VarType::FP32)); + auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); + + fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); + + // Out Linear input + fused_multi_transformer_op_desc.SetInput("OutLinearW", + {matmul_linear_w->Name()}); + fused_multi_transformer_op_desc.SetInput("OutLinearBias", + {eltadd_linear_b->Name()}); + + // Feed Forward input + fused_multi_transformer_op_desc.SetInput("FFNLnScale", + {ffn_layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("FFNLnBias", + {ffn_layer_norm_bias->Name()}); + 
fused_multi_transformer_op_desc.SetInput("FFN1Weight", + {ffn_matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Bias", + {ffn_eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Weight", + {ffn_matmul1_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Bias", + {ffn_eltadd1_b->Name()}); + + // 2. Output setting + fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()}); + fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()}); + + // Attribute setting + fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); + fused_multi_transformer_op_desc.SetAttr( + "epsilon", layer_norm->Op()->GetAttr("epsilon")); + + // output dropout attribute + auto* dropout_op = dropout_linear->Op(); + fused_multi_transformer_op_desc.SetAttr( + "dropout_rate", dropout_op->GetAttr("dropout_prob")); + fused_multi_transformer_op_desc.SetAttr("is_test", + dropout_op->GetAttr("is_test")); + fused_multi_transformer_op_desc.SetAttr( + "dropout_implementation", + dropout_op->GetAttr("dropout_implementation")); + + auto* fused_multi_transformer = + graph->CreateOpNode(&fused_multi_transformer_op_desc); + IR_NODE_LINK_TO(input0, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); + + IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); + + IR_NODE_LINK_TO(input0, fill_const_op); + IR_NODE_LINK_TO(fill_const_op, cache_kv); + IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); + + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); + + // rewrite while OP input + // 1. delete k, v + // 2. delete matmul1/2_w eltadd1/2_w + // 3. 
add cache_kv + auto while_Xs = while0->Op()->Input("X"); + while_Xs.erase( + std::remove( + std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()), + std::end(while_Xs)); + while_Xs.erase( + std::remove( + std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()), + std::end(while_Xs)); + while_Xs.emplace_back(cache_kv->Name()); + while0->Op()->SetInput("X", while_Xs); + + // rewrite while OP output + // 1. delete k, v + // 2. add cache_kv + auto while_Outs = while0->Op()->Output("Out"); + while_Outs.erase( + std::remove( + std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()), + std::end(while_Outs)); + while_Outs.erase( + std::remove( + std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()), + std::end(while_Outs)); + while_Outs.emplace_back(cache_kv->Name()); + while0->Op()->SetOutput("Out", while_Outs); + + // link CacheKV to while + IR_NODE_LINK_TO(cache_kv, while0) + // unlink origin KV output to while + IR_NODE_UNLINK(split0_k_out, while0); + IR_NODE_UNLINK(split0_v_out, while0); + IR_NODE_UNLINK(while0, split0_k_out); + IR_NODE_UNLINK(while0, split0_v_out); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "fused_multi_transformer_encoder_fuse_qkv " + "pass in op compat failed."; + return; + } + + VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse"; + GET_IR_NODE_FROM_SUBGRAPH( + input0, input0, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, + layer_norm_scale, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, + layer_norm_bias, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, + layer_norm_mean, + fused_multi_transformer_fuse_qkv_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, + layer_norm_variance, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, + layer_norm_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, + reshape2_0_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, + transpose2_0_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + split0, split0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, + ffn_layer_norm, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, + ffn_layer_norm_scale, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, + ffn_layer_norm_bias, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, + ffn_layer_norm_mean, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, + ffn_layer_norm_variance, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, + ffn_layer_norm_out, + 
fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, + ffn_matmul0_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, + ffn_eltadd0_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, + ffn_matmul1_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, + ffn_eltadd1_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out, + ffn_dropout_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, + ffn_eltadd_out, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern) + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0, eltadd0, 
fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, + softmax_qk_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out, + dropout_qk_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, + matmul_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, + reshape2_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, + transpose2_qkv, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, + transpose2_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, + matmul_linear_w, + 
fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, + matmul_linear_out, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, + eltadd_linear_b, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, + eltadd_linear_out, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(dropout_linear, + dropout_linear, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out, + dropout_linear_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + while0, while0, fused_multi_transformer_fuse_qkv_pattern) + + fuse_creater(input0, + layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + matmul0_w, + eltadd0_b, + split0_k_out, + split0_v_out, + eltadd_qk_b, + dropout_qk, + reshape2_0, + matmul_linear_w, + eltadd_linear_b, + dropout_linear, + while0, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_matmul0_w, + ffn_matmul1_w, + ffn_eltadd0_b, + ffn_eltadd1_b, + ffn_dropout, + ffn_output); + + std::unordered_set marked_nodes({layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + layer_norm_out, + matmul0, + matmul0_out, + eltadd0, + eltadd0_out, + reshape2_0, + reshape2_0_out, + transpose2_0, + transpose2_0_out, + split0, + split0_q_out, + split0_k_out, + split0_v_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + dropout_qk, + dropout_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + reshape2_qkv, + transpose2_qkv, + transpose2_qkv_out, + matmul_linear, + 
matmul_linear_w, + matmul_linear_out, + eltadd_linear, + eltadd_linear_b, + eltadd_linear_out, + dropout_linear, + dropout_linear_out, + eltadd_out, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_layer_norm_out, + ffn_matmul0, + ffn_matmul1, + ffn_matmul0_out, + ffn_matmul1_out, + ffn_eltadd0, + ffn_eltadd1, + ffn_eltadd0_out, + ffn_eltadd1_out, + ffn_gelu, + ffn_gelu_out, + ffn_dropout, + ffn_dropout_out, + ffn_eltadd_out}); + + // Remove unneeded nodes. + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal( + "During the fused_multi_transformer_encoder pass, " + "The scope should not be null.")); + + int fusion_count = BuildFusion(graph, name_scope_, scope); + if (fusion_count > 0) { + graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true)); + } + AddStatis(fusion_count); +} + +FusedMultiTransformerEncoderFuseQKVPass:: + FusedMultiTransformerEncoderFuseQKVPass() { + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddInput("Y") // the shape shoule be (N*H, N*H) + .IsTensor() + .End() + .AddOutput("Out") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + 
.IsType()
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsIntIn({2, -1, 0})
+      .End();
+
+  AddOpCompat(OpCompat("reshape2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Shape")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ShapeTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddAttr("shape")  // -->(B, S, H, N) <--(B, S, N*H)
+      .IsType>()
+      .End();
+
+  AddOpCompat(OpCompat("transpose2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddAttr("axis")  // {0, 2, 1, 3}
+      .IsType>()
+      .End();
+
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumGE(0.0f)
+      .IsNumLE(1.0f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")
+      .IsType()
+      .End();
+
+  AddOpCompat(OpCompat("softmax"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsIntIn({-1, 3})  // shape is (B, H, S, S), so axis is -1 or 3
+      .End();
+
+  AddOpCompat(OpCompat("gelu"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("approximate")
+      .IsType()
+      .End();
+
+  AddOpCompat(OpCompat("while"))
+      .AddInput("X")  // A set of variables, unconstrained
+      .End()
+      .AddInput("Condition")  // A scalar
+      .IsTensor()
+      .End()
+      .AddOutput("Out")  // A set of variables, unconstrained
+      .End()
+      .AddOutput("StepScopes")  // A vector of local scope, unconstrained
+      .End()
+      .AddAttr("sub_block")
+      .IsType()
+      .End();
+}
+
+int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const {
+  GraphPatternDetector gpd;
+  auto* pattern = gpd.mutable_pattern();
+
+  // Create pattern.
+  patterns::MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
+      fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
+  fused_multi_transformer_fuse_qkv_pattern();
+
+  // Create New OpDesc
+  auto fuse_creater = [&](Node* input0,
+                          Node* layer_norm,
+                          Node* layer_norm_scale,
+                          Node* layer_norm_bias,
+                          Node* layer_norm_mean,
+                          Node* layer_norm_variance,
+                          Node* c_identity,
+                          Node* matmul0_w,
+                          Node* eltadd0_b,
+                          Node* split0_k_out,
+                          Node* split0_v_out,
+                          Node* eltadd_qk_b,
+                          Node* dropout_qk,
+                          Node* reshape2_0,
+                          Node* matmul_linear_w,
+                          Node* eltadd_linear_b,
+                          Node* dropout_linear,
+                          Node* while0,
+                          Node* ffn_layer_norm,
+                          Node* ffn_layer_norm_scale,
+                          Node* ffn_layer_norm_bias,
+                          Node* ffn_layer_norm_mean,
+                          Node* ffn_layer_norm_variance,
+                          Node* ffn_matmul0_w,
+                          Node* ffn_matmul1_w,
+                          Node* ffn_eltadd0_b,
+                          Node* ffn_eltadd1_b,
+                          Node* ffn_dropout,
+                          Node* ffn_output) {
+    auto reshape_desc = reshape2_0->Op();
+    int num_head =
+        PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape"))
+            .at(2);
+    int dim_head =
+        PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape"))
+            .at(3) /
+        3;  // 3 for qkv
+
+    // Calc index of transformer layer by LayerNorm Scale name
+    // This calculation assumes:
+    //    1. no LayerNorm before all transformer layers
+    //    2.
each transformer layer contains 2 LayerNorm layers
+    auto ln_scale_name = layer_norm_scale->Name();
+    auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
+    auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
+    int layer_idx = atoi(ln_idx_str.c_str()) / 2;
+
+    auto* qkv_w_tensor =
+        scope->FindVar(matmul0_w->Name())->GetMutable();
+    auto* qkv_b_tensor =
+        scope->FindVar(eltadd0_b->Name())->GetMutable();
+
+    int dim_embed = qkv_w_tensor->dims()[0];
+
+    if (qkv_w_tensor->dtype() == phi::DataType::FLOAT32) {
+      QKVWeightsProcessFuseQKV(
+          qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
+    } else if (qkv_w_tensor->dtype() == phi::DataType::FLOAT16) {
+      QKVWeightsProcessFuseQKV(
+          qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
+    } else {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "fused_multi_transformer not supported weight dtype. "
+          "we now only support fp32 and fp16."));
+    }
+
+    // create fused_multi_transformer
+    OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
+    fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
+
+    // 1.
Input setting + fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); + + // pre-LayerNorm input + fused_multi_transformer_op_desc.SetInput("LnScale", + {layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("LnBias", + {layer_norm_bias->Name()}); + + // QKV computation input + fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()}); + + // CacheKV input + VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); + // FIXME: only support max_seq_len <= 1024 + cache_kv_desc.SetDataType( + framework::TransToProtoVarType(qkv_w_tensor->dtype())); + cache_kv_desc.SetPersistable(false); + auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); + + OpDesc fill_const_op_desc(layer_norm->Op()->Block()); + fill_const_op_desc.SetType("fill_constant_batch_size_like"); + fill_const_op_desc.SetInput("Input", {input0->Name()}); + fill_const_op_desc.SetOutput("Out", {cache_kv->Name()}); + std::vector shape = {2, -1, num_head, 1024, dim_head}; + fill_const_op_desc.SetAttr("shape", shape); + fill_const_op_desc.SetAttr("input_dim_idx", 0); + fill_const_op_desc.SetAttr("output_dim_idx", 1); + fill_const_op_desc.SetAttr("value", 0); + fill_const_op_desc.SetAttr("dtype", static_cast(proto::VarType::FP32)); + auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); + + fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); + + // Out Linear input + fused_multi_transformer_op_desc.SetInput("OutLinearW", + {matmul_linear_w->Name()}); + fused_multi_transformer_op_desc.SetInput("OutLinearBias", + {eltadd_linear_b->Name()}); + + // Feed Forward input + fused_multi_transformer_op_desc.SetInput("FFNLnScale", + {ffn_layer_norm_scale->Name()}); + fused_multi_transformer_op_desc.SetInput("FFNLnBias", + {ffn_layer_norm_bias->Name()}); + 
fused_multi_transformer_op_desc.SetInput("FFN1Weight", + {ffn_matmul0_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN1Bias", + {ffn_eltadd0_b->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Weight", + {ffn_matmul1_w->Name()}); + fused_multi_transformer_op_desc.SetInput("FFN2Bias", + {ffn_eltadd1_b->Name()}); + + // 2. Output setting + fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()}); + fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()}); + + // Attribute setting + fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true); + fused_multi_transformer_op_desc.SetAttr( + "epsilon", layer_norm->Op()->GetAttr("epsilon")); + + // output dropout attribute + auto* dropout_op = dropout_linear->Op(); + fused_multi_transformer_op_desc.SetAttr( + "dropout_rate", dropout_op->GetAttr("dropout_prob")); + fused_multi_transformer_op_desc.SetAttr("is_test", + dropout_op->GetAttr("is_test")); + fused_multi_transformer_op_desc.SetAttr( + "dropout_implementation", + dropout_op->GetAttr("dropout_implementation")); + + // parallel ring id + auto* c_identity_op = c_identity->Op(); + fused_multi_transformer_op_desc.SetAttr("ring_id", + c_identity_op->GetAttr("ring_id")); + + auto* fused_multi_transformer = + graph->CreateOpNode(&fused_multi_transformer_op_desc); + IR_NODE_LINK_TO(input0, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); + IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); + + IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer); + IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer); + + IR_NODE_LINK_TO(input0, fill_const_op); + IR_NODE_LINK_TO(fill_const_op, cache_kv); + IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); + + IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); + + // rewrite while OP input + // 1. delete k, v + // 2. delete matmul1/2_w eltadd1/2_w + // 3. 
add cache_kv + auto while_Xs = while0->Op()->Input("X"); + while_Xs.erase( + std::remove( + std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()), + std::end(while_Xs)); + while_Xs.erase( + std::remove( + std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()), + std::end(while_Xs)); + while_Xs.emplace_back(cache_kv->Name()); + while0->Op()->SetInput("X", while_Xs); + + // rewrite while OP output + // 1. delete k, v + // 2. add cache_kv + auto while_Outs = while0->Op()->Output("Out"); + while_Outs.erase( + std::remove( + std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()), + std::end(while_Outs)); + while_Outs.erase( + std::remove( + std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()), + std::end(while_Outs)); + while_Outs.emplace_back(cache_kv->Name()); + while0->Op()->SetOutput("Out", while_Outs); + + // link CacheKV to while + IR_NODE_LINK_TO(cache_kv, while0) + // unlink origin KV output to while + IR_NODE_UNLINK(split0_k_out, while0); + IR_NODE_UNLINK(split0_v_out, while0); + IR_NODE_UNLINK(while0, split0_k_out); + IR_NODE_UNLINK(while0, split0_v_out); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "fused_multi_transformer_encoder_fuse_qkv " + "pass in op compat failed."; + return; + } + + VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse"; + GET_IR_NODE_FROM_SUBGRAPH( + input0, input0, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, + layer_norm_scale, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, + layer_norm_bias, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, + layer_norm_mean, + fused_multi_transformer_fuse_qkv_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, + layer_norm_variance, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, + layer_norm_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + c_identity, c_identity, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_identity_out, + c_identity_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, + reshape2_0_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, + transpose2_0_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + split0, split0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm, + ffn_layer_norm, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale, + ffn_layer_norm_scale, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias, + ffn_layer_norm_bias, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean, + ffn_layer_norm_mean, + fused_multi_transformer_fuse_qkv_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance, + ffn_layer_norm_variance, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out, + ffn_layer_norm_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity, + ffn_c_identity, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity_out, + ffn_c_identity_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out, + ffn_matmul0_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out, + ffn_eltadd0_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out, + ffn_matmul1_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum, + ffn_c_allreduce_sum, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum_out, + ffn_c_allreduce_sum_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ffn_eltadd1_b, 
ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out, + ffn_eltadd1_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out, + ffn_dropout_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, + ffn_eltadd_out, + fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH( + ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern) + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, + softmax_qk_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern) + GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out, + dropout_qk_out, + fused_multi_transformer_fuse_qkv_pattern) + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, + matmul_qkv_out, + 
fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, + reshape2_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, + transpose2_qkv, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, + transpose2_qkv_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w, + matmul_linear_w, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out, + matmul_linear_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum, + c_allreduce_sum, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum_out, + c_allreduce_sum_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b, + eltadd_linear_b, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, + eltadd_linear_out, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dropout_linear, + dropout_linear, + fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out, + dropout_linear_out, + fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + while0, while0, fused_multi_transformer_fuse_qkv_pattern); + + fuse_creater(input0, + layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + c_identity, + matmul0_w, + eltadd0_b, + split0_k_out, + split0_v_out, + eltadd_qk_b, + dropout_qk, + 
reshape2_0, + matmul_linear_w, + eltadd_linear_b, + dropout_linear, + while0, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_matmul0_w, + ffn_matmul1_w, + ffn_eltadd0_b, + ffn_eltadd1_b, + ffn_dropout, + ffn_output); + + std::unordered_set marked_nodes({layer_norm, + layer_norm_scale, + layer_norm_bias, + layer_norm_mean, + layer_norm_variance, + layer_norm_out, + c_identity, + c_identity_out, + matmul0, + matmul0_out, + eltadd0, + eltadd0_out, + reshape2_0, + reshape2_0_out, + transpose2_0, + transpose2_0_out, + split0, + split0_q_out, + split0_k_out, + split0_v_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + dropout_qk, + dropout_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + reshape2_qkv, + transpose2_qkv, + transpose2_qkv_out, + matmul_linear, + matmul_linear_w, + matmul_linear_out, + c_allreduce_sum, + c_allreduce_sum_out, + eltadd_linear, + eltadd_linear_b, + eltadd_linear_out, + dropout_linear, + dropout_linear_out, + eltadd_out, + ffn_layer_norm, + ffn_layer_norm_scale, + ffn_layer_norm_bias, + ffn_layer_norm_mean, + ffn_layer_norm_variance, + ffn_layer_norm_out, + ffn_c_identity, + ffn_c_identity_out, + ffn_matmul0, + ffn_matmul1, + ffn_matmul0_out, + ffn_matmul1_out, + ffn_c_allreduce_sum, + ffn_c_allreduce_sum_out, + ffn_eltadd0, + ffn_eltadd1, + ffn_eltadd0_out, + ffn_eltadd1_out, + ffn_gelu, + ffn_gelu_out, + ffn_dropout, + ffn_dropout_out, + ffn_eltadd_out}); + + // Remove unneeded nodes. 
+ GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::ApplyImpl( + Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal( + "During the fused_multi_transformer_encoder pass, " + "The scope should not be null.")); + + int fusion_count = BuildFusion(graph, name_scope_, scope); + if (fusion_count > 0) { + graph->Set(kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass, + new bool(true)); + } + AddStatis(fusion_count); +} + +MultiDevicesFusedMultiTransformerEncoderFuseQKVPass:: + MultiDevicesFusedMultiTransformerEncoderFuseQKVPass() { + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddInput("Y") // the shape shoule be (N*H, N*H) + .IsTensor() + .End() + .AddOutput("Out") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + .IsType() + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({2, -1, 0}) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + 
.End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H) + .IsType>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") // {0, 2, 1, 3} + .IsType>() + .End(); + + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.0f) + .IsNumLE(1.0f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsType() + .End(); + + AddOpCompat(OpCompat("softmax")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3 + .End(); + + AddOpCompat(OpCompat("gelu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("approximate") + .IsType() + .End(); + + AddOpCompat(OpCompat("while")) + .AddInput("X") // A set of variables, unconstrained + .End() + .AddInput("Condition") // An scalar + .IsTensor() + .End() + .AddOutput("Out") // A set of variables, unconstrained + .End() + .AddOutput("StepScopes") // A vector of local scope, unconstrained + .End() + .AddAttr("sub_block") + .IsType() + .End(); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fused_multi_transformer_encoder_pass, + paddle::framework::ir::FusedMultiTransformerEncoderPass); +REGISTER_PASS(fused_multi_transformer_encoder_fuse_qkv_pass, + paddle::framework::ir::FusedMultiTransformerEncoderFuseQKVPass); +REGISTER_PASS( + multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass, + paddle::framework::ir::MultiDevicesFusedMultiTransformerEncoderFuseQKVPass); + +REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_pass) + 
.AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); +REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_fuse_qkv_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); +REGISTER_PASS_CAPABILITY( + multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h new file mode 100644 index 00000000000..6e62a69cdf1 --- /dev/null +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h @@ -0,0 +1,398 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct FusedMultiTransformerEncoderPattern : public PatternBase { + FusedMultiTransformerEncoderPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "fused_multi_transformer_encoder") {} + + PDNode* operator()(); + + // Q, K, V path + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(matmul0); + PATTERN_DECL_NODE(matmul1); + PATTERN_DECL_NODE(matmul2); + PATTERN_DECL_NODE(matmul0_w); + PATTERN_DECL_NODE(matmul1_w); + PATTERN_DECL_NODE(matmul2_w); + PATTERN_DECL_NODE(matmul0_out); + PATTERN_DECL_NODE(matmul1_out); + PATTERN_DECL_NODE(matmul2_out); + PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(eltadd1_out); + PATTERN_DECL_NODE(eltadd2_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_1); + PATTERN_DECL_NODE(reshape2_2); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(reshape2_1_out); + PATTERN_DECL_NODE(reshape2_2_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(transpose2_2); + PATTERN_DECL_NODE(transpose2_0_out); + PATTERN_DECL_NODE(transpose2_1_out); 
+ PATTERN_DECL_NODE(transpose2_2_out); + + // Q, K matmul + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(eltadd_qk); + PATTERN_DECL_NODE(eltadd_qk_b); + PATTERN_DECL_NODE(eltadd_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + PATTERN_DECL_NODE(dropout_qk); + PATTERN_DECL_NODE(dropout_qk_out); + + // QK, V matmul + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_qkv_out); + + // out linear + PATTERN_DECL_NODE(matmul_linear); + PATTERN_DECL_NODE(matmul_linear_w); + PATTERN_DECL_NODE(matmul_linear_out); + PATTERN_DECL_NODE(eltadd_linear); + PATTERN_DECL_NODE(eltadd_linear_b); + PATTERN_DECL_NODE(eltadd_linear_out); + PATTERN_DECL_NODE(dropout_linear); + PATTERN_DECL_NODE(dropout_linear_out); + + // output elementwise_add + PATTERN_DECL_NODE(eltadd_out) + PATTERN_DECL_NODE(attention_output); + + // while loop + PATTERN_DECL_NODE(while0); + + // Feed Forward nodes + PATTERN_DECL_NODE(ffn_layer_norm); + PATTERN_DECL_NODE(ffn_layer_norm_scale); + PATTERN_DECL_NODE(ffn_layer_norm_bias); + PATTERN_DECL_NODE(ffn_layer_norm_mean); + PATTERN_DECL_NODE(ffn_layer_norm_variance); + PATTERN_DECL_NODE(ffn_layer_norm_out); + PATTERN_DECL_NODE(ffn_matmul0); + PATTERN_DECL_NODE(ffn_matmul0_w); + PATTERN_DECL_NODE(ffn_matmul0_out); + PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_out); + PATTERN_DECL_NODE(ffn_gelu); + PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_matmul1); + PATTERN_DECL_NODE(ffn_matmul1_w); + PATTERN_DECL_NODE(ffn_matmul1_out); + PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_out); + PATTERN_DECL_NODE(ffn_dropout); + 
PATTERN_DECL_NODE(ffn_dropout_out); + + // output elementwise_add + PATTERN_DECL_NODE(ffn_eltadd_out) + PATTERN_DECL_NODE(ffn_output); +}; + +struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { + FusedMultiTransformerEncoderFuseQKVPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase( + pattern, name_scope, "fused_multi_transformer_encoder_fuse_qkv") {} + + PDNode* operator()(); + + // Q, K, V path + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(matmul0); + PATTERN_DECL_NODE(matmul0_w); + PATTERN_DECL_NODE(matmul0_out); + PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_0_out); + + PATTERN_DECL_NODE(split0) + PATTERN_DECL_NODE(split0_q_out) + PATTERN_DECL_NODE(split0_k_out) + PATTERN_DECL_NODE(split0_v_out) + + // Q, K matmul + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(eltadd_qk); + PATTERN_DECL_NODE(eltadd_qk_b); + PATTERN_DECL_NODE(eltadd_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + PATTERN_DECL_NODE(dropout_qk); + PATTERN_DECL_NODE(dropout_qk_out); + + // QK, V matmul + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_qkv_out); + + // while loop + PATTERN_DECL_NODE(while0); + + // out linear + PATTERN_DECL_NODE(matmul_linear); + PATTERN_DECL_NODE(matmul_linear_w); + PATTERN_DECL_NODE(matmul_linear_out); + PATTERN_DECL_NODE(eltadd_linear); + 
PATTERN_DECL_NODE(eltadd_linear_b); + PATTERN_DECL_NODE(eltadd_linear_out); + PATTERN_DECL_NODE(dropout_linear); + PATTERN_DECL_NODE(dropout_linear_out); + + // output elementwise_add + PATTERN_DECL_NODE(eltadd_out) + PATTERN_DECL_NODE(attention_output); + + // Feed Forward nodes + PATTERN_DECL_NODE(ffn_layer_norm); + PATTERN_DECL_NODE(ffn_layer_norm_scale); + PATTERN_DECL_NODE(ffn_layer_norm_bias); + PATTERN_DECL_NODE(ffn_layer_norm_mean); + PATTERN_DECL_NODE(ffn_layer_norm_variance); + PATTERN_DECL_NODE(ffn_layer_norm_out); + PATTERN_DECL_NODE(ffn_matmul0); + PATTERN_DECL_NODE(ffn_matmul0_w); + PATTERN_DECL_NODE(ffn_matmul0_out); + PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_out); + PATTERN_DECL_NODE(ffn_gelu); + PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_matmul1); + PATTERN_DECL_NODE(ffn_matmul1_w); + PATTERN_DECL_NODE(ffn_matmul1_out); + PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_out); + PATTERN_DECL_NODE(ffn_dropout); + PATTERN_DECL_NODE(ffn_dropout_out); + + // output elementwise_add + PATTERN_DECL_NODE(ffn_eltadd_out) + PATTERN_DECL_NODE(ffn_output); +}; + +struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern + : public PatternBase { + MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, + name_scope, + "multi_devices_fused_multi_transformer_encoder_fuse_qkv") {} + + PDNode* operator()(); + + // Q, K, V path + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(c_identity); + PATTERN_DECL_NODE(c_identity_out); + PATTERN_DECL_NODE(matmul0); + 
PATTERN_DECL_NODE(matmul0_w); + PATTERN_DECL_NODE(matmul0_out); + PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_0_out); + + PATTERN_DECL_NODE(split0) + PATTERN_DECL_NODE(split0_q_out) + PATTERN_DECL_NODE(split0_k_out) + PATTERN_DECL_NODE(split0_v_out) + + // Q, K matmul + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(eltadd_qk); + PATTERN_DECL_NODE(eltadd_qk_b); + PATTERN_DECL_NODE(eltadd_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + PATTERN_DECL_NODE(dropout_qk); + PATTERN_DECL_NODE(dropout_qk_out); + + // QK, V matmul + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_qkv_out); + + // while loop + PATTERN_DECL_NODE(while0); + + // out linear + PATTERN_DECL_NODE(matmul_linear); + PATTERN_DECL_NODE(matmul_linear_w); + PATTERN_DECL_NODE(matmul_linear_out); + PATTERN_DECL_NODE(c_allreduce_sum); + PATTERN_DECL_NODE(c_allreduce_sum_out); + PATTERN_DECL_NODE(eltadd_linear); + PATTERN_DECL_NODE(eltadd_linear_b); + PATTERN_DECL_NODE(eltadd_linear_out); + PATTERN_DECL_NODE(dropout_linear); + PATTERN_DECL_NODE(dropout_linear_out); + + // output elementwise_add + PATTERN_DECL_NODE(eltadd_out) + PATTERN_DECL_NODE(attention_output); + + // Feed Forward nodes + PATTERN_DECL_NODE(ffn_layer_norm); + PATTERN_DECL_NODE(ffn_layer_norm_scale); + PATTERN_DECL_NODE(ffn_layer_norm_bias); + PATTERN_DECL_NODE(ffn_layer_norm_mean); + PATTERN_DECL_NODE(ffn_layer_norm_variance); + PATTERN_DECL_NODE(ffn_layer_norm_out); + PATTERN_DECL_NODE(ffn_c_identity); + PATTERN_DECL_NODE(ffn_c_identity_out); + PATTERN_DECL_NODE(ffn_matmul0); + 
PATTERN_DECL_NODE(ffn_matmul0_w); + PATTERN_DECL_NODE(ffn_matmul0_out); + PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd0_out); + PATTERN_DECL_NODE(ffn_gelu); + PATTERN_DECL_NODE(ffn_gelu_out); + PATTERN_DECL_NODE(ffn_matmul1); + PATTERN_DECL_NODE(ffn_matmul1_w); + PATTERN_DECL_NODE(ffn_matmul1_out); + PATTERN_DECL_NODE(ffn_c_allreduce_sum); + PATTERN_DECL_NODE(ffn_c_allreduce_sum_out); + PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(ffn_eltadd1_out); + PATTERN_DECL_NODE(ffn_dropout); + PATTERN_DECL_NODE(ffn_dropout_out); + + // output elementwise_add + PATTERN_DECL_NODE(ffn_eltadd_out) + PATTERN_DECL_NODE(ffn_output); +}; +} // namespace patterns + +class FusedMultiTransformerEncoderPass : public FusePassBase { + public: + FusedMultiTransformerEncoderPass(); + virtual ~FusedMultiTransformerEncoderPass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"fused_multi_transformer_encoder"}; + + private: + int BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase { + public: + FusedMultiTransformerEncoderFuseQKVPass(); + virtual ~FusedMultiTransformerEncoderFuseQKVPass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"fused_multi_transformer_encoder_fuse_qkv"}; + + private: + int BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass + : public FusePassBase { + public: + MultiDevicesFusedMultiTransformerEncoderFuseQKVPass(); + virtual ~MultiDevicesFusedMultiTransformerEncoderFuseQKVPass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{ + 
"multi_devices_fused_multi_transformer_encoder_fuse_qkv"}; + + private: + int BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc new file mode 100644 index 00000000000..61017b273a0 --- /dev/null +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc @@ -0,0 +1,563 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h" // NOLINT +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +void AddVarToScope(Scope* param_scope, + const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + + // MHA: pre Layer Norm + AddVarToScope(param_scope, "ln_scale", {1024}); + AddVarToScope(param_scope, "ln_bias", {1024}); + + // MHA: QKV fc + AddVarToScope(param_scope, "weights0", {1024, 1024}); + AddVarToScope(param_scope, "weights1", {1024, 1024}); + AddVarToScope(param_scope, "weights2", {1024, 1024}); + AddVarToScope(param_scope, "bias_0", {1024}); + AddVarToScope(param_scope, "bias_1", {1024}); + AddVarToScope(param_scope, "bias_2", {1024}); + + // MHA: QK bias + AddVarToScope(param_scope, "biasqk", {1024}); + + // MHA: out Linear + AddVarToScope(param_scope, "weights_l", {1024, 1024}); + AddVarToScope(param_scope, "bias_l", {1024}); + + // MHA: pre Layer Norm + AddVarToScope(param_scope, "ffn_ln_scale", {1024}); + AddVarToScope(param_scope, "ffn_ln_bias", {1024}); + + // FFN: fc1 -> (gelu) -> fc2 + AddVarToScope(param_scope, "ffn_weights0", {1024, 4096}); + AddVarToScope(param_scope, "ffn_weights1", {4096, 1024}); + AddVarToScope(param_scope, "ffn_bias_0", {4096}); + AddVarToScope(param_scope, "ffn_bias_1", {1024}); + + return param_scope; +} + +TEST(FusedMultiTransformerEncoderPass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x, ln_scale, ln_bias) layer_norm -> layer_norm_out + // (layer_norm_out, weights_0) matmul_v2 -> matmul_out0 + // (layer_norm_out, weights_1) matmul_v2 -> matmul_out1 + // (layer_norm_out, weights_2) 
matmul_v2 -> matmul_out2 + // (matmul_out0, bias_0) elementwise_add -> eltadd_0 + // (matmul_out1, bias_1) elementwise_add -> eltadd_1 + // (matmul_out2, bias_2) elementwise_add -> eltadd_2 + // (eltadd_0) reshape2 -> reshape_0 + // (eltadd_1) reshape2 -> reshape_1 + // (eltadd_2) reshape2 -> reshape_2 + // (reshape_0) transpose2 -> transpose_0 + // (reshape_1) transpose2 -> transpose_1 + // (reshape_2) transpose2 -> transpose_2 + // (transpose_0, transpose_1) matmul -> matmul_qk + // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (eltadd_qk) softmax -> softmax_qk + // (softmax_qk) dropout -> dropout_qk + // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (matmul_qkv) transpose -> transpose_qkv + // (transpose_qkv) reshape -> reshape_qkv + // (reshape_qkv) matmul_v2 -> matmul_linear + // (matmul_linear) elementwise_add -> eltadd_linear + // (eltadd_linear) dropout -> dropout_linear + // (eltadd_out) elementwise_add -> attention_out + // + // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out + // (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0 + // (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0 + // (ffn_eltadd0) gelu -> ffn_gelu + // (ffn_gelu) matmul_v2 -> ffn_matmul1 + // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 + // (ffn_eltadd1) dropout -> ffn_dropout + // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // + // (transpose_1, transpose_2) while -> decoder block + + Layers layers; + // MHA: pre LayerNorm + auto* x = layers.data("x", {1, 128, 1024}); + auto* ln_scale = layers.data("ln_scale", {1024}, true); + auto* ln_bias = layers.data("ln_bias", {1024}, true); + auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0]; + + // MHA: QKV fc + auto* weights_0 = layers.data("weights0", {1024, 1024}, true); + auto* weights_1 = layers.data("weights1", {1024, 1024}, true); + auto* weights_2 = layers.data("weights2", {1024, 1024}, true); + auto* matmul_out_0 = + layers.matmul_v2(ln_out, weights_0, 
nullptr, false, true); + auto* matmul_out_1 = + layers.matmul_v2(ln_out, weights_1, nullptr, false, true); + auto* matmul_out_2 = + layers.matmul_v2(ln_out, weights_2, nullptr, false, true); + + auto* b0 = layers.data("bias_0", {1024}, true); + auto* b1 = layers.data("bias_1", {1024}, true); + auto* b2 = layers.data("bias_2", {1024}, true); + auto* elementwise_out_0 = + layers.elementwise_add(matmul_out_0, b0, nullptr, 2); + auto* elementwise_out_1 = + layers.elementwise_add(matmul_out_1, b1, nullptr, 2); + auto* elementwise_out_2 = + layers.elementwise_add(matmul_out_2, b2, nullptr, 2); + + std::vector shape = {1, 128, 16, 64}; + auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true); + auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true); + auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true); + + std::vector axis = {0, 2, 1, 3}; + auto* transpose_0 = layers.transpose2(reshape_0, axis, true); + auto* transpose_1 = layers.transpose2(reshape_1, axis, true); + auto* transpose_2 = layers.transpose2(reshape_2, axis, true); + + // Link to decoder while block + layers.while_loop({transpose_1, transpose_2}); + + // MHA: QK matmul + auto* matmul_qk = + layers.matmul(transpose_0, transpose_1, nullptr, false, true); + + auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); + auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1); + auto* softmax_qk = layers.softmax(elementwise_qk, -1); + auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); + + // MHA: QKV matmul + auto* matmul_qkv = layers.matmul_v2(dropout_qk, transpose_2); + + auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); + auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); + + // MHA: out Linear + auto* weights_l = layers.data("weights_l", {1024, 1024}, true); + auto* bias_l = layers.data("weightsl", {1024, 1024}, true); + auto* linear_matmut_out = + layers.matmul_v2(reshape_qkv_out, 
weights_l, nullptr, false, true); + auto* linear_eltadd_out = + layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); + + auto* dropout_qkv = + layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); + auto* attention_out = layers.elementwise_add(x, dropout_qkv); + + // FFN: pre LayerNorm + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); + auto* ffn_ln_out = + layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0]; + + // FFN: fc1 -> gelu -> fc2 + auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true); + auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true); + auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true); + auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true); + auto* ffn_matmul0_out = + layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true); + auto* ffn_eltadd0_out = + layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2); + auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out); + auto* ffn_matmul1_out = + layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true); + auto* ffn_eltadd1_out = + layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); + + // FFN: dropout -> elementwise_add + auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); + layers.elementwise_add(attention_out, ffn_dropout); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + auto pass = + PassRegistry::Instance().Get("fused_multi_transformer_encoder_pass"); + if (pass.get() == nullptr) + LOG(INFO) << "get fused_multi_transformer_encoder_pass failed"; + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + int num_fused_nodes_after = GetNumOpNodes(graph, 
"fused_multi_transformer"); + + PADDLE_ENFORCE_EQ(num_nodes_before, + num_nodes_after + 68, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_encoder_pass, The " + "node num in graph " + "should be %d, but the result is %d", + num_nodes_before - 68, + num_nodes_after)); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_encoder pass, " + "there should be one fused_multi_transformer op, " + "but the result is %d", + num_fused_nodes_after)); +} + +TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("fused_multi_transformer_encoder_pass")); +} + +TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x, ln_scale, ln_bias) layer_norm -> layer_norm_out + // (layer_norm_out, weights_0) matmul_v2 -> matmul_out0 + // (matmul_out0, bias_0) elementwise_add -> eltadd_0 + // (eltadd_0) reshape2 -> reshape_0 + // (reshape_0) transpose2 -> transpose_0 + // (transpose_0) split -> split_q, split_k, + // split_v (split_k) assign -> assign_k + // (split_v) assign -> assign_v + // (split_q, split_k) matmul -> matmul_qk + // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (eltadd_qk) softmax -> softmax_qk + // (softmax_qk) dropout -> dropout_qk + // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (matmul_qkv) transpose -> transpose_qkv + // (transpose_qkv) reshape -> reshape_qkv + // (reshape_qkv) matmul_v2 -> matmul_linear + // (matmul_linear) elementwise_add -> eltadd_linear + // (eltadd_linear) dropout -> dropout_linear + // (eltadd_out) elementwise_add -> attention_out + // + // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out + // (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0 + // (ffn_matmul0, ffn_bias0) elementwise_add 
-> ffn_eltadd0 + // (ffn_eltadd0) gelu -> ffn_gelu + // (ffn_gelu) matmul_v2 -> ffn_matmul1 + // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 + // (ffn_eltadd1) dropout -> ffn_dropout + // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // + // (transpose_1, transpose_2) while -> decoder block + + Layers layers; + // MHA: pre LayerNorm + auto* x = layers.data("x", {1, 128, 1024}); + auto* ln_scale = layers.data("ln_scale", {1024}, true); + auto* ln_bias = layers.data("ln_bias", {1024}, true); + auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0]; + + // MHA: QKV fc + auto* weights_0 = layers.data("weights0", {1024, 3072}, true); + auto* matmul_out_0 = + layers.matmul_v2(ln_out, weights_0, nullptr, false, true); + + auto* b0 = layers.data("bias_0", {3072}, true); + auto* elementwise_out_0 = + layers.elementwise_add(matmul_out_0, b0, nullptr, 2); + + std::vector shape = {1, 128, 16, 64}; + auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true); + + std::vector axis = {0, 2, 1, 3}; + auto* transpose_0 = layers.transpose2(reshape_0, axis, true); + + auto split_outs = layers.split(transpose_0, 3, 3); + auto* split_q = split_outs[0]; + auto* split_k = split_outs[1]; + auto* split_v = split_outs[2]; + layers.assign(split_k); + layers.assign(split_v); + + // Link to decoder while block + layers.while_loop({split_k, split_v}); + + // MHA: QK matmul + auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true); + + auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); + auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* softmax_qk = layers.softmax(elementwise_qk, -1); + auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); + + // MHA: QKV matmul + auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v); + + auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); + auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); + + // MHA: out 
Linear + auto* weights_l = layers.data("weights_l", {1024, 1024}, true); + auto* bias_l = layers.data("weightsl", {1024, 1024}, true); + auto* linear_matmut_out = + layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true); + auto* linear_eltadd_out = + layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); + + auto* dropout_qkv = + layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); + auto* attention_out = layers.elementwise_add(x, dropout_qkv); + + // FFN: pre LayerNorm + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); + auto* ffn_ln_out = + layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0]; + + // FFN: fc1 -> gelu -> fc2 + auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true); + auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true); + auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true); + auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true); + auto* ffn_matmul0_out = + layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true); + auto* ffn_eltadd0_out = + layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2); + auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out); + auto* ffn_matmul1_out = + layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true); + auto* ffn_eltadd1_out = + layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); + + // FFN: dropout -> elementwise_add + auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); + layers.elementwise_add(attention_out, ffn_dropout); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + auto pass = PassRegistry::Instance().Get( + "fused_multi_transformer_encoder_fuse_qkv_pass"); + if (pass.get() == nullptr) + LOG(INFO) << "get fused_multi_transformer_encoder_fuse_qkv_pass failed"; + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << 
DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); + + PADDLE_ENFORCE_EQ( + num_nodes_before, + num_nodes_after + 56, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_encoder_fuse_qkv_pass, " + "The node num in graph should be %d, but the result is %d", + num_nodes_before - 56, + num_nodes_after)); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_encoder_fuse_qkv " + "pass, there should be one fused_multi_transformer " + "op, but the result is %d", + num_fused_nodes_after)); +} + +TEST(FusedMultiTransformerEncoderFuseQKVPass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("fused_multi_transformer_encoder_fuse_qkv_pass")); +} + +TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x, ln_scale, ln_bias) layer_norm -> layer_norm_out + // (layer_norm_out) c_identity -> c_identity_out + // (c_identity_out, weights_0) matmul_v2 -> matmul_out0 + // (matmul_out0) elementwise_add -> eltadd_0 + // (eltadd_0) reshape2 -> reshape_0 + // (reshape_0) transpose2 -> transpose_0 + // (transpose_0) split -> split_q, split_k, + // split_v (split_k) assign -> assign_k + // (split_v) assign -> assign_v + // (split_q, split_k) matmul -> matmul_qk + // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (eltadd_qk) softmax -> softmax_qk + // (softmax_qk) dropout -> dropout_qk + // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv + // (matmul_qkv) transpose -> transpose_qkv + // (transpose_qkv) reshape -> reshape_qkv + // (reshape_qkv) matmul_v2 -> matmul_linear + // (matmul_linear) c_all_reduce -> 
c_all_reduce_out + // (c_all_reduce_out) elementwise_add -> eltadd_linear + // (eltadd_linear) dropout -> dropout_linear + // (eltadd_out) elementwise_add -> attention_out + // + // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out + // (ffn_layer_norm_out) c_identity -> ffn_c_identity_out + // (ffn_c_identity_out, ffn_matmul0_w)matmul_v2 -> ffn_matmul0 + // (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0 + // (ffn_eltadd0) gelu -> ffn_gelu + // (ffn_gelu) matmul_v2 -> ffn_matmul1 + // (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out + // (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1 + // (ffn_eltadd1) dropout -> ffn_dropout + // (attention_out, ffn_dropout) elementwise_add -> ffn_output + // + // (transpose_1, transpose_2) while -> decoder block + + Layers layers; + // MHA: pre LayerNorm + auto* x = layers.data("x", {1, 128, 1024}); + auto* ln_scale = layers.data("ln_scale", {1024}, true); + auto* ln_bias = layers.data("ln_bias", {1024}, true); + auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0]; + auto* c_identity_out = layers.c_identity(ln_out); + + // MHA: QKV fc + auto* weights_0 = layers.data("weights0", {1024, 3072}, true); + auto* matmul_out_0 = + layers.matmul_v2(c_identity_out, weights_0, nullptr, false, true); + + auto* b0 = layers.data("bias_0", {3072}, true); + auto* elementwise_out_0 = + layers.elementwise_add(matmul_out_0, b0, nullptr, 2); + + std::vector shape = {1, 128, 16, 64}; + auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true); + + std::vector axis = {0, 2, 1, 3}; + auto* transpose_0 = layers.transpose2(reshape_0, axis, true); + + auto split_outs = layers.split(transpose_0, 3, 3); + auto* split_q = split_outs[0]; + auto* split_k = split_outs[1]; + auto* split_v = split_outs[2]; + layers.assign(split_k); + layers.assign(split_v); + + // Link to decoder while block + layers.while_loop({split_k, split_v}); + + // MHA: QK matmul + auto* matmul_qk = layers.matmul(split_q, split_k, 
nullptr, false, true); + + auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); + auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* softmax_qk = layers.softmax(elementwise_qk, -1); + auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train"); + + // MHA: QKV matmul + auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v); + + auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); + auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); + + // MHA: out Linear + auto* weights_l = layers.data("weights_l", {1024, 1024}, true); + auto* bias_l = layers.data("weightsl", {1024, 1024}, true); + auto* linear_matmut_out = + layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true); + auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmut_out); + auto* linear_eltadd_out = + layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2); + + auto* dropout_qkv = + layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train"); + auto* attention_out = layers.elementwise_add(x, dropout_qkv); + + // FFN: pre LayerNorm + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); + auto* ffn_ln_out = + layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0]; + auto* ffn_c_identity_out = layers.c_identity(ffn_ln_out); + + // FFN: fc1 -> gelu -> fc2 + auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true); + auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true); + auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true); + auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true); + auto* ffn_matmul0_out = + layers.matmul_v2(ffn_c_identity_out, ffn_weights0, nullptr, false, true); + auto* ffn_eltadd0_out = + layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2); + auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out); + auto* ffn_matmul1_out = + layers.matmul_v2(ffn_gelu_out, ffn_weights1, 
nullptr, false, true); + auto* ffn_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out); + auto* ffn_eltadd1_out = + layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2); + + // FFN: dropout -> elementwise_add + auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train"); + layers.elementwise_add(attention_out, ffn_dropout); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + auto pass = PassRegistry::Instance().Get( + "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"); + if (pass.get() == nullptr) + LOG(INFO) + << "get multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass " + "failed"; + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); + + PADDLE_ENFORCE_EQ( + num_nodes_before, + num_nodes_after + 64, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_encoder_fuse_qkv_pass, " + "The node num in graph should be %d, but the result is %d", + num_nodes_before - 64, + num_nodes_after)); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fused_multi_transformer_encoder_fuse_qkv " + "multi-devices pass, there should be one " + "fused_multi_transformer op, but the result is %d", + num_fused_nodes_after)); +} + +TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, + pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible( + "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass")); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(fused_multi_transformer_encoder_pass); +USE_PASS(fused_multi_transformer_encoder_fuse_qkv_pass); 
+USE_PASS(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass); diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index f5f28219ecd..fe2c9adf68f 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -815,9 +815,14 @@ void GraphToProgram(const Graph &graph, // avoid kRootBlockIndex not 0 if (idx == kRootBlockIndex) continue; - block = program_pb.add_blocks(); - block->set_idx(idx); - block->set_parent_idx(kRootBlockIndex); + if (static_cast(idx) < program_pb.blocks_size()) { + block = program_pb.mutable_blocks(idx); + } else { + block = program_pb.add_blocks(); + block->set_idx(idx); + block->set_parent_idx(kRootBlockIndex); + } + GraphToBlock(*graph.GetSubGraph(idx), block, sort_kind, diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 72b19c1dd52..55c7787012b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -112,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) { if (graph.Nodes().empty()) return false; for (auto &node : GraphTraits::DFS(graph)) { + if (node.Name().rfind("__control_var") == 0) continue; for (const auto &pdnode : pattern_.nodes()) { if (pdnode->Tell(&node)) { VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name(); @@ -383,7 +384,6 @@ std::string PDPattern::DotString() const { // Create Edges for (const auto &edge : edges()) { if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) { - LOG(ERROR) << "no node " << edge.first << " " << edge.second; continue; } auto &src = node2dot.at(edge.first); @@ -453,7 +453,8 @@ PDNode *PDNode::assert_var_not_persistable() { PDNode *PDNode::assert_is_persistable_var() { assert_is_var(); - asserts_.emplace_back([=](Node *x) { return x->Var()->Persistable(); }); + asserts_.emplace_back( + [=](Node *x) { 
return x->Var() && x->Var()->Persistable(); }); return this; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index bd38b2123e9..110e73b228e 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1990,6 +1990,14 @@ struct AddSupportInt8 : public PatternBase { a->outputs.push_back(b); \ b->inputs.push_back(a); +// UnLink 2 ir::Nodes from each other. +#define IR_NODE_UNLINK(a, b) \ + a->outputs.erase( \ + std::remove(std::begin(a->outputs), std::end(a->outputs), b), \ + std::end(a->outputs)); \ + b->inputs.erase(std::remove(std::begin(b->inputs), std::end(b->inputs), a), \ + std::end(b->inputs)); + // Set the out_var as the output of the op #define IR_OP_VAR_LINK(op, out_var) \ op->outputs.push_back(out_var); \ diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 35f72deab89..4a1bf4baecc 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -22,6 +22,7 @@ limitations under the License. 
*/ namespace paddle { namespace framework { +class Scope; namespace ir { class Graph; } // namespace ir @@ -35,6 +36,17 @@ namespace paddle { namespace framework { namespace ir { +static const char kParamScopeAttr[] = "__param_scope__"; + +static const std::vector<std::string> support_subgraph_passes = { + "fused_multi_transformer_encoder_pass", + "fused_multi_transformer_decoder_pass", + "fused_multi_transformer_encoder_fuse_qkv_pass", + "fused_multi_transformer_decoder_fuse_qkv_pass", + "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", + "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", +}; + Graph *Pass::Apply(Graph *graph) const { VLOG(10) << "start to apply pass " << Type() << " to graph"; CheckPrevPass(); @@ -65,11 +77,41 @@ Graph *Pass::Apply(Graph *graph) const { true, platform::errors::InvalidArgument( "The VarDescs of persistable variable are not consistency.")); - applied_ = true; if (!graph->Has(kPassRecorder)) { graph->Set(kPassRecorder, new PassRecorder); } graph->Get<PassRecorder>(kPassRecorder).insert(Type()); + + if (graph->IsMainGraph() && std::count(support_subgraph_passes.begin(), + support_subgraph_passes.end(), + Type())) { + for (size_t i = 1; i < graph->SubGraphsSize(); i++) { + auto *sub_graph = graph->GetSubGraph(i); + if (!sub_graph->Has(framework::ir::kParamScopeAttr)) { + sub_graph->SetNotOwned( + framework::ir::kParamScopeAttr, + &graph->Get<framework::Scope>(framework::ir::kParamScopeAttr)); + } + + ApplyImpl(sub_graph); + PADDLE_ENFORCE_EQ( + HasCircle(*sub_graph), + false, + platform::errors::InvalidArgument( + "Illegal pass %s.
Generated graph shouldn't contain cycle.", + Type())); + PADDLE_ENFORCE_EQ( + VarDescIsConsistency(*sub_graph), + true, + platform::errors::InvalidArgument( + "The VarDescs of persistable variable are not consistency.")); + if (!sub_graph->Has(kPassRecorder)) { + sub_graph->Set(kPassRecorder, new PassRecorder); + } + sub_graph->Get<PassRecorder>(kPassRecorder).insert(Type()); + } + } + applied_ = true; #ifdef PADDLE_WITH_MKLDNN // Clear mkl-dnn cache, // Passes can change params, tensors, so caching need to be discarded diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 37a28bec16d..2ed753cdeb7 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -47,6 +47,18 @@ constexpr char kPassRecorder[] = "pass_recorder"; constexpr char kEmbEltwiseLayernormPass[] = "embedding_eltwise_layernorm_fuse_pass_flag"; constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag"; +constexpr char kFusedMultiTransformerEncoderPass[] = + "fused_multi_transformer_encoder_pass_flag"; +constexpr char kFusedMultiTransformerDecoderPass[] = + "fused_multi_transformer_decoder_pass_flag"; +constexpr char kFusedMultiTransformerEncoderFuseQKVPass[] = + "fused_multi_transformer_encoder_fuse_qkv_pass_flag"; +constexpr char kFusedMultiTransformerDecoderFuseQKVPass[] = + "fused_multi_transformer_decoder_fuse_qkv_pass_flag"; +constexpr char kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass[] = + "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass_flag"; +constexpr char kMultiDevicesFusedMultiTransformerDecoderFuseQKVPass[] = + "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass_flag"; constexpr char kPrelnEmbEltwiseLayernormPass[] = "preln_embedding_eltwise_layernorm_fuse_pass_flag"; diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 4d34e9e0900..3cce19e10c6 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++
b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -146,6 +146,12 @@ struct Layers { return unary_op("relu", x, out); } + VarDesc* gelu(VarDesc* x, VarDesc* out = nullptr, bool approximate = true) { + AttributeMap attrs; + attrs["approximate"] = approximate; + return unary_op("gelu", x, out, &attrs); + } + VarDesc* sigmoid(VarDesc* x, VarDesc* out = nullptr) { return unary_op("sigmoid", x, out); } @@ -154,6 +160,20 @@ struct Layers { return unary_op("tanh", x, out); } + VarDesc* c_identity(VarDesc* x, VarDesc* out = nullptr, int ring_id = -1) { + AttributeMap attrs; + attrs["ring_id"] = ring_id; + return unary_op("c_identity", x, out, &attrs); + } + + VarDesc* c_allreduce_sum(VarDesc* x, + VarDesc* out = nullptr, + int ring_id = -1) { + AttributeMap attrs; + attrs["ring_id"] = ring_id; + return unary_op("c_allreduce_sum", x, out, &attrs); + } + VarDesc* fc(VarDesc* input, VarDesc* w, VarDesc* bias, @@ -332,6 +352,37 @@ struct Layers { return outs; } + std::vector<VarDesc*> split(VarDesc* x, int num_or_section, int axis = 0) { + std::vector<VarDesc*> outs(num_or_section); + for (int i = 0; i < num_or_section; i++) { + outs[i] = lod_tensor(unique_name()); + } + std::vector<std::string> out_names(num_or_section); + for (int i = 0; i < num_or_section; i++) { + out_names[i] = outs[i]->Name(); + } + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("split"); + op->SetInput("X", {x->Name()}); + op->SetOutput("Out", out_names); + op->SetAttr("num_or_section", num_or_section); + op->SetAttr("axis", axis); + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast<int>(OpRole::kForward)); + return outs; + } + + VarDesc* assign(VarDesc* x) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("assign"); + op->SetInput("X", {x->Name()}); + op->SetOutput("Out", {out->Name()}); + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast<int>(OpRole::kForward)); + return out; + } + VarDesc* matmul(VarDesc* x, VarDesc* y,
VarDesc* alpha = nullptr, @@ -459,6 +510,24 @@ struct Layers { return out; } + VarDesc* while_loop(std::vector<VarDesc*> xs, VarDesc* cond = nullptr) { + VarDesc* out = lod_tensor(unique_name()); + VarDesc* step_scopes = lod_tensor(unique_name()); + if (cond == nullptr) cond = lod_tensor(unique_name()); + + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("while"); + std::vector<std::string> xs_names; + for (auto& x : xs) xs_names.emplace_back(x->Name()); + op->SetInput("X", xs_names); + op->SetInput("Condition", {cond->Name()}); + op->SetOutput("Out", {out->Name()}); + op->SetOutput("StepScopes", {step_scopes->Name()}); + op->SetAttr("sub_block", {program_.MutableBlock(0)}); + op->SetAttr("is_test", true); + return out; + } + void backward(std::vector<VarDesc*> targets) { // This function is designed to simulate the structure of training program, // but is constructed differently as the actual program. @@ -523,7 +592,10 @@ struct Layers { return var; } - VarDesc* unary_op(std::string type, VarDesc* x, VarDesc* out = nullptr) { + VarDesc* unary_op(std::string type, + VarDesc* x, + VarDesc* out = nullptr, + const AttributeMap* attrs = nullptr) { if (!out) { out = lod_tensor(unique_name()); } @@ -531,6 +603,11 @@ struct Layers { op->SetType(type); op->SetInput("X", {x->Name()}); op->SetOutput("Out", {out->Name()}); + if (attrs) { + for (auto& iter : *attrs) { + op->SetAttr(iter.first, iter.second); + } + } op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), static_cast<int>(OpRole::kForward)); return out; diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc index bfe9e1e4b26..775b61e9494 100644 --- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc +++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc @@ -76,6 +76,7 @@ void MemoryOptimizePass::CollectLifeCycle( } else { // Normal operators.
for (const Node* node : requires) { + if (!node->Var()) continue; if (node->Var()->Persistable()) continue; std::string var = node->Name(); if (!lifecycles->count(var)) { @@ -133,7 +134,7 @@ void MemoryOptimizePass::CollectVarMemorySize( // between performance and underlying principle. std::unordered_set<std::string> black_list; for (auto* node : graph->Nodes()) { - if (node->IsVar() && + if (node->IsVar() && node->Var() && node->Var()->GetType() == framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) { if (!valid_var(node)) { @@ -144,7 +145,7 @@ void MemoryOptimizePass::CollectVarMemorySize( // Collect tensors from graph. for (auto* node : graph->Nodes()) { - if (node->IsVar() && + if (node->IsVar() && node->Var() && node->Var()->GetType() == framework::proto::VarType::Type::VarType_Type_LOD_TENSOR && !black_list.count(node->Var()->Name())) { diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 931aad80ce2..8b7d90fc5cc 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -193,22 +193,28 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { passes_.assign({ // "identity_scale_op_clean_pass", // - "is_test_pass", // - "simplify_with_basic_ops_pass", // - "conv_bn_fuse_pass", // - "conv_eltwiseadd_bn_fuse_pass", // - "embedding_eltwise_layernorm_fuse_pass", // - "multihead_matmul_fuse_pass_v2", // - "gpu_cpu_squeeze2_matmul_fuse_pass", // - "gpu_cpu_reshape2_matmul_fuse_pass", // - "gpu_cpu_flatten2_matmul_fuse_pass", // - "gpu_cpu_map_matmul_v2_to_mul_pass", // - "gpu_cpu_map_matmul_v2_to_matmul_pass", // - "matmul_scale_fuse_pass", // - "multihead_matmul_fuse_pass_v3", // - "gpu_cpu_map_matmul_to_mul_pass", // - "fc_fuse_pass", // - "fc_elementwise_layernorm_fuse_pass", // + "is_test_pass", // + "simplify_with_basic_ops_pass", // + "conv_bn_fuse_pass", // + "conv_eltwiseadd_bn_fuse_pass",
// + "embedding_eltwise_layernorm_fuse_pass", // + "multihead_matmul_fuse_pass_v2", // + "fused_multi_transformer_encoder_pass", // + "fused_multi_transformer_decoder_pass", // + "fused_multi_transformer_encoder_fuse_qkv_pass", // + "fused_multi_transformer_decoder_fuse_qkv_pass", // + "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", // + "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", // + "gpu_cpu_squeeze2_matmul_fuse_pass", // + "gpu_cpu_reshape2_matmul_fuse_pass", // + "gpu_cpu_flatten2_matmul_fuse_pass", // + "gpu_cpu_map_matmul_v2_to_mul_pass", // + "gpu_cpu_map_matmul_v2_to_matmul_pass", // + "matmul_scale_fuse_pass", // + "multihead_matmul_fuse_pass_v3", // + "gpu_cpu_map_matmul_to_mul_pass", // + "fc_fuse_pass", // + "fc_elementwise_layernorm_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we -- GitLab