Unverified commit 5a2e5179, authored by Kaipeng Deng, committed by GitHub

Add FusedMultiTransformer fuse pass for GPT3 (#45907)


* Add fused_multi_transformer_encoder/decoder passes; GPT-3 runs successfully with them
Parent 4dc4d5fc
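As a usage sketch (an assumption, not part of this commit): the new passes are registered as inference passes, and the snippet below shows one way they could be appended explicitly to the inference pass pipeline through the paddle_infer C++ API. The model paths are placeholders.

// Hypothetical usage sketch: enable the new fuse passes via the Paddle
// inference C++ API. Paths are placeholders, not taken from this commit.
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./gpt3/model.pdmodel",     // placeholder path
                              "./gpt3/model.pdiparams");  // placeholder path
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/2048, /*device_id=*/0);
  // Append the passes added in this commit to the analysis pass pipeline.
  config.pass_builder()->AppendPass("fused_multi_transformer_encoder_pass");
  config.pass_builder()->AppendPass("fused_multi_transformer_decoder_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}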
...@@ -105,6 +105,8 @@ pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
pass_library(multihead_matmul_fuse_pass inference)
pass_library(fused_multi_transformer_encoder_pass inference)
pass_library(fused_multi_transformer_decoder_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(yolo_box_fuse_pass inference)
...@@ -311,6 +313,14 @@ cc_test(
test_multihead_matmul_fuse_pass
SRCS multihead_matmul_fuse_pass_tester.cc
DEPS multihead_matmul_fuse_pass)
cc_test(
test_fused_multi_transformer_encoder_pass
SRCS fused_multi_transformer_encoder_pass_tester.cc
DEPS fused_multi_transformer_encoder_pass)
cc_test(
test_fused_multi_transformer_decoder_pass
SRCS fused_multi_transformer_decoder_pass_tester.cc
DEPS fused_multi_transformer_decoder_pass)
cc_test(
test_conv_bn_fuse_pass_cc
SRCS conv_bn_fuse_pass_tester.cc
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h"
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
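// The decoder pattern below matches one pre-LayerNorm GPT-style decoder
// block: layer_norm, separate Q/K/V matmul_v2 + elementwise_add followed by
// reshape2/transpose2, K and V concatenated with the cached KV (concat +
// assign), the attention core (matmul -> elementwise_add -> softmax ->
// dropout -> matmul_v2), the out-linear projection with a residual add, and
// the feed-forward sub-block (layer_norm, fc1 -> gelu -> fc2 -> dropout)
// closed by a second residual add.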
PDNode* FusedMultiTransformerDecoderPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X")
->assert_more([](Node* x) { return x->outputs.size() == 3; });
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// Q path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
// Q path Links
matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
// K path Nodes
auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2");
auto* matmul1_w_var = pattern->NewNode(matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul1_out_var = pattern->NewNode(matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd1 =
pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
auto* eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate();
auto* concat_0_in_var = pattern->NewNode(concat_0_in_repr())->AsInput();
auto* concat_0 = pattern->NewNode(concat_0_repr())->assert_is_op("concat");
auto* concat_0_out_var = pattern->NewNode(concat_0_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_input("assign");
auto assign_0 = pattern->NewNode(assign_0_repr())->assert_is_op("assign");
// K path Links
matmul1->LinksFrom({layer_norm_out_var, matmul1_w_var})
.LinksTo({matmul1_out_var});
eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var})
.LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
concat_0->LinksFrom({transpose2_1_out_var, concat_0_in_var})
.LinksTo({concat_0_out_var});
assign_0->LinksFrom({concat_0_out_var});
// V path Nodes
auto* matmul2 = pattern->NewNode(matmul2_repr())->assert_is_op("matmul_v2");
auto* matmul2_w_var = pattern->NewNode(matmul2_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul2_out_var = pattern->NewNode(matmul2_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd2 =
pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
auto* eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
auto* concat_1_in_var = pattern->NewNode(concat_1_in_repr())
->AsInput()
->assert_is_op_input("concat");
auto* concat_1 = pattern->NewNode(concat_1_repr())->assert_is_op("concat");
auto* concat_1_out_var = pattern->NewNode(concat_1_out_repr())
->assert_is_op_output("concat")
->assert_is_op_input("matmul_v2")
->assert_is_op_input("assign");
auto assign_1 = pattern->NewNode(assign_1_repr())->assert_is_op("assign");
// V path Links
matmul2->LinksFrom({layer_norm_out_var, matmul2_w_var})
.LinksTo({matmul2_out_var});
eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var})
.LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
concat_1->LinksFrom({transpose2_2_out_var, concat_1_in_var})
.LinksTo({concat_1_out_var});
assign_1->LinksFrom({concat_1_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
// QK path Links
matmul_qk->LinksFrom({transpose2_0_out_var, concat_0_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, concat_1_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
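// The FuseQKV variant matches the same decoder block, but with a single
// fused QKV matmul_v2 + elementwise_add whose reshape2/transpose2 output is
// split into Q, K and V; K and V are then concatenated with the cached KV.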
PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// QKV fused path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("split", "X");
auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("concat");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("concat");
auto* concat_k_in_var = pattern
->NewNode(concat_k_in_repr())
// ->AsInput()
->assert_is_op_input("concat");
auto* concat_k = pattern->NewNode(concat_k_repr())->assert_is_op("concat");
auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_input("assign");
auto* concat_v_in_var = pattern
->NewNode(concat_v_in_repr())
// ->AsInput()
->assert_is_op_input("concat");
auto* concat_v = pattern->NewNode(concat_v_repr())->assert_is_op("concat");
auto* concat_v_out_var = pattern->NewNode(concat_v_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul_v2")
->assert_is_op_input("assign");
auto* assign_k = pattern->NewNode(assign_k_repr())->assert_is_op("assign");
auto* assign_v = pattern->NewNode(assign_v_repr())->assert_is_op("assign");
// QKV fused path Links
matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
split0->LinksFrom({transpose2_0_out_var})
.LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
concat_k->LinksFrom({concat_k_in_var, split0_k_out_var})
.LinksTo({concat_k_out_var});
concat_v->LinksFrom({concat_v_in_var, split0_v_out_var})
.LinksTo({concat_v_out_var});
assign_k->LinksFrom({concat_k_out_var});
assign_v->LinksFrom({concat_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
// QK path Links
matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
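// The MultiDevices variant additionally matches the communication ops used
// for multi-device (tensor-parallel) execution: a c_identity before the
// fused QKV matmul and the FFN fc1 matmul, and a c_allreduce_sum after the
// out-linear matmul and the FFN fc2 matmul.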
PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// communication c_identity
auto* c_identity =
pattern->NewNode(c_identity_repr())->assert_is_op("c_identity");
auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr())
->AsIntermediate()
->assert_is_op_output("c_identity", "Out")
->assert_is_op_input("matmul_v2", "X");
c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var});
// QKV fused path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("split", "X");
auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("concat");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("concat");
auto* concat_k_in_var = pattern
->NewNode(concat_k_in_repr())
// ->AsInput()
->assert_is_op_input("concat");
auto* concat_k = pattern->NewNode(concat_k_repr())->assert_is_op("concat");
auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_input("assign");
auto* concat_v_in_var = pattern
->NewNode(concat_v_in_repr())
// ->AsInput()
->assert_is_op_input("concat");
auto* concat_v = pattern->NewNode(concat_v_repr())->assert_is_op("concat");
auto* concat_v_out_var = pattern->NewNode(concat_v_out_repr())
->assert_is_op_output("concat")
->AsIntermediate()
->assert_is_op_input("matmul_v2")
->assert_is_op_input("assign");
auto* assign_k = pattern->NewNode(assign_k_repr())->assert_is_op("assign");
auto* assign_v = pattern->NewNode(assign_v_repr())->assert_is_op("assign");
// QKV fused path Links
matmul0->LinksFrom({c_identity_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
split0->LinksFrom({transpose2_0_out_var})
.LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
concat_k->LinksFrom({concat_k_in_var, split0_k_out_var})
.LinksTo({concat_k_out_var});
concat_v->LinksFrom({concat_v_in_var, split0_v_out_var})
.LinksTo({concat_v_out_var});
assign_k->LinksFrom({concat_k_out_var});
assign_v->LinksFrom({concat_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
// QK path Links
matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("c_allreduce_sum");
// communication c_allreduce_sum
auto* c_allreduce_sum =
pattern->NewNode(c_allreduce_sum_repr())->assert_is_op("c_allreduce_sum");
auto* c_allreduce_sum_out_var = pattern->NewNode(c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
c_allreduce_sum->LinksFrom({matmul_linear_out_var})
.LinksTo({c_allreduce_sum_out_var});
eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// communication c_identity
auto* ffn_c_identity =
pattern->NewNode(ffn_c_identity_repr())->assert_is_op("c_identity");
auto* ffn_c_identity_out_var = pattern->NewNode(ffn_c_identity_out_repr())
->assert_is_op_output("c_identity", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X");
ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
.LinksTo({ffn_c_identity_out_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("c_allreduce_sum");
// communication c_allreduce_sum
auto* ffn_c_allreduce_sum = pattern->NewNode(ffn_c_allreduce_sum_repr())
->assert_is_op("c_allreduce_sum");
auto* ffn_c_allreduce_sum_out_var =
pattern->NewNode(ffn_c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
.LinksTo({ffn_c_allreduce_sum_out_var});
ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
} // namespace patterns
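// BuildFusion detects the decoder pattern above and, for each match, emits a
// single fused_multi_transformer op: the fuse_creater lambda below wires the
// pre-LayerNorm, QKV, out-linear and FFN weights/biases as inputs, reuses the
// per-layer cache_kv variable as CacheKV/CacheKVOut, derives TimeStep from
// the SrcMask shape, copies the dropout/epsilon attributes, and links the new
// op into the graph; the handler then marks the matched nodes for removal.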
int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FusedMultiTransformerDecoderPattern fused_multi_transformer_pattern(
pattern, name_scope);
fused_multi_transformer_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0_w,
Node* matmul1_w,
Node* matmul2_w,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* transpose2_1_out,
Node* transpose2_2_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
// Calculate the index of the transformer layer from the LayerNorm Scale name.
// This calculation assumes:
// 1. there is no LayerNorm before the first transformer layer
// 2. each transformer layer contains 2 LayerNorm ops
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
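// e.g. a hypothetical scale var named "layer_norm_6.w_0" gives
// ln_name = "layer_norm_6", ln_idx_str = "6", layer_idx = 3.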
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV input: reuse the cache_kv variable created by the encoder pass
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});
VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv_name});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "fused_multi_transformer_decoder "
"pass in op compat failed.";
return;
}
VLOG(4) << "handle MultiTransformer decoder fuse";
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_out, layer_norm_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1, matmul1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1_out, matmul1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1_w, matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1, reshape2_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1, transpose2_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_0, concat_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_0_out, concat_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_0, assign_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2, matmul2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_out, matmul2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_w, matmul2_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2, reshape2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2, transpose2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_1, concat_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_1_out, concat_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_1, assign_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
attention_output, attention_output, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_pattern)
// nodes to be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1, eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_b, eltadd1_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_out, eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2, eltadd2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_b, eltadd2_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_out, eltadd2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_linear, dropout_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0_w,
matmul1_w,
matmul2_w,
eltadd0_b,
eltadd1_b,
eltadd2_b,
transpose2_1_out,
transpose2_2_out,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
matmul0,
matmul1,
matmul2,
matmul0_out,
matmul1_out,
matmul2_out,
eltadd0,
eltadd1,
eltadd2,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
concat_0,
concat_1,
concat_0_out,
concat_1_out,
assign_0,
assign_1,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void FusedMultiTransformerDecoderPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal("During the multi_transformer pass, "
"The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderPass, new bool(true));
}
AddStatis(fusion_count);
}
FusedMultiTransformerDecoderPass::FusedMultiTransformerDecoderPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
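// Note: after the transpose2 above (axis {0, 2, 1, 3}) the K/V tensors are
// laid out as [B, num_head, seq_len, head_dim], so the decoder appends the
// current step's K/V to the cache by concatenating along axis 2.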
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(2)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
}
int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FusedMultiTransformerDecoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
fused_multi_transformer_fuse_qkv_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0_w,
Node* eltadd0_b,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
// Calculate the index of the transformer layer from the LayerNorm Scale name.
// This calculation assumes:
// 1. there is no LayerNorm before the transformer layers
// 2. each transformer layer contains 2 LayerNorm ops
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
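// e.g. assuming a GPT-style scale name such as "layer_norm_6.w_0"
// (hypothetical), ln_idx_str is "6" and layer_idx is 3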
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// Cache KV reuses the cache_kv variable created by the encoder pass
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});
VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv_name});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
// fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"});
// fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true});
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "fused_multi_transformer_decoder_fuse_qkv "
"pass in op compat failed.";
return;
}
VLOG(4) << "handle MultiTransformer decoder(Fuse-QKV) fuse";
GET_IR_NODE_FROM_SUBGRAPH(
input0, input0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale,
layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias,
layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean,
layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out,
layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out,
reshape2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out,
transpose2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0, split0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k_in, concat_k_in, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k, concat_k, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k_out, concat_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v_in, concat_v_in, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v, concat_v, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v_out, concat_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_k, assign_k, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_v, assign_v, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm,
ffn_layer_norm,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out,
ffn_matmul0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out,
ffn_eltadd0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out,
ffn_matmul1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
ffn_dropout_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
// nodes that need to be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
dropout_qk_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out,
matmul_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out,
reshape2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv,
transpose2_qkv,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w,
matmul_linear_w,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out,
matmul_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b,
eltadd_linear_b,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
dropout_linear,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
dropout_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0_w,
eltadd0_b,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
matmul0,
matmul0_out,
eltadd0,
eltadd0_out,
reshape2_0,
reshape2_0_out,
transpose2_0,
transpose2_0_out,
split0,
split0_q_out,
split0_k_out,
split0_v_out,
concat_k_in,
concat_k,
concat_k_out,
concat_v_in,
concat_v,
concat_v_out,
assign_k,
assign_v,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void FusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal("During the fused_multi_transformer_decoder "
"pass, The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
}
AddStatis(fusion_count);
}
FusedMultiTransformerDecoderFuseQKVPass::
FusedMultiTransformerDecoderFuseQKVPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(2)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
}
int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
fused_multi_transformer_fuse_qkv_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* c_identity,
Node* matmul0_w,
Node* eltadd0_b,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
// Calculate the index of the transformer layer from the LayerNorm Scale name.
// This calculation assumes:
// 1. there is no LayerNorm before the transformer layers
// 2. each transformer layer contains 2 LayerNorm ops
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// Cache KV reuses the cache_kv variable created by the encoder pass
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});
VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);
OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);
VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);
OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);
fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv_name});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
// parallel ring id
auto* c_identity_op = c_identity->Op();
fused_multi_transformer_op_desc.SetAttr("ring_id",
c_identity_op->GetAttr("ring_id"));
// fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"});
// fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true});
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "fused_multi_transformer_decoder_fuse_qkv "
"pass in op compat failed.";
return;
}
VLOG(4) << "handle MultiTransformer decoder(Fuse-QKV) fuse";
GET_IR_NODE_FROM_SUBGRAPH(
input0, input0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale,
layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias,
layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean,
layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out,
layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
c_identity, c_identity, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity_out,
c_identity_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out,
reshape2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out,
transpose2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0, split0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k_in, concat_k_in, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k, concat_k, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_k_out, concat_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v_in, concat_v_in, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v, concat_v, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
concat_v_out, concat_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_k, assign_k, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
assign_v, assign_v, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm,
ffn_layer_norm,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity,
ffn_c_identity,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity_out,
ffn_c_identity_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out,
ffn_matmul0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out,
ffn_eltadd0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out,
ffn_matmul1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum,
ffn_c_allreduce_sum,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum_out,
ffn_c_allreduce_sum_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
ffn_dropout_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
// nodes that need to be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
dropout_qk_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out,
matmul_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out,
reshape2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv,
transpose2_qkv,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w,
matmul_linear_w,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out,
matmul_linear_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum,
c_allreduce_sum,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum_out,
c_allreduce_sum_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b,
eltadd_linear_b,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
dropout_linear,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
dropout_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
c_identity,
matmul0_w,
eltadd0_b,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
c_identity,
c_identity_out,
matmul0,
matmul0_out,
eltadd0,
eltadd0_out,
reshape2_0,
reshape2_0_out,
transpose2_0,
transpose2_0_out,
split0,
split0_q_out,
split0_k_out,
split0_v_out,
concat_k_in,
concat_k,
concat_k_out,
concat_v_in,
concat_v,
concat_v_out,
assign_k,
assign_v,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
c_allreduce_sum,
c_allreduce_sum_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_c_identity,
ffn_c_identity_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_c_allreduce_sum,
ffn_c_allreduce_sum_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(
Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal("During the fused_multi_transformer_decoder "
"pass, The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
}
AddStatis(fusion_count);
}
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(2)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fused_multi_transformer_decoder_pass,
paddle::framework::ir::FusedMultiTransformerDecoderPass);
REGISTER_PASS(fused_multi_transformer_decoder_fuse_qkv_pass,
paddle::framework::ir::FusedMultiTransformerDecoderFuseQKVPass);
REGISTER_PASS(
multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass,
paddle::framework::ir::MultiDevicesFusedMultiTransformerDecoderFuseQKVPass);
REGISTER_PASS_CAPABILITY(fused_multi_transformer_decoder_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(fused_multi_transformer_decoder_fuse_qkv_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(
multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct FusedMultiTransformerDecoderPattern : public PatternBase {
FusedMultiTransformerDecoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_multi_transformer_decoder") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul1_w);
PATTERN_DECL_NODE(matmul2_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(matmul1_out);
PATTERN_DECL_NODE(matmul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(concat_0_in);
PATTERN_DECL_NODE(concat_0);
PATTERN_DECL_NODE(concat_0_out);
PATTERN_DECL_NODE(assign_0);
PATTERN_DECL_NODE(concat_1_in);
PATTERN_DECL_NODE(concat_1);
PATTERN_DECL_NODE(concat_1_out);
PATTERN_DECL_NODE(assign_1);
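// The concat_*/assign_* pairs above model the decoder KV-cache update: the
// current step's K/V is concatenated onto the cached tensor and then assigned
// back to the cache variable.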
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// while loop
PATTERN_DECL_NODE(while0);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
FusedMultiTransformerDecoderFuseQKVPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(
pattern, name_scope, "fused_multi_transformer_decoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
PATTERN_DECL_NODE(concat_k_in)
PATTERN_DECL_NODE(concat_v_in)
PATTERN_DECL_NODE(concat_k)
PATTERN_DECL_NODE(concat_v)
PATTERN_DECL_NODE(concat_k_out)
PATTERN_DECL_NODE(concat_v_out)
PATTERN_DECL_NODE(assign_k)
PATTERN_DECL_NODE(assign_v)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
: public PatternBase {
MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(c_identity);
PATTERN_DECL_NODE(c_identity_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
PATTERN_DECL_NODE(concat_k_in)
PATTERN_DECL_NODE(concat_v_in)
PATTERN_DECL_NODE(concat_k)
PATTERN_DECL_NODE(concat_v)
PATTERN_DECL_NODE(concat_k_out)
PATTERN_DECL_NODE(concat_v_out)
PATTERN_DECL_NODE(assign_k)
PATTERN_DECL_NODE(assign_v)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(c_allreduce_sum);
PATTERN_DECL_NODE(c_allreduce_sum_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_c_identity);
PATTERN_DECL_NODE(ffn_c_identity_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_c_allreduce_sum);
PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
} // namespace patterns
class FusedMultiTransformerDecoderPass : public FusePassBase {
public:
FusedMultiTransformerDecoderPass();
virtual ~FusedMultiTransformerDecoderPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_decoder"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class FusedMultiTransformerDecoderFuseQKVPass : public FusePassBase {
public:
FusedMultiTransformerDecoderFuseQKVPass();
virtual ~FusedMultiTransformerDecoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_decoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
: public FusePassBase {
public:
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass();
virtual ~MultiDevicesFusedMultiTransformerDecoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{
"multi_devices_fused_multi_transformer_decoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
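// These passes are registered with the inference PassRegistry and are looked
// up purely by name. A rough usage sketch with the public inference API
// (assumes the standard paddle_infer::Config interface; adjust to your build):
//
//   paddle_infer::Config config(model_dir);
//   config.EnableUseGpu(1000 /* MB */, 0);
//   config.pass_builder()->AppendPass("fused_multi_transformer_decoder_pass");
//   auto predictor = paddle_infer::CreatePredictor(config);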
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h" // NOLINT
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
// MHA: pre Layer Norm
AddVarToScope(param_scope, "ln_scale", {1024});
AddVarToScope(param_scope, "ln_bias", {1024});
// MHA: QKV fc
AddVarToScope(param_scope, "weights0", {1024, 1024});
AddVarToScope(param_scope, "weights1", {1024, 1024});
AddVarToScope(param_scope, "weights2", {1024, 1024});
AddVarToScope(param_scope, "bias_0", {1024});
AddVarToScope(param_scope, "bias_1", {1024});
AddVarToScope(param_scope, "bias_2", {1024});
// MHA: QK bias
AddVarToScope(param_scope, "biasqk", {1024});
// MHA: out Linear
AddVarToScope(param_scope, "weights_l", {1024, 1024});
AddVarToScope(param_scope, "bias_l", {1024});
  // FFN: pre Layer Norm
AddVarToScope(param_scope, "ffn_ln_scale", {1024});
AddVarToScope(param_scope, "ffn_ln_bias", {1024});
// FFN: fc1 -> (gelu) -> fc2
AddVarToScope(param_scope, "ffn_weights0", {1024, 4096});
AddVarToScope(param_scope, "ffn_weights1", {4096, 1024});
AddVarToScope(param_scope, "ffn_bias_0", {4096});
AddVarToScope(param_scope, "ffn_bias_1", {1024});
return param_scope;
}
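// Shape convention used throughout these tests: hidden size 1024 is split
// into 16 heads x 64 dims per head (see the {1, 128, 16, 64} reshape below),
// and the FFN expands to 4 * 1024 = 4096 before projecting back to 1024.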
TEST(FusedMultiTransformerDecoderPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (layer_norm_out, weights_1) matmul_v2 -> matmul_out1
// (layer_norm_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
  // (cache_k, transpose_1) concat -> concat_k
  // (cache_v, transpose_2) concat -> concat_v
  // (concat_k) assign -> assign_k
  // (concat_v) assign -> assign_v
  // (transpose_0, concat_k) matmul -> matmul_qk
  // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
  // (eltadd_qk) softmax -> softmax_qk
  // (softmax_qk) dropout -> dropout_qk
  // (dropout_qk, concat_v) matmul_v2 -> matmul_qkv
  // (matmul_qkv) transpose -> transpose_qkv
  // (transpose_qkv) reshape -> reshape_qkv
  // (reshape_qkv, weights_l) matmul_v2 -> matmul_linear
  // (matmul_linear, bias_l) elementwise_add -> eltadd_linear
  // (eltadd_linear) dropout -> dropout_linear
  // (x, dropout_linear) elementwise_add -> attention_out
  //
  // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
  // (ffn_layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
auto* matmul_out_0 =
layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
auto* matmul_out_1 =
layers.matmul_v2(ln_out, weights_1, nullptr, false, true);
auto* matmul_out_2 =
layers.matmul_v2(ln_out, weights_2, nullptr, false, true);
auto* b0 = layers.data("bias_0", {1024}, true);
auto* b1 = layers.data("bias_1", {1024}, true);
auto* b2 = layers.data("bias_2", {1024}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
auto* elementwise_out_1 =
layers.elementwise_add(matmul_out_1, b1, nullptr, 2);
auto* elementwise_out_2 =
layers.elementwise_add(matmul_out_2, b2, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
auto* cache_k = layers.data("cache_k", {1, 16, 128, 64});
auto* cache_v = layers.data("cache_v", {1, 16, 128, 64});
auto* concat_k = layers.concat({cache_k, transpose_1}, 2);
auto* concat_v = layers.concat({cache_v, transpose_2}, 2);
layers.assign(concat_k);
layers.assign(concat_v);
// MHA: QK matmul
auto* matmul_qk = layers.matmul(transpose_0, concat_k, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
  auto* bias_l = layers.data("bias_l", {1024}, true);
  auto* linear_matmul_out =
      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
  auto* linear_eltadd_out =
      layers.elementwise_add(linear_matmul_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_decoder_pass");
if (pass.get() == nullptr)
LOG(INFO) << "get fused_multi_transformer_decoder_pass failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 72,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_pass, The "
"node num in graph "
"should be %d, but the result is %d",
num_nodes_before - 72,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder pass, "
"there should be one fused_multi_transformer op, "
"but the result is %d",
num_fused_nodes_after));
}
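// The "+ 72" above is the expected net node reduction for one decoder block:
// the ops and intermediate vars matched by the pattern disappear, while a
// single fused_multi_transformer op (plus its inputs) is added. The constant
// must be kept in sync whenever the pattern gains or loses nodes.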
TEST(FusedMultiTransformerDecoderPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("fused_multi_transformer_decoder_pass"));
}
TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
  // (transpose_0) split -> split_q, split_k, split_v
  // (cache_k, split_k) concat -> concat_k
  // (cache_v, split_v) concat -> concat_v
  // (concat_k) assign -> assign_k
  // (concat_v) assign -> assign_v
  // (split_q, concat_k) matmul -> matmul_qk
  // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
  // (eltadd_qk) softmax -> softmax_qk
  // (softmax_qk) dropout -> dropout_qk
  // (dropout_qk, concat_v) matmul_v2 -> matmul_qkv
  // (matmul_qkv) transpose -> transpose_qkv
  // (transpose_qkv) reshape -> reshape_qkv
  // (reshape_qkv, weights_l) matmul_v2 -> matmul_linear
  // (matmul_linear, bias_l) elementwise_add -> eltadd_linear
  // (eltadd_linear) dropout -> dropout_linear
  // (x, dropout_linear) elementwise_add -> attention_out
  //
  // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
  // (ffn_layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 3072}, true);
auto* matmul_out_0 =
layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
auto* b0 = layers.data("bias_0", {3072}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto split_outs = layers.split(transpose_0, 3, 3);
auto* split_q = split_outs[0];
auto* split_k = split_outs[1];
auto* split_v = split_outs[2];
auto* cache_k = layers.data("cache_k", {1, 16, 128, 64});
auto* cache_v = layers.data("cache_v", {1, 16, 128, 64});
auto* concat_k = layers.concat({cache_k, split_k}, 2);
auto* concat_v = layers.concat({cache_v, split_v}, 2);
layers.assign(concat_k);
layers.assign(concat_v);
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
  auto* bias_l = layers.data("bias_l", {1024}, true);
  auto* linear_matmul_out =
      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
  auto* linear_eltadd_out =
      layers.elementwise_add(linear_matmul_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(
"fused_multi_transformer_decoder_fuse_qkv_pass");
if (pass.get() == nullptr)
LOG(INFO) << "get fused_multi_transformer_decoder_fuse_qkv_pass failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 62,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 62,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv "
"pass, there should be one fused_multi_transformer "
"op, but the result is %d",
num_fused_nodes_after));
}
TEST(FusedMultiTransformerDecoderFuseQKVPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("fused_multi_transformer_decoder_fuse_qkv_pass"));
}
TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out) c_identity -> c_identity_out
// (c_identity_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
  // (transpose_0) split -> split_q, split_k, split_v
  // (cache_k, split_k) concat -> concat_k
  // (cache_v, split_v) concat -> concat_v
  // (concat_k) assign -> assign_k
  // (concat_v) assign -> assign_v
  // (split_q, concat_k) matmul -> matmul_qk
  // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
  // (eltadd_qk) softmax -> softmax_qk
  // (softmax_qk) dropout -> dropout_qk
  // (dropout_qk, concat_v) matmul_v2 -> matmul_qkv
  // (matmul_qkv) transpose -> transpose_qkv
  // (transpose_qkv) reshape -> reshape_qkv
  // (reshape_qkv, weights_l) matmul_v2 -> matmul_linear
  // (matmul_linear) c_allreduce_sum -> c_allreduce_out
  // (c_allreduce_out, bias_l) elementwise_add -> eltadd_linear
  // (eltadd_linear) dropout -> dropout_linear
  // (x, dropout_linear) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (ffn_layer_norm_out) c_identity -> ffn_c_identity_out
  // (ffn_c_identity_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
  // (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
  // (ffn_eltadd0) gelu -> ffn_gelu
  // (ffn_gelu) matmul_v2 -> ffn_matmul1
  // (ffn_matmul1) c_allreduce_sum -> ffn_c_allreduce_out
  // (ffn_c_allreduce_out, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
auto* c_identity_out = layers.c_identity(ln_out);
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 3072}, true);
auto* matmul_out_0 =
layers.matmul_v2(c_identity_out, weights_0, nullptr, false, true);
auto* b0 = layers.data("bias_0", {3072}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto split_outs = layers.split(transpose_0, 3, 3);
auto* split_q = split_outs[0];
auto* split_k = split_outs[1];
auto* split_v = split_outs[2];
auto* cache_k = layers.data("cache_k", {1, 16, 128, 64});
auto* cache_v = layers.data("cache_v", {1, 16, 128, 64});
auto* concat_k = layers.concat({cache_k, split_k}, 2);
auto* concat_v = layers.concat({cache_v, split_v}, 2);
layers.assign(concat_k);
layers.assign(concat_v);
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
  auto* bias_l = layers.data("bias_l", {1024}, true);
  auto* linear_matmul_out =
      layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
  auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmul_out);
  auto* linear_eltadd_out =
      layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
auto* ffn_c_identity_out = layers.c_identity(ffn_ln_out);
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_c_identity_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_c_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_c_allreduce_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass");
if (pass.get() == nullptr)
LOG(INFO)
<< "get multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass "
"failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 70,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 70,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv "
"multi-devices pass, there should be one "
"fused_multi_transformer op, but the result is %d",
num_fused_nodes_after));
}
TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass,
pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible(
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
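// USE_PASS pulls in the static pass registrars so that
// PassRegistry::Instance().Get(...) above can locate the passes when the
// test binary is linked.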
USE_PASS(fused_multi_transformer_decoder_pass);
USE_PASS(fused_multi_transformer_decoder_fuse_qkv_pass);
USE_PASS(multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X")
                                 ->assert_more([](Node* x) {
                                   // Matched only when the pre-LN output
                                   // feeds exactly the three Q/K/V matmuls.
                                   return x->outputs.size() == 3;
                                 });
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// Q path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
// Q path Links
matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
// K path Nodes
auto* matmul1 = pattern->NewNode(matmul1_repr())->assert_is_op("matmul_v2");
auto* matmul1_w_var = pattern->NewNode(matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul1_out_var = pattern->NewNode(matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd1 =
pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
auto* eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2")
->AsOutput()
->assert_is_op_input("matmul", "Y")
->assert_is_op_input("while")
                                   ->assert_more([](Node* x) {
                                     // K output feeds the QK matmul and the
                                     // decoder while op.
                                     return x->outputs.size() == 2;
                                   });
// K path Links
matmul1->LinksFrom({layer_norm_out_var, matmul1_w_var})
.LinksTo({matmul1_out_var});
eltadd1->LinksFrom({matmul1_out_var, eltadd1_b_var})
.LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// V path Nodes
auto* matmul2 = pattern->NewNode(matmul2_repr())->assert_is_op("matmul_v2");
auto* matmul2_w_var = pattern->NewNode(matmul2_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul2_out_var = pattern->NewNode(matmul2_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd2 =
pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
auto* eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2")
->AsOutput()
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while")
                                   ->assert_more([](Node* x) {
                                     // V output feeds the QKV matmul_v2 and
                                     // the decoder while op.
                                     return x->outputs.size() == 2;
                                   });
// V path Links
matmul2->LinksFrom({layer_norm_out_var, matmul2_w_var})
.LinksTo({matmul2_out_var});
eltadd2->LinksFrom({matmul2_out_var, eltadd2_b_var})
.LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
  // QK path Links
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// while loop
auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
while0->LinksFrom({transpose2_1_out_var, transpose2_2_out_var});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
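// The returned ffn_output PDNode is the tail of the matched encoder block.
// Nodes marked AsIntermediate() above are the ones the fuse handler is
// expected to erase after creating the fused_multi_transformer op, while
// AsInput()/AsOutput() nodes (e.g. the K/V outputs feeding the while op)
// survive the rewrite.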
PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// QKV fused path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("split", "X");
auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul", "Y")
->assert_is_op_input("while");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while");
// QKV fused path Links
matmul0->LinksFrom({layer_norm_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
split0->LinksFrom({transpose2_0_out_var})
.LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
// while loop
auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
while0->LinksFrom({split0_k_out_var, split0_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
  // QK path Links
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("matmul_v2", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_layer_norm_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
// pre-LayerNorm
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
layer_norm->LinksFrom({input0, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
// communication c_identity
auto* c_identity =
pattern->NewNode(c_identity_repr())->assert_is_op("c_identity");
auto* c_identity_out_var = pattern->NewNode(c_identity_out_repr())
->AsIntermediate()
->assert_is_op_output("c_identity", "Out")
->assert_is_op_input("matmul_v2", "X");
c_identity->LinksFrom({layer_norm_out_var}).LinksTo({c_identity_out_var});
// QKV fused path Nodes
auto* matmul0 = pattern->NewNode(matmul0_repr())->assert_is_op("matmul_v2");
auto* matmul0_w_var = pattern->NewNode(matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul0_out_var = pattern->NewNode(matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd0 =
pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
auto* eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->AsIntermediate()
->assert_is_op_input("split", "X");
auto* split0 = pattern->NewNode(split0_repr())->assert_is_op("split");
auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
->assert_is_op_output("split")
->AsIntermediate()
->assert_is_op_input("matmul", "X");
auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul", "Y")
->assert_is_op_input("while");
auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr())
->assert_is_op_output("split")
->AsOutput()
->assert_is_op_input("matmul_v2", "Y")
->assert_is_op_input("while");
// QKV fused path Links
matmul0->LinksFrom({c_identity_out_var, matmul0_w_var})
.LinksTo({matmul0_out_var});
eltadd0->LinksFrom({matmul0_out_var, eltadd0_b_var})
.LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
split0->LinksFrom({transpose2_0_out_var})
.LinksTo({split0_q_out_var, split0_k_out_var, split0_v_out_var});
// while loop
auto* while0 = pattern->NewNode(while0_repr())->assert_is_op("while");
while0->LinksFrom({split0_k_out_var, split0_v_out_var});
// QK path Nodes
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
  // QK path Links
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul_v2");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul_v2");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var =
pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsIntermediate()
->assert_is_op_input("matmul_v2"); // -> out_linear
auto* matmul_linear =
pattern->NewNode(matmul_linear_repr())->assert_is_op("matmul_v2");
auto* matmul_linear_w_var = pattern->NewNode(matmul_linear_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* matmul_linear_out_var = pattern->NewNode(matmul_linear_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("c_allreduce_sum");
// communication c_allreduce_sum
auto* c_allreduce_sum =
pattern->NewNode(c_allreduce_sum_repr())->assert_is_op("c_allreduce_sum");
auto* c_allreduce_sum_out_var = pattern->NewNode(c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_linear =
pattern->NewNode(eltadd_linear_repr())->assert_is_op("elementwise_add");
auto* eltadd_linear_b_var = pattern->NewNode(eltadd_linear_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
auto* attention_output = pattern->NewNode(attention_output_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate();
// QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
matmul_linear->LinksFrom({reshape2_qkv_out_var, matmul_linear_w_var})
.LinksTo({matmul_linear_out_var});
c_allreduce_sum->LinksFrom({matmul_linear_out_var})
.LinksTo({c_allreduce_sum_out_var});
eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output});
// Feed Forward LayerNorm Nodes
auto* ffn_layer_norm =
pattern->NewNode(ffn_layer_norm_repr())->assert_is_op("layer_norm");
auto* ffn_layer_norm_scale_var =
pattern->NewNode(ffn_layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* ffn_layer_norm_bias_var =
pattern->NewNode(ffn_layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* ffn_layer_norm_mean_var =
pattern->NewNode(ffn_layer_norm_mean_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Mean");
auto* ffn_layer_norm_variance_var =
pattern->NewNode(ffn_layer_norm_variance_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Variance");
auto* ffn_layer_norm_out_var = pattern->NewNode(ffn_layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("c_identity", "X");
ffn_layer_norm
->LinksFrom(
{attention_output, ffn_layer_norm_bias_var, ffn_layer_norm_scale_var})
.LinksTo({ffn_layer_norm_out_var,
ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var});
// communication c_identity
auto* ffn_c_identity =
pattern->NewNode(ffn_c_identity_repr())->assert_is_op("c_identity");
auto* ffn_c_identity_out_var = pattern->NewNode(ffn_c_identity_out_repr())
->assert_is_op_output("c_identity", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X");
ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
.LinksTo({ffn_c_identity_out_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul0_out_var = pattern->NewNode(ffn_matmul0_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd0 =
pattern->NewNode(ffn_eltadd0_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd0_b_var = pattern->NewNode(ffn_eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul1_w_var = pattern->NewNode(ffn_matmul1_w_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto* ffn_matmul1_out_var = pattern->NewNode(ffn_matmul1_out_repr())
->assert_is_op_output("matmul_v2")
->AsIntermediate()
->assert_is_op_input("c_allreduce_sum");
// communication c_allreduce_sum
auto* ffn_c_allreduce_sum = pattern->NewNode(ffn_c_allreduce_sum_repr())
->assert_is_op("c_allreduce_sum");
auto* ffn_c_allreduce_sum_out_var =
pattern->NewNode(ffn_c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd1 =
pattern->NewNode(ffn_eltadd1_repr())->assert_is_op("elementwise_add");
auto* ffn_eltadd1_b_var = pattern->NewNode(ffn_eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out =
pattern->NewNode(ffn_eltadd_out_repr())->assert_is_op("elementwise_add");
auto* ffn_output = pattern->NewNode(ffn_output_repr())
->assert_is_op_output("elementwise_add")
->AsOutput();
ffn_matmul0->LinksFrom({ffn_c_identity_out_var, ffn_matmul0_w_var})
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
.LinksTo({ffn_c_allreduce_sum_out_var});
ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
.LinksTo({ffn_output});
return ffn_output;
}
} // namespace patterns
template <typename T>
inline void QKVWeightsProcess(framework::LoDTensor* wq_tensor,
framework::LoDTensor* wk_tensor,
framework::LoDTensor* wv_tensor,
framework::LoDTensor* bq_tensor,
framework::LoDTensor* bk_tensor,
framework::LoDTensor* bv_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
auto* wq_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<T>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<T>(platform::CPUPlace());
auto* bq_data = bq_tensor->mutable_data<T>(platform::CPUPlace());
auto* bk_data = bk_tensor->mutable_data<T>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<T>(platform::CPUPlace());
auto combined_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head});
framework::LoDTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
auto* tmp_combined_w_data =
tmp_combined_w_tensor.mutable_data<T>(platform::CPUPlace());
std::vector<T*> w_vec = {wq_data, wk_data, wv_data};
// Combine the three fc weights together.
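// Layout sketch (assuming each separate fc weight is stored row-major as
// [dim_embed, num_head * dim_head]): in_idx walks the source as
// [dim_embed, num_head, dim_head] while out_idx writes the combined tensor as
// [3, num_head, dim_head, dim_embed], i.e. each weight is transposed to put
// dim_embed last and Q/K/V are stacked along a new leading axis of size 3.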
for (int i = 0; i < 3; i++) {
for (int j = 0; j < num_head; j++) {
for (int k = 0; k < dim_head; k++) {
for (int l = 0; l < dim_embed; l++) {
int out_idx = i * num_head * dim_head * dim_embed +
j * dim_head * dim_embed + k * dim_embed + l;
int in_idx = l * num_head * dim_head + j * dim_head + k;
tmp_combined_w_data[out_idx] = w_vec[i][in_idx];
}
}
}
}
wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data = wq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(
new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel());
framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<T>(platform::CPUPlace());
size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(T) * bias_size);
memcpy(tmp_combined_bias_data + bias_size, bk_data, sizeof(T) * bias_size);
memcpy(
tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(T) * bias_size);
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(T) * bq_tensor->numel());
}
template <typename T>
inline void QKVWeightsProcessFuseQKV(framework::LoDTensor* qkv_w_tensor,
framework::LoDTensor* qkv_b_tensor,
const int num_head,
const int dim_head,
const int dim_embed) {
auto* qkv_w_data = qkv_w_tensor->mutable_data<T>(platform::CPUPlace());
auto transpose_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed});
framework::LoDTensor tmp_transpose_w_tensor;
tmp_transpose_w_tensor.Resize(transpose_w_dims);
auto* tmp_transpose_w_data =
tmp_transpose_w_tensor.mutable_data<T>(platform::CPUPlace());
// transpose qkv matmul Y to QKVWeights
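// Layout sketch (assuming the fused QKV weight is stored row-major as
// [dim_embed, num_head, 3, dim_head]): in_idx walks that layout while out_idx
// writes [3, num_head, dim_head, dim_embed], regrouping the per-head
// interleaved Q/K/V columns into a leading axis of size 3.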
for (int i = 0; i < 3; i++) {
for (int j = 0; j < num_head; j++) {
for (int k = 0; k < dim_head; k++) {
for (int l = 0; l < dim_embed; l++) {
int out_idx = i * num_head * dim_head * dim_embed +
j * dim_head * dim_embed + k * dim_embed + l;
int in_idx =
l * num_head * 3 * dim_head + j * 3 * dim_head + i * dim_head + k;
tmp_transpose_w_data[out_idx] = qkv_w_data[in_idx];
}
}
}
}
qkv_w_tensor->Resize(transpose_w_dims);
auto* new_transpose_w_data =
qkv_w_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(new_transpose_w_data,
tmp_transpose_w_data,
sizeof(T) * qkv_w_tensor->numel());
auto* qkv_b_data = qkv_b_tensor->mutable_data<T>(platform::CPUPlace());
auto transpose_b_dims = phi::make_ddim({3, num_head, dim_head});
framework::LoDTensor tmp_transpose_b_tensor;
tmp_transpose_b_tensor.Resize(transpose_b_dims);
auto* tmp_transpose_b_data =
tmp_transpose_b_tensor.mutable_data<T>(platform::CPUPlace());
// transpose qkv elementwise_add Y to QKVBias
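// Layout sketch (assuming the fused QKV bias is stored row-major as
// [num_head, 3, dim_head]): in_idx walks that layout while out_idx writes
// [3, num_head, dim_head].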
for (int i = 0; i < 3; i++) {
for (int j = 0; j < num_head; j++) {
for (int k = 0; k < dim_head; k++) {
int out_idx = i * num_head * dim_head + j * dim_head + k;
int in_idx = j * 3 * dim_head + i * dim_head + k;
tmp_transpose_b_data[out_idx] = qkv_b_data[in_idx];
}
}
}
qkv_b_tensor->Resize({3, num_head, dim_head});
auto* new_transpose_b_data =
qkv_b_tensor->mutable_data<T>(platform::CPUPlace());
memcpy(new_transpose_b_data,
tmp_transpose_b_data,
sizeof(T) * qkv_b_tensor->numel());
}
int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern(
pattern, name_scope);
fused_multi_transformer_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0_w,
Node* matmul1_w,
Node* matmul2_w,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* transpose2_1_out,
Node* transpose2_2_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
auto reshape_desc = reshape2_0->Op();
int num_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(2);
int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3);
int dim_embed = num_head * dim_head;
// Calculate the index of the transformer layer from the LayerNorm Scale name.
// This calculation assumes:
// 1. no LayerNorm before all transformer layers
// 2. each transformer layer contains 2 LayerNorm layers
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
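// e.g. assuming a (hypothetical) scale name like "layer_norm_4.w_0": the
// prefix before '.' is "layer_norm_4", the suffix after the last '_' is "4",
// so layer_idx = 4 / 2 = 2.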
auto* wq_tensor =
scope->FindVar(matmul0_w->Name())->GetMutable<LoDTensor>();
auto* wk_tensor =
scope->FindVar(matmul1_w->Name())->GetMutable<LoDTensor>();
auto* wv_tensor =
scope->FindVar(matmul2_w->Name())->GetMutable<LoDTensor>();
auto* bq_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
auto* bk_tensor =
scope->FindVar(eltadd1_b->Name())->GetMutable<LoDTensor>();
auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();
if (wq_tensor->dtype() == phi::DataType::FLOAT32) {
QKVWeightsProcess<float>(wq_tensor,
wk_tensor,
wv_tensor,
bq_tensor,
bk_tensor,
bv_tensor,
num_head,
dim_head,
dim_embed);
} else if (wq_tensor->dtype() == phi::DataType::FLOAT16) {
QKVWeightsProcess<platform::float16>(wq_tensor,
wk_tensor,
wv_tensor,
bq_tensor,
bk_tensor,
bv_tensor,
num_head,
dim_head,
dim_embed);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32 and fp16."));
}
// Reuse the matmul0_w and eltadd0_b nodes to hold the combined QKV weight and bias.
auto* combined_w_desc = matmul0_w->Var();
combined_w_desc->SetShape({3, num_head, dim_head, dim_embed});
combined_w_desc->SetPersistable(true);
auto* combined_bias_desc = eltadd0_b->Var();
combined_bias_desc->SetShape({3, num_head, dim_head});
combined_bias_desc->SetPersistable(true);
scope->EraseVars({matmul1_w->Name(), matmul2_w->Name()});
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV input
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType(
framework::TransToProtoVarType(wq_tensor->dtype()));
cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
OpDesc fill_const_op_desc(layer_norm->Op()->Block());
fill_const_op_desc.SetType("fill_constant_batch_size_like");
fill_const_op_desc.SetInput("Input", {input0->Name()});
fill_const_op_desc.SetOutput("Out", {cache_kv->Name()});
std::vector<int> shape = {2, -1, num_head, 1024, dim_head};
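// CacheKV layout: {2 (K and V), batch (filled from input dim 0 by
// fill_constant_batch_size_like), num_head, max_seq_len (fixed at 1024, see
// the FIXME above), dim_head}.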
fill_const_op_desc.SetAttr("shape", shape);
fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
IR_NODE_LINK_TO(input0, fill_const_op);
IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input
// 1. delete k, v
// 2. delete matmul1/2_w, eltadd1/2_b
// 3. add cache_kv
auto while_Xs = while0->Op()->Input("X");
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), transpose2_1_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), transpose2_2_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), matmul1_w->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), matmul2_w->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), eltadd1_b->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), eltadd2_b->Name()),
std::end(while_Xs));
while_Xs.emplace_back(cache_kv->Name());
while0->Op()->SetInput("X", while_Xs);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto while_Outs = while0->Op()->Output("Out");
while_Outs.erase(std::remove(std::begin(while_Outs),
std::end(while_Outs),
transpose2_1_out->Name()),
std::end(while_Outs));
while_Outs.erase(std::remove(std::begin(while_Outs),
std::end(while_Outs),
transpose2_2_out->Name()),
std::end(while_Outs));
while_Outs.emplace_back(cache_kv->Name());
while0->Op()->SetOutput("Out", while_Outs);
// link CacheKV to while
IR_NODE_LINK_TO(cache_kv, while0)
// unlink the original KV outputs from while
IR_NODE_UNLINK(transpose2_1_out, while0);
IR_NODE_UNLINK(transpose2_2_out, while0);
IR_NODE_UNLINK(while0, transpose2_1_out);
IR_NODE_UNLINK(while0, transpose2_2_out);
// unlink KV weight/bias from while after they are merged into the Q weight/bias
IR_NODE_UNLINK(matmul1_w, while0);
IR_NODE_UNLINK(matmul2_w, while0);
IR_NODE_UNLINK(eltadd1_b, while0);
IR_NODE_UNLINK(eltadd2_b, while0);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "fused_multi_transformer_encoder pass in "
"op compat failed.";
return;
}
VLOG(4) << "handle MultiTransformer encoder fuse";
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_mean, layer_norm_mean, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_out, layer_norm_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1, matmul1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1_out, matmul1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul1_w, matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1, reshape2_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1, transpose2_1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2, matmul2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_out, matmul2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul2_w, matmul2_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2, reshape2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2, transpose2_2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
attention_output, attention_output, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(while0, while0, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_layer_norm, ffn_layer_norm, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_out, ffn_matmul0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_out, ffn_matmul1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_pattern)
// nodes that need to be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1, eltadd1, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_b, eltadd1_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd1_out, eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2, eltadd2, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_b, eltadd2_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd2_out, eltadd2_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear_w, matmul_linear_w, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear_out, matmul_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_linear, dropout_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0_w,
matmul1_w,
matmul2_w,
eltadd0_b,
eltadd1_b,
eltadd2_b,
transpose2_1_out,
transpose2_2_out,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
while0,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
matmul0,
matmul1,
matmul2,
matmul0_out,
matmul1_out,
matmul2_out,
eltadd0,
eltadd1,
eltadd2,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the multi_transformer pass, The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
}
AddStatis(fusion_count);
}
FusedMultiTransformerEncoderPass::FusedMultiTransformerEncoderPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("while"))
.AddInput("X") // A set of variables, unconstrained
.End()
.AddInput("Condition") // An scalar
.IsTensor()
.End()
.AddOutput("Out") // A set of variables, unconstrained
.End()
.AddOutput("StepScopes") // A vector of local scope, unconstrained
.End()
.AddAttr("sub_block")
.IsType<framework::BlockDesc*>()
.End();
}
int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FusedMultiTransformerEncoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
fused_multi_transformer_fuse_qkv_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* matmul0_w,
Node* eltadd0_b,
Node* split0_k_out,
Node* split0_v_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
auto reshape_desc = reshape2_0->Op();
int num_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(2);
int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3) /
3; // 3 for qkv
int dim_embed = num_head * dim_head;
// Calculate the index of the transformer layer from the LayerNorm Scale name.
// This calculation assumes:
// 1. no LayerNorm before all transformer layers
// 2. each transformer layer contains 2 LayerNorm layers
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
auto* qkv_w_tensor =
scope->FindVar(matmul0_w->Name())->GetMutable<LoDTensor>();
auto* qkv_b_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
if (qkv_w_tensor->dtype() == phi::DataType::FLOAT32) {
QKVWeightsProcessFuseQKV<float>(
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
} else if (qkv_w_tensor->dtype() == phi::DataType::FLOAT16) {
QKVWeightsProcessFuseQKV<platform::float16>(
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32 and fp16."));
}
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV input
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType(
framework::TransToProtoVarType(qkv_w_tensor->dtype()));
cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
OpDesc fill_const_op_desc(layer_norm->Op()->Block());
fill_const_op_desc.SetType("fill_constant_batch_size_like");
fill_const_op_desc.SetInput("Input", {input0->Name()});
fill_const_op_desc.SetOutput("Out", {cache_kv->Name()});
std::vector<int> shape = {2, -1, num_head, 1024, dim_head};
fill_const_op_desc.SetAttr("shape", shape);
fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
IR_NODE_LINK_TO(input0, fill_const_op);
IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input
// 1. delete k, v
// 2. add cache_kv
auto while_Xs = while0->Op()->Input("X");
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()),
std::end(while_Xs));
while_Xs.emplace_back(cache_kv->Name());
while0->Op()->SetInput("X", while_Xs);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto while_Outs = while0->Op()->Output("Out");
while_Outs.erase(
std::remove(
std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()),
std::end(while_Outs));
while_Outs.erase(
std::remove(
std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()),
std::end(while_Outs));
while_Outs.emplace_back(cache_kv->Name());
while0->Op()->SetOutput("Out", while_Outs);
// link CacheKV to while
IR_NODE_LINK_TO(cache_kv, while0)
// unlink the original KV outputs from while
IR_NODE_UNLINK(split0_k_out, while0);
IR_NODE_UNLINK(split0_v_out, while0);
IR_NODE_UNLINK(while0, split0_k_out);
IR_NODE_UNLINK(while0, split0_v_out);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "fused_multi_transformer_encoder_fuse_qkv "
"pass in op compat failed.";
return;
}
VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse";
GET_IR_NODE_FROM_SUBGRAPH(
input0, input0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale,
layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias,
layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean,
layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out,
layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out,
reshape2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out,
transpose2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0, split0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm,
ffn_layer_norm,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out,
ffn_matmul0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out,
ffn_eltadd0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out,
ffn_matmul1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
ffn_dropout_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
// nodes that need to be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
dropout_qk_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out,
matmul_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out,
reshape2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv,
transpose2_qkv,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w,
matmul_linear_w,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out,
matmul_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b,
eltadd_linear_b,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
dropout_linear,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
dropout_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
while0, while0, fused_multi_transformer_fuse_qkv_pattern)
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
matmul0_w,
eltadd0_b,
split0_k_out,
split0_v_out,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
while0,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
matmul0,
matmul0_out,
eltadd0,
eltadd0_out,
reshape2_0,
reshape2_0_out,
transpose2_0,
transpose2_0_out,
split0,
split0_q_out,
split0_k_out,
split0_v_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true));
}
AddStatis(fusion_count);
}
FusedMultiTransformerEncoderFuseQKVPass::
FusedMultiTransformerEncoderFuseQKVPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("while"))
.AddInput("X") // A set of variables, unconstrained
.End()
.AddInput("Condition") // An scalar
.IsTensor()
.End()
.AddOutput("Out") // A set of variables, unconstrained
.End()
.AddOutput("StepScopes") // A vector of local scope, unconstrained
.End()
.AddAttr("sub_block")
.IsType<framework::BlockDesc*>()
.End();
}
int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Graph* graph, const std::string& name_scope, Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope);
fused_multi_transformer_fuse_qkv_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* layer_norm,
Node* layer_norm_scale,
Node* layer_norm_bias,
Node* layer_norm_mean,
Node* layer_norm_variance,
Node* c_identity,
Node* matmul0_w,
Node* eltadd0_b,
Node* split0_k_out,
Node* split0_v_out,
Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0,
Node* matmul_linear_w,
Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0,
Node* ffn_layer_norm,
Node* ffn_layer_norm_scale,
Node* ffn_layer_norm_bias,
Node* ffn_layer_norm_mean,
Node* ffn_layer_norm_variance,
Node* ffn_matmul0_w,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) {
auto reshape_desc = reshape2_0->Op();
int num_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(2);
int dim_head =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(3) /
3; // 3 for qkv
// Calculate the index of the transformer layer from the LayerNorm Scale name.
// This calculation assumes:
// 1. no LayerNorm before all transformer layers
// 2. each transformer layer contains 2 LayerNorm layers
auto ln_scale_name = layer_norm_scale->Name();
auto ln_name = ln_scale_name.substr(0, ln_scale_name.find('.'));
auto ln_idx_str = ln_name.substr(ln_name.rfind('_') + 1);
int layer_idx = atoi(ln_idx_str.c_str()) / 2;
auto* qkv_w_tensor =
scope->FindVar(matmul0_w->Name())->GetMutable<LoDTensor>();
auto* qkv_b_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
int dim_embed = qkv_w_tensor->dims()[0];
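// Note: in this multi-devices (tensor-parallel) pattern, num_head from the
// reshape attr is presumably the per-device head count, so dim_embed is read
// from the weight's first dimension (the full hidden size) instead of
// num_head * dim_head; this reading is an assumption based on a
// column-parallel weight layout.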
if (qkv_w_tensor->dtype() == phi::DataType::FLOAT32) {
QKVWeightsProcessFuseQKV<float>(
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
} else if (qkv_w_tensor->dtype() == phi::DataType::FLOAT16) {
QKVWeightsProcessFuseQKV<platform::float16>(
qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32 and fp16."));
}
// create fused_multi_transformer
OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block());
fused_multi_transformer_op_desc.SetType("fused_multi_transformer");
// 1. Input setting
fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
// pre-LayerNorm input
fused_multi_transformer_op_desc.SetInput("LnScale",
{layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("LnBias",
{layer_norm_bias->Name()});
// QKV computation input
fused_multi_transformer_op_desc.SetInput("QKVW", {matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("QKVBias", {eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("SrcMask", {eltadd_qk_b->Name()});
// CacheKV input
VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc.SetDataType(
framework::TransToProtoVarType(qkv_w_tensor->dtype()));
cache_kv_desc.SetPersistable(false);
auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
OpDesc fill_const_op_desc(layer_norm->Op()->Block());
fill_const_op_desc.SetType("fill_constant_batch_size_like");
fill_const_op_desc.SetInput("Input", {input0->Name()});
fill_const_op_desc.SetOutput("Out", {cache_kv->Name()});
std::vector<int> shape = {2, -1, num_head, 1024, dim_head};
fill_const_op_desc.SetAttr("shape", shape);
fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
{matmul_linear_w->Name()});
fused_multi_transformer_op_desc.SetInput("OutLinearBias",
{eltadd_linear_b->Name()});
// Feed Forward input
fused_multi_transformer_op_desc.SetInput("FFNLnScale",
{ffn_layer_norm_scale->Name()});
fused_multi_transformer_op_desc.SetInput("FFNLnBias",
{ffn_layer_norm_bias->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Weight",
{ffn_matmul0_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN1Bias",
{ffn_eltadd0_b->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Weight",
{ffn_matmul1_w->Name()});
fused_multi_transformer_op_desc.SetInput("FFN2Bias",
{ffn_eltadd1_b->Name()});
// 2. Output setting
fused_multi_transformer_op_desc.SetOutput("Out", {ffn_output->Name()});
fused_multi_transformer_op_desc.SetOutput("CacheKVOut", {cache_kv->Name()});
// Attribute setting
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute
auto* dropout_op = dropout_linear->Op();
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
// parallel ring id
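// reuse the ring_id of the tensor-parallel group so the fused op's internal
// all-reduce runs on the same communicator as c_identity / c_allreduce_sum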
auto* c_identity_op = c_identity->Op();
fused_multi_transformer_op_desc.SetAttr("ring_id",
c_identity_op->GetAttr("ring_id"));
auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc);
IR_NODE_LINK_TO(input0, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);
IR_NODE_LINK_TO(input0, fill_const_op);
IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input
// 1. delete the split-out k, v tensors
// 2. add cache_kv
auto while_Xs = while0->Op()->Input("X");
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), split0_k_out->Name()),
std::end(while_Xs));
while_Xs.erase(
std::remove(
std::begin(while_Xs), std::end(while_Xs), split0_v_out->Name()),
std::end(while_Xs));
while_Xs.emplace_back(cache_kv->Name());
while0->Op()->SetInput("X", while_Xs);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto while_Outs = while0->Op()->Output("Out");
while_Outs.erase(
std::remove(
std::begin(while_Outs), std::end(while_Outs), split0_k_out->Name()),
std::end(while_Outs));
while_Outs.erase(
std::remove(
std::begin(while_Outs), std::end(while_Outs), split0_v_out->Name()),
std::end(while_Outs));
while_Outs.emplace_back(cache_kv->Name());
while0->Op()->SetOutput("Out", while_Outs);
// link CacheKV to while
IR_NODE_LINK_TO(cache_kv, while0)
// unlink origin KV output to while
IR_NODE_UNLINK(split0_k_out, while0);
IR_NODE_UNLINK(split0_v_out, while0);
IR_NODE_UNLINK(while0, split0_k_out);
IR_NODE_UNLINK(while0, split0_v_out);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "multi_devices_fused_multi_transformer_encoder_fuse_qkv "
"pass op compat check failed.";
return;
}
VLOG(4) << "handle MultiTransformer encoder(Fuse-QKV) fuse";
GET_IR_NODE_FROM_SUBGRAPH(
input0, input0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm, layer_norm, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale,
layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias,
layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean,
layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance,
layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out,
layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
c_identity, c_identity, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_identity_out,
c_identity_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0, matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_out, matmul0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul0_w, matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0, reshape2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out,
reshape2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0, transpose2_0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out,
transpose2_0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0, split0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_q_out, split0_q_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_k_out, split0_k_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
split0_v_out, split0_v_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm,
ffn_layer_norm,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_scale,
ffn_layer_norm_scale,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_bias,
ffn_layer_norm_bias,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_mean,
ffn_layer_norm_mean,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_variance,
ffn_layer_norm_variance,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_layer_norm_out,
ffn_layer_norm_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity,
ffn_c_identity,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_identity_out,
ffn_c_identity_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0, ffn_matmul0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul0_out,
ffn_matmul0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul0_w, ffn_matmul0_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0, ffn_eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd0_b, ffn_eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd0_out,
ffn_eltadd0_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_matmul1_out,
ffn_matmul1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1_w, ffn_matmul1_w, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum,
ffn_c_allreduce_sum,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_c_allreduce_sum_out,
ffn_c_allreduce_sum_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1, ffn_eltadd1, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_b, ffn_eltadd1_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd1_out,
ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
ffn_dropout_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_output, ffn_output, fused_multi_transformer_fuse_qkv_pattern)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0, eltadd0, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_b, eltadd0_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd0_out, eltadd0_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk, matmul_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_b, eltadd_qk_b, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_qk_out, eltadd_qk_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk, softmax_qk, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
dropout_qk_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out,
matmul_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv, reshape2_qkv, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out,
reshape2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv,
transpose2_qkv,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out,
transpose2_qkv_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_linear, matmul_linear, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_w,
matmul_linear_w,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_linear_out,
matmul_linear_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum,
c_allreduce_sum,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(c_allreduce_sum_out,
c_allreduce_sum_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear, eltadd_linear, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_b,
eltadd_linear_b,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
dropout_linear,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
dropout_linear_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
while0, while0, fused_multi_transformer_fuse_qkv_pattern);
fuse_creater(input0,
layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
c_identity,
matmul0_w,
eltadd0_b,
split0_k_out,
split0_v_out,
eltadd_qk_b,
dropout_qk,
reshape2_0,
matmul_linear_w,
eltadd_linear_b,
dropout_linear,
while0,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_matmul0_w,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_dropout,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean,
layer_norm_variance,
layer_norm_out,
c_identity,
c_identity_out,
matmul0,
matmul0_out,
eltadd0,
eltadd0_out,
reshape2_0,
reshape2_0_out,
transpose2_0,
transpose2_0_out,
split0,
split0_q_out,
split0_k_out,
split0_v_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
dropout_qk,
dropout_qk_out,
matmul_qkv,
matmul_qkv_out,
reshape2_qkv,
reshape2_qkv_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_linear,
matmul_linear_w,
matmul_linear_out,
c_allreduce_sum,
c_allreduce_sum_out,
eltadd_linear,
eltadd_linear_b,
eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out,
ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean,
ffn_layer_norm_variance,
ffn_layer_norm_out,
ffn_c_identity,
ffn_c_identity_out,
ffn_matmul0,
ffn_matmul1,
ffn_matmul0_out,
ffn_matmul1_out,
ffn_c_allreduce_sum,
ffn_c_allreduce_sum_out,
ffn_eltadd0,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(
Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
new bool(true));
}
AddStatis(fusion_count);
}
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddInput("Y") // the shape shoule be (N*H, N*H)
.IsTensor()
.End()
.AddOutput("Out") // the shape shoule be (B, S, N*H)
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({2, -1, 0})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("softmax"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3
.End();
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("approximate")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("while"))
.AddInput("X") // A set of variables, unconstrained
.End()
.AddInput("Condition") // An scalar
.IsTensor()
.End()
.AddOutput("Out") // A set of variables, unconstrained
.End()
.AddOutput("StepScopes") // A vector of local scope, unconstrained
.End()
.AddAttr("sub_block")
.IsType<framework::BlockDesc*>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fused_multi_transformer_encoder_pass,
paddle::framework::ir::FusedMultiTransformerEncoderPass);
REGISTER_PASS(fused_multi_transformer_encoder_fuse_qkv_pass,
paddle::framework::ir::FusedMultiTransformerEncoderFuseQKVPass);
REGISTER_PASS(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass,
paddle::framework::ir::MultiDevicesFusedMultiTransformerEncoderFuseQKVPass);
REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(fused_multi_transformer_encoder_fuse_qkv_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct FusedMultiTransformerEncoderPattern : public PatternBase {
FusedMultiTransformerEncoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_multi_transformer_encoder") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul1_w);
PATTERN_DECL_NODE(matmul2_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(matmul1_out);
PATTERN_DECL_NODE(matmul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// while loop
PATTERN_DECL_NODE(while0);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
FusedMultiTransformerEncoderFuseQKVPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(
pattern, name_scope, "fused_multi_transformer_encoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// while loop
PATTERN_DECL_NODE(while0);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
: public PatternBase {
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(c_identity);
PATTERN_DECL_NODE(c_identity_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(split0)
PATTERN_DECL_NODE(split0_q_out)
PATTERN_DECL_NODE(split0_k_out)
PATTERN_DECL_NODE(split0_v_out)
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// while loop
PATTERN_DECL_NODE(while0);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(c_allreduce_sum);
PATTERN_DECL_NODE(c_allreduce_sum_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_c_identity);
PATTERN_DECL_NODE(ffn_c_identity_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_c_allreduce_sum);
PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
} // namespace patterns
class FusedMultiTransformerEncoderPass : public FusePassBase {
public:
FusedMultiTransformerEncoderPass();
virtual ~FusedMultiTransformerEncoderPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_encoder"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
public:
FusedMultiTransformerEncoderFuseQKVPass();
virtual ~FusedMultiTransformerEncoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_multi_transformer_encoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
: public FusePassBase {
public:
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass();
virtual ~MultiDevicesFusedMultiTransformerEncoderFuseQKVPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{
"multi_devices_fused_multi_transformer_encoder_fuse_qkv"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h" // NOLINT
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
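// The shapes below model one GPT-style layer: hidden size 1024 and FFN size
// 4096; the tests use 16 heads of size 64.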
// MHA: pre Layer Norm
AddVarToScope(param_scope, "ln_scale", {1024});
AddVarToScope(param_scope, "ln_bias", {1024});
// MHA: QKV fc
AddVarToScope(param_scope, "weights0", {1024, 1024});
AddVarToScope(param_scope, "weights1", {1024, 1024});
AddVarToScope(param_scope, "weights2", {1024, 1024});
AddVarToScope(param_scope, "bias_0", {1024});
AddVarToScope(param_scope, "bias_1", {1024});
AddVarToScope(param_scope, "bias_2", {1024});
// MHA: QK bias
AddVarToScope(param_scope, "biasqk", {1024});
// MHA: out Linear
AddVarToScope(param_scope, "weights_l", {1024, 1024});
AddVarToScope(param_scope, "bias_l", {1024});
// MHA: pre Layer Norm
AddVarToScope(param_scope, "ffn_ln_scale", {1024});
AddVarToScope(param_scope, "ffn_ln_bias", {1024});
// FFN: fc1 -> (gelu) -> fc2
AddVarToScope(param_scope, "ffn_weights0", {1024, 4096});
AddVarToScope(param_scope, "ffn_weights1", {4096, 1024});
AddVarToScope(param_scope, "ffn_bias_0", {4096});
AddVarToScope(param_scope, "ffn_bias_1", {1024});
return param_scope;
}
TEST(FusedMultiTransformerEncoderPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (layer_norm_out, weights_1) matmul_v2 -> matmul_out1
// (layer_norm_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (transpose_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (x, dropout_linear) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
auto* matmul_out_0 =
layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
auto* matmul_out_1 =
layers.matmul_v2(ln_out, weights_1, nullptr, false, true);
auto* matmul_out_2 =
layers.matmul_v2(ln_out, weights_2, nullptr, false, true);
auto* b0 = layers.data("bias_0", {1024}, true);
auto* b1 = layers.data("bias_1", {1024}, true);
auto* b2 = layers.data("bias_2", {1024}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
auto* elementwise_out_1 =
layers.elementwise_add(matmul_out_1, b1, nullptr, 2);
auto* elementwise_out_2 =
layers.elementwise_add(matmul_out_2, b2, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
// Link to decoder while block
layers.while_loop({transpose_1, transpose_2});
// MHA: QK matmul
auto* matmul_qk =
layers.matmul(transpose_0, transpose_1, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, transpose_2);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
auto* bias_l = layers.data("weightsl", {1024, 1024}, true);
auto* linear_matmut_out =
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass =
PassRegistry::Instance().Get("fused_multi_transformer_encoder_pass");
if (pass.get() == nullptr)
LOG(INFO) << "get fused_multi_transformer_encoder_pass failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 68,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_pass, The "
"node num in graph "
"should be %d, but the result is %d",
num_nodes_before - 68,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder pass, "
"there should be one fused_multi_transformer op, "
"but the result is %d",
num_fused_nodes_after));
}
TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("fused_multi_transformer_encoder_pass"));
}
TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
// (transpose_0) split -> split_q, split_k, split_v
// (split_k) assign -> assign_k
// (split_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, split_v) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (x, dropout_linear) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (split_k, split_v) while -> decoder block
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 3072}, true);
auto* matmul_out_0 =
layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
auto* b0 = layers.data("bias_0", {3072}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto split_outs = layers.split(transpose_0, 3, 3);
auto* split_q = split_outs[0];
auto* split_k = split_outs[1];
auto* split_v = split_outs[2];
layers.assign(split_k);
layers.assign(split_v);
// Link to decoder while block
layers.while_loop({split_k, split_v});
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
auto* bias_l = layers.data("weightsl", {1024, 1024}, true);
auto* linear_matmut_out =
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(
"fused_multi_transformer_encoder_fuse_qkv_pass");
if (pass.get() == nullptr)
LOG(INFO) << "get fused_multi_transformer_encoder_fuse_qkv_pass failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 56,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 56,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv "
"pass, there should be one fused_multi_transformer "
"op, but the result is %d",
num_fused_nodes_after));
}
TEST(FusedMultiTransformerEncoderFuseQKVPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("fused_multi_transformer_encoder_fuse_qkv_pass"));
}
TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out) c_identity -> c_identity_out
// (c_identity_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
// (transpose_0) split -> split_q, split_k, split_v
// (split_k) assign -> assign_k
// (split_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, split_v) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_all_reduce -> c_all_reduce_out
// (c_all_reduce_out) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (x, dropout_linear) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (ffn_layer_norm_out) c_identity -> ffn_c_identity_out
// (ffn_c_identity_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out
// (ffn_c_all_reduce_out, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (split_k, split_v) while -> decoder block
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
auto* c_identity_out = layers.c_identity(ln_out);
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 3072}, true);
auto* matmul_out_0 =
layers.matmul_v2(c_identity_out, weights_0, nullptr, false, true);
auto* b0 = layers.data("bias_0", {3072}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto split_outs = layers.split(transpose_0, 3, 3);
auto* split_q = split_outs[0];
auto* split_k = split_outs[1];
auto* split_v = split_outs[2];
layers.assign(split_k);
layers.assign(split_v);
// Link to decoder while block
layers.while_loop({split_k, split_v});
// MHA: QK matmul
auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
auto* bias_l = layers.data("weightsl", {1024, 1024}, true);
auto* linear_matmut_out =
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmut_out);
auto* linear_eltadd_out =
layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* dropout_qkv =
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
auto* ffn_c_identity_out = layers.c_identity(ffn_ln_out);
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_c_identity_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, true);
auto* ffn_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get(
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass");
if (pass.get() == nullptr)
LOG(INFO)
<< "get multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass "
"failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 64,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d",
num_nodes_before - 64,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv "
"multi-devices pass, there should be one "
"fused_multi_transformer op, but the result is %d",
num_fused_nodes_after));
}
TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible(
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fused_multi_transformer_encoder_pass);
USE_PASS(fused_multi_transformer_encoder_fuse_qkv_pass);
USE_PASS(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass);
...@@ -815,9 +815,14 @@ void GraphToProgram(const Graph &graph, ...@@ -815,9 +815,14 @@ void GraphToProgram(const Graph &graph,
// avoid kRootBlockIndex not 0 // avoid kRootBlockIndex not 0
if (idx == kRootBlockIndex) continue; if (idx == kRootBlockIndex) continue;
block = program_pb.add_blocks(); if (static_cast<int>(idx) < program_pb.blocks_size()) {
block->set_idx(idx); block = program_pb.mutable_blocks(idx);
block->set_parent_idx(kRootBlockIndex); } else {
block = program_pb.add_blocks();
block->set_idx(idx);
block->set_parent_idx(kRootBlockIndex);
}
GraphToBlock(*graph.GetSubGraph(idx), GraphToBlock(*graph.GetSubGraph(idx),
block, block,
sort_kind, sort_kind,
......
...@@ -112,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) { ...@@ -112,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
if (graph.Nodes().empty()) return false; if (graph.Nodes().empty()) return false;
for (auto &node : GraphTraits::DFS(graph)) { for (auto &node : GraphTraits::DFS(graph)) {
if (node.Name().rfind("__control_var") == 0) continue;
for (const auto &pdnode : pattern_.nodes()) { for (const auto &pdnode : pattern_.nodes()) {
if (pdnode->Tell(&node)) { if (pdnode->Tell(&node)) {
VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name(); VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name();
...@@ -383,7 +384,6 @@ std::string PDPattern::DotString() const { ...@@ -383,7 +384,6 @@ std::string PDPattern::DotString() const {
// Create Edges // Create Edges
for (const auto &edge : edges()) { for (const auto &edge : edges()) {
if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) { if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
LOG(ERROR) << "no node " << edge.first << " " << edge.second;
continue; continue;
} }
auto &src = node2dot.at(edge.first); auto &src = node2dot.at(edge.first);
...@@ -453,7 +453,8 @@ PDNode *PDNode::assert_var_not_persistable() { ...@@ -453,7 +453,8 @@ PDNode *PDNode::assert_var_not_persistable() {
PDNode *PDNode::assert_is_persistable_var() { PDNode *PDNode::assert_is_persistable_var() {
assert_is_var(); assert_is_var();
asserts_.emplace_back([=](Node *x) { return x->Var()->Persistable(); }); asserts_.emplace_back(
[=](Node *x) { return x->Var() && x->Var()->Persistable(); });
return this; return this;
} }
......
...@@ -1990,6 +1990,14 @@ struct AddSupportInt8 : public PatternBase { ...@@ -1990,6 +1990,14 @@ struct AddSupportInt8 : public PatternBase {
a->outputs.push_back(b); \ a->outputs.push_back(b); \
b->inputs.push_back(a); b->inputs.push_back(a);
// UnLink 2 ir::Nodes from each other.
#define IR_NODE_UNLINK(a, b) \
a->outputs.erase( \
std::remove(std::begin(a->outputs), std::end(a->outputs), b), \
std::end(a->outputs)); \
b->inputs.erase(std::remove(std::begin(b->inputs), std::end(b->inputs), a), \
std::end(b->inputs));
// Set the out_var as the output of the op // Set the out_var as the output of the op
#define IR_OP_VAR_LINK(op, out_var) \ #define IR_OP_VAR_LINK(op, out_var) \
op->outputs.push_back(out_var); \ op->outputs.push_back(out_var); \
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Scope;
namespace ir { namespace ir {
class Graph; class Graph;
} // namespace ir } // namespace ir
...@@ -35,6 +36,17 @@ namespace paddle { ...@@ -35,6 +36,17 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static const char kParamScopeAttr[] = "__param_scope__";
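// Fusion passes listed here are additionally applied to every sub-graph of
// the main graph (e.g. the decoder's while block); see Pass::Apply below.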
static const std::vector<std::string> support_subgraph_passes = {
"fused_multi_transformer_encoder_pass",
"fused_multi_transformer_decoder_pass",
"fused_multi_transformer_encoder_fuse_qkv_pass",
"fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
};
Graph *Pass::Apply(Graph *graph) const { Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph"; VLOG(10) << "start to apply pass " << Type() << " to graph";
CheckPrevPass(); CheckPrevPass();
...@@ -65,11 +77,41 @@ Graph *Pass::Apply(Graph *graph) const { ...@@ -65,11 +77,41 @@ Graph *Pass::Apply(Graph *graph) const {
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The VarDescs of persistable variable are not consistency.")); "The VarDescs of persistable variable are not consistency."));
applied_ = true;
if (!graph->Has(kPassRecorder)) { if (!graph->Has(kPassRecorder)) {
graph->Set<PassRecorder>(kPassRecorder, new PassRecorder); graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
} }
graph->Get<PassRecorder>(kPassRecorder).insert(Type()); graph->Get<PassRecorder>(kPassRecorder).insert(Type());
if (graph->IsMainGraph() && std::count(support_subgraph_passes.begin(),
support_subgraph_passes.end(),
Type())) {
for (size_t i = 1; i < graph->SubGraphsSize(); i++) {
auto *sub_graph = graph->GetSubGraph(i);
if (!sub_graph->Has(framework::ir::kParamScopeAttr)) {
sub_graph->SetNotOwned<Scope>(
framework::ir::kParamScopeAttr,
&graph->Get<Scope>(framework::ir::kParamScopeAttr));
}
ApplyImpl(sub_graph);
PADDLE_ENFORCE_EQ(
HasCircle(*sub_graph),
false,
platform::errors::InvalidArgument(
"Illegal pass %s. Generated graph shouldn't contain cycle.",
Type()));
PADDLE_ENFORCE_EQ(
VarDescIsConsistency(*sub_graph),
true,
platform::errors::InvalidArgument(
"The VarDescs of persistable variable are not consistency."));
if (!sub_graph->Has(kPassRecorder)) {
sub_graph->Set<PassRecorder>(kPassRecorder, new PassRecorder);
}
sub_graph->Get<PassRecorder>(kPassRecorder).insert(Type());
}
}
applied_ = true;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
// Passes can change params, tensors, so caching need to be discarded // Passes can change params, tensors, so caching need to be discarded
......
...@@ -47,6 +47,18 @@ constexpr char kPassRecorder[] = "pass_recorder"; ...@@ -47,6 +47,18 @@ constexpr char kPassRecorder[] = "pass_recorder";
constexpr char kEmbEltwiseLayernormPass[] = constexpr char kEmbEltwiseLayernormPass[] =
"embedding_eltwise_layernorm_fuse_pass_flag"; "embedding_eltwise_layernorm_fuse_pass_flag";
constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag"; constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag";
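// Graph attribute flags set by the new fusion passes so that later stages
// can check whether the corresponding fusion has already been applied.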
constexpr char kFusedMultiTransformerEncoderPass[] =
"fused_multi_transformer_encoder_pass_flag";
constexpr char kFusedMultiTransformerDecoderPass[] =
"fused_multi_transformer_decoder_pass_flag";
constexpr char kFusedMultiTransformerEncoderFuseQKVPass[] =
"fused_multi_transformer_encoder_fuse_qkv_pass_flag";
constexpr char kFusedMultiTransformerDecoderFuseQKVPass[] =
"fused_multi_transformer_decoder_fuse_qkv_pass_flag";
constexpr char kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass[] =
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass_flag";
constexpr char kMultiDevicesFusedMultiTransformerDecoderFuseQKVPass[] =
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass_flag";
constexpr char kPrelnEmbEltwiseLayernormPass[] =
"preln_embedding_eltwise_layernorm_fuse_pass_flag";
......
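The six new *_pass_flag constants follow the convention of kEmbEltwiseLayernormPass and kMultiheadMatmulPass above: a fuse pass can attach its flag to the graph once it has fired so later stages can detect that the fusion happened. A hedged sketch of that idiom; the helper functions are illustrative, only the constants and Graph::Has/Set come from the codebase.

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

// Sketch: record and query "this graph already contains fused multi_transformer
// encoder ops" via the new flag, mirroring how the older flags are used.
void MarkEncoderFused(paddle::framework::ir::Graph *graph) {
  using paddle::framework::ir::kFusedMultiTransformerEncoderPass;
  if (!graph->Has(kFusedMultiTransformerEncoderPass)) {
    graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
  }
}

bool EncoderWasFused(const paddle::framework::ir::Graph &graph) {
  return graph.Has(paddle::framework::ir::kFusedMultiTransformerEncoderPass);
}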
...@@ -146,6 +146,12 @@ struct Layers {
return unary_op("relu", x, out);
}
VarDesc* gelu(VarDesc* x, VarDesc* out = nullptr, bool approximate = true) {
AttributeMap attrs;
attrs["approximate"] = approximate;
return unary_op("gelu", x, out, &attrs);
}
VarDesc* sigmoid(VarDesc* x, VarDesc* out = nullptr) {
return unary_op("sigmoid", x, out);
}
...@@ -154,6 +160,20 @@ struct Layers {
return unary_op("tanh", x, out);
}
VarDesc* c_identity(VarDesc* x, VarDesc* out = nullptr, int ring_id = -1) {
AttributeMap attrs;
attrs["ring_id"] = ring_id;
return unary_op("c_identity", x, out, &attrs);
}
VarDesc* c_allreduce_sum(VarDesc* x,
VarDesc* out = nullptr,
int ring_id = -1) {
AttributeMap attrs;
attrs["ring_id"] = ring_id;
return unary_op("c_allreduce_sum", x, out, &attrs);
}
VarDesc* fc(VarDesc* input,
VarDesc* w,
VarDesc* bias,
...@@ -332,6 +352,37 @@ struct Layers {
return outs;
}
std::vector<VarDesc*> split(VarDesc* x, int num_or_section, int axis = 0) {
std::vector<VarDesc*> outs(num_or_section);
for (int i = 0; i < num_or_section; i++) {
outs[i] = lod_tensor(unique_name());
}
std::vector<std::string> out_names(num_or_section);
for (int i = 0; i < num_or_section; i++) {
out_names[i] = outs[i]->Name();
}
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("split");
op->SetInput("X", {x->Name()});
op->SetOutput("Out", out_names);
op->SetAttr("num_or_section", num_or_section);
op->SetAttr("axis", axis);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return outs;
}
VarDesc* assign(VarDesc* x) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {x->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
VarDesc* matmul(VarDesc* x,
VarDesc* y,
VarDesc* alpha = nullptr,
...@@ -459,6 +510,24 @@ struct Layers {
return out;
}
VarDesc* while_loop(std::vector<VarDesc*> xs, VarDesc* cond = nullptr) {
VarDesc* out = lod_tensor(unique_name());
VarDesc* step_scopes = lod_tensor(unique_name());
if (cond == nullptr) cond = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("while");
std::vector<std::string> xs_names;
for (auto& x : xs) xs_names.emplace_back(x->Name());
op->SetInput("X", xs_names);
op->SetInput("Condition", {cond->Name()});
op->SetOutput("Out", {out->Name()});
op->SetOutput("StepScopes", {step_scopes->Name()});
op->SetAttr("sub_block", {program_.MutableBlock(0)});
op->SetAttr("is_test", true);
return out;
}
void backward(std::vector<VarDesc*> targets) {
// This function is designed to simulate the structure of training program,
// but is constructed differently as the actual program.
...@@ -523,7 +592,10 @@ struct Layers {
return var;
}
VarDesc* unary_op(std::string type, VarDesc* x, VarDesc* out = nullptr) {
VarDesc* unary_op(std::string type,
VarDesc* x,
VarDesc* out = nullptr,
const AttributeMap* attrs = nullptr) {
if (!out) {
out = lod_tensor(unique_name());
}
...@@ -531,6 +603,11 @@
op->SetType(type);
op->SetInput("X", {x->Name()});
op->SetOutput("Out", {out->Name()});
if (attrs) {
for (auto& iter : *attrs) {
op->SetAttr(iter.first, iter.second);
}
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
......
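The helper additions above (gelu, c_identity, c_allreduce_sum, split, assign, while_loop, plus attribute support in unary_op) are what the new pass testers use to build miniature GPT-3 blocks. A rough sketch of how they compose; Layers::data is assumed to exist elsewhere in pass_tester_helper.h, and the shapes are arbitrary.

#include "paddle/fluid/framework/ir/pass_tester_helper.h"

// Sketch only: stitch a toy decoder-ish block out of the new test helpers.
void BuildToyBlock() {
  paddle::framework::ir::Layers layers;
  auto *x = layers.data("x", {1, 128, 1024});          // assumed helper
  auto *w = layers.data("qkv_w", {1024, 3072}, true);  // assumed persistable weight
  auto *qkv = layers.matmul(x, w);
  auto chunks = layers.split(qkv, 3, /*axis=*/2);      // Q, K, V
  auto *cache = layers.assign(chunks[1]);              // decoder-style cache write
  auto *act = layers.gelu(chunks[0]);
  auto *merged = layers.c_allreduce_sum(act, nullptr, /*ring_id=*/0);
  layers.while_loop({merged, cache});                  // wrap the step in a while op
}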
...@@ -76,6 +76,7 @@ void MemoryOptimizePass::CollectLifeCycle(
} else {
// Normal operators.
for (const Node* node : requires) {
if (!node->Var()) continue;
if (node->Var()->Persistable()) continue;
std::string var = node->Name();
if (!lifecycles->count(var)) {
...@@ -133,7 +134,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
// between performance and underlying principle.
std::unordered_set<std::string> black_list;
for (auto* node : graph->Nodes()) {
if (node->IsVar() &&
if (node->IsVar() && node->Var() &&
node->Var()->GetType() ==
framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
if (!valid_var(node)) {
...@@ -144,7 +145,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
// Collect tensors from graph.
for (auto* node : graph->Nodes()) {
if (node->IsVar() &&
if (node->IsVar() && node->Var() &&
node->Var()->GetType() ==
framework::proto::VarType::Type::VarType_Type_LOD_TENSOR &&
!black_list.count(node->Var()->Name())) {
......
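Both MemoryOptimizePass hunks add the same defensive check: a graph var node may carry no VarDesc (node->Var() returns nullptr, e.g. for dependency-only nodes), so it has to be tested before GetType() or Persistable() is touched. The guard reduced to a sketch; the helper name SkipNode is hypothetical.

#include "paddle/fluid/framework/ir/node.h"

// Sketch of the guard the hunks above introduce: a var node without a
// VarDesc (node->Var() == nullptr) must not be dereferenced.
bool SkipNode(const paddle::framework::ir::Node *node) {
  return !node->IsVar() || node->Var() == nullptr;
}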
...@@ -193,22 +193,28 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
// "identity_scale_op_clean_pass",  //
"is_test_pass",  //
"simplify_with_basic_ops_pass",  //
"conv_bn_fuse_pass",  //
"conv_eltwiseadd_bn_fuse_pass",  //
"embedding_eltwise_layernorm_fuse_pass",  //
"multihead_matmul_fuse_pass_v2",  //
"fused_multi_transformer_encoder_pass",  //
"fused_multi_transformer_decoder_pass",  //
"fused_multi_transformer_encoder_fuse_qkv_pass",  //
"fused_multi_transformer_decoder_fuse_qkv_pass",  //
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",  //
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",  //
"gpu_cpu_squeeze2_matmul_fuse_pass",  //
"gpu_cpu_reshape2_matmul_fuse_pass",  //
"gpu_cpu_flatten2_matmul_fuse_pass",  //
"gpu_cpu_map_matmul_v2_to_mul_pass",  //
"gpu_cpu_map_matmul_v2_to_matmul_pass",  //
"matmul_scale_fuse_pass",  //
"multihead_matmul_fuse_pass_v3",  //
"gpu_cpu_map_matmul_to_mul_pass",  //
"fc_fuse_pass",  //
"fc_elementwise_layernorm_fuse_pass",  //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
......
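For inference users, the practical effect of the GpuPassStrategy change is that the fused_multi_transformer_* passes run automatically whenever IR optimization is enabled on a GPU config. A hedged sketch from the Paddle Inference C++ API side; the model paths are placeholders, and the DeletePass line shows how one of the new passes could be opted out of.

#include <memory>
#include "paddle_inference_api.h"

// Sketch only: a GPU config with IR optimization picks up GpuPassStrategy,
// which now schedules the fused_multi_transformer_* passes by default.
std::shared_ptr<paddle_infer::Predictor> BuildGpt3Predictor() {
  paddle_infer::Config config("./gpt3/model.pdmodel", "./gpt3/model.pdiparams");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/1000, /*device_id=*/0);
  config.SwitchIrOptim(true);
  // Optional: drop one of the new fuse passes if it misbehaves on a model.
  // config.pass_builder()->DeletePass("fused_multi_transformer_decoder_pass");
  return paddle_infer::CreatePredictor(config);
}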