Unverified commit 5a2e5179 authored by Kaipeng Deng, committed by GitHub

Add FusedMultiTransformer fuse pass for GPT3 (#45907)


* add fused_multi_transformer_encoder/decoder pass; GPT-3 runs successfully with it
Parent 4dc4d5fc
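The two new passes are registered for inference in the CMake changes below. As a rough usage sketch only (not part of this commit; the model path, GPU memory size, include path, and pass ordering are assumptions), a C++ inference program could append them by name through the pass builder:

  #include "paddle_inference_api.h"  // include path depends on the install layout

  int main() {
    paddle_infer::Config config("./gpt3_model");  // hypothetical model directory
    config.EnableUseGpu(4096, 0);                 // the fused op targets GPU execution
    // Request the fuse passes added by this commit by their registered names.
    config.pass_builder()->AppendPass("fused_multi_transformer_encoder_pass");
    config.pass_builder()->AppendPass("fused_multi_transformer_decoder_pass");
    auto predictor = paddle_infer::CreatePredictor(config);
    return predictor != nullptr ? 0 : 1;
  }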
@@ -105,6 +105,8 @@ pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
pass_library(multihead_matmul_fuse_pass inference)
pass_library(fused_multi_transformer_encoder_pass inference)
pass_library(fused_multi_transformer_decoder_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(yolo_box_fuse_pass inference)
@@ -311,6 +313,14 @@ cc_test(
test_multihead_matmul_fuse_pass
SRCS multihead_matmul_fuse_pass_tester.cc
DEPS multihead_matmul_fuse_pass)
cc_test(
test_fused_multi_transformer_encoder_pass
SRCS fused_multi_transformer_encoder_pass_tester.cc
DEPS fused_multi_transformer_encoder_pass)
cc_test(
test_fused_multi_transformer_decoder_pass
SRCS fused_multi_transformer_decoder_pass_tester.cc
DEPS fused_multi_transformer_decoder_pass)
cc_test(
test_conv_bn_fuse_pass_cc
SRCS conv_bn_fuse_pass_tester.cc
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h"
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
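// FusedMultiTransformerDecoderPattern matches one pre-LayerNorm decoder block
// with separate Q/K/V projections:
//   layer_norm -> 3x (matmul_v2 + elementwise_add + reshape2 + transpose2)
//   -> concat of the new K/V with the cached K/V (kept alive via assign)
//   -> QK matmul + mask add + softmax + dropout -> QKV matmul
//   -> out-linear (matmul_v2 + elementwise_add + dropout) + residual add
//   -> FFN (layer_norm, fc, gelu, fc, dropout, residual add).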
PDNode* FusedMultiTransformerDecoderPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X")
->assert_more([](Node* x) {
// layer_norm output must feed exactly the three Q/K/V projection matmuls
return x->outputs.size() == 3;
});
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// Q path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
// Q path Links
matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
// K path Nodes
auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2");
auto* matmul1_w_var = pattern->NewNode(matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul1_out_var = pattern->NewNode(matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd1 =
pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
auto* eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate();
auto* concat_0_in_var = pattern->NewNode(concat_0_in_repr())->AsInput();
auto* concat_0 = pattern->NewNode(concat_0_repr())->assert_is_op("concat");
auto* concat_0_out_var = pattern->NewNode(concat_0_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_input("assign");
auto assign_0 = pattern->NewNode(assign_0_repr())->assert_is_op("assign");
// K path Links
matmul1->LinksFrom({layer_norm_out_var, matmul1_w_var})
.LinksTo({matmul1_out_var});
eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var})
.LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
concat_0->LinksFrom({transpose2_1_out_var, concat_0_in_var})
.LinksTo({concat_0_out_var});
assign_0->LinksFrom({concat_0_out_var});
// V path Nodes
auto* matmul2 = pattern->NewNode(matmul2_repr())->assert_is_op("matmul_v2");
auto* matmul2_w_var = pattern->NewNode(matmul2_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul2_out_var = pattern->NewNode(matmul2_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd2 =
pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
auto* eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
auto* concat_1_in_var = pattern->NewNode(concat_1_in_repr())
->AsInput()
->assert_is_op_input("concat");
auto* concat_1 = pattern->NewNode(concat_1_repr())->assert_is_op("concat");
auto* concat_1_out_var = pattern->NewNode(concat_1_out_repr())
->assert_is_op_output("concat")
->assert_is_op_input("matmul_v2")
->assert_is_op_input("assign");
auto assign_1 = pattern->NewNode(assign_1_repr())->assert_is_op("assign");
// V path Links
matmul2->LinksFrom({layer_norm_out_var, matmul2_w_var})
.LinksTo({matmul2_out_var});
eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var})
.LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
concat_1->LinksFrom({transpose2_2_out_var, concat_1_in_var})
.LinksTo({concat_1_out_var});
assign_1->LinksFrom({concat_1_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
// QK path Links
matmul_qk->LinksFrom({transpose2_0_out_var, concat_0_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, concat_1_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
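// FusedMultiTransformerDecoderFuseQKVPattern is the variant for models whose
// Q/K/V projections are fused into one GEMM: a single matmul_v2 + bias add is
// reshaped/transposed and then split into Q, K and V, followed by the same
// cached-KV concat/assign, attention and FFN structure as above.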
PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// QKV fused path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("split", "X");
auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("concat");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("concat");
auto* concat_k_in_var = pattern
->NewNode(concat_k_in_repr())
// ->AsInput()
->assert_is_op_input("concat");
auto* concat_k = pattern->NewNode(concat_k_repr())->assert_is_op("concat");
auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_input("assign");
auto* concat_v_in_var = pattern
->NewNode(concat_v_in_repr())
// ->AsInput()
->assert_is_op_input("concat");
auto* concat_v = pattern->NewNode(concat_v_repr())->assert_is_op("concat");
auto* concat_v_out_var = pattern->NewNode(concat_v_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul_v2")
->assert_is_op_input("assign");
auto* assign_k = pattern->NewNode(assign_k_repr())->assert_is_op("assign");
auto* assign_v = pattern->NewNode(assign_v_repr())->assert_is_op("assign");
// QKV fused path Links
matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
split0->LinksFrom({transpose2_0_out_var})
.LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
concat_k->LinksFrom({concat_k_in_var, split0_k_out_var})
.LinksTo({concat_k_out_var});
concat_v->LinksFrom({concat_v_in_var, split0_v_out_var})
.LinksTo({concat_v_out_var});
assign_k->LinksFrom({concat_k_out_var});
assign_v->LinksFrom({concat_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
// QK path Links
matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
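// MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern extends the fused-QKV
// pattern to tensor-parallel (multi-device) graphs: a c_identity op precedes
// the QKV and FFN GEMM inputs, and a c_allreduce_sum op follows the out-linear
// and FFN second GEMM to sum the partial results across devices.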
PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// communication c_identity
auto* c_identity =
pattern->NewNode(c_identity_repr())->assert_is_op("c_identity");
auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr())
->AsIntermediate()
->assert_is_op_output("c_identity", "Out")
->assert_is_op_input("matmul_v2", "X");
c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var});
// QKV fused path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("split", "X");
auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("concat");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("concat");
auto* concat_k_in_var = pattern
->NewNode(concat_k_in_repr())
// ->AsInput()
->assert_is_op_input("concat");
auto* concat_k = pattern->NewNode(concat_k_repr())->assert_is_op("concat");
auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_input("assign");
auto* concat_v_in_var = pattern
->NewNode(concat_v_in_repr())
// ->AsInput()
->assert_is_op_input("concat");
auto* concat_v = pattern->NewNode(concat_v_repr())->assert_is_op("concat");
auto* concat_v_out_var = pattern->NewNode(concat_v_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul_v2")
->assert_is_op_input("assign");
auto* assign_k = pattern->NewNode(assign_k_repr())->assert_is_op("assign");
auto* assign_v = pattern->NewNode(assign_v_repr())->assert_is_op("assign");
// QKV fused path Links
matmul0->LinksFrom({c_identity_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
split0->LinksFrom({transpose2_0_out_var})
.LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
concat_k->LinksFrom({concat_k_in_var, split0_k_out_var})
.LinksTo({concat_k_out_var});
concat_v->LinksFrom({concat_v_in_var, split0_v_out_var})
.LinksTo({concat_v_out_var});
assign_k->LinksFrom({concat_k_out_var});
assign_v->LinksFrom({concat_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
// QK path Links
matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("c_allreduce_sum");
// communication c_allreduce_sum
auto* c_allreduce_sum =
pattern->NewNode(c_allreduce_sum_repr())->assert_is_op("c_allreduce_sum");
auto* c_allreduce_sum_out_var = pattern->NewNode(c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
c_allreduce_sum->LinksFrom({matmul_linear_out_var})
.LinksTo({c_allreduce_sum_out_var});
eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// communication c_identity
auto* ffn_c_identity =
pattern->NewNode(ffn_c_identity_repr())->assert_is_op("c_identity");
auto* ffn_c_identity_out_var = pattern->NewNode(ffn_c_identity_out_repr())
->assert_is_op_output("c_identity", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X");
ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
.LinksTo({ffn_c_identity_out_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("c_allreduce_sum");
// communication c_allreduce_sum
auto* ffn_c_allreduce_sum = pattern->NewNode(ffn_c_allreduce_sum_repr())
->assert_is_op("c_allreduce_sum");
auto* ffn_c_allreduce_sum_out_var =
pattern->NewNode(ffn_c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
.LinksTo({ffn_c_allreduce_sum_out_var});
ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
} // namespace patterns
int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FusedMultiTransformerDecoderPattern fused_multi_transformer_pattern(
pattern, name_scope);
fused_multi_transformer_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0_w,
Node* matmul1_w,
Node* matmul2_w,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* transpose2_1_out,
Node* transpose2_2_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
// Calculate the index of the transformer layer from the LayerNorm Scale name.
// This calculation assumes:
//   1. there is no LayerNorm before the first transformer layer
//   2. each transformer layer contains exactly 2 LayerNorm ops
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
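// Illustrative example (var name format assumed): a scale named
// "layer_norm_6.w_0" gives ln_name = "layer_norm_6", ln_idx_str = "6",
// layer_idx = 3.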
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV reuses the cache_kv variable created by the encoder pass
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});
VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv_name});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "fused_multi_transformer_decoder "
"pass in op compat failed.";
return;
}
VLOG(4) << "handle MultiTransformer decoder fuse";
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_out, layer_norm_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1, matmul1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1_out, matmul1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1_w, matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1, reshape2_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1, transpose2_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_0, concat_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_0_out, concat_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_0, assign_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2, matmul2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_out, matmul2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_w, matmul2_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2, reshape2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2, transpose2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_1, concat_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_1_out, concat_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_1, assign_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
attention_output, attention_output, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_pattern)
// nodes that need to be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1, eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_b, eltadd1_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_out, eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2, eltadd2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_b, eltadd2_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_out, eltadd2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_linear, dropout_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0_w,
matmul1_w,
matmul2_w,
eltadd0_b,
eltadd1_b,
eltadd2_b,
transpose2_1_out,
transpose2_2_out,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
matmul0,
matmul1,
matmul2,
matmul0_out,
matmul1_out,
matmul2_out,
eltadd0,
eltadd1,
eltadd2,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
concat_0,
concat_1,
concat_0_out,
concat_1_out,
assign_0,
assign_1,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
transpose2_qkv,
transpose2_qkv_out,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void FusedMultiTransformerDecoderPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal("During the multi_transformer pass, "
"The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderPass, new bool(true));
}
AddStatis(fusion_count);
}
FusedMultiTransformerDecoderPass::FusedMultiTransformerDecoderPass() {
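// The OpCompat specs below constrain the attributes of every op that may
// appear in the matched subgraph; the IsCompat() check in the fusion
// handler skips any subgraph that violates them.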
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(2)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
}
int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FusedMultiTransformerDecoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
fused_multi_transformer_fuse_qkv_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0_w,
Node* eltadd0_b,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
// Calculate the transformer layer index from the LayerNorm Scale name.
// This calculation assumes:
// 1. there is no LayerNorm before the first transformer layer
// 2. each transformer layer contains 2 LayerNorm ops
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
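// e.g. (hypothetical name) a LnScale named "layer_norm_6.w_0" gives
// ln_name "layer_norm_6", ln_idx_str "6" and layer_idx 6 / 2 = 3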
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV reuses the cache_kv variable created by the encoder pass
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});
VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv_name});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
// fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"});
// fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true});
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "fused_multi_transformer_decoder_fuse_qkv "
"pass op compat check failed.";
return;
}
VLOG(4) << "handle MultiTransformer decoder (Fuse-QKV) fuse";
GET_IR_NODE_FROM_SUBGRAPH(
input0, input0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale,
layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias,
layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean,
layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out,
layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out,
reshape2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out,
transpose2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0, split0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k_in, concat_k_in, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k, concat_k, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k_out, concat_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v_in, concat_v_in, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v, concat_v, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v_out, concat_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_k, assign_k, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_v, assign_v, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm,
ffn_layer_norm,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out,
ffn_matmul0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out,
ffn_eltadd0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out,
ffn_matmul1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
ffn_dropout_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
// nodes to be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
dropout_qk_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out,
matmul_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out,
reshape2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv,
transpose2_qkv,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w,
matmul_linear_w,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out,
matmul_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b,
eltadd_linear_b,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
dropout_linear,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
dropout_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0_w,
eltadd0_b,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
matmul0,
matmul0_out,
eltadd0,
eltadd0_out,
reshape2_0,
reshape2_0_out,
transpose2_0,
transpose2_0_out,
split0,
split0_q_out,
split0_k_out,
split0_v_out,
concat_k_in,
concat_k,
concat_k_out,
concat_v_in,
concat_v,
concat_v_out,
assign_k,
assign_v,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
transpose2_qkv,
transpose2_qkv_out,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void FusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal("During the fused_multi_transformer_decoder "
"pass, the scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
}
AddStatis(fusion_count);
}
FusedMultiTransformerDecoderFuseQKVPass::
FusedMultiTransformerDecoderFuseQKVPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(2)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
}
int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
fused_multi_transformer_fuse_qkv_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* c_identity,
Node* matmul0_w,
Node* eltadd0_b,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
// Calculate the transformer layer index from the LayerNorm Scale name.
// This calculation assumes:
// 1. there is no LayerNorm before the first transformer layer
// 2. each transformer layer contains 2 LayerNorm ops
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV reuses the cache_kv variable created by the encoder pass
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});
VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv_name});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
// parallel ring id
auto* c_identity_op = c_identity->Op();
fused_multi_transformer_op_desc.SetAttr("ring_id",
c_identity_op->GetAttr("ring_id"));
// fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"});
// fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true});
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "multi_devices_fused_multi_transformer_decoder_fuse_qkv "
"pass op compat check failed.";
return;
}
VLOG(4) << "handle MultiTransformer decoder (Fuse-QKV) fuse";
GET_IR_NODE_FROM_SUBGRAPH(
input0, input0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale,
layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias,
layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean,
layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out,
layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
c_identity, c_identity, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity_out,
c_identity_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out,
reshape2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out,
transpose2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0, split0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k_in, concat_k_in, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k, concat_k, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k_out, concat_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v_in, concat_v_in, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v, concat_v, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v_out, concat_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_k, assign_k, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_v, assign_v, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm,
ffn_layer_norm,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity,
ffn_c_identity,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity_out,
ffn_c_identity_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out,
ffn_matmul0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out,
ffn_eltadd0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out,
ffn_matmul1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum,
ffn_c_allreduce_sum,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum_out,
ffn_c_allreduce_sum_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
ffn_dropout_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
// nodes to be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
dropout_qk_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out,
matmul_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out,
reshape2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv,
transpose2_qkv,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w,
matmul_linear_w,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out,
matmul_linear_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum,
c_allreduce_sum,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum_out,
c_allreduce_sum_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b,
eltadd_linear_b,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
dropout_linear,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
dropout_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
c_identity,
matmul0_w,
eltadd0_b,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
c_identity,
c_identity_out,
matmul0,
matmul0_out,
eltadd0,
eltadd0_out,
reshape2_0,
reshape2_0_out,
transpose2_0,
transpose2_0_out,
split0,
split0_q_out,
split0_k_out,
split0_v_out,
concat_k_in,
concat_k,
concat_k_out,
concat_v_in,
concat_v,
concat_v_out,
assign_k,
assign_v,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
transpose2_qkv,
transpose2_qkv_out,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
c_allreduce_sum,
c_allreduce_sum_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_c_identity,
ffn_c_identity_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_c_allreduce_sum,
ffn_c_allreduce_sum_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(
Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal("During the fused_multi_transformer_decoder "
"pass, the scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
}
AddStatis(fusion_count);
}
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(2)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fused_multi_transformer_decoder_pass,
paddle::framework::ir::FusedMultiTransformerDecoderPass);
REGISTER_PASS(fused_multi_transformer_decoder_fuse_qkv_pass,
paddle::framework::ir::FusedMultiTransformerDecoderFuseQKVPass);
REGISTER_PASS(
multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass,
paddle::framework::ir::MultiDevicesFusedMultiTransformerDecoderFuseQKVPass);
REGISTER_PASS_CAPABILITY(fused_multi_transformer_decoder_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(fused_multi_transformer_decoder_fuse_qkv_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(
multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct FusedMultiTransformerDecoderPattern : public PatternBase {
FusedMultiTransformerDecoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_multi_transformer_decoder") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul1_w);
PATTERN_DECL_NODE(matmul2_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(matmul1_out);
PATTERN_DECL_NODE(matmul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(concat_0_in);
PATTERN_DECL_NODE(concat_0);
PATTERN_DECL_NODE(concat_0_out);
PATTERN_DECL_NODE(assign_0);
PATTERN_DECL_NODE(concat_1_in);
PATTERN_DECL_NODE(concat_1);
PATTERN_DECL_NODE(concat_1_out);
PATTERN_DECL_NODE(assign_1);
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// while loop
PATTERN_DECL_NODE(while0);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
FusedMultiTransformerDecoderFuseQKVPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(
pattern, name_scope, "fused_multi_transformer_decoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
PATTERN_DECL_NODE(concat_k_in)
PATTERN_DECL_NODE(concat_v_in)
PATTERN_DECL_NODE(concat_k)
PATTERN_DECL_NODE(concat_v)
PATTERN_DECL_NODE(concat_k_out)
PATTERN_DECL_NODE(concat_v_out)
PATTERN_DECL_NODE(assign_k)
PATTERN_DECL_NODE(assign_v)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
: public PatternBase {
MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(c_identity);
PATTERN_DECL_NODE(c_identity_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
PATTERN_DECL_NODE(concat_k_in)
PATTERN_DECL_NODE(concat_v_in)
PATTERN_DECL_NODE(concat_k)
PATTERN_DECL_NODE(concat_v)
PATTERN_DECL_NODE(concat_k_out)
PATTERN_DECL_NODE(concat_v_out)
PATTERN_DECL_NODE(assign_k)
PATTERN_DECL_NODE(assign_v)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(c_allreduce_sum);
PATTERN_DECL_NODE(c_allreduce_sum_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_c_identity);
PATTERN_DECL_NODE(ffn_c_identity_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_c_allreduce_sum);
PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
} // namespace patterns
class FusedMultiTransformerDecoderPass : public FusePassBase {
public:
FusedMultiTransformerDecoderPass();
virtual ~FusedMultiTransformerDecoderPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_decoder"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class FusedMultiTransformerDecoderFuseQKVPass : public FusePassBase {
public:
FusedMultiTransformerDecoderFuseQKVPass();
virtual ~FusedMultiTransformerDecoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_decoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
: public FusePassBase {
public:
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass();
virtual ~MultiDevicesFusedMultiTransformerDecoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{
"multi_devices_fused_multi_transformer_decoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h"  // NOLINT
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
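// Registers a float32 parameter of the given shape in the scope so that the
// pass under test can look up weight/bias tensors by name.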
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
// MHA: pre Layer Norm
AddVarToScope(param_scope, "ln_scale", {1024});
AddVarToScope(param_scope, "ln_bias", {1024});
// MHA: QKV fc
AddVarToScope(param_scope, "weights0", {1024, 1024});
AddVarToScope(param_scope, "weights1", {1024, 1024});
AddVarToScope(param_scope, "weights2", {1024, 1024});
AddVarToScope(param_scope, "bias_0", {1024});
AddVarToScope(param_scope, "bias_1", {1024});
AddVarToScope(param_scope, "bias_2", {1024});
// MHA: QK bias
AddVarToScope(param_scope, "biasqk", {1024});
// MHA: out Linear
AddVarToScope(param_scope, "weights_l", {1024, 1024});
AddVarToScope(param_scope, "bias_l", {1024});
  // FFN: pre Layer Norm
AddVarToScope(param_scope, "ffn_ln_scale", {1024});
AddVarToScope(param_scope, "ffn_ln_bias", {1024});
// FFN: fc1 -> (gelu) -> fc2
AddVarToScope(param_scope, "ffn_weights0", {1024, 4096});
AddVarToScope(param_scope, "ffn_weights1", {4096, 1024});
AddVarToScope(param_scope, "ffn_bias_0", {4096});
AddVarToScope(param_scope, "ffn_bias_1", {1024});
return param_scope;
}
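// Builds a single decoder block with separate Q/K/V projections and a
// concat/assign K/V cache update, then checks that the pass removes 72 graph
// nodes and leaves exactly one fused_multi_transformer op.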
TEST(FusedMultiTransformerDecoderPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (layer_norm_out, weights_1) matmul_v2 -> matmul_out1
// (layer_norm_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
  // (cache_k, transpose_1)             concat            -> concat_k
  // (cache_v, transpose_2)             concat            -> concat_v
  // (concat_k)                         assign            -> assign_k
  // (concat_v)                         assign            -> assign_v
  // (transpose_0, concat_k)            matmul            -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
  // (dropout_qk, concat_v)             matmul_v2         -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
  // (matmul_linear, bias_l)            elementwise_add   -> eltadd_linear
  // (eltadd_linear)                    dropout           -> dropout_linear
  // (x, dropout_linear)                elementwise_add   -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
  // (ffn_layer_norm_out, ffn_weights0) matmul_v2         -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
auto* matmul_out_0 =
layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
auto* matmul_out_1 =
layers.matmul_v2(ln_out, weights_1, nullptr, false, true);
auto* matmul_out_2 =
layers.matmul_v2(ln_out, weights_2, nullptr, false, true);
auto* b0 = layers.data("bias_0", {1024}, true);
auto* b1 = layers.data("bias_1", {1024}, true);
auto* b2 = layers.data("bias_2", {1024}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
auto* elementwise_out_1 =
layers.elementwise_add(matmul_out_1, b1, nullptr, 2);
auto* elementwise_out_2 =
layers.elementwise_add(matmul_out_2, b2, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
auto* cache_k = layers.data("cache_k", {1, 16, 128, 64});
auto* cache_v = layers.data("cache_v", {1, 16, 128, 64});
auto* concat_k = layers.concat({cache_k, transpose_1}, 2);
auto* concat_v = layers.concat({cache_v, transpose_2}, 2);
layers.assign(concat_k);
layers.assign(concat_v);
// MHA: QK matmul
auto* matmul_qk = layers.matmul(transpose_0, concat_k, nullptr, false, true);
  auto* bqk = layers.data("biasqk", {1, 16, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
  auto* bias_l = layers.data("bias_l", {1024}, true);
  auto* linear_matmul_out =
      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
  auto* linear_eltadd_out =
      layers.elementwise_add(linear_matmul_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_decoder_pass");
if (pass.get() == nullptr)
LOG(INFO) << "get fused_multi_transformer_decoder_pass failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 72,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_pass, The "
"node num in graph "
"should be %d, but the result is %d",
num_nodes_before - 72,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder pass, "
"there should be one fused_multi_transformer op, "
"but the result is %d",
num_fused_nodes_after));
}
TEST(FusedMultiTransformerDecoderPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("fused_multi_transformer_decoder_pass"));
}
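// Same decoder block, but with one fused QKV projection followed by a split
// op; the fuse-QKV pass should remove 62 graph nodes and leave exactly one
// fused_multi_transformer op.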
TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
  // (transpose_0)                      split             -> split_q, split_k, split_v
  // (cache_k, split_k)                 concat            -> concat_k
  // (cache_v, split_v)                 concat            -> concat_v
// (concat_k) assign -> assign_k
// (concat_v) assign -> assign_v
  // (split_q, concat_k)                matmul            -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
  // (dropout_qk, concat_v)             matmul_v2         -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
  // (matmul_linear, bias_l)            elementwise_add   -> eltadd_linear
  // (eltadd_linear)                    dropout           -> dropout_linear
  // (x, dropout_linear)                elementwise_add   -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
  // (ffn_layer_norm_out, ffn_weights0) matmul_v2         -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 3072}, true);
auto* matmul_out_0 =
layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
auto* b0 = layers.data("bias_0", {3072}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto split_outs = layers.split(transpose_0, 3, 3);
auto* split_q = split_outs[0];
auto* split_k = split_outs[1];
auto* split_v = split_outs[2];
auto* cache_k = layers.data("cache_k", {1, 16, 128, 64});
auto* cache_v = layers.data("cache_v", {1, 16, 128, 64});
auto* concat_k = layers.concat({cache_k, split_k}, 2);
auto* concat_v = layers.concat({cache_v, split_v}, 2);
layers.assign(concat_k);
layers.assign(concat_v);
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true);
  auto* bqk = layers.data("biasqk", {1, 16, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
  auto* bias_l = layers.data("bias_l", {1024}, true);
  auto* linear_matmul_out =
      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
  auto* linear_eltadd_out =
      layers.elementwise_add(linear_matmul_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(
"fused_multi_transformer_decoder_fuse_qkv_pass");
if (pass.get() == nullptr)
LOG(INFO) << "get fused_multi_transformer_decoder_fuse_qkv_pass failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 62,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 62,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv "
"pass, there should be one fused_multi_transformer "
"op, but the result is %d",
num_fused_nodes_after));
}
TEST(FusedMultiTransformerDecoderFuseQKVPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("fused_multi_transformer_decoder_fuse_qkv_pass"));
}
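// Tensor-parallel decoder block: c_identity before the fused QKV matmul and
// the FFN fc1 matmul, and c_allreduce_sum after the out-linear and FFN fc2
// matmuls; the multi-devices pass should remove 70 graph nodes and leave
// exactly one fused_multi_transformer op.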
TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out) c_identity -> c_identity_out
// (c_identity_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
  // (transpose_0)                      split             -> split_q, split_k, split_v
  // (cache_k, split_k)                 concat            -> concat_k
  // (cache_v, split_v)                 concat            -> concat_v
// (concat_k) assign -> assign_k
// (concat_v) assign -> assign_v
  // (split_q, concat_k)                matmul            -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
  // (dropout_qk, concat_v)             matmul_v2         -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
  // (matmul_linear)                    c_allreduce_sum   -> c_allreduce_out
  // (c_allreduce_out, bias_l)          elementwise_add   -> eltadd_linear
  // (eltadd_linear)                    dropout           -> dropout_linear
  // (x, dropout_linear)                elementwise_add   -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (ffn_layer_norm_out) c_identity -> ffn_c_identity_out
  // (ffn_c_identity_out, ffn_weights0) matmul_v2         -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
  // (ffn_matmul1)                      c_allreduce_sum   -> ffn_c_allreduce_out
  // (ffn_c_allreduce_out, ffn_bias1)   elementwise_add   -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
auto* c_identity_out = layers.c_identity(ln_out);
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 3072}, true);
auto* matmul_out_0 =
layers.matmul_v2(c_identity_out, weights_0, nullptr, false, true);
auto* b0 = layers.data("bias_0", {3072}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto split_outs = layers.split(transpose_0, 3, 3);
auto* split_q = split_outs[0];
auto* split_k = split_outs[1];
auto* split_v = split_outs[2];
auto* cache_k = layers.data("cache_k", {1, 16, 128, 64});
auto* cache_v = layers.data("cache_v", {1, 16, 128, 64});
auto* concat_k = layers.concat({cache_k, split_k}, 2);
auto* concat_v = layers.concat({cache_v, split_v}, 2);
layers.assign(concat_k);
layers.assign(concat_v);
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true);
  auto* bqk = layers.data("biasqk", {1, 16, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
  auto* bias_l = layers.data("bias_l", {1024}, true);
  auto* linear_matmul_out =
      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
  auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmul_out);
auto* linear_eltadd_out =
layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
auto* ffn_c_identity_out = layers.c_identity(ffn_ln_out);
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_c_identity_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_c_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_c_allreduce_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass");
if (pass.get() == nullptr)
LOG(INFO)
<< "get multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass "
"failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 70,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 70,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv "
"multi-devices pass, there should be one "
"fused_multi_transformer op, but the result is %d",
num_fused_nodes_after));
}
TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass,
pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible(
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fused_multi_transformer_decoder_pass);
USE_PASS(fused_multi_transformer_decoder_fuse_qkv_pass);
USE_PASS(multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
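// Matches one pre-LayerNorm encoder block: layer_norm -> separate Q/K/V fc ->
// QK matmul + bias + softmax + dropout -> QKV matmul -> out-linear + dropout
// -> residual add -> layer_norm -> fc1 -> gelu -> fc2 -> dropout -> residual
// add. The K and V transpose outputs are additionally consumed by the while
// op of the decoding loop.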
PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
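  // layer_norm_out must feed exactly the three Q/K/V matmul_v2 ops below,
  // hence the three-consumer check.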
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X")
->assert_more([](Node* x) {
if (x->outputs.size() == 3) {
return true;
} else {
return false;
}
});
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// Q path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
// Q path Links
matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
// K path Nodes
auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2");
auto* matmul1_w_var = pattern->NewNode(matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul1_out_var = pattern->NewNode(matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd1 =
pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
auto* eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
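  // The K transpose output is consumed by both matmul_qk (as Y) and the while
  // op of the decoding loop, hence the two-consumer check.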
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2")
->AsOutput()
->assert_is_op_input("matmul", "Y")
->assert_is_op_input("while")
->assert_more([](Node* x) {
if (x->outputs.size() == 2) {
return true;
} else {
return false;
}
});
// K path Links
matmul1->LinksFrom({layer_norm_out_var, matmul1_w_var})
.LinksTo({matmul1_out_var});
eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var})
.LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// V path Nodes
auto* matmul2 = pattern->NewNode(matmul2_repr())->assert_is_op("matmul_v2");
auto* matmul2_w_var = pattern->NewNode(matmul2_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul2_out_var = pattern->NewNode(matmul2_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd2 =
pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
auto* eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2")
->AsOutput()
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while")
->assert_more([](Node* x) {
if (x->outputs.size() == 2) {
return true;
} else {
return false;
}
});
// V path Links
matmul2->LinksFrom({layer_norm_out_var, matmul2_w_var})
.LinksTo({matmul2_out_var});
eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var})
.LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
  // QK path Links
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// while loop
auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
while0->LinksFrom({transpose2_1_out_var, transpose2_2_out_var});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
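// Same encoder block as above, except that Q, K and V come from a single
// fused matmul whose transposed output is split into the three tensors.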
PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// QKV fused path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("split", "X");
auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul", "Y")
->assert_is_op_input("while");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while");
// QKV fused path Links
matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
split0->LinksFrom({transpose2_0_out_var})
.LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
// while loop
auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
while0->LinksFrom({split0_k_out_var, split0_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
  // QK path Links
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
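// Tensor-parallel version of the fuse-QKV encoder pattern: a c_identity op
// precedes the fused QKV matmul and the FFN fc1 matmul, and a c_allreduce_sum
// follows the attention out-linear matmul.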
PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// communication c_identity
auto* c_identity =
pattern->NewNode(c_identity_repr())->assert_is_op("c_identity");
auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr())
->AsIntermediate()
->assert_is_op_output("c_identity", "Out")
->assert_is_op_input("matmul_v2", "X");
c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var});
// QKV fused path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("split", "X");
auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul", "Y")
->assert_is_op_input("while");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while");
// QKV fused path Links
matmul0->LinksFrom({c_identity_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
split0->LinksFrom({transpose2_0_out_var})
.LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
// while loop
auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
while0->LinksFrom({split0_k_out_var, split0_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
  // QK path Links
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("c_allreduce_sum");
// communication c_allreduce_sum
auto* c_allreduce_sum =
pattern->NewNode(c_allreduce_sum_repr())->assert_is_op("c_allreduce_sum");
auto* c_allreduce_sum_out_var = pattern->NewNode(c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
c_allreduce_sum->LinksFrom({matmul_linear_out_var})
.LinksTo({c_allreduce_sum_out_var});
eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// communication c_identity
auto* ffn_c_identity =
pattern->NewNode(ffn_c_identity_repr())->assert_is_op("c_identity");
auto* ffn_c_identity_out_var = pattern->NewNode(ffn_c_identity_out_repr())
->assert_is_op_output("c_identity", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X");
ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
.LinksTo({ffn_c_identity_out_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("c_allreduce_sum");
// communication c_allreduce_sum
auto* ffn_c_allreduce_sum = pattern->NewNode(ffn_c_allreduce_sum_repr())
->assert_is_op("c_allreduce_sum");
auto* ffn_c_allreduce_sum_out_var =
pattern->NewNode(ffn_c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
.LinksTo({ffn_c_allreduce_sum_out_var});
ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
} // namespace patterns
template <typename T>
inline void QKVWeightsProcess(framework::LoDTensor* wq_tensor,
framework::LoDTensor* wk_tensor,
framework::LoDTensor* wv_tensor,
framework::LoDTensor* bq_tensor,
framework::LoDTensor* bk_tensor,
framework::LoDTensor* bv_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
auto* wq_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<T>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<T>(platform::CPUPlace());
auto* bq_data = bq_tensor->mutable_data<T>(platform::CPUPlace());
auto* bk_data = bk_tensor->mutable_data<T>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<T>(platform::CPUPlace());
auto combined_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head});
framework::LoDTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
auto* tmp_combined_w_data =
tmp_combined_w_tensor.mutable_data<T>(platform::CPUPlace());
std::vector<T*> w_vec = {wq_data, wk_data, wv_data};
// Combine the three fc weights together.
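  // Each source weight is laid out as (dim_embed, num_head * dim_head) with the
  // input dim leading; the combined target layout is
  // (3, num_head, dim_head, dim_embed), i.e. the weights are also transposed.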
for (int i = 0; i < 3; i++) {
for (int j = 0; j < num_head; j++) {
for (int k = 0; k < dim_head; k++) {
for (int l = 0; l < dim_embed; l++) {
int out_idx = i * num_head * dim_head * dim_embed +
j * dim_head * dim_embed + k * dim_embed + l;
int in_idx = l * num_head * dim_head + j * dim_head + k;
tmp_combined_w_data[out_idx] = w_vec[i][in_idx];
}
}
}
}
wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(
new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());
framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<T>(platform::CPUPlace());
size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
memcpy(
tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size);
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(T) * bq_tensor->numel());
}
template <typename T>
inline void QKVWeightsProcessFuseQKV(framework::LoDTensor* qkv_w_tensor,
framework::LoDTensor* qkv_b_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
auto* qkv_w_data = qkv_w_tensor->mutable_data<T>(platform::CPUPlace());
auto transpose_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
framework::LoDTensor tmp_transpose_w_tensor;
tmp_transpose_w_tensor.Resize(transpose_w_dims);
auto* tmp_transpose_w_data =
tmp_transpose_w_tensor.mutable_data<T>(platform::CPUPlace());
// transpose qkv matmul Y to QKVWeights
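  // Source layout: (dim_embed, num_head, 3, dim_head), i.e. the Q/K/V slices
  // are interleaved per head; target layout: (3, num_head, dim_head, dim_embed).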
for (int i = 0; i < 3; i++) {
for (int j = 0; j < num_head; j++) {
for (int k = 0; k < dim_head; k++) {
for (int l = 0; l < dim_embed; l++) {
int out_idx = i * num_head * dim_head * dim_embed +
j * dim_head * dim_embed + k * dim_embed + l;
int in_idx =
l * num_head * 3 * dim_head + j * 3 * dim_head + i * dim_head + k;
tmp_transpose_w_data[out_idx] = qkv_w_data[in_idx];
}
}
}
}
qkv_w_tensor->Resize(transpose_w_dims);
auto* new_transpose_w_data =
qkv_w_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(new_transpose_w_data,
tmp_transpose_w_data,
sizeof(T) * qkv_w_tensor->numel());
auto* qkv_b_data = qkv_b_tensor->mutable_data<T>(platform::CPUPlace());
auto transpose_b_dims = phi::make_ddim({3, num_head, dim_head});
framework::LoDTensor tmp_transpose_b_tensor;
tmp_transpose_b_tensor.Resize(transpose_b_dims);
auto* tmp_transpose_b_data =
tmp_transpose_b_tensor.mutable_data<T>(platform::CPUPlace());
  // transpose qkv elementwise_add Y to QKVBias
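  // Source layout: (num_head, 3, dim_head) with Q/K/V interleaved per head;
  // target layout: (3, num_head, dim_head).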
for (int i = 0; i < 3; i++) {
for (int j = 0; j < num_head; j++) {
for (int k = 0; k < dim_head; k++) {
int out_idx = i * num_head * dim_head + j * dim_head + k;
int in_idx = j * 3 * dim_head + i * dim_head + k;
tmp_transpose_b_data[out_idx] = qkv_b_data[in_idx];
}
}
}
qkv_b_tensor->Resize({3, num_head, dim_head});
auto* new_transpose_b_data =
qkv_b_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(new_transpose_b_data,
tmp_transpose_b_data,
sizeof(T) * qkv_b_tensor->numel());
}
int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern(
pattern, name_scope);
fused_multi_transformer_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0_w,
Node* matmul1_w,
Node* matmul2_w,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* transpose2_1_out,
Node* transpose2_2_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
auto reshape_desc = reshape2_0->Op();
int num_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(2);
int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3);
int dim_embed = num_head * dim_head;
    // Calc the index of the transformer layer from the LayerNorm Scale name.
    // This calculation assumes:
    //    1. there is no LayerNorm before the first transformer layer
    //    2. each transformer layer contains 2 LayerNorm ops
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
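    // e.g. a scale var named like "layer_norm_4.w_0" (illustrative name only)
    // gives ln_idx_str "4" and layer_idx 2.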
auto* wq_tensor =
scope->FindVar(matmul0_w->Name())->GetMutable<LoDTensor>();
auto* wk_tensor =
scope->FindVar(matmul1_w->Name())->GetMutable<LoDTensor>();
auto* wv_tensor =
scope->FindVar(matmul2_w->Name())->GetMutable<LoDTensor>();
auto* bq_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
auto* bk_tensor =
scope->FindVar(eltadd1_b->Name())->GetMutable<LoDTensor>();
auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();
if (wq_tensor->dtype() == phi::DataType::FLOAT32) {
QKVWeightsProcess<float>(wq_tensor,
wk_tensor,
wv_tensor,
bq_tensor,
bk_tensor,
bv_tensor,
num_head,
dim_head,
dim_embed);
} else if (wq_tensor->dtype() == phi::DataType::FLOAT16) {
QKVWeightsProcess<platform::float16>(wq_tensor,
wk_tensor,
wv_tensor,
bq_tensor,
bk_tensor,
bv_tensor,
num_head,
dim_head,
dim_embed);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32 and fp16."));
}
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto* combined_w_desc = matmul0_w->Var();
combined_w_desc->SetShape({3, num_head, dim_head, dim_embed});
combined_w_desc->SetPersistable(true);
auto* combined_bias_desc = eltadd0_b->Var();
combined_bias_desc->SetShape({3, num_head, dim_head});
combined_bias_desc->SetPersistable(true);
scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()});
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV input
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType(
framework::TransToProtoVarType(wq_tensor->dtype()));
cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
OpDesc fill_const_op_desc(layer_norm->Op()->Block());
fill_const_op_desc.SetType("fill_constant_batch_size_like");
fill_const_op_desc.SetInput("Input", {input0->Name()});
fill_const_op_desc.SetOutput("Out", {cache_kv->Name()});
std::vector<int> shape = {2, -1, num_head, 1024, dim_head};
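    // CacheKV layout: {2 (K and V), batch_size (filled in at runtime),
    // num_head, max_seq_len (hard-coded to 1024, see FIXME above), dim_head}.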
fill_const_op_desc.SetAttr("shape", shape);
fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
IR_NODE_LINK_TO(input0, fill_const_op);
IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input
// 1. delete k, v
    // 2. delete matmul1/2_w, eltadd1/2_b
// 3. add cache_kv
auto while_Xs = while0->Op()->Input("X");
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), transpose2_1_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), transpose2_2_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), matmul1_w->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), matmul2_w->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), eltadd1_b->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), eltadd2_b->Name()),
std::end(while_Xs));
while_Xs.emplace_back(cache_kv->Name());
while0->Op()->SetInput("X", while_Xs);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto while_Outs = while0->Op()->Output("Out");
while_Outs.erase(std::remove(std::begin(while_Outs),
std::end(while_Outs),
transpose2_1_out->Name()),
std::end(while_Outs));
while_Outs.erase(std::remove(std::begin(while_Outs),
std::end(while_Outs),
transpose2_2_out->Name()),
std::end(while_Outs));
while_Outs.emplace_back(cache_kv->Name());
while0->Op()->SetOutput("Out", while_Outs);
// link CacheKV to while
    IR_NODE_LINK_TO(cache_kv, while0);
    // unlink the original KV outputs from while
IR_NODE_UNLINK(transpose2_1_out, while0);
IR_NODE_UNLINK(transpose2_2_out, while0);
IR_NODE_UNLINK(while0, transpose2_1_out);
IR_NODE_UNLINK(while0, transpose2_2_out);
    // unlink KV weights/biases from while after they are merged into Q weight/bias
IR_NODE_UNLINK(matmul1_w, while0);
IR_NODE_UNLINK(matmul2_w, while0);
IR_NODE_UNLINK(eltadd1_b, while0);
IR_NODE_UNLINK(eltadd2_b, while0);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "fused_multi_transformer_encoder pass: "
                      "op compat check failed.";
return;
}
VLOG(4) << "handle MultiTransformer encoder fuse";
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_out, layer_norm_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1, matmul1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1_out, matmul1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1_w, matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1, reshape2_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1, transpose2_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2, matmul2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_out, matmul2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_w, matmul2_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2, reshape2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2, transpose2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
attention_output, attention_output, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(while0, while0, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_pattern)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1, eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_b, eltadd1_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_out, eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2, eltadd2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_b, eltadd2_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_out, eltadd2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_linear, dropout_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0_w,
matmul1_w,
matmul2_w,
eltadd0_b,
eltadd1_b,
eltadd2_b,
transpose2_1_out,
transpose2_2_out,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
while0,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
matmul0,
matmul1,
matmul2,
matmul0_out,
matmul1_out,
matmul2_out,
eltadd0,
eltadd1,
eltadd2,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
transpose2_qkv,
transpose2_qkv_out,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the multi_transformer pass, The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
}
AddStatis(fusion_count);
}
FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("while"))
.AddInput("X") // A set of variables, unconstrained
.End()
.AddInput("Condition") // An scalar
.IsTensor()
.End()
.AddOutput("Out") // A set of variables, unconstrained
.End()
.AddOutput("StepScopes") // A vector of local scope, unconstrained
.End()
.AddAttr("sub_block")
.IsType<framework::BlockDesc*>()
.End();
}
int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FusedMultiTransformerEncoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
fused_multi_transformer_fuse_qkv_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0_w,
Node* eltadd0_b,
Node* split0_k_out,
Node* split0_v_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
auto reshape_desc = reshape2_0->Op();
int num_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(2);
int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3) /
3; // 3 for qkv
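    // In the fused-QKV graph the reshape target is presumably
    // (B, S, num_head, 3 * dim_head), hence the division by 3.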
int dim_embed = num_head * dim_head;
    // Calc the index of the transformer layer from the LayerNorm Scale name.
    // This calculation assumes:
    //    1. there is no LayerNorm before the first transformer layer
    //    2. each transformer layer contains 2 LayerNorm ops
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
auto* qkv_w_tensor =
scope->FindVar(matmul0_w->Name())->GetMutable<LoDTensor>();
auto* qkv_b_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
if (qkv_w_tensor->dtype() == phi::DataType::FLOAT32) {
QKVWeightsProcessFuseQKV<float>(
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
} else if (qkv_w_tensor->dtype() == phi::DataType::FLOAT16) {
QKVWeightsProcessFuseQKV<platform::float16>(
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32 and fp16."));
}
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV input
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType(
framework::TransToProtoVarType(qkv_w_tensor->dtype()));
cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
OpDesc fill_const_op_desc(layer_norm->Op()->Block());
fill_const_op_desc.SetType("fill_constant_batch_size_like");
fill_const_op_desc.SetInput("Input", {input0->Name()});
fill_const_op_desc.SetOutput("Out", {cache_kv->Name()});
std::vector<int> shape = {2, -1, num_head, 1024, dim_head};
fill_const_op_desc.SetAttr("shape", shape);
fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
IR_NODE_LINK_TO(input0, fill_const_op);
IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input
// 1. delete k, v
    // 2. add cache_kv
auto while_Xs = while0->Op()->Input("X");
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()),
std::end(while_Xs));
while_Xs.emplace_back(cache_kv->Name());
while0->Op()->SetInput("X", while_Xs);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto while_Outs = while0->Op()->Output("Out");
while_Outs.erase(
std::remove(
std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()),
std::end(while_Outs));
while_Outs.erase(
std::remove(
std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()),
std::end(while_Outs));
while_Outs.emplace_back(cache_kv->Name());
while0->Op()->SetOutput("Out", while_Outs);
// link CacheKV to while
    IR_NODE_LINK_TO(cache_kv, while0);
    // unlink the original KV outputs from while
IR_NODE_UNLINK(split0_k_out, while0);
IR_NODE_UNLINK(split0_v_out, while0);
IR_NODE_UNLINK(while0, split0_k_out);
IR_NODE_UNLINK(while0, split0_v_out);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "fused_multi_transformer_encoder_fuse_qkv "
                      "pass: op compat check failed.";
return;
}
VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse";
GET_IR_NODE_FROM_SUBGRAPH(
input0, input0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale,
layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias,
layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean,
layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out,
layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out,
reshape2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out,
transpose2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0, split0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm,
ffn_layer_norm,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out,
ffn_matmul0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out,
ffn_eltadd0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out,
ffn_matmul1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
ffn_dropout_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
dropout_qk_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out,
matmul_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out,
reshape2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv,
transpose2_qkv,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w,
matmul_linear_w,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out,
matmul_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b,
eltadd_linear_b,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
dropout_linear,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
dropout_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
while0, while0, fused_multi_transformer_fuse_qkv_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0_w,
eltadd0_b,
split0_k_out,
split0_v_out,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
while0,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
matmul0,
matmul0_out,
eltadd0,
eltadd0_out,
reshape2_0,
reshape2_0_out,
transpose2_0,
transpose2_0_out,
split0,
split0_q_out,
split0_k_out,
split0_v_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
transpose2_qkv,
transpose2_qkv_out,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true));
}
AddStatis(fusion_count);
}
FusedMultiTransformerEncoderFuseQKVPass::
FusedMultiTransformerEncoderFuseQKVPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("while"))
.AddInput("X") // A set of variables, unconstrained
.End()
.AddInput("Condition") // An scalar
.IsTensor()
.End()
.AddOutput("Out") // A set of variables, unconstrained
.End()
.AddOutput("StepScopes") // A vector of local scope, unconstrained
.End()
.AddAttr("sub_block")
.IsType<framework::BlockDesc*>()
.End();
}
int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
fused_multi_transformer_fuse_qkv_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* c_identity,
Node* matmul0_w,
Node* eltadd0_b,
Node* split0_k_out,
Node* split0_v_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
auto reshape_desc = reshape2_0->Op();
int num_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(2);
int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3) /
3; // 3 for qkv
    // Calc the index of the transformer layer from the LayerNorm Scale name.
    // This calculation assumes:
    //    1. there is no LayerNorm before the first transformer layer
    //    2. each transformer layer contains 2 LayerNorm ops
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
auto* qkv_w_tensor =
scope->FindVar(matmul0_w->Name())->GetMutable<LoDTensor>();
auto* qkv_b_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
int dim_embed = qkv_w_tensor->dims()[0];
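    // num_head/dim_head above come from the per-device reshape, so the full
    // embedding dim is read from the first (input) axis of the QKV weight
    // rather than computed as num_head * dim_head (the per-device slice).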
if (qkv_w_tensor->dtype() == phi::DataType::FLOAT32) {
QKVWeightsProcessFuseQKV<float>(
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
} else if (qkv_w_tensor->dtype() == phi::DataType::FLOAT16) {
QKVWeightsProcessFuseQKV<platform::float16>(
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32 and fp16."));
}
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV input
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType(
framework::TransToProtoVarType(qkv_w_tensor->dtype()));
cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
OpDesc fill_const_op_desc(layer_norm->Op()->Block());
fill_const_op_desc.SetType("fill_constant_batch_size_like");
fill_const_op_desc.SetInput("Input", {input0->Name()});
fill_const_op_desc.SetOutput("Out", {cache_kv->Name()});
std::vector<int> shape = {2, -1, num_head, 1024, dim_head};
fill_const_op_desc.SetAttr("shape", shape);
fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
// parallel ring id
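    // Reuse the c_identity op's ring_id so the fused op communicates in the
    // same tensor-parallel group as the original graph.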
auto* c_identity_op = c_identity->Op();
fused_multi_transformer_op_desc.SetAttr("ring_id",
c_identity_op->GetAttr("ring_id"));
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
IR_NODE_LINK_TO(input0, fill_const_op);
IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input
// 1. delete k, v
// 2. delete matmul1/2_w eltadd1/2_w
// 3. add cache_kv
auto while_Xs = while0->Op()->Input("X");
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()),
std::end(while_Xs));
while_Xs.emplace_back(cache_kv->Name());
while0->Op()->SetInput("X", while_Xs);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto while_Outs = while0->Op()->Output("Out");
while_Outs.erase(
std::remove(
std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()),
std::end(while_Outs));
while_Outs.erase(
std::remove(
std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()),
std::end(while_Outs));
while_Outs.emplace_back(cache_kv->Name());
while0->Op()->SetOutput("Out", while_Outs);
// link CacheKV to while
IR_NODE_LINK_TO(cache_kv, while0)
// unlink origin KV output to while
IR_NODE_UNLINK(split0_k_out, while0);
IR_NODE_UNLINK(split0_v_out, while0);
IR_NODE_UNLINK(while0, split0_k_out);
IR_NODE_UNLINK(while0, split0_v_out);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "multi_devices_fused_multi_transformer_encoder_fuse_qkv "
                      "pass in op compat failed.";
return;
}
    VLOG(4) << "handle MultiDevices MultiTransformer encoder(Fuse-QKV) fuse";
GET_IR_NODE_FROM_SUBGRAPH(
input0, input0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale,
layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias,
layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean,
layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out,
layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
c_identity, c_identity, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity_out,
c_identity_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out,
reshape2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out,
transpose2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0, split0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm,
ffn_layer_norm,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity,
ffn_c_identity,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity_out,
ffn_c_identity_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out,
ffn_matmul0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out,
ffn_eltadd0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out,
ffn_matmul1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum,
ffn_c_allreduce_sum,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum_out,
ffn_c_allreduce_sum_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
ffn_dropout_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
dropout_qk_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out,
matmul_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out,
reshape2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv,
transpose2_qkv,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w,
matmul_linear_w,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out,
matmul_linear_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum,
c_allreduce_sum,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum_out,
c_allreduce_sum_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b,
eltadd_linear_b,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
dropout_linear,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
dropout_linear_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
while0, while0, fused_multi_transformer_fuse_qkv_pattern);
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
c_identity,
matmul0_w,
eltadd0_b,
split0_k_out,
split0_v_out,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
while0,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
c_identity,
c_identity_out,
matmul0,
matmul0_out,
eltadd0,
eltadd0_out,
reshape2_0,
reshape2_0_out,
transpose2_0,
transpose2_0_out,
split0,
split0_q_out,
split0_k_out,
split0_v_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
transpose2_qkv,
transpose2_qkv_out,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
c_allreduce_sum,
c_allreduce_sum_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_c_identity,
ffn_c_identity_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_c_allreduce_sum,
ffn_c_allreduce_sum_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(
Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
new bool(true));
}
AddStatis(fusion_count);
}
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("while"))
.AddInput("X") // A set of variables, unconstrained
.End()
.AddInput("Condition") // An scalar
.IsTensor()
.End()
.AddOutput("Out") // A set of variables, unconstrained
.End()
.AddOutput("StepScopes") // A vector of local scope, unconstrained
.End()
.AddAttr("sub_block")
.IsType<framework::BlockDesc*>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fused_multi_transformer_encoder_pass,
paddle::framework::ir::FusedMultiTransformerEncoderPass);
REGISTER_PASS(fused_multi_transformer_encoder_fuse_qkv_pass,
paddle::framework::ir::FusedMultiTransformerEncoderFuseQKVPass);
REGISTER_PASS(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass,
paddle::framework::ir::MultiDevicesFusedMultiTransformerEncoderFuseQKVPass);
REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_fuse_qkv_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct FusedMultiTransformerEncoderPattern : public PatternBase {
FusedMultiTransformerEncoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_multi_transformer_encoder") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul1_w);
PATTERN_DECL_NODE(matmul2_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(matmul1_out);
PATTERN_DECL_NODE(matmul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// while loop
PATTERN_DECL_NODE(while0);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
FusedMultiTransformerEncoderFuseQKVPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(
pattern, name_scope, "fused_multi_transformer_encoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// while loop
PATTERN_DECL_NODE(while0);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
: public PatternBase {
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(c_identity);
PATTERN_DECL_NODE(c_identity_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// while loop
PATTERN_DECL_NODE(while0);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(c_allreduce_sum);
PATTERN_DECL_NODE(c_allreduce_sum_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_c_identity);
PATTERN_DECL_NODE(ffn_c_identity_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_c_allreduce_sum);
PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
} // namespace patterns
class FusedMultiTransformerEncoderPass : public FusePassBase {
public:
FusedMultiTransformerEncoderPass();
virtual ~FusedMultiTransformerEncoderPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_encoder"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
public:
FusedMultiTransformerEncoderFuseQKVPass();
virtual ~FusedMultiTransformerEncoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_encoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
: public FusePassBase {
public:
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass();
virtual ~MultiDevicesFusedMultiTransformerEncoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{
"multi_devices_fused_multi_transformer_encoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h" // NOLINT
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
// MHA: pre Layer Norm
AddVarToScope(param_scope, "ln_scale", {1024});
AddVarToScope(param_scope, "ln_bias", {1024});
// MHA: QKV fc
AddVarToScope(param_scope, "weights0", {1024, 1024});
AddVarToScope(param_scope, "weights1", {1024, 1024});
AddVarToScope(param_scope, "weights2", {1024, 1024});
AddVarToScope(param_scope, "bias_0", {1024});
AddVarToScope(param_scope, "bias_1", {1024});
AddVarToScope(param_scope, "bias_2", {1024});
// MHA: QK bias
AddVarToScope(param_scope, "biasqk", {1024});
// MHA: out Linear
AddVarToScope(param_scope, "weights_l", {1024, 1024});
AddVarToScope(param_scope, "bias_l", {1024});
  // FFN: pre Layer Norm
AddVarToScope(param_scope, "ffn_ln_scale", {1024});
AddVarToScope(param_scope, "ffn_ln_bias", {1024});
// FFN: fc1 -> (gelu) -> fc2
AddVarToScope(param_scope, "ffn_weights0", {1024, 4096});
AddVarToScope(param_scope, "ffn_weights1", {4096, 1024});
AddVarToScope(param_scope, "ffn_bias_0", {4096});
AddVarToScope(param_scope, "ffn_bias_1", {1024});
return param_scope;
}
TEST(FusedMultiTransformerEncoderPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (layer_norm_out, weights_1) matmul_v2 -> matmul_out1
// (layer_norm_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (transpose_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
auto* matmul_out_0 =
layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
auto* matmul_out_1 =
layers.matmul_v2(ln_out, weights_1, nullptr, false, true);
auto* matmul_out_2 =
layers.matmul_v2(ln_out, weights_2, nullptr, false, true);
auto* b0 = layers.data("bias_0", {1024}, true);
auto* b1 = layers.data("bias_1", {1024}, true);
auto* b2 = layers.data("bias_2", {1024}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
auto* elementwise_out_1 =
layers.elementwise_add(matmul_out_1, b1, nullptr, 2);
auto* elementwise_out_2 =
layers.elementwise_add(matmul_out_2, b2, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
// Link to decoder while block
layers.while_loop({transpose_1, transpose_2});
// MHA: QK matmul
auto* matmul_qk =
layers.matmul(transpose_0, transpose_1, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, transpose_2);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
auto* bias_l = layers.data("weightsl", {1024, 1024}, true);
auto* linear_matmut_out =
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_encoder_pass");
if (pass.get() == nullptr)
LOG(INFO) << "get fused_multi_transformer_encoder_pass failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 68,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_pass, The "
"node num in graph "
"should be %d, but the result is %d",
num_nodes_before - 68,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder pass, "
"there should be one fused_multi_transformer op, "
"but the result is %d",
num_fused_nodes_after));
}
TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("fused_multi_transformer_encoder_pass"));
}
TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
  //  (transpose_0)                      split             -> split_q, split_k, split_v
  //  (split_k)                          assign            -> assign_k
  //  (split_v)                          assign            -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 3072}, true);
auto* matmul_out_0 =
layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
auto* b0 = layers.data("bias_0", {3072}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto split_outs = layers.split(transpose_0, 3, 3);
auto* split_q = split_outs[0];
auto* split_k = split_outs[1];
auto* split_v = split_outs[2];
layers.assign(split_k);
layers.assign(split_v);
// Link to decoder while block
layers.while_loop({split_k, split_v});
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
auto* bias_l = layers.data("weightsl", {1024, 1024}, true);
auto* linear_matmut_out =
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(
"fused_multi_transformer_encoder_fuse_qkv_pass");
if (pass.get() == nullptr)
LOG(INFO) << "get fused_multi_transformer_encoder_fuse_qkv_pass failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 56,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 56,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv "
"pass, there should be one fused_multi_transformer "
"op, but the result is %d",
num_fused_nodes_after));
}
TEST(FusedMultiTransformerEncoderFuseQKVPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("fused_multi_transformer_encoder_fuse_qkv_pass"));
}
TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out) c_identity -> c_identity_out
// (c_identity_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
  //  (transpose_0)                      split             -> split_q, split_k, split_v
  //  (split_k)                          assign            -> assign_k
  //  (split_v)                          assign            -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_all_reduce -> c_all_reduce_out
// (c_all_reduce_out) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (ffn_layer_norm_out) c_identity -> ffn_c_identity_out
// (ffn_c_identity_out, ffn_matmul0_w)matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out
// (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
auto* c_identity_out = layers.c_identity(ln_out);
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 3072}, true);
auto* matmul_out_0 =
layers.matmul_v2(c_identity_out, weights_0, nullptr, false, true);
auto* b0 = layers.data("bias_0", {3072}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto split_outs = layers.split(transpose_0, 3, 3);
auto* split_q = split_outs[0];
auto* split_k = split_outs[1];
auto* split_v = split_outs[2];
layers.assign(split_k);
layers.assign(split_v);
// Link to decoder while block
layers.while_loop({split_k, split_v});
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
auto* bias_l = layers.data("weightsl", {1024, 1024}, true);
auto* linear_matmut_out =
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmut_out);
auto* linear_eltadd_out =
layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
auto* ffn_c_identity_out = layers.c_identity(ffn_ln_out);
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_c_identity_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass");
if (pass.get() == nullptr)
LOG(INFO)
<< "get multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass "
"failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 64,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 64,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv "
"multi-devices pass, there should be one "
"fused_multi_transformer op, but the result is %d",
num_fused_nodes_after));
}
TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible(
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fused_multi_transformer_encoder_pass);
USE_PASS(fused_multi_transformer_encoder_fuse_qkv_pass);
USE_PASS(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass);
......@@ -815,9 +815,14 @@ void GraphToProgram(const Graph &graph,
// avoid kRootBlockIndex not 0
if (idx == kRootBlockIndex) continue;
block = program_pb.add_blocks();
block->set_idx(idx);
block->set_parent_idx(kRootBlockIndex);
if (static_cast<int>(idx) < program_pb.blocks_size()) {
block = program_pb.mutable_blocks(idx);
} else {
block = program_pb.add_blocks();
block->set_idx(idx);
block->set_parent_idx(kRootBlockIndex);
}
GraphToBlock(*graph.GetSubGraph(idx),
block,
sort_kind,
......
......@@ -112,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
if (graph.Nodes().empty()) return false;
for (auto &node : GraphTraits::DFS(graph)) {
if (node.Name().rfind("__control_var") == 0) continue;
for (const auto &pdnode : pattern_.nodes()) {
if (pdnode->Tell(&node)) {
VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name();
......@@ -383,7 +384,6 @@ std::string PDPattern::DotString() const {
// Create Edges
for (const auto &edge : edges()) {
if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
LOG(ERROR) << "no node " << edge.first << " " << edge.second;
continue;
}
auto &src = node2dot.at(edge.first);
......@@ -453,7 +453,8 @@ PDNode *PDNode::assert_var_not_persistable() {
PDNode *PDNode::assert_is_persistable_var() {
assert_is_var();
asserts_.emplace_back([=](Node *x) { return x->Var()->Persistable(); });
asserts_.emplace_back(
[=](Node *x) { return x->Var() && x->Var()->Persistable(); });
return this;
}
......
......@@ -1990,6 +1990,14 @@ struct AddSupportInt8 : public PatternBase {
a->outputs.push_back(b); \
b->inputs.push_back(a);
// UnLink 2 ir::Nodes from each other.
#define IR_NODE_UNLINK(a, b) \
a->outputs.erase( \
std::remove(std::begin(a->outputs), std::end(a->outputs), b), \
std::end(a->outputs)); \
b->inputs.erase(std::remove(std::begin(b->inputs), std::end(b->inputs), a), \
std::end(b->inputs));
// Set the out_var as the output of the op
#define IR_OP_VAR_LINK(op, out_var) \
op->outputs.push_back(out_var); \
......
......@@ -22,6 +22,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
class Scope;
namespace ir {
class Graph;
} // namespace ir
......@@ -35,6 +36,17 @@ namespace paddle {
namespace framework {
namespace ir {
static const char kParamScopeAttr[] = "__param_scope__";
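// Passes listed here are applied to the sub-graphs (e.g. the decoder's while
// block) in addition to the main graph.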
static const std::vector<std::string> support_subgraph_passes = {
"fused_multi_transformer_encoder_pass",
"fused_multi_transformer_decoder_pass",
"fused_multi_transformer_encoder_fuse_qkv_pass",
"fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
};
Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph";
CheckPrevPass();
......@@ -65,11 +77,41 @@ Graph *Pass::Apply(Graph *graph) const {
true,
platform::errors::InvalidArgument(
"The VarDescs of persistable variable are not consistency."));
applied_ = true;
if (!graph->Has(kPassRecorder)) {
graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
}
graph->Get<PassRecorder>(kPassRecorder).insert(Type());
if (graph->IsMainGraph() && std::count(support_subgraph_passes.begin(),
support_subgraph_passes.end(),
Type())) {
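    // Also run this pass on every sub-graph (e.g. the decoder's while block),
    // sharing the main graph's param scope so persistable weights are visible.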
for (size_t i = 1; i < graph->SubGraphsSize(); i++) {
auto *sub_graph = graph->GetSubGraph(i);
if (!sub_graph->Has(framework::ir::kParamScopeAttr)) {
sub_graph->SetNotOwned<Scope>(
framework::ir::kParamScopeAttr,
&graph->Get<Scope>(framework::ir::kParamScopeAttr));
}
ApplyImpl(sub_graph);
PADDLE_ENFORCE_EQ(
HasCircle(*sub_graph),
false,
platform::errors::InvalidArgument(
"Illegal pass %s. Generated graph shouldn't contain cycle.",
Type()));
PADDLE_ENFORCE_EQ(
VarDescIsConsistency(*sub_graph),
true,
platform::errors::InvalidArgument(
"The VarDescs of persistable variable are not consistency."));
if (!sub_graph->Has(kPassRecorder)) {
sub_graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
}
sub_graph->Get<PassRecorder>(kPassRecorder).insert(Type());
}
}
applied_ = true;
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// Passes can change params, tensors, so caching need to be discarded
......
......@@ -47,6 +47,18 @@ constexpr char kPassRecorder[] = "pass_recorder";
constexpr char kEmbEltwiseLayernormPass[] =
"embedding_eltwise_layernorm_fuse_pass_flag";
constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag";
constexpr char kFusedMultiTransformerEncoderPass[] =
"fused_multi_transformer_encoder_pass_flag";
constexpr char kFusedMultiTransformerDecoderPass[] =
"fused_multi_transformer_decoder_pass_flag";
constexpr char kFusedMultiTransformerEncoderFuseQKVPass[] =
"fused_multi_transformer_encoder_fuse_qkv_pass_flag";
constexpr char kFusedMultiTransformerDecoderFuseQKVPass[] =
"fused_multi_transformer_decoder_fuse_qkv_pass_flag";
constexpr char kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass[] =
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass_flag";
constexpr char kMultiDevicesFusedMultiTransformerDecoderFuseQKVPass[] =
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass_flag";
constexpr char kPrelnEmbEltwiseLayernormPass[] =
"preln_embedding_eltwise_layernorm_fuse_pass_flag";
......
......@@ -146,6 +146,12 @@ struct Layers {
return unary_op("relu", x, out);
}
VarDesc* gelu(VarDesc* x, VarDesc* out = nullptr, bool approximate = true) {
AttributeMap attrs;
attrs["approximate"] = approximate;
return unary_op("gelu", x, out, &attrs);
}
VarDesc* sigmoid(VarDesc* x, VarDesc* out = nullptr) {
return unary_op("sigmoid", x, out);
}
......@@ -154,6 +160,20 @@ struct Layers {
return unary_op("tanh", x, out);
}
VarDesc* c_identity(VarDesc* x, VarDesc* out = nullptr, int ring_id = -1) {
AttributeMap attrs;
attrs["ring_id"] = ring_id;
return unary_op("c_identity", x, out, &attrs);
}
VarDesc* c_allreduce_sum(VarDesc* x,
VarDesc* out = nullptr,
int ring_id = -1) {
AttributeMap attrs;
attrs["ring_id"] = ring_id;
return unary_op("c_allreduce_sum", x, out, &attrs);
}
VarDesc* fc(VarDesc* input,
VarDesc* w,
VarDesc* bias,
......@@ -332,6 +352,37 @@ struct Layers {
return outs;
}
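  // Appends a "split" op that splits x into `num_or_section` outputs along
  // `axis`.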
std::vector<VarDesc*> split(VarDesc* x, int num_or_section, int axis = 0) {
std::vector<VarDesc*> outs(num_or_section);
for (int i = 0; i < num_or_section; i++) {
outs[i] = lod_tensor(unique_name());
}
std::vector<std::string> out_names(num_or_section);
for (int i = 0; i < num_or_section; i++) {
out_names[i] = outs[i]->Name();
}
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("split");
op->SetInput("X", {x->Name()});
op->SetOutput("Out", out_names);
op->SetAttr("num_or_section", num_or_section);
op->SetAttr("axis", axis);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return outs;
}
VarDesc* assign(VarDesc* x) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {x->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
VarDesc* matmul(VarDesc* x,
VarDesc* y,
VarDesc* alpha = nullptr,
......@@ -459,6 +510,24 @@ struct Layers {
return out;
}
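  // Appends a "while" op; in these tests its sub_block simply points back at
  // block 0, standing in for the real decoder block.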
VarDesc* while_loop(std::vector<VarDesc*> xs, VarDesc* cond = nullptr) {
VarDesc* out = lod_tensor(unique_name());
VarDesc* step_scopes = lod_tensor(unique_name());
if (cond == nullptr) cond = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("while");
std::vector<std::string> xs_names;
for (auto& x : xs) xs_names.emplace_back(x->Name());
op->SetInput("X", xs_names);
op->SetInput("Condition", {cond->Name()});
op->SetOutput("Out", {out->Name()});
op->SetOutput("StepScopes", {step_scopes->Name()});
op->SetAttr("sub_block", {program_.MutableBlock(0)});
op->SetAttr("is_test", true);
return out;
}
void backward(std::vector<VarDesc*> targets) {
// This function is designed to simulate the structure of a training program,
// but is constructed differently from the actual program.
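The while_loop helper above gives testers a way to imitate the generation loop that wraps a GPT-style decoder graph at inference time. A hypothetical sketch, assuming the existing Layers::data helper; shapes are illustrative:

// Sketch only: a while op over a KV cache and the current token ids,
// standing in for the decoding loop around the decoder block.
Layers layers;
auto* cache_kv = layers.data("cache_kv", {2, 1, 16, 128, 64});
auto* token_id = layers.data("token_id", {1, 1});
auto* loop_out = layers.while_loop({cache_kv, token_id});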
......@@ -523,7 +592,10 @@ struct Layers {
return var;
}
VarDesc* unary_op(std::string type, VarDesc* x, VarDesc* out = nullptr) {
VarDesc* unary_op(std::string type,
VarDesc* x,
VarDesc* out = nullptr,
const AttributeMap* attrs = nullptr) {
if (!out) {
out = lod_tensor(unique_name());
}
......@@ -531,6 +603,11 @@ struct Layers {
op->SetType(type);
op->SetInput("X", {x->Name()});
op->SetOutput("Out", {out->Name()});
if (attrs) {
for (auto& iter : *attrs) {
op->SetAttr(iter.first, iter.second);
}
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
......
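The attrs parameter threaded through unary_op above is what gelu, c_identity and c_allreduce_sum rely on, and it makes further attributed one-input helpers a three-line addition. A hypothetical example (not part of this diff):

// Sketch only: another attributed unary helper built on the extended unary_op.
VarDesc* leaky_relu(VarDesc* x, VarDesc* out = nullptr, float alpha = 0.02f) {
  AttributeMap attrs;
  attrs["alpha"] = alpha;
  return unary_op("leaky_relu", x, out, &attrs);
}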
......@@ -76,6 +76,7 @@ void MemoryOptimizePass::CollectLifeCycle(
} else {
// Normal operators.
for (const Node* node : requires) {
if (!node->Var()) continue;
if (node->Var()->Persistable()) continue;
std::string var = node->Name();
if (!lifecycles->count(var)) {
......@@ -133,7 +134,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
// between performance and underlying principle.
std::unordered_set<std::string> black_list;
for (auto* node : graph->Nodes()) {
if (node->IsVar() &&
if (node->IsVar() && node->Var() &&
node->Var()->GetType() ==
framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
if (!valid_var(node)) {
......@@ -144,7 +145,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
// Collect tensors from graph.
for (auto* node : graph->Nodes()) {
if (node->IsVar() &&
if (node->IsVar() && node->Var() &&
node->Var()->GetType() ==
framework::proto::VarType::Type::VarType_Type_LOD_TENSOR &&
!black_list.count(node->Var()->Name())) {
......
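The extra node->Var() checks above guard against variable nodes that carry no VarDesc, which a graph can legitimately contain (for example control-dependency variables). A minimal sketch of the defensive pattern, assuming a generic iteration over graph nodes:

// Sketch only: skip var nodes without a VarDesc before touching the desc.
for (auto* node : graph->Nodes()) {
  if (!node->IsVar() || node->Var() == nullptr) continue;
  if (node->Var()->Persistable()) continue;
  // ... it is now safe to query type, shape and name from node->Var() ...
}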
......@@ -193,22 +193,28 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
// "identity_scale_op_clean_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"gpu_cpu_squeeze2_matmul_fuse_pass", //
"gpu_cpu_reshape2_matmul_fuse_pass", //
"gpu_cpu_flatten2_matmul_fuse_pass", //
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"matmul_scale_fuse_pass", //
"multihead_matmul_fuse_pass_v3", //
"gpu_cpu_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"fused_multi_transformer_encoder_pass", //
"fused_multi_transformer_decoder_pass", //
"fused_multi_transformer_encoder_fuse_qkv_pass", //
"fused_multi_transformer_decoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", //
"gpu_cpu_squeeze2_matmul_fuse_pass", //
"gpu_cpu_reshape2_matmul_fuse_pass", //
"gpu_cpu_flatten2_matmul_fuse_pass", //
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"matmul_scale_fuse_pass", //
"multihead_matmul_fuse_pass_v3", //
"gpu_cpu_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
......
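With this change the six fused_multi_transformer passes run by default in the GPU pass strategy. If one of them needs to be switched off for a particular model, the public inference API can drop it before the predictor is built; a hypothetical sketch (the model path is a placeholder):

// Sketch only: disable one of the new fusions via the pass builder.
paddle_infer::Config config("./gpt3_model");
config.EnableUseGpu(/*memory_pool_init_size_mb=*/500, /*device_id=*/0);
config.pass_builder()->DeletePass("fused_multi_transformer_decoder_pass");
auto predictor = paddle_infer::CreatePredictor(config);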