Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5a2e5179
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5a2e5179
编写于
10月 20, 2022
作者:
K
Kaipeng Deng
提交者:
GitHub
10月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add FusedMultiTransformer fuse pass for GPT3 (#45907)
* add fused_multi_transformer_encoder/decoder pass, run GPT-3 success
上级
4dc4d5fc
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
8802 addition
and
25 deletion
+8802
-25
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+10
-0
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
...luid/framework/ir/fused_multi_transformer_decoder_pass.cc
+3214
-0
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h
...fluid/framework/ir/fused_multi_transformer_decoder_pass.h
+416
-0
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc
...amework/ir/fused_multi_transformer_decoder_pass_tester.cc
+576
-0
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
...luid/framework/ir/fused_multi_transformer_encoder_pass.cc
+3448
-0
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h
...fluid/framework/ir/fused_multi_transformer_encoder_pass.h
+398
-0
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc
...amework/ir/fused_multi_transformer_encoder_pass_tester.cc
+563
-0
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+8
-3
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+3
-2
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+8
-0
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+43
-1
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+12
-0
paddle/fluid/framework/ir/pass_tester_helper.h
paddle/fluid/framework/ir/pass_tester_helper.h
+78
-1
paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
...e/fluid/inference/analysis/passes/memory_optimize_pass.cc
+3
-2
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+22
-16
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
5a2e5179
...
@@ -105,6 +105,8 @@ pass_library(simplify_with_basic_ops_pass base)
...
@@ -105,6 +105,8 @@ pass_library(simplify_with_basic_ops_pass base)
pass_library
(
fc_elementwise_layernorm_fuse_pass base
)
pass_library
(
fc_elementwise_layernorm_fuse_pass base
)
pass_library
(
skip_layernorm_fuse_pass base
)
pass_library
(
skip_layernorm_fuse_pass base
)
pass_library
(
multihead_matmul_fuse_pass inference
)
pass_library
(
multihead_matmul_fuse_pass inference
)
pass_library
(
fused_multi_transformer_encoder_pass inference
)
pass_library
(
fused_multi_transformer_decoder_pass inference
)
pass_library
(
adaptive_pool2d_convert_global_pass inference
)
pass_library
(
adaptive_pool2d_convert_global_pass inference
)
pass_library
(
unsqueeze2_eltwise_fuse_pass inference
)
pass_library
(
unsqueeze2_eltwise_fuse_pass inference
)
pass_library
(
yolo_box_fuse_pass inference
)
pass_library
(
yolo_box_fuse_pass inference
)
...
@@ -311,6 +313,14 @@ cc_test(
...
@@ -311,6 +313,14 @@ cc_test(
test_multihead_matmul_fuse_pass
test_multihead_matmul_fuse_pass
SRCS multihead_matmul_fuse_pass_tester.cc
SRCS multihead_matmul_fuse_pass_tester.cc
DEPS multihead_matmul_fuse_pass
)
DEPS multihead_matmul_fuse_pass
)
cc_test
(
test_fused_multi_transformer_encoder_pass
SRCS fused_multi_transformer_encoder_pass_tester.cc
DEPS fused_multi_transformer_encoder_pass
)
cc_test
(
test_fused_multi_transformer_decoder_pass
SRCS fused_multi_transformer_decoder_pass_tester.cc
DEPS fused_multi_transformer_decoder_pass
)
cc_test
(
cc_test
(
test_conv_bn_fuse_pass_cc
test_conv_bn_fuse_pass_cc
SRCS conv_bn_fuse_pass_tester.cc
SRCS conv_bn_fuse_pass_tester.cc
...
...
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
0 → 100644
浏览文件 @
5a2e5179
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h"
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
PDNode
*
FusedMultiTransformerDecoderPattern
::
operator
()()
{
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
// pre-LayerNorm
auto
*
layer_norm
=
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
)
->
assert_more
([](
Node
*
x
)
{
if
(
x
->
outputs
.
size
()
==
3
)
{
return
true
;
}
else
{
return
false
;
}
});
layer_norm
->
LinksFrom
({
input0
,
layer_norm_bias_var
,
layer_norm_scale_var
})
.
LinksTo
(
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
// Q path Nodes
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul0_out_var
=
pattern
->
NewNode
(
matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd0
=
pattern
->
NewNode
(
eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd0_b_var
=
pattern
->
NewNode
(
eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd0_out_var
=
pattern
->
NewNode
(
eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_0
=
pattern
->
NewNode
(
reshape2_0_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_0_out_var
=
pattern
->
NewNode
(
reshape2_0_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_0
=
pattern
->
NewNode
(
transpose2_0_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
// Q path Links
matmul0
->
LinksFrom
({
layer_norm_out_var
,
matmul0_w_var
})
.
LinksTo
({
matmul0_out_var
});
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
.
LinksTo
({
eltadd0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
// K path Nodes
auto
*
matmul1
=
pattern
->
NewNode
(
matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul1_w_var
=
pattern
->
NewNode
(
matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul1_out_var
=
pattern
->
NewNode
(
matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd1
=
pattern
->
NewNode
(
eltadd1_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd1_b_var
=
pattern
->
NewNode
(
eltadd1_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd1_out_var
=
pattern
->
NewNode
(
eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_1
=
pattern
->
NewNode
(
reshape2_1_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_1_out_var
=
pattern
->
NewNode
(
reshape2_1_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_1
=
pattern
->
NewNode
(
transpose2_1_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_1_out_var
=
pattern
->
NewNode
(
transpose2_1_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
();
auto
*
concat_0_in_var
=
pattern
->
NewNode
(
concat_0_in_repr
())
->
AsInput
();
auto
*
concat_0
=
pattern
->
NewNode
(
concat_0_repr
())
->
assert_is_op
(
"concat"
);
auto
*
concat_0_out_var
=
pattern
->
NewNode
(
concat_0_out_repr
())
->
assert_is_op_output
(
"concat"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
)
->
assert_is_op_input
(
"assign"
);
auto
assign_0
=
pattern
->
NewNode
(
assign_0_repr
())
->
assert_is_op
(
"assign"
);
// K path Links
matmul1
->
LinksFrom
({
layer_norm_out_var
,
matmul1_w_var
})
.
LinksTo
({
matmul1_out_var
});
eltadd1
->
LinksFrom
({
matmul1_out_var
,
eltadd1_b_var
})
.
LinksTo
({
eltadd1_out_var
});
reshape2_1
->
LinksFrom
({
eltadd1_out_var
}).
LinksTo
({
reshape2_1_out_var
});
transpose2_1
->
LinksFrom
({
reshape2_1_out_var
}).
LinksTo
({
transpose2_1_out_var
});
concat_0
->
LinksFrom
({
transpose2_1_out_var
,
concat_0_in_var
})
.
LinksTo
({
concat_0_out_var
});
assign_0
->
LinksFrom
({
concat_0_out_var
});
// V path Nodes
auto
*
matmul2
=
pattern
->
NewNode
(
matmul2_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul2_w_var
=
pattern
->
NewNode
(
matmul2_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul2_out_var
=
pattern
->
NewNode
(
matmul2_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd2
=
pattern
->
NewNode
(
eltadd2_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd2_b_var
=
pattern
->
NewNode
(
eltadd2_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd2_out_var
=
pattern
->
NewNode
(
eltadd2_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_2
=
pattern
->
NewNode
(
reshape2_2_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_2_out_var
=
pattern
->
NewNode
(
reshape2_2_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_2
=
pattern
->
NewNode
(
transpose2_2_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_2_out_var
=
pattern
->
NewNode
(
transpose2_2_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
auto
*
concat_1_in_var
=
pattern
->
NewNode
(
concat_1_in_repr
())
->
AsInput
()
->
assert_is_op_input
(
"concat"
);
auto
*
concat_1
=
pattern
->
NewNode
(
concat_1_repr
())
->
assert_is_op
(
"concat"
);
auto
*
concat_1_out_var
=
pattern
->
NewNode
(
concat_1_out_repr
())
->
assert_is_op_output
(
"concat"
)
->
assert_is_op_input
(
"matmul_v2"
)
->
assert_is_op_input
(
"assign"
);
auto
assign_1
=
pattern
->
NewNode
(
assign_1_repr
())
->
assert_is_op
(
"assign"
);
// V path Links
matmul2
->
LinksFrom
({
layer_norm_out_var
,
matmul2_w_var
})
.
LinksTo
({
matmul2_out_var
});
eltadd2
->
LinksFrom
({
matmul2_out_var
,
eltadd2_b_var
})
.
LinksTo
({
eltadd2_out_var
});
reshape2_2
->
LinksFrom
({
eltadd2_out_var
}).
LinksTo
({
reshape2_2_out_var
});
transpose2_2
->
LinksFrom
({
reshape2_2_out_var
}).
LinksTo
({
transpose2_2_out_var
});
concat_1
->
LinksFrom
({
transpose2_2_out_var
,
concat_1_in_var
})
.
LinksTo
({
concat_1_out_var
});
assign_1
->
LinksFrom
({
concat_1_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax"
);
auto
*
softmax_qk
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk_out_repr
())
->
assert_is_op_output
(
"softmax"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_qk
=
pattern
->
NewNode
(
dropout_qk_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_qk_out_var
=
pattern
->
NewNode
(
dropout_qk_out_repr
())
->
assert_is_op_output
(
"dropout"
,
"Out"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// -> matmul_qkv
// QK path Linsk
matmul_qk
->
LinksFrom
({
transpose2_0_out_var
,
concat_0_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
dropout_qk
->
LinksFrom
({
softmax_qk_out_var
}).
LinksTo
({
dropout_qk_out_var
});
// QKV path Nodes
auto
*
matmul_qkv
=
pattern
->
NewNode
(
matmul_qkv_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qkv_out_var
=
pattern
->
NewNode
(
matmul_qkv_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_qkv
=
pattern
->
NewNode
(
transpose2_qkv_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_qkv_out_var
=
pattern
->
NewNode
(
transpose2_qkv_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
transpose2_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_qkv
=
pattern
->
NewNode
(
reshape2_qkv_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_qkv_out_var
=
pattern
->
NewNode
(
reshape2_qkv_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
// -> out_linear
auto
*
matmul_linear
=
pattern
->
NewNode
(
matmul_linear_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_linear_w_var
=
pattern
->
NewNode
(
matmul_linear_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul_linear_out_var
=
pattern
->
NewNode
(
matmul_linear_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_linear
=
pattern
->
NewNode
(
eltadd_linear_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_linear_b_var
=
pattern
->
NewNode
(
eltadd_linear_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_linear_out_var
=
pattern
->
NewNode
(
eltadd_linear_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_linear
=
pattern
->
NewNode
(
dropout_linear_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_linear_out_var
=
pattern
->
NewNode
(
dropout_linear_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_out
=
pattern
->
NewNode
(
eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
attention_output
=
pattern
->
NewNode
(
attention_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
();
// QKV path Links
matmul_qkv
->
LinksFrom
({
dropout_qk_out_var
,
concat_1_out_var
})
.
LinksTo
({
matmul_qkv_out_var
});
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
.
LinksTo
({
transpose2_qkv_out_var
});
reshape2_qkv
->
LinksFrom
({
transpose2_qkv_out_var
})
.
LinksTo
({
reshape2_qkv_out_var
});
matmul_linear
->
LinksFrom
({
reshape2_qkv_out_var
,
matmul_linear_w_var
})
.
LinksTo
({
matmul_linear_out_var
});
eltadd_linear
->
LinksFrom
({
matmul_linear_out_var
,
eltadd_linear_b_var
})
.
LinksTo
({
eltadd_linear_out_var
});
dropout_linear
->
LinksFrom
({
eltadd_linear_out_var
})
.
LinksTo
({
dropout_linear_out_var
});
eltadd_out
->
LinksFrom
({
input0
,
dropout_linear_out_var
})
.
LinksTo
({
attention_output
});
// Feed Forward LayerNorm Nodes
auto
*
ffn_layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_layer_norm_variance_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
ffn_layer_norm
->
LinksFrom
(
{
attention_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
({
ffn_layer_norm_out_var
,
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto
*
ffn_matmul0
=
pattern
->
NewNode
(
ffn_matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul0_w_var
=
pattern
->
NewNode
(
ffn_matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul0_out_var
=
pattern
->
NewNode
(
ffn_matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd0
=
pattern
->
NewNode
(
ffn_eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd0_b_var
=
pattern
->
NewNode
(
ffn_eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"gelu"
);
auto
*
ffn_gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
auto
*
ffn_matmul1
=
pattern
->
NewNode
(
ffn_matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul1_w_var
=
pattern
->
NewNode
(
ffn_matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul1_out_var
=
pattern
->
NewNode
(
ffn_matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd1
=
pattern
->
NewNode
(
ffn_eltadd1_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd1_b_var
=
pattern
->
NewNode
(
ffn_eltadd1_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd1_out_var
=
pattern
->
NewNode
(
ffn_eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
ffn_dropout
=
pattern
->
NewNode
(
ffn_dropout_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
ffn_dropout_out_var
=
pattern
->
NewNode
(
ffn_dropout_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd_out
=
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsOutput
();
ffn_matmul0
->
LinksFrom
({
ffn_layer_norm_out_var
,
ffn_matmul0_w_var
})
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_gelu_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
ffn_dropout
->
LinksFrom
({
ffn_eltadd1_out_var
}).
LinksTo
({
ffn_dropout_out_var
});
ffn_eltadd_out
->
LinksFrom
({
attention_output
,
ffn_dropout_out_var
})
.
LinksTo
({
ffn_output
});
return
ffn_output
;
}
PDNode
*
FusedMultiTransformerDecoderFuseQKVPattern
::
operator
()()
{
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
// pre-LayerNorm
auto
*
layer_norm
=
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
layer_norm
->
LinksFrom
({
input0
,
layer_norm_bias_var
,
layer_norm_scale_var
})
.
LinksTo
(
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
// QKV fused path Nodes
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul0_out_var
=
pattern
->
NewNode
(
matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd0
=
pattern
->
NewNode
(
eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd0_b_var
=
pattern
->
NewNode
(
eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd0_out_var
=
pattern
->
NewNode
(
eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_0
=
pattern
->
NewNode
(
reshape2_0_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_0_out_var
=
pattern
->
NewNode
(
reshape2_0_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_0
=
pattern
->
NewNode
(
transpose2_0_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"split"
,
"X"
);
auto
*
split0
=
pattern
->
NewNode
(
split0_repr
())
->
assert_is_op
(
"split"
);
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"concat"
);
auto
*
split0_v_out_var
=
pattern
->
NewNode
(
split0_v_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"concat"
);
auto
*
concat_k_in_var
=
pattern
->
NewNode
(
concat_k_in_repr
())
// ->AsInput()
->
assert_is_op_input
(
"concat"
);
auto
*
concat_k
=
pattern
->
NewNode
(
concat_k_repr
())
->
assert_is_op
(
"concat"
);
auto
*
concat_k_out_var
=
pattern
->
NewNode
(
concat_k_out_repr
())
->
assert_is_op_output
(
"concat"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
)
->
assert_is_op_input
(
"assign"
);
auto
*
concat_v_in_var
=
pattern
->
NewNode
(
concat_v_in_repr
())
// ->AsInput()
->
assert_is_op_input
(
"concat"
);
auto
*
concat_v
=
pattern
->
NewNode
(
concat_v_repr
())
->
assert_is_op
(
"concat"
);
auto
*
concat_v_out_var
=
pattern
->
NewNode
(
concat_v_out_repr
())
->
assert_is_op_output
(
"concat"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
)
->
assert_is_op_input
(
"assign"
);
auto
*
assign_k
=
pattern
->
NewNode
(
assign_k_repr
())
->
assert_is_op
(
"assign"
);
auto
*
assign_v
=
pattern
->
NewNode
(
assign_v_repr
())
->
assert_is_op
(
"assign"
);
// QKV fused path Links
matmul0
->
LinksFrom
({
layer_norm_out_var
,
matmul0_w_var
})
.
LinksTo
({
matmul0_out_var
});
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
.
LinksTo
({
eltadd0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
split0
->
LinksFrom
({
transpose2_0_out_var
})
.
LinksTo
({
split0_q_out_var
,
split0_k_out_var
,
split0_v_out_var
});
concat_k
->
LinksFrom
({
concat_k_in_var
,
split0_k_out_var
})
.
LinksTo
({
concat_k_out_var
});
concat_v
->
LinksFrom
({
concat_v_in_var
,
split0_v_out_var
})
.
LinksTo
({
concat_v_out_var
});
assign_k
->
LinksFrom
({
concat_k_out_var
});
assign_v
->
LinksFrom
({
concat_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax"
);
auto
*
softmax_qk
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk_out_repr
())
->
assert_is_op_output
(
"softmax"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_qk
=
pattern
->
NewNode
(
dropout_qk_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_qk_out_var
=
pattern
->
NewNode
(
dropout_qk_out_repr
())
->
assert_is_op_output
(
"dropout"
,
"Out"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// -> matmul_qkv
// QK path Linsk
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
concat_k_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
dropout_qk
->
LinksFrom
({
softmax_qk_out_var
}).
LinksTo
({
dropout_qk_out_var
});
// QKV path Nodes
auto
*
matmul_qkv
=
pattern
->
NewNode
(
matmul_qkv_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qkv_out_var
=
pattern
->
NewNode
(
matmul_qkv_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_qkv
=
pattern
->
NewNode
(
transpose2_qkv_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_qkv_out_var
=
pattern
->
NewNode
(
transpose2_qkv_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
transpose2_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_qkv
=
pattern
->
NewNode
(
reshape2_qkv_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_qkv_out_var
=
pattern
->
NewNode
(
reshape2_qkv_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
// -> out_linear
auto
*
matmul_linear
=
pattern
->
NewNode
(
matmul_linear_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_linear_w_var
=
pattern
->
NewNode
(
matmul_linear_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul_linear_out_var
=
pattern
->
NewNode
(
matmul_linear_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_linear
=
pattern
->
NewNode
(
eltadd_linear_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_linear_b_var
=
pattern
->
NewNode
(
eltadd_linear_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_linear_out_var
=
pattern
->
NewNode
(
eltadd_linear_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_linear
=
pattern
->
NewNode
(
dropout_linear_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_linear_out_var
=
pattern
->
NewNode
(
dropout_linear_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_out
=
pattern
->
NewNode
(
eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
attention_output
=
pattern
->
NewNode
(
attention_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
();
// QKV path Links
matmul_qkv
->
LinksFrom
({
dropout_qk_out_var
,
concat_v_out_var
})
.
LinksTo
({
matmul_qkv_out_var
});
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
.
LinksTo
({
transpose2_qkv_out_var
});
reshape2_qkv
->
LinksFrom
({
transpose2_qkv_out_var
})
.
LinksTo
({
reshape2_qkv_out_var
});
matmul_linear
->
LinksFrom
({
reshape2_qkv_out_var
,
matmul_linear_w_var
})
.
LinksTo
({
matmul_linear_out_var
});
eltadd_linear
->
LinksFrom
({
matmul_linear_out_var
,
eltadd_linear_b_var
})
.
LinksTo
({
eltadd_linear_out_var
});
dropout_linear
->
LinksFrom
({
eltadd_linear_out_var
})
.
LinksTo
({
dropout_linear_out_var
});
eltadd_out
->
LinksFrom
({
input0
,
dropout_linear_out_var
})
.
LinksTo
({
attention_output
});
// Feed Forward LayerNorm Nodes
auto
*
ffn_layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_layer_norm_variance_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
ffn_layer_norm
->
LinksFrom
(
{
attention_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
({
ffn_layer_norm_out_var
,
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto
*
ffn_matmul0
=
pattern
->
NewNode
(
ffn_matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul0_w_var
=
pattern
->
NewNode
(
ffn_matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul0_out_var
=
pattern
->
NewNode
(
ffn_matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd0
=
pattern
->
NewNode
(
ffn_eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd0_b_var
=
pattern
->
NewNode
(
ffn_eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"gelu"
);
auto
*
ffn_gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
auto
*
ffn_matmul1
=
pattern
->
NewNode
(
ffn_matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul1_w_var
=
pattern
->
NewNode
(
ffn_matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul1_out_var
=
pattern
->
NewNode
(
ffn_matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd1
=
pattern
->
NewNode
(
ffn_eltadd1_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd1_b_var
=
pattern
->
NewNode
(
ffn_eltadd1_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd1_out_var
=
pattern
->
NewNode
(
ffn_eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
ffn_dropout
=
pattern
->
NewNode
(
ffn_dropout_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
ffn_dropout_out_var
=
pattern
->
NewNode
(
ffn_dropout_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd_out
=
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsOutput
();
ffn_matmul0
->
LinksFrom
({
ffn_layer_norm_out_var
,
ffn_matmul0_w_var
})
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_gelu_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
ffn_dropout
->
LinksFrom
({
ffn_eltadd1_out_var
}).
LinksTo
({
ffn_dropout_out_var
});
ffn_eltadd_out
->
LinksFrom
({
attention_output
,
ffn_dropout_out_var
})
.
LinksTo
({
ffn_output
});
return
ffn_output
;
}
PDNode
*
MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
::
operator
()()
{
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
// pre-LayerNorm
auto
*
layer_norm
=
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"c_identity"
,
"X"
);
layer_norm
->
LinksFrom
({
input0
,
layer_norm_bias_var
,
layer_norm_scale_var
})
.
LinksTo
(
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
// communication c_identity
auto
*
c_identity
=
pattern
->
NewNode
(
c_identity_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
c_identity_out_var
=
pattern
->
NewNode
(
c_identity_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
c_identity
->
LinksFrom
({
layer_norm_out_var
}).
LinksTo
({
c_identity_out_var
});
// QKV fused path Nodes
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul0_out_var
=
pattern
->
NewNode
(
matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd0
=
pattern
->
NewNode
(
eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd0_b_var
=
pattern
->
NewNode
(
eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd0_out_var
=
pattern
->
NewNode
(
eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_0
=
pattern
->
NewNode
(
reshape2_0_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_0_out_var
=
pattern
->
NewNode
(
reshape2_0_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_0
=
pattern
->
NewNode
(
transpose2_0_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"split"
,
"X"
);
auto
*
split0
=
pattern
->
NewNode
(
split0_repr
())
->
assert_is_op
(
"split"
);
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"concat"
);
auto
*
split0_v_out_var
=
pattern
->
NewNode
(
split0_v_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"concat"
);
auto
*
concat_k_in_var
=
pattern
->
NewNode
(
concat_k_in_repr
())
// ->AsInput()
->
assert_is_op_input
(
"concat"
);
auto
*
concat_k
=
pattern
->
NewNode
(
concat_k_repr
())
->
assert_is_op
(
"concat"
);
auto
*
concat_k_out_var
=
pattern
->
NewNode
(
concat_k_out_repr
())
->
assert_is_op_output
(
"concat"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
)
->
assert_is_op_input
(
"assign"
);
auto
*
concat_v_in_var
=
pattern
->
NewNode
(
concat_v_in_repr
())
// ->AsInput()
->
assert_is_op_input
(
"concat"
);
auto
*
concat_v
=
pattern
->
NewNode
(
concat_v_repr
())
->
assert_is_op
(
"concat"
);
auto
*
concat_v_out_var
=
pattern
->
NewNode
(
concat_v_out_repr
())
->
assert_is_op_output
(
"concat"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
)
->
assert_is_op_input
(
"assign"
);
auto
*
assign_k
=
pattern
->
NewNode
(
assign_k_repr
())
->
assert_is_op
(
"assign"
);
auto
*
assign_v
=
pattern
->
NewNode
(
assign_v_repr
())
->
assert_is_op
(
"assign"
);
// QKV fused path Links
matmul0
->
LinksFrom
({
c_identity_out_var
,
matmul0_w_var
})
.
LinksTo
({
matmul0_out_var
});
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
.
LinksTo
({
eltadd0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
split0
->
LinksFrom
({
transpose2_0_out_var
})
.
LinksTo
({
split0_q_out_var
,
split0_k_out_var
,
split0_v_out_var
});
concat_k
->
LinksFrom
({
concat_k_in_var
,
split0_k_out_var
})
.
LinksTo
({
concat_k_out_var
});
concat_v
->
LinksFrom
({
concat_v_in_var
,
split0_v_out_var
})
.
LinksTo
({
concat_v_out_var
});
assign_k
->
LinksFrom
({
concat_k_out_var
});
assign_v
->
LinksFrom
({
concat_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax"
);
auto
*
softmax_qk
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk_out_repr
())
->
assert_is_op_output
(
"softmax"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_qk
=
pattern
->
NewNode
(
dropout_qk_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_qk_out_var
=
pattern
->
NewNode
(
dropout_qk_out_repr
())
->
assert_is_op_output
(
"dropout"
,
"Out"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// -> matmul_qkv
// QK path Linsk
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
concat_k_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
dropout_qk
->
LinksFrom
({
softmax_qk_out_var
}).
LinksTo
({
dropout_qk_out_var
});
// QKV path Nodes
auto
*
matmul_qkv
=
pattern
->
NewNode
(
matmul_qkv_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qkv_out_var
=
pattern
->
NewNode
(
matmul_qkv_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_qkv
=
pattern
->
NewNode
(
transpose2_qkv_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_qkv_out_var
=
pattern
->
NewNode
(
transpose2_qkv_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
transpose2_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_qkv
=
pattern
->
NewNode
(
reshape2_qkv_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_qkv_out_var
=
pattern
->
NewNode
(
reshape2_qkv_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
// -> out_linear
auto
*
matmul_linear
=
pattern
->
NewNode
(
matmul_linear_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_linear_w_var
=
pattern
->
NewNode
(
matmul_linear_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul_linear_out_var
=
pattern
->
NewNode
(
matmul_linear_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"c_allreduce_sum"
);
// communication c_allreduce_sum
auto
*
c_allreduce_sum
=
pattern
->
NewNode
(
c_allreduce_sum_repr
())
->
assert_is_op
(
"c_allreduce_sum"
);
auto
*
c_allreduce_sum_out_var
=
pattern
->
NewNode
(
c_allreduce_sum_out_repr
())
->
assert_is_op_output
(
"c_allreduce_sum"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_linear
=
pattern
->
NewNode
(
eltadd_linear_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_linear_b_var
=
pattern
->
NewNode
(
eltadd_linear_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_linear_out_var
=
pattern
->
NewNode
(
eltadd_linear_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_linear
=
pattern
->
NewNode
(
dropout_linear_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_linear_out_var
=
pattern
->
NewNode
(
dropout_linear_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_out
=
pattern
->
NewNode
(
eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
attention_output
=
pattern
->
NewNode
(
attention_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
();
// QKV path Links
matmul_qkv
->
LinksFrom
({
dropout_qk_out_var
,
concat_v_out_var
})
.
LinksTo
({
matmul_qkv_out_var
});
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
.
LinksTo
({
transpose2_qkv_out_var
});
reshape2_qkv
->
LinksFrom
({
transpose2_qkv_out_var
})
.
LinksTo
({
reshape2_qkv_out_var
});
matmul_linear
->
LinksFrom
({
reshape2_qkv_out_var
,
matmul_linear_w_var
})
.
LinksTo
({
matmul_linear_out_var
});
c_allreduce_sum
->
LinksFrom
({
matmul_linear_out_var
})
.
LinksTo
({
c_allreduce_sum_out_var
});
eltadd_linear
->
LinksFrom
({
c_allreduce_sum_out_var
,
eltadd_linear_b_var
})
.
LinksTo
({
eltadd_linear_out_var
});
dropout_linear
->
LinksFrom
({
eltadd_linear_out_var
})
.
LinksTo
({
dropout_linear_out_var
});
eltadd_out
->
LinksFrom
({
input0
,
dropout_linear_out_var
})
.
LinksTo
({
attention_output
});
// Feed Forward LayerNorm Nodes
auto
*
ffn_layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_layer_norm_variance_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"c_identity"
,
"X"
);
ffn_layer_norm
->
LinksFrom
(
{
attention_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
({
ffn_layer_norm_out_var
,
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
// communication c_identity
auto
*
ffn_c_identity
=
pattern
->
NewNode
(
ffn_c_identity_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
ffn_c_identity_out_var
=
pattern
->
NewNode
(
ffn_c_identity_out_repr
())
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
ffn_c_identity
->
LinksFrom
({
ffn_layer_norm_out_var
})
.
LinksTo
({
ffn_c_identity_out_var
});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto
*
ffn_matmul0
=
pattern
->
NewNode
(
ffn_matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul0_w_var
=
pattern
->
NewNode
(
ffn_matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul0_out_var
=
pattern
->
NewNode
(
ffn_matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd0
=
pattern
->
NewNode
(
ffn_eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd0_b_var
=
pattern
->
NewNode
(
ffn_eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"gelu"
);
auto
*
ffn_gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
auto
*
ffn_matmul1
=
pattern
->
NewNode
(
ffn_matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul1_w_var
=
pattern
->
NewNode
(
ffn_matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul1_out_var
=
pattern
->
NewNode
(
ffn_matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"c_allreduce_sum"
);
// communication c_allreduce_sum
auto
*
ffn_c_allreduce_sum
=
pattern
->
NewNode
(
ffn_c_allreduce_sum_repr
())
->
assert_is_op
(
"c_allreduce_sum"
);
auto
*
ffn_c_allreduce_sum_out_var
=
pattern
->
NewNode
(
ffn_c_allreduce_sum_out_repr
())
->
assert_is_op_output
(
"c_allreduce_sum"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd1
=
pattern
->
NewNode
(
ffn_eltadd1_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd1_b_var
=
pattern
->
NewNode
(
ffn_eltadd1_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd1_out_var
=
pattern
->
NewNode
(
ffn_eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
ffn_dropout
=
pattern
->
NewNode
(
ffn_dropout_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
ffn_dropout_out_var
=
pattern
->
NewNode
(
ffn_dropout_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd_out
=
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsOutput
();
ffn_matmul0
->
LinksFrom
({
ffn_c_identity_out_var
,
ffn_matmul0_w_var
})
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_gelu_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_c_allreduce_sum
->
LinksFrom
({
ffn_matmul1_out_var
})
.
LinksTo
({
ffn_c_allreduce_sum_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_c_allreduce_sum_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
ffn_dropout
->
LinksFrom
({
ffn_eltadd1_out_var
}).
LinksTo
({
ffn_dropout_out_var
});
ffn_eltadd_out
->
LinksFrom
({
attention_output
,
ffn_dropout_out_var
})
.
LinksTo
({
ffn_output
});
return
ffn_output
;
}
}
// namespace patterns
int
FusedMultiTransformerDecoderPass
::
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
// Create pattern.
patterns
::
FusedMultiTransformerDecoderPattern
fused_multi_transformer_pattern
(
pattern
,
name_scope
);
fused_multi_transformer_pattern
();
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
layer_norm
,
Node
*
layer_norm_scale
,
Node
*
layer_norm_bias
,
Node
*
layer_norm_mean
,
Node
*
layer_norm_variance
,
Node
*
matmul0_w
,
Node
*
matmul1_w
,
Node
*
matmul2_w
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
eltadd2_b
,
Node
*
transpose2_1_out
,
Node
*
transpose2_2_out
,
Node
*
eltadd_qk_b
,
Node
*
dropout_qk
,
Node
*
reshape2_0
,
Node
*
matmul_linear_w
,
Node
*
eltadd_linear_b
,
Node
*
dropout_linear
,
Node
*
ffn_layer_norm
,
Node
*
ffn_layer_norm_scale
,
Node
*
ffn_layer_norm_bias
,
Node
*
ffn_layer_norm_mean
,
Node
*
ffn_layer_norm_variance
,
Node
*
ffn_matmul0_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_dropout
,
Node
*
ffn_output
)
{
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
// 2. each transformer layer contains 2 LayerNorm layer
auto
ln_scale_name
=
layer_norm_scale
->
Name
();
auto
ln_name
=
ln_scale_name
.
substr
(
0
,
ln_scale_name
.
find
(
'.'
));
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
// create fused_multi_transformer
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
fused_multi_transformer_op_desc
.
SetType
(
"fused_multi_transformer"
);
// 1. Input setting
fused_multi_transformer_op_desc
.
SetInput
(
"X"
,
{
input0
->
Name
()});
// pre-LayerNorm input
fused_multi_transformer_op_desc
.
SetInput
(
"LnScale"
,
{
layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"LnBias"
,
{
layer_norm_bias
->
Name
()});
// QKV computation input
fused_multi_transformer_op_desc
.
SetInput
(
"QKVW"
,
{
matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"QKVBias"
,
{
eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"SrcMask"
,
{
eltadd_qk_b
->
Name
()});
// Cache KV use cache_kv in encoder
auto
cache_kv_name
=
"cache_kv"
+
std
::
to_string
(
layer_idx
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv_name
});
VarDesc
shape_out_desc
(
"shape_out."
+
std
::
to_string
(
layer_idx
));
shape_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
shape_out_desc
.
SetPersistable
(
false
);
auto
*
shape_out
=
graph
->
CreateVarNode
(
&
shape_out_desc
);
OpDesc
shape_op_desc
(
layer_norm
->
Op
()
->
Block
());
shape_op_desc
.
SetType
(
"shape"
);
shape_op_desc
.
SetInput
(
"Input"
,
{
eltadd_qk_b
->
Name
()});
shape_op_desc
.
SetOutput
(
"Out"
,
{
shape_out
->
Name
()});
auto
*
shape_op
=
graph
->
CreateOpNode
(
&
shape_op_desc
);
VarDesc
slice_out_desc
(
"slice_out."
+
std
::
to_string
(
layer_idx
));
slice_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
slice_out_desc
.
SetPersistable
(
false
);
auto
*
slice_out
=
graph
->
CreateVarNode
(
&
slice_out_desc
);
OpDesc
slice_op_desc
(
layer_norm
->
Op
()
->
Block
());
slice_op_desc
.
SetType
(
"slice"
);
slice_op_desc
.
SetInput
(
"Input"
,
{
shape_out
->
Name
()});
slice_op_desc
.
SetOutput
(
"Out"
,
{
slice_out
->
Name
()});
std
::
vector
<
int
>
axes
=
{
0
};
std
::
vector
<
int
>
starts
=
{
3
};
std
::
vector
<
int
>
ends
=
{
4
};
slice_op_desc
.
SetAttr
(
"axes"
,
axes
);
slice_op_desc
.
SetAttr
(
"starts"
,
starts
);
slice_op_desc
.
SetAttr
(
"ends"
,
ends
);
auto
*
slice_op
=
graph
->
CreateOpNode
(
&
slice_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"TimeStep"
,
{
slice_out
->
Name
()});
// Out Linear input
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
{
matmul_linear_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearBias"
,
{
eltadd_linear_b
->
Name
()});
// Feed Forward input
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnScale"
,
{
ffn_layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnBias"
,
{
ffn_layer_norm_bias
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Weight"
,
{
ffn_matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Bias"
,
{
ffn_eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Weight"
,
{
ffn_matmul1_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Bias"
,
{
ffn_eltadd1_b
->
Name
()});
// 2. Output setting
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_output
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv_name
});
// Attribute setting
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// output dropout attribute
auto
*
dropout_op
=
dropout_linear
->
Op
();
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
dropout_op
->
GetAttr
(
"dropout_prob"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
dropout_op
->
GetAttr
(
"is_test"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_implementation"
,
dropout_op
->
GetAttr
(
"dropout_implementation"
));
auto
*
fused_multi_transformer
=
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
IR_NODE_LINK_TO
(
input0
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul0_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
// TimeStep link
IR_NODE_LINK_TO
(
eltadd_qk_b
,
shape_op
);
IR_NODE_LINK_TO
(
shape_op
,
shape_out
);
IR_NODE_LINK_TO
(
shape_out
,
slice_op
);
IR_NODE_LINK_TO
(
slice_op
,
slice_out
);
IR_NODE_LINK_TO
(
slice_out
,
fused_multi_transformer
)
IR_NODE_LINK_TO
(
fused_multi_transformer
,
ffn_output
);
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"fused_multi_transformer_decoder "
"pass in op compat failed."
;
return
;
}
VLOG
(
4
)
<<
"handle MultiTransformer decoder fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
input0
,
input0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm
,
layer_norm
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
layer_norm_scale
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
layer_norm_bias
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
layer_norm_mean
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
layer_norm_variance
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
layer_norm_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0
,
matmul0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_out
,
matmul0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_w
,
matmul0_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1
,
matmul1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1_out
,
matmul1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1_w
,
matmul1_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1
,
reshape2_1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1_out
,
reshape2_1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1
,
transpose2_1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1_out
,
transpose2_1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_0
,
concat_0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_0_out
,
concat_0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
assign_0
,
assign_0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2
,
matmul2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2_out
,
matmul2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2_w
,
matmul2_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2
,
reshape2_2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2_out
,
reshape2_2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2
,
transpose2_2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2_out
,
transpose2_2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_1
,
concat_1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_1_out
,
concat_1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
assign_1
,
assign_1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
attention_output
,
attention_output
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm
,
ffn_layer_norm
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
ffn_layer_norm_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0
,
ffn_matmul0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_out
,
ffn_matmul0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_w
,
ffn_matmul0_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0
,
ffn_eltadd0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_b
,
ffn_eltadd0_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu
,
ffn_gelu
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu_out
,
ffn_gelu_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_out
,
ffn_matmul1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_w
,
ffn_matmul1_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1
,
ffn_eltadd1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_b
,
ffn_eltadd1_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
ffn_eltadd1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout
,
ffn_dropout
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout_out
,
ffn_dropout_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
ffn_eltadd_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_output
,
ffn_output
,
fused_multi_transformer_pattern
)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_b
,
eltadd0_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_out
,
eltadd0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1
,
eltadd1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1_b
,
eltadd1_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1_out
,
eltadd1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2
,
eltadd2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2_b
,
eltadd2_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2_out
,
eltadd2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk
,
matmul_qk
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_b
,
eltadd_qk_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk
,
dropout_qk
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk_out
,
dropout_qk_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear
,
matmul_linear
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_w
,
matmul_linear_w
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
matmul_linear_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear
,
eltadd_linear
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
eltadd_linear_b
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
eltadd_linear_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear
,
dropout_linear
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear_out
,
dropout_linear_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fused_multi_transformer_pattern
)
fuse_creater
(
input0
,
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
matmul0_w
,
matmul1_w
,
matmul2_w
,
eltadd0_b
,
eltadd1_b
,
eltadd2_b
,
transpose2_1_out
,
transpose2_2_out
,
eltadd_qk_b
,
dropout_qk
,
reshape2_0
,
matmul_linear_w
,
eltadd_linear_b
,
dropout_linear
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_matmul0_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_dropout
,
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
layer_norm_out
,
matmul0
,
matmul1
,
matmul2
,
matmul0_out
,
matmul1_out
,
matmul2_out
,
eltadd0
,
eltadd1
,
eltadd2
,
eltadd0_out
,
eltadd1_out
,
eltadd2_out
,
reshape2_0
,
reshape2_1
,
reshape2_2
,
reshape2_0_out
,
reshape2_1_out
,
reshape2_2_out
,
transpose2_0
,
transpose2_1
,
transpose2_2
,
transpose2_0_out
,
transpose2_1_out
,
transpose2_2_out
,
concat_0
,
concat_1
,
concat_0_out
,
concat_1_out
,
assign_0
,
assign_1
,
matmul_qk
,
matmul_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
softmax_qk_out
,
dropout_qk
,
dropout_qk_out
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_qkv
,
matmul_qkv_out
,
reshape2_qkv
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_linear
,
matmul_linear_w
,
matmul_linear_out
,
eltadd_linear
,
eltadd_linear_b
,
eltadd_linear_out
,
dropout_linear
,
dropout_linear_out
,
eltadd_out
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_layer_norm_out
,
ffn_matmul0
,
ffn_matmul1
,
ffn_matmul0_out
,
ffn_matmul1_out
,
ffn_eltadd0
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_gelu
,
ffn_gelu_out
,
ffn_dropout
,
ffn_dropout_out
,
ffn_eltadd_out
});
// Remove unneeded nodes.
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
void
FusedMultiTransformerDecoderPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
Fatal
(
"During the multi_transformer pass, "
"The scope should not be null."
));
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
if
(
fusion_count
>
0
)
{
graph
->
Set
(
kFusedMultiTransformerDecoderPass
,
new
bool
(
true
));
}
AddStatis
(
fusion_count
);
}
FusedMultiTransformerDecoderPass
::
FusedMultiTransformerDecoderPass
()
{
AddOpCompat
(
OpCompat
(
"layer_norm"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Scale"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Mean"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Variance"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"epsilon"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
0.001
f
)
.
End
()
.
AddAttr
(
"begin_norm_axis"
)
.
IsNumGT
(
0
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
// the shape shoule be (N*H, N*H)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsType
<
bool
>
()
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
2
,
-
1
,
0
})
.
End
();
AddOpCompat
(
OpCompat
(
"reshape2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Shape"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ShapeTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"shape"
)
// -->(B, S, H, N) <--(B, S, N*H)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"transpose2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
// {0, 2, 1, 3}
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"concat"
))
.
AddInput
(
"X"
)
// Input("X"): vector<tensors>
.
End
()
.
AddInput
(
"AxisTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsNumEQ
(
2
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"alpha"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
1.0
f
)
.
End
()
.
AddAttr
(
"transpose_X"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"transpose_Y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"softmax"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
-
1
,
3
})
// shape is (B, H, S, S), so axis is -1 or 3
.
End
();
AddOpCompat
(
OpCompat
(
"gelu"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"approximate"
)
.
IsType
<
bool
>
()
.
End
();
}
int
FusedMultiTransformerDecoderFuseQKVPass
::
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
// Create pattern.
patterns
::
FusedMultiTransformerDecoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern
(
pattern
,
name_scope
);
fused_multi_transformer_fuse_qkv_pattern
();
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
layer_norm
,
Node
*
layer_norm_scale
,
Node
*
layer_norm_bias
,
Node
*
layer_norm_mean
,
Node
*
layer_norm_variance
,
Node
*
matmul0_w
,
Node
*
eltadd0_b
,
Node
*
eltadd_qk_b
,
Node
*
dropout_qk
,
Node
*
reshape2_0
,
Node
*
matmul_linear_w
,
Node
*
eltadd_linear_b
,
Node
*
dropout_linear
,
Node
*
ffn_layer_norm
,
Node
*
ffn_layer_norm_scale
,
Node
*
ffn_layer_norm_bias
,
Node
*
ffn_layer_norm_mean
,
Node
*
ffn_layer_norm_variance
,
Node
*
ffn_matmul0_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_dropout
,
Node
*
ffn_output
)
{
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
// 2. each transformer layer contains 2 LayerNorm layer
auto
ln_scale_name
=
layer_norm_scale
->
Name
();
auto
ln_name
=
ln_scale_name
.
substr
(
0
,
ln_scale_name
.
find
(
'.'
));
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
// create fused_multi_transformer
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
fused_multi_transformer_op_desc
.
SetType
(
"fused_multi_transformer"
);
// 1. Input setting
fused_multi_transformer_op_desc
.
SetInput
(
"X"
,
{
input0
->
Name
()});
// pre-LayerNorm input
fused_multi_transformer_op_desc
.
SetInput
(
"LnScale"
,
{
layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"LnBias"
,
{
layer_norm_bias
->
Name
()});
// QKV computation input
fused_multi_transformer_op_desc
.
SetInput
(
"QKVW"
,
{
matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"QKVBias"
,
{
eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"SrcMask"
,
{
eltadd_qk_b
->
Name
()});
// Cache KV use cache_kv in encoder
auto
cache_kv_name
=
"cache_kv"
+
std
::
to_string
(
layer_idx
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv_name
});
VarDesc
shape_out_desc
(
"shape_out."
+
std
::
to_string
(
layer_idx
));
shape_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
shape_out_desc
.
SetPersistable
(
false
);
auto
*
shape_out
=
graph
->
CreateVarNode
(
&
shape_out_desc
);
OpDesc
shape_op_desc
(
layer_norm
->
Op
()
->
Block
());
shape_op_desc
.
SetType
(
"shape"
);
shape_op_desc
.
SetInput
(
"Input"
,
{
eltadd_qk_b
->
Name
()});
shape_op_desc
.
SetOutput
(
"Out"
,
{
shape_out
->
Name
()});
auto
*
shape_op
=
graph
->
CreateOpNode
(
&
shape_op_desc
);
VarDesc
slice_out_desc
(
"slice_out."
+
std
::
to_string
(
layer_idx
));
slice_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
slice_out_desc
.
SetPersistable
(
false
);
auto
*
slice_out
=
graph
->
CreateVarNode
(
&
slice_out_desc
);
OpDesc
slice_op_desc
(
layer_norm
->
Op
()
->
Block
());
slice_op_desc
.
SetType
(
"slice"
);
slice_op_desc
.
SetInput
(
"Input"
,
{
shape_out
->
Name
()});
slice_op_desc
.
SetOutput
(
"Out"
,
{
slice_out
->
Name
()});
std
::
vector
<
int
>
axes
=
{
0
};
std
::
vector
<
int
>
starts
=
{
3
};
std
::
vector
<
int
>
ends
=
{
4
};
slice_op_desc
.
SetAttr
(
"axes"
,
axes
);
slice_op_desc
.
SetAttr
(
"starts"
,
starts
);
slice_op_desc
.
SetAttr
(
"ends"
,
ends
);
auto
*
slice_op
=
graph
->
CreateOpNode
(
&
slice_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"TimeStep"
,
{
slice_out
->
Name
()});
// Out Linear input
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
{
matmul_linear_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearBias"
,
{
eltadd_linear_b
->
Name
()});
// Feed Forward input
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnScale"
,
{
ffn_layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnBias"
,
{
ffn_layer_norm_bias
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Weight"
,
{
ffn_matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Bias"
,
{
ffn_eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Weight"
,
{
ffn_matmul1_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Bias"
,
{
ffn_eltadd1_b
->
Name
()});
// 2. Output setting
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_output
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv_name
});
// Attribute setting
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// output dropout attribute
auto
*
dropout_op
=
dropout_linear
->
Op
();
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
dropout_op
->
GetAttr
(
"dropout_prob"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
dropout_op
->
GetAttr
(
"is_test"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_implementation"
,
dropout_op
->
GetAttr
(
"dropout_implementation"
));
// fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"});
// fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true});
auto
*
fused_multi_transformer
=
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
IR_NODE_LINK_TO
(
input0
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul0_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
// TimeStep link
IR_NODE_LINK_TO
(
eltadd_qk_b
,
shape_op
);
IR_NODE_LINK_TO
(
shape_op
,
shape_out
);
IR_NODE_LINK_TO
(
shape_out
,
slice_op
);
IR_NODE_LINK_TO
(
slice_op
,
slice_out
);
IR_NODE_LINK_TO
(
slice_out
,
fused_multi_transformer
)
IR_NODE_LINK_TO
(
fused_multi_transformer
,
ffn_output
);
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"fused_multi_transformer_decoder_fuse_qkv "
"pass in op compat failed."
;
return
;
}
VLOG
(
4
)
<<
"handle MultiTransformer decoder(Fuse-QKV) fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
input0
,
input0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm
,
layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
layer_norm_scale
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
layer_norm_bias
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
layer_norm_mean
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
layer_norm_variance
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
layer_norm_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0
,
matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_out
,
matmul0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_w
,
matmul0_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0
,
split0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_q_out
,
split0_q_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_k_out
,
split0_k_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_v_out
,
split0_v_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_k_in
,
concat_k_in
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_k
,
concat_k
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_k_out
,
concat_k_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_v_in
,
concat_v_in
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_v
,
concat_v
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_v_out
,
concat_v_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
assign_k
,
assign_k
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
assign_v
,
assign_v
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm
,
ffn_layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
ffn_layer_norm_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0
,
ffn_matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_out
,
ffn_matmul0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_w
,
ffn_matmul0_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0
,
ffn_eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_b
,
ffn_eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu
,
ffn_gelu
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu_out
,
ffn_gelu_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_out
,
ffn_matmul1_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_w
,
ffn_matmul1_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1
,
ffn_eltadd1
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_b
,
ffn_eltadd1_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
ffn_eltadd1_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout
,
ffn_dropout
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout_out
,
ffn_dropout_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
ffn_eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_output
,
ffn_output
,
fused_multi_transformer_fuse_qkv_pattern
)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_b
,
eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_out
,
eltadd0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk
,
matmul_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_b
,
eltadd_qk_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk
,
dropout_qk
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk_out
,
dropout_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear
,
matmul_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_w
,
matmul_linear_w
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
matmul_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear
,
eltadd_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
eltadd_linear_b
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
eltadd_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear
,
dropout_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear_out
,
dropout_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
)
fuse_creater
(
input0
,
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
matmul0_w
,
eltadd0_b
,
eltadd_qk_b
,
dropout_qk
,
reshape2_0
,
matmul_linear_w
,
eltadd_linear_b
,
dropout_linear
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_matmul0_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_dropout
,
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
layer_norm_out
,
matmul0
,
matmul0_out
,
eltadd0
,
eltadd0_out
,
reshape2_0
,
reshape2_0_out
,
transpose2_0
,
transpose2_0_out
,
split0
,
split0_q_out
,
split0_k_out
,
split0_v_out
,
concat_k_in
,
concat_k
,
concat_k_out
,
concat_v_in
,
concat_v
,
concat_v_out
,
assign_k
,
assign_v
,
matmul_qk
,
matmul_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
softmax_qk_out
,
dropout_qk
,
dropout_qk_out
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_qkv
,
matmul_qkv_out
,
reshape2_qkv
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_linear
,
matmul_linear_w
,
matmul_linear_out
,
eltadd_linear
,
eltadd_linear_b
,
eltadd_linear_out
,
dropout_linear
,
dropout_linear_out
,
eltadd_out
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_layer_norm_out
,
ffn_matmul0
,
ffn_matmul1
,
ffn_matmul0_out
,
ffn_matmul1_out
,
ffn_eltadd0
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_gelu
,
ffn_gelu_out
,
ffn_dropout
,
ffn_dropout_out
,
ffn_eltadd_out
});
// Remove unneeded nodes.
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
void
FusedMultiTransformerDecoderFuseQKVPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
Fatal
(
"During the fused_multi_transformer_decoder "
"pass, The scope should not be null."
));
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
if
(
fusion_count
>
0
)
{
graph
->
Set
(
kFusedMultiTransformerDecoderFuseQKVPass
,
new
bool
(
true
));
}
AddStatis
(
fusion_count
);
}
FusedMultiTransformerDecoderFuseQKVPass
::
FusedMultiTransformerDecoderFuseQKVPass
()
{
AddOpCompat
(
OpCompat
(
"layer_norm"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Scale"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Mean"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Variance"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"epsilon"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
0.001
f
)
.
End
()
.
AddAttr
(
"begin_norm_axis"
)
.
IsNumGT
(
0
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
// the shape shoule be (N*H, N*H)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsType
<
bool
>
()
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
2
,
-
1
,
0
})
.
End
();
AddOpCompat
(
OpCompat
(
"reshape2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Shape"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ShapeTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"shape"
)
// -->(B, S, H, N) <--(B, S, N*H)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"transpose2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
// {0, 2, 1, 3}
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"concat"
))
.
AddInput
(
"X"
)
// Input("X"): vector<tensors>
.
End
()
.
AddInput
(
"AxisTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsNumEQ
(
2
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"alpha"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
1.0
f
)
.
End
()
.
AddAttr
(
"transpose_X"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"transpose_Y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"softmax"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
-
1
,
3
})
// shape is (B, H, S, S), so axis is -1 or 3
.
End
();
AddOpCompat
(
OpCompat
(
"gelu"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"approximate"
)
.
IsType
<
bool
>
()
.
End
();
}
int
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
::
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
// Create pattern.
patterns
::
MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern
(
pattern
,
name_scope
);
fused_multi_transformer_fuse_qkv_pattern
();
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
layer_norm
,
Node
*
layer_norm_scale
,
Node
*
layer_norm_bias
,
Node
*
layer_norm_mean
,
Node
*
layer_norm_variance
,
Node
*
c_identity
,
Node
*
matmul0_w
,
Node
*
eltadd0_b
,
Node
*
eltadd_qk_b
,
Node
*
dropout_qk
,
Node
*
reshape2_0
,
Node
*
matmul_linear_w
,
Node
*
eltadd_linear_b
,
Node
*
dropout_linear
,
Node
*
ffn_layer_norm
,
Node
*
ffn_layer_norm_scale
,
Node
*
ffn_layer_norm_bias
,
Node
*
ffn_layer_norm_mean
,
Node
*
ffn_layer_norm_variance
,
Node
*
ffn_matmul0_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_dropout
,
Node
*
ffn_output
)
{
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
// 2. each transformer layer contains 2 LayerNorm layer
auto
ln_scale_name
=
layer_norm_scale
->
Name
();
auto
ln_name
=
ln_scale_name
.
substr
(
0
,
ln_scale_name
.
find
(
'.'
));
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
// create fused_multi_transformer
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
fused_multi_transformer_op_desc
.
SetType
(
"fused_multi_transformer"
);
// 1. Input setting
fused_multi_transformer_op_desc
.
SetInput
(
"X"
,
{
input0
->
Name
()});
// pre-LayerNorm input
fused_multi_transformer_op_desc
.
SetInput
(
"LnScale"
,
{
layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"LnBias"
,
{
layer_norm_bias
->
Name
()});
// QKV computation input
fused_multi_transformer_op_desc
.
SetInput
(
"QKVW"
,
{
matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"QKVBias"
,
{
eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"SrcMask"
,
{
eltadd_qk_b
->
Name
()});
// Cache KV use cache_kv in encoder
auto
cache_kv_name
=
"cache_kv"
+
std
::
to_string
(
layer_idx
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv_name
});
VarDesc
shape_out_desc
(
"shape_out."
+
std
::
to_string
(
layer_idx
));
shape_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
shape_out_desc
.
SetPersistable
(
false
);
auto
*
shape_out
=
graph
->
CreateVarNode
(
&
shape_out_desc
);
OpDesc
shape_op_desc
(
layer_norm
->
Op
()
->
Block
());
shape_op_desc
.
SetType
(
"shape"
);
shape_op_desc
.
SetInput
(
"Input"
,
{
eltadd_qk_b
->
Name
()});
shape_op_desc
.
SetOutput
(
"Out"
,
{
shape_out
->
Name
()});
auto
*
shape_op
=
graph
->
CreateOpNode
(
&
shape_op_desc
);
VarDesc
slice_out_desc
(
"slice_out."
+
std
::
to_string
(
layer_idx
));
slice_out_desc
.
SetDataType
(
proto
::
VarType
::
INT32
);
slice_out_desc
.
SetPersistable
(
false
);
auto
*
slice_out
=
graph
->
CreateVarNode
(
&
slice_out_desc
);
OpDesc
slice_op_desc
(
layer_norm
->
Op
()
->
Block
());
slice_op_desc
.
SetType
(
"slice"
);
slice_op_desc
.
SetInput
(
"Input"
,
{
shape_out
->
Name
()});
slice_op_desc
.
SetOutput
(
"Out"
,
{
slice_out
->
Name
()});
std
::
vector
<
int
>
axes
=
{
0
};
std
::
vector
<
int
>
starts
=
{
3
};
std
::
vector
<
int
>
ends
=
{
4
};
slice_op_desc
.
SetAttr
(
"axes"
,
axes
);
slice_op_desc
.
SetAttr
(
"starts"
,
starts
);
slice_op_desc
.
SetAttr
(
"ends"
,
ends
);
auto
*
slice_op
=
graph
->
CreateOpNode
(
&
slice_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"TimeStep"
,
{
slice_out
->
Name
()});
// Out Linear input
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
{
matmul_linear_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearBias"
,
{
eltadd_linear_b
->
Name
()});
// Feed Forward input
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnScale"
,
{
ffn_layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnBias"
,
{
ffn_layer_norm_bias
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Weight"
,
{
ffn_matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Bias"
,
{
ffn_eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Weight"
,
{
ffn_matmul1_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Bias"
,
{
ffn_eltadd1_b
->
Name
()});
// 2. Output setting
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_output
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv_name
});
// Attribute setting
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// output dropout attribute
auto
*
dropout_op
=
dropout_linear
->
Op
();
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
dropout_op
->
GetAttr
(
"dropout_prob"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
dropout_op
->
GetAttr
(
"is_test"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_implementation"
,
dropout_op
->
GetAttr
(
"dropout_implementation"
));
// parallel ring id
auto
*
c_identity_op
=
c_identity
->
Op
();
fused_multi_transformer_op_desc
.
SetAttr
(
"ring_id"
,
c_identity_op
->
GetAttr
(
"ring_id"
));
// fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"});
// fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true});
auto
*
fused_multi_transformer
=
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
IR_NODE_LINK_TO
(
input0
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul0_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
// TimeStep link
IR_NODE_LINK_TO
(
eltadd_qk_b
,
shape_op
);
IR_NODE_LINK_TO
(
shape_op
,
shape_out
);
IR_NODE_LINK_TO
(
shape_out
,
slice_op
);
IR_NODE_LINK_TO
(
slice_op
,
slice_out
);
IR_NODE_LINK_TO
(
slice_out
,
fused_multi_transformer
)
IR_NODE_LINK_TO
(
fused_multi_transformer
,
ffn_output
);
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"fused_multi_transformer_decoder_fuse_qkv "
"pass in op compat failed."
;
return
;
}
VLOG
(
4
)
<<
"handle MultiTransformer decoder(Fuse-QKV) fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
input0
,
input0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm
,
layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
layer_norm_scale
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
layer_norm_bias
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
layer_norm_mean
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
layer_norm_variance
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
layer_norm_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_identity
,
c_identity
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_identity_out
,
c_identity_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0
,
matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_out
,
matmul0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_w
,
matmul0_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0
,
split0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_q_out
,
split0_q_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_k_out
,
split0_k_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_v_out
,
split0_v_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_k_in
,
concat_k_in
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_k
,
concat_k
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_k_out
,
concat_k_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_v_in
,
concat_v_in
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_v
,
concat_v
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
concat_v_out
,
concat_v_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
assign_k
,
assign_k
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
assign_v
,
assign_v
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm
,
ffn_layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
ffn_layer_norm_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_identity
,
ffn_c_identity
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_identity_out
,
ffn_c_identity_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0
,
ffn_matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_out
,
ffn_matmul0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_w
,
ffn_matmul0_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0
,
ffn_eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_b
,
ffn_eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu
,
ffn_gelu
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu_out
,
ffn_gelu_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_out
,
ffn_matmul1_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_w
,
ffn_matmul1_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_allreduce_sum
,
ffn_c_allreduce_sum
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_allreduce_sum_out
,
ffn_c_allreduce_sum_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1
,
ffn_eltadd1
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_b
,
ffn_eltadd1_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
ffn_eltadd1_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout
,
ffn_dropout
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout_out
,
ffn_dropout_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
ffn_eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_output
,
ffn_output
,
fused_multi_transformer_fuse_qkv_pattern
)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_b
,
eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_out
,
eltadd0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk
,
matmul_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_b
,
eltadd_qk_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk
,
dropout_qk
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk_out
,
dropout_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear
,
matmul_linear
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_w
,
matmul_linear_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
matmul_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_allreduce_sum
,
c_allreduce_sum
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_allreduce_sum_out
,
c_allreduce_sum_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear
,
eltadd_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
eltadd_linear_b
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
eltadd_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear
,
dropout_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear_out
,
dropout_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
)
fuse_creater
(
input0
,
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
c_identity
,
matmul0_w
,
eltadd0_b
,
eltadd_qk_b
,
dropout_qk
,
reshape2_0
,
matmul_linear_w
,
eltadd_linear_b
,
dropout_linear
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_matmul0_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_dropout
,
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
layer_norm_out
,
c_identity
,
c_identity_out
,
matmul0
,
matmul0_out
,
eltadd0
,
eltadd0_out
,
reshape2_0
,
reshape2_0_out
,
transpose2_0
,
transpose2_0_out
,
split0
,
split0_q_out
,
split0_k_out
,
split0_v_out
,
concat_k_in
,
concat_k
,
concat_k_out
,
concat_v_in
,
concat_v
,
concat_v_out
,
assign_k
,
assign_v
,
matmul_qk
,
matmul_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
softmax_qk_out
,
dropout_qk
,
dropout_qk_out
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_qkv
,
matmul_qkv_out
,
reshape2_qkv
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_linear
,
matmul_linear_w
,
matmul_linear_out
,
c_allreduce_sum
,
c_allreduce_sum_out
,
eltadd_linear
,
eltadd_linear_b
,
eltadd_linear_out
,
dropout_linear
,
dropout_linear_out
,
eltadd_out
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_layer_norm_out
,
ffn_c_identity
,
ffn_c_identity_out
,
ffn_matmul0
,
ffn_matmul1
,
ffn_matmul0_out
,
ffn_matmul1_out
,
ffn_c_allreduce_sum
,
ffn_c_allreduce_sum_out
,
ffn_eltadd0
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_gelu
,
ffn_gelu_out
,
ffn_dropout
,
ffn_dropout_out
,
ffn_eltadd_out
});
// Remove unneeded nodes.
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
void
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
Fatal
(
"During the fused_multi_transformer_decoder "
"pass, The scope should not be null."
));
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
if
(
fusion_count
>
0
)
{
graph
->
Set
(
kFusedMultiTransformerDecoderFuseQKVPass
,
new
bool
(
true
));
}
AddStatis
(
fusion_count
);
}
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
::
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
()
{
AddOpCompat
(
OpCompat
(
"layer_norm"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Scale"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Mean"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Variance"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"epsilon"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
0.001
f
)
.
End
()
.
AddAttr
(
"begin_norm_axis"
)
.
IsNumGT
(
0
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
// the shape shoule be (N*H, N*H)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsType
<
bool
>
()
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
2
,
-
1
,
0
})
.
End
();
AddOpCompat
(
OpCompat
(
"reshape2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Shape"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ShapeTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"shape"
)
// -->(B, S, H, N) <--(B, S, N*H)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"transpose2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
// {0, 2, 1, 3}
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"concat"
))
.
AddInput
(
"X"
)
// Input("X"): vector<tensors>
.
End
()
.
AddInput
(
"AxisTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsNumEQ
(
2
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"alpha"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
1.0
f
)
.
End
()
.
AddAttr
(
"transpose_X"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"transpose_Y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"softmax"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
-
1
,
3
})
// shape is (B, H, S, S), so axis is -1 or 3
.
End
();
AddOpCompat
(
OpCompat
(
"gelu"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"approximate"
)
.
IsType
<
bool
>
()
.
End
();
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
fused_multi_transformer_decoder_pass
,
paddle
::
framework
::
ir
::
FusedMultiTransformerDecoderPass
);
REGISTER_PASS
(
fused_multi_transformer_decoder_fuse_qkv_pass
,
paddle
::
framework
::
ir
::
FusedMultiTransformerDecoderFuseQKVPass
);
REGISTER_PASS
(
multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass
,
paddle
::
framework
::
ir
::
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
);
REGISTER_PASS_CAPABILITY
(
fused_multi_transformer_decoder_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"elementwise_add"
,
1
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"transpose2"
,
0
)
.
EQ
(
"scale"
,
0
)
.
LE
(
"matmul"
,
1
)
.
EQ
(
"matmul_v2"
,
0
)
.
EQ
(
"softmax"
,
0
));
REGISTER_PASS_CAPABILITY
(
fused_multi_transformer_decoder_fuse_qkv_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"elementwise_add"
,
1
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"transpose2"
,
0
)
.
EQ
(
"scale"
,
0
)
.
LE
(
"matmul"
,
1
)
.
EQ
(
"matmul_v2"
,
0
)
.
EQ
(
"softmax"
,
0
));
REGISTER_PASS_CAPABILITY
(
multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"elementwise_add"
,
1
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"transpose2"
,
0
)
.
EQ
(
"scale"
,
0
)
.
LE
(
"matmul"
,
1
)
.
EQ
(
"matmul_v2"
,
0
)
.
EQ
(
"softmax"
,
0
));
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h
0 → 100644
浏览文件 @
5a2e5179
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
FusedMultiTransformerDecoderPattern
:
public
PatternBase
{
FusedMultiTransformerDecoderPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fused_multi_transformer_decoder"
)
{}
PDNode
*
operator
()();
// Q, K, V path
PATTERN_DECL_NODE
(
input0
);
PATTERN_DECL_NODE
(
layer_norm
);
PATTERN_DECL_NODE
(
layer_norm_scale
);
PATTERN_DECL_NODE
(
layer_norm_bias
);
PATTERN_DECL_NODE
(
layer_norm_mean
);
PATTERN_DECL_NODE
(
layer_norm_variance
);
PATTERN_DECL_NODE
(
layer_norm_out
);
PATTERN_DECL_NODE
(
matmul0
);
PATTERN_DECL_NODE
(
matmul1
);
PATTERN_DECL_NODE
(
matmul2
);
PATTERN_DECL_NODE
(
matmul0_w
);
PATTERN_DECL_NODE
(
matmul1_w
);
PATTERN_DECL_NODE
(
matmul2_w
);
PATTERN_DECL_NODE
(
matmul0_out
);
PATTERN_DECL_NODE
(
matmul1_out
);
PATTERN_DECL_NODE
(
matmul2_out
);
PATTERN_DECL_NODE
(
eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd1
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd2
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd1_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd2_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_out
);
PATTERN_DECL_NODE
(
eltadd1_out
);
PATTERN_DECL_NODE
(
eltadd2_out
);
PATTERN_DECL_NODE
(
reshape2_0
);
PATTERN_DECL_NODE
(
reshape2_1
);
PATTERN_DECL_NODE
(
reshape2_2
);
PATTERN_DECL_NODE
(
reshape2_0_out
);
PATTERN_DECL_NODE
(
reshape2_1_out
);
PATTERN_DECL_NODE
(
reshape2_2_out
);
PATTERN_DECL_NODE
(
transpose2_0
);
PATTERN_DECL_NODE
(
transpose2_1
);
PATTERN_DECL_NODE
(
transpose2_2
);
PATTERN_DECL_NODE
(
transpose2_0_out
);
PATTERN_DECL_NODE
(
transpose2_1_out
);
PATTERN_DECL_NODE
(
transpose2_2_out
);
PATTERN_DECL_NODE
(
concat_0_in
);
PATTERN_DECL_NODE
(
concat_0
);
PATTERN_DECL_NODE
(
concat_0_out
);
PATTERN_DECL_NODE
(
assign_0
);
PATTERN_DECL_NODE
(
concat_1_in
);
PATTERN_DECL_NODE
(
concat_1
);
PATTERN_DECL_NODE
(
concat_1_out
);
PATTERN_DECL_NODE
(
assign_1
);
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
PATTERN_DECL_NODE
(
softmax_qk
);
PATTERN_DECL_NODE
(
softmax_qk_out
);
PATTERN_DECL_NODE
(
dropout_qk
);
PATTERN_DECL_NODE
(
dropout_qk_out
);
// QK, V matmul
PATTERN_DECL_NODE
(
matmul_qkv
);
PATTERN_DECL_NODE
(
matmul_qkv_out
);
PATTERN_DECL_NODE
(
reshape2_qkv
);
PATTERN_DECL_NODE
(
reshape2_qkv_out
);
PATTERN_DECL_NODE
(
transpose2_qkv
);
PATTERN_DECL_NODE
(
transpose2_qkv_out
);
// out linear
PATTERN_DECL_NODE
(
matmul_linear
);
PATTERN_DECL_NODE
(
matmul_linear_w
);
PATTERN_DECL_NODE
(
matmul_linear_out
);
PATTERN_DECL_NODE
(
eltadd_linear
);
PATTERN_DECL_NODE
(
eltadd_linear_b
);
PATTERN_DECL_NODE
(
eltadd_linear_out
);
PATTERN_DECL_NODE
(
dropout_linear
);
PATTERN_DECL_NODE
(
dropout_linear_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
eltadd_out
)
PATTERN_DECL_NODE
(
attention_output
);
// while loop
PATTERN_DECL_NODE
(
while0
);
// Feed Forward nodes
PATTERN_DECL_NODE
(
ffn_layer_norm
);
PATTERN_DECL_NODE
(
ffn_layer_norm_scale
);
PATTERN_DECL_NODE
(
ffn_layer_norm_bias
);
PATTERN_DECL_NODE
(
ffn_layer_norm_mean
);
PATTERN_DECL_NODE
(
ffn_layer_norm_variance
);
PATTERN_DECL_NODE
(
ffn_layer_norm_out
);
PATTERN_DECL_NODE
(
ffn_matmul0
);
PATTERN_DECL_NODE
(
ffn_matmul0_w
);
PATTERN_DECL_NODE
(
ffn_matmul0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_gelu
);
PATTERN_DECL_NODE
(
ffn_gelu_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_eltadd1
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_out
);
PATTERN_DECL_NODE
(
ffn_dropout
);
PATTERN_DECL_NODE
(
ffn_dropout_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
ffn_eltadd_out
)
PATTERN_DECL_NODE
(
ffn_output
);
};
struct
FusedMultiTransformerDecoderFuseQKVPattern
:
public
PatternBase
{
FusedMultiTransformerDecoderFuseQKVPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fused_multi_transformer_decoder_fuse_qkv"
)
{}
PDNode
*
operator
()();
// Q, K, V path
PATTERN_DECL_NODE
(
input0
);
PATTERN_DECL_NODE
(
layer_norm
);
PATTERN_DECL_NODE
(
layer_norm_scale
);
PATTERN_DECL_NODE
(
layer_norm_bias
);
PATTERN_DECL_NODE
(
layer_norm_mean
);
PATTERN_DECL_NODE
(
layer_norm_variance
);
PATTERN_DECL_NODE
(
layer_norm_out
);
PATTERN_DECL_NODE
(
matmul0
);
PATTERN_DECL_NODE
(
matmul0_w
);
PATTERN_DECL_NODE
(
matmul0_out
);
PATTERN_DECL_NODE
(
eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_out
);
PATTERN_DECL_NODE
(
reshape2_0
);
PATTERN_DECL_NODE
(
reshape2_0_out
);
PATTERN_DECL_NODE
(
transpose2_0
);
PATTERN_DECL_NODE
(
transpose2_0_out
);
PATTERN_DECL_NODE
(
split0
)
PATTERN_DECL_NODE
(
split0_q_out
)
PATTERN_DECL_NODE
(
split0_k_out
)
PATTERN_DECL_NODE
(
split0_v_out
)
PATTERN_DECL_NODE
(
concat_k_in
)
PATTERN_DECL_NODE
(
concat_v_in
)
PATTERN_DECL_NODE
(
concat_k
)
PATTERN_DECL_NODE
(
concat_v
)
PATTERN_DECL_NODE
(
concat_k_out
)
PATTERN_DECL_NODE
(
concat_v_out
)
PATTERN_DECL_NODE
(
assign_k
)
PATTERN_DECL_NODE
(
assign_v
)
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
PATTERN_DECL_NODE
(
softmax_qk
);
PATTERN_DECL_NODE
(
softmax_qk_out
);
PATTERN_DECL_NODE
(
dropout_qk
);
PATTERN_DECL_NODE
(
dropout_qk_out
);
// QK, V matmul
PATTERN_DECL_NODE
(
matmul_qkv
);
PATTERN_DECL_NODE
(
matmul_qkv_out
);
PATTERN_DECL_NODE
(
reshape2_qkv
);
PATTERN_DECL_NODE
(
reshape2_qkv_out
);
PATTERN_DECL_NODE
(
transpose2_qkv
);
PATTERN_DECL_NODE
(
transpose2_qkv_out
);
// out linear
PATTERN_DECL_NODE
(
matmul_linear
);
PATTERN_DECL_NODE
(
matmul_linear_w
);
PATTERN_DECL_NODE
(
matmul_linear_out
);
PATTERN_DECL_NODE
(
eltadd_linear
);
PATTERN_DECL_NODE
(
eltadd_linear_b
);
PATTERN_DECL_NODE
(
eltadd_linear_out
);
PATTERN_DECL_NODE
(
dropout_linear
);
PATTERN_DECL_NODE
(
dropout_linear_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
eltadd_out
)
PATTERN_DECL_NODE
(
attention_output
);
// Feed Forward nodes
PATTERN_DECL_NODE
(
ffn_layer_norm
);
PATTERN_DECL_NODE
(
ffn_layer_norm_scale
);
PATTERN_DECL_NODE
(
ffn_layer_norm_bias
);
PATTERN_DECL_NODE
(
ffn_layer_norm_mean
);
PATTERN_DECL_NODE
(
ffn_layer_norm_variance
);
PATTERN_DECL_NODE
(
ffn_layer_norm_out
);
PATTERN_DECL_NODE
(
ffn_matmul0
);
PATTERN_DECL_NODE
(
ffn_matmul0_w
);
PATTERN_DECL_NODE
(
ffn_matmul0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_gelu
);
PATTERN_DECL_NODE
(
ffn_gelu_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_eltadd1
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_out
);
PATTERN_DECL_NODE
(
ffn_dropout
);
PATTERN_DECL_NODE
(
ffn_dropout_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
ffn_eltadd_out
)
PATTERN_DECL_NODE
(
ffn_output
);
};
struct
MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
:
public
PatternBase
{
MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv"
)
{}
PDNode
*
operator
()();
// Q, K, V path
PATTERN_DECL_NODE
(
input0
);
PATTERN_DECL_NODE
(
layer_norm
);
PATTERN_DECL_NODE
(
layer_norm_scale
);
PATTERN_DECL_NODE
(
layer_norm_bias
);
PATTERN_DECL_NODE
(
layer_norm_mean
);
PATTERN_DECL_NODE
(
layer_norm_variance
);
PATTERN_DECL_NODE
(
layer_norm_out
);
PATTERN_DECL_NODE
(
c_identity
);
PATTERN_DECL_NODE
(
c_identity_out
);
PATTERN_DECL_NODE
(
matmul0
);
PATTERN_DECL_NODE
(
matmul0_w
);
PATTERN_DECL_NODE
(
matmul0_out
);
PATTERN_DECL_NODE
(
eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_out
);
PATTERN_DECL_NODE
(
reshape2_0
);
PATTERN_DECL_NODE
(
reshape2_0_out
);
PATTERN_DECL_NODE
(
transpose2_0
);
PATTERN_DECL_NODE
(
transpose2_0_out
);
PATTERN_DECL_NODE
(
split0
)
PATTERN_DECL_NODE
(
split0_q_out
)
PATTERN_DECL_NODE
(
split0_k_out
)
PATTERN_DECL_NODE
(
split0_v_out
)
PATTERN_DECL_NODE
(
concat_k_in
)
PATTERN_DECL_NODE
(
concat_v_in
)
PATTERN_DECL_NODE
(
concat_k
)
PATTERN_DECL_NODE
(
concat_v
)
PATTERN_DECL_NODE
(
concat_k_out
)
PATTERN_DECL_NODE
(
concat_v_out
)
PATTERN_DECL_NODE
(
assign_k
)
PATTERN_DECL_NODE
(
assign_v
)
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
PATTERN_DECL_NODE
(
softmax_qk
);
PATTERN_DECL_NODE
(
softmax_qk_out
);
PATTERN_DECL_NODE
(
dropout_qk
);
PATTERN_DECL_NODE
(
dropout_qk_out
);
// QK, V matmul
PATTERN_DECL_NODE
(
matmul_qkv
);
PATTERN_DECL_NODE
(
matmul_qkv_out
);
PATTERN_DECL_NODE
(
reshape2_qkv
);
PATTERN_DECL_NODE
(
reshape2_qkv_out
);
PATTERN_DECL_NODE
(
transpose2_qkv
);
PATTERN_DECL_NODE
(
transpose2_qkv_out
);
// out linear
PATTERN_DECL_NODE
(
matmul_linear
);
PATTERN_DECL_NODE
(
matmul_linear_w
);
PATTERN_DECL_NODE
(
matmul_linear_out
);
PATTERN_DECL_NODE
(
c_allreduce_sum
);
PATTERN_DECL_NODE
(
c_allreduce_sum_out
);
PATTERN_DECL_NODE
(
eltadd_linear
);
PATTERN_DECL_NODE
(
eltadd_linear_b
);
PATTERN_DECL_NODE
(
eltadd_linear_out
);
PATTERN_DECL_NODE
(
dropout_linear
);
PATTERN_DECL_NODE
(
dropout_linear_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
eltadd_out
)
PATTERN_DECL_NODE
(
attention_output
);
// Feed Forward nodes
PATTERN_DECL_NODE
(
ffn_layer_norm
);
PATTERN_DECL_NODE
(
ffn_layer_norm_scale
);
PATTERN_DECL_NODE
(
ffn_layer_norm_bias
);
PATTERN_DECL_NODE
(
ffn_layer_norm_mean
);
PATTERN_DECL_NODE
(
ffn_layer_norm_variance
);
PATTERN_DECL_NODE
(
ffn_layer_norm_out
);
PATTERN_DECL_NODE
(
ffn_c_identity
);
PATTERN_DECL_NODE
(
ffn_c_identity_out
);
PATTERN_DECL_NODE
(
ffn_matmul0
);
PATTERN_DECL_NODE
(
ffn_matmul0_w
);
PATTERN_DECL_NODE
(
ffn_matmul0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_gelu
);
PATTERN_DECL_NODE
(
ffn_gelu_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_c_allreduce_sum
);
PATTERN_DECL_NODE
(
ffn_c_allreduce_sum_out
);
PATTERN_DECL_NODE
(
ffn_eltadd1
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_out
);
PATTERN_DECL_NODE
(
ffn_dropout
);
PATTERN_DECL_NODE
(
ffn_dropout_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
ffn_eltadd_out
)
PATTERN_DECL_NODE
(
ffn_output
);
};
}
// namespace patterns
class
FusedMultiTransformerDecoderPass
:
public
FusePassBase
{
public:
FusedMultiTransformerDecoderPass
();
virtual
~
FusedMultiTransformerDecoderPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"fused_multi_transformer_decoder"
};
private:
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
;
};
class
FusedMultiTransformerDecoderFuseQKVPass
:
public
FusePassBase
{
public:
FusedMultiTransformerDecoderFuseQKVPass
();
virtual
~
FusedMultiTransformerDecoderFuseQKVPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"fused_multi_transformer_decoder_fuse_qkv"
};
private:
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
;
};
class
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
:
public
FusePassBase
{
public:
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
();
virtual
~
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"multi_devices_fused_multi_transformer_decoder_fuse_qkv"
};
private:
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc
0 → 100644
浏览文件 @
5a2e5179
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h" // NOLINT
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
AddVarToScope
(
Scope
*
param_scope
,
const
std
::
string
&
name
,
const
DDim
&
dims
)
{
auto
*
tensor
=
param_scope
->
Var
(
name
)
->
GetMutable
<
LoDTensor
>
();
tensor
->
Resize
(
dims
);
tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
}
Scope
*
CreateParamScope
()
{
auto
param_scope
=
new
Scope
();
// MHA: pre Layer Norm
AddVarToScope
(
param_scope
,
"ln_scale"
,
{
1024
});
AddVarToScope
(
param_scope
,
"ln_bias"
,
{
1024
});
// MHA: QKV fc
AddVarToScope
(
param_scope
,
"weights0"
,
{
1024
,
1024
});
AddVarToScope
(
param_scope
,
"weights1"
,
{
1024
,
1024
});
AddVarToScope
(
param_scope
,
"weights2"
,
{
1024
,
1024
});
AddVarToScope
(
param_scope
,
"bias_0"
,
{
1024
});
AddVarToScope
(
param_scope
,
"bias_1"
,
{
1024
});
AddVarToScope
(
param_scope
,
"bias_2"
,
{
1024
});
// MHA: QK bias
AddVarToScope
(
param_scope
,
"biasqk"
,
{
1024
});
// MHA: out Linear
AddVarToScope
(
param_scope
,
"weights_l"
,
{
1024
,
1024
});
AddVarToScope
(
param_scope
,
"bias_l"
,
{
1024
});
// MHA: pre Layer Norm
AddVarToScope
(
param_scope
,
"ffn_ln_scale"
,
{
1024
});
AddVarToScope
(
param_scope
,
"ffn_ln_bias"
,
{
1024
});
// FFN: fc1 -> (gelu) -> fc2
AddVarToScope
(
param_scope
,
"ffn_weights0"
,
{
1024
,
4096
});
AddVarToScope
(
param_scope
,
"ffn_weights1"
,
{
4096
,
1024
});
AddVarToScope
(
param_scope
,
"ffn_bias_0"
,
{
4096
});
AddVarToScope
(
param_scope
,
"ffn_bias_1"
,
{
1024
});
return
param_scope
;
}
TEST
(
FusedMultiTransformerDecoderPass
,
basic
)
{
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (layer_norm_out, weights_1) matmul_v2 -> matmul_out1
// (layer_norm_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (transpose_1) concat -> concat_0
// (transpose_2) concat -> concat_2
// (concat_0) assign -> assign_0
// (concat_1) assign -> assign_2
// (transpose_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
Layers
layers
;
// MHA: pre LayerNorm
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
1024
});
auto
*
ln_scale
=
layers
.
data
(
"ln_scale"
,
{
1024
},
true
);
auto
*
ln_bias
=
layers
.
data
(
"ln_bias"
,
{
1024
},
true
);
auto
*
ln_out
=
layers
.
layer_norm
(
x
,
ln_scale
,
ln_bias
)[
0
];
// MHA: QKV fc
auto
*
weights_0
=
layers
.
data
(
"weights0"
,
{
1024
,
1024
},
true
);
auto
*
weights_1
=
layers
.
data
(
"weights1"
,
{
1024
,
1024
},
true
);
auto
*
weights_2
=
layers
.
data
(
"weights2"
,
{
1024
,
1024
},
true
);
auto
*
matmul_out_0
=
layers
.
matmul_v2
(
ln_out
,
weights_0
,
nullptr
,
false
,
true
);
auto
*
matmul_out_1
=
layers
.
matmul_v2
(
ln_out
,
weights_1
,
nullptr
,
false
,
true
);
auto
*
matmul_out_2
=
layers
.
matmul_v2
(
ln_out
,
weights_2
,
nullptr
,
false
,
true
);
auto
*
b0
=
layers
.
data
(
"bias_0"
,
{
1024
},
true
);
auto
*
b1
=
layers
.
data
(
"bias_1"
,
{
1024
},
true
);
auto
*
b2
=
layers
.
data
(
"bias_2"
,
{
1024
},
true
);
auto
*
elementwise_out_0
=
layers
.
elementwise_add
(
matmul_out_0
,
b0
,
nullptr
,
2
);
auto
*
elementwise_out_1
=
layers
.
elementwise_add
(
matmul_out_1
,
b1
,
nullptr
,
2
);
auto
*
elementwise_out_2
=
layers
.
elementwise_add
(
matmul_out_2
,
b2
,
nullptr
,
2
);
std
::
vector
<
int
>
shape
=
{
1
,
128
,
16
,
64
};
auto
*
reshape_0
=
layers
.
reshape2
(
elementwise_out_0
,
shape
,
true
);
auto
*
reshape_1
=
layers
.
reshape2
(
elementwise_out_1
,
shape
,
true
);
auto
*
reshape_2
=
layers
.
reshape2
(
elementwise_out_2
,
shape
,
true
);
std
::
vector
<
int
>
axis
=
{
0
,
2
,
1
,
3
};
auto
*
transpose_0
=
layers
.
transpose2
(
reshape_0
,
axis
,
true
);
auto
*
transpose_1
=
layers
.
transpose2
(
reshape_1
,
axis
,
true
);
auto
*
transpose_2
=
layers
.
transpose2
(
reshape_2
,
axis
,
true
);
auto
*
cache_k
=
layers
.
data
(
"cache_k"
,
{
1
,
16
,
128
,
64
});
auto
*
cache_v
=
layers
.
data
(
"cache_v"
,
{
1
,
16
,
128
,
64
});
auto
*
concat_k
=
layers
.
concat
({
cache_k
,
transpose_1
},
2
);
auto
*
concat_v
=
layers
.
concat
({
cache_v
,
transpose_2
},
2
);
layers
.
assign
(
concat_k
);
layers
.
assign
(
concat_v
);
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul
(
transpose_0
,
concat_k
,
nullptr
,
false
,
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
dropout_qk
=
layers
.
dropout
(
softmax_qk
,
0.1
,
"upscale_in_train"
);
// MHA: QKV matmul
auto
*
matmul_qkv
=
layers
.
matmul_v2
(
dropout_qk
,
concat_v
);
auto
*
transpose_qkv
=
layers
.
transpose2
(
matmul_qkv
,
{
0
,
2
,
1
,
3
},
true
);
auto
*
reshape_qkv_out
=
layers
.
reshape2
(
transpose_qkv
,
{
1
,
128
,
1024
},
true
);
// MHA: out Linear
auto
*
weights_l
=
layers
.
data
(
"weights_l"
,
{
1024
,
1024
},
true
);
auto
*
bias_l
=
layers
.
data
(
"weightsl"
,
{
1024
,
1024
},
true
);
auto
*
linear_matmut_out
=
layers
.
matmul_v2
(
reshape_qkv_out
,
weights_l
,
nullptr
,
false
,
true
);
auto
*
linear_eltadd_out
=
layers
.
elementwise_add
(
linear_matmut_out
,
bias_l
,
nullptr
,
2
);
auto
*
dropout_qkv
=
layers
.
dropout
(
linear_eltadd_out
,
0.1
,
"upscale_in_train"
);
auto
*
attention_out
=
layers
.
elementwise_add
(
x
,
dropout_qkv
);
// FFN: pre LayerNorm
auto
*
ffn_ln_scale
=
layers
.
data
(
"ffn_ln_scale"
,
{
1024
},
true
);
auto
*
ffn_ln_bias
=
layers
.
data
(
"ffn_ln_bias"
,
{
1024
},
true
);
auto
*
ffn_ln_out
=
layers
.
layer_norm
(
attention_out
,
ffn_ln_scale
,
ffn_ln_bias
)[
0
];
// FFN: fc1 -> gelu -> fc2
auto
*
ffn_weights0
=
layers
.
data
(
"ffn_weights0"
,
{
1024
,
4096
},
true
);
auto
*
ffn_weights1
=
layers
.
data
(
"ffn_weights1"
,
{
4096
,
1024
},
true
);
auto
*
ffn_bias0
=
layers
.
data
(
"ffn_bias0"
,
{
4096
},
true
);
auto
*
ffn_bias1
=
layers
.
data
(
"ffn_bias1"
,
{
1024
},
true
);
auto
*
ffn_matmul0_out
=
layers
.
matmul_v2
(
ffn_ln_out
,
ffn_weights0
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd0_out
=
layers
.
elementwise_add
(
ffn_matmul0_out
,
ffn_bias0
,
nullptr
,
2
);
auto
*
ffn_gelu_out
=
layers
.
gelu
(
ffn_eltadd0_out
);
auto
*
ffn_matmul1_out
=
layers
.
matmul_v2
(
ffn_gelu_out
,
ffn_weights1
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd1_out
=
layers
.
elementwise_add
(
ffn_matmul1_out
,
ffn_bias1
,
nullptr
,
2
);
// FFN: dropout -> elementwise_add
auto
*
ffn_dropout
=
layers
.
dropout
(
ffn_eltadd1_out
,
0.1
,
"upscale_in_train"
);
layers
.
elementwise_add
(
attention_out
,
ffn_dropout
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fused_multi_transformer_decoder_pass"
);
if
(
pass
.
get
()
==
nullptr
)
LOG
(
INFO
)
<<
"get fused_multi_transformer_decoder_pass failed"
;
int
num_nodes_before
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
int
num_fused_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer"
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
72
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_decoder_pass, The "
"node num in graph "
"should be %d, but the result is %d"
,
num_nodes_before
-
72
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_decoder pass, "
"there should be one fused_multi_transformer op, "
"but the result is %d"
,
num_fused_nodes_after
));
}
TEST
(
FusedMultiTransformerDecoderPass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"fused_multi_transformer_decoder_pass"
));
}
TEST
(
FusedMultiTransformerDecoderFuseQKVPass
,
basic
)
{
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
// (transpose_0) split -> split_q, split_k,
// split_v (split_k) concat -> concat_k
// (split_v) concat -> concat_v
// (concat_k) assign -> assign_k
// (concat_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers
layers
;
// MHA: pre LayerNorm
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
1024
});
auto
*
ln_scale
=
layers
.
data
(
"ln_scale"
,
{
1024
},
true
);
auto
*
ln_bias
=
layers
.
data
(
"ln_bias"
,
{
1024
},
true
);
auto
*
ln_out
=
layers
.
layer_norm
(
x
,
ln_scale
,
ln_bias
)[
0
];
// MHA: QKV fc
auto
*
weights_0
=
layers
.
data
(
"weights0"
,
{
1024
,
3072
},
true
);
auto
*
matmul_out_0
=
layers
.
matmul_v2
(
ln_out
,
weights_0
,
nullptr
,
false
,
true
);
auto
*
b0
=
layers
.
data
(
"bias_0"
,
{
3072
},
true
);
auto
*
elementwise_out_0
=
layers
.
elementwise_add
(
matmul_out_0
,
b0
,
nullptr
,
2
);
std
::
vector
<
int
>
shape
=
{
1
,
128
,
16
,
64
};
auto
*
reshape_0
=
layers
.
reshape2
(
elementwise_out_0
,
shape
,
true
);
std
::
vector
<
int
>
axis
=
{
0
,
2
,
1
,
3
};
auto
*
transpose_0
=
layers
.
transpose2
(
reshape_0
,
axis
,
true
);
auto
split_outs
=
layers
.
split
(
transpose_0
,
3
,
3
);
auto
*
split_q
=
split_outs
[
0
];
auto
*
split_k
=
split_outs
[
1
];
auto
*
split_v
=
split_outs
[
2
];
auto
*
cache_k
=
layers
.
data
(
"cache_k"
,
{
1
,
16
,
128
,
64
});
auto
*
cache_v
=
layers
.
data
(
"cache_v"
,
{
1
,
16
,
128
,
64
});
auto
*
concat_k
=
layers
.
concat
({
cache_k
,
split_k
},
2
);
auto
*
concat_v
=
layers
.
concat
({
cache_v
,
split_v
},
2
);
layers
.
assign
(
concat_k
);
layers
.
assign
(
concat_v
);
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul
(
split_q
,
concat_k
,
nullptr
,
false
,
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
dropout_qk
=
layers
.
dropout
(
softmax_qk
,
0.1
,
"upscale_in_train"
);
// MHA: QKV matmul
auto
*
matmul_qkv
=
layers
.
matmul_v2
(
dropout_qk
,
concat_v
);
auto
*
transpose_qkv
=
layers
.
transpose2
(
matmul_qkv
,
{
0
,
2
,
1
,
3
},
true
);
auto
*
reshape_qkv_out
=
layers
.
reshape2
(
transpose_qkv
,
{
1
,
128
,
1024
},
true
);
// MHA: out Linear
auto
*
weights_l
=
layers
.
data
(
"weights_l"
,
{
1024
,
1024
},
true
);
auto
*
bias_l
=
layers
.
data
(
"weightsl"
,
{
1024
,
1024
},
true
);
auto
*
linear_matmut_out
=
layers
.
matmul_v2
(
reshape_qkv_out
,
weights_l
,
nullptr
,
false
,
true
);
auto
*
linear_eltadd_out
=
layers
.
elementwise_add
(
linear_matmut_out
,
bias_l
,
nullptr
,
2
);
auto
*
dropout_qkv
=
layers
.
dropout
(
linear_eltadd_out
,
0.1
,
"upscale_in_train"
);
auto
*
attention_out
=
layers
.
elementwise_add
(
x
,
dropout_qkv
);
// FFN: pre LayerNorm
auto
*
ffn_ln_scale
=
layers
.
data
(
"ffn_ln_scale"
,
{
1024
},
true
);
auto
*
ffn_ln_bias
=
layers
.
data
(
"ffn_ln_bias"
,
{
1024
},
true
);
auto
*
ffn_ln_out
=
layers
.
layer_norm
(
attention_out
,
ffn_ln_scale
,
ffn_ln_bias
)[
0
];
// FFN: fc1 -> gelu -> fc2
auto
*
ffn_weights0
=
layers
.
data
(
"ffn_weights0"
,
{
1024
,
4096
},
true
);
auto
*
ffn_weights1
=
layers
.
data
(
"ffn_weights1"
,
{
4096
,
1024
},
true
);
auto
*
ffn_bias0
=
layers
.
data
(
"ffn_bias0"
,
{
4096
},
true
);
auto
*
ffn_bias1
=
layers
.
data
(
"ffn_bias1"
,
{
1024
},
true
);
auto
*
ffn_matmul0_out
=
layers
.
matmul_v2
(
ffn_ln_out
,
ffn_weights0
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd0_out
=
layers
.
elementwise_add
(
ffn_matmul0_out
,
ffn_bias0
,
nullptr
,
2
);
auto
*
ffn_gelu_out
=
layers
.
gelu
(
ffn_eltadd0_out
);
auto
*
ffn_matmul1_out
=
layers
.
matmul_v2
(
ffn_gelu_out
,
ffn_weights1
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd1_out
=
layers
.
elementwise_add
(
ffn_matmul1_out
,
ffn_bias1
,
nullptr
,
2
);
// FFN: dropout -> elementwise_add
auto
*
ffn_dropout
=
layers
.
dropout
(
ffn_eltadd1_out
,
0.1
,
"upscale_in_train"
);
layers
.
elementwise_add
(
attention_out
,
ffn_dropout
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fused_multi_transformer_decoder_fuse_qkv_pass"
);
if
(
pass
.
get
()
==
nullptr
)
LOG
(
INFO
)
<<
"get fused_multi_transformer_decoder_fuse_qkv_pass failed"
;
int
num_nodes_before
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
int
num_fused_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer"
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
62
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d"
,
num_nodes_before
-
62
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_decoder_fuse_qkv "
"pass, there should be one fused_multi_transformer "
"op, but the result is %d"
,
num_fused_nodes_after
));
}
TEST
(
FusedMultiTransformerDecoderFuseQKVPass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"fused_multi_transformer_decoder_fuse_qkv_pass"
));
}
TEST
(
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
,
basic
)
{
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out) c_identity -> c_identity_out
// (c_identity_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
// (transpose_0) split -> split_q, split_k,
// split_v (split_k) concat -> concat_k
// (split_v) concat -> concat_v
// (concat_k) assign -> assign_k
// (concat_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_allreduce_sum -> c_all_reduce_out
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (ffn_layer_norm_out) c_identity -> ffn_c_identity_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_allreduce_sum -> c_allreduce_out
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers
layers
;
// MHA: pre LayerNorm
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
1024
});
auto
*
ln_scale
=
layers
.
data
(
"ln_scale"
,
{
1024
},
true
);
auto
*
ln_bias
=
layers
.
data
(
"ln_bias"
,
{
1024
},
true
);
auto
*
ln_out
=
layers
.
layer_norm
(
x
,
ln_scale
,
ln_bias
)[
0
];
auto
*
c_identity_out
=
layers
.
c_identity
(
ln_out
);
// MHA: QKV fc
auto
*
weights_0
=
layers
.
data
(
"weights0"
,
{
1024
,
3072
},
true
);
auto
*
matmul_out_0
=
layers
.
matmul_v2
(
c_identity_out
,
weights_0
,
nullptr
,
false
,
true
);
auto
*
b0
=
layers
.
data
(
"bias_0"
,
{
3072
},
true
);
auto
*
elementwise_out_0
=
layers
.
elementwise_add
(
matmul_out_0
,
b0
,
nullptr
,
2
);
std
::
vector
<
int
>
shape
=
{
1
,
128
,
16
,
64
};
auto
*
reshape_0
=
layers
.
reshape2
(
elementwise_out_0
,
shape
,
true
);
std
::
vector
<
int
>
axis
=
{
0
,
2
,
1
,
3
};
auto
*
transpose_0
=
layers
.
transpose2
(
reshape_0
,
axis
,
true
);
auto
split_outs
=
layers
.
split
(
transpose_0
,
3
,
3
);
auto
*
split_q
=
split_outs
[
0
];
auto
*
split_k
=
split_outs
[
1
];
auto
*
split_v
=
split_outs
[
2
];
auto
*
cache_k
=
layers
.
data
(
"cache_k"
,
{
1
,
16
,
128
,
64
});
auto
*
cache_v
=
layers
.
data
(
"cache_v"
,
{
1
,
16
,
128
,
64
});
auto
*
concat_k
=
layers
.
concat
({
cache_k
,
split_k
},
2
);
auto
*
concat_v
=
layers
.
concat
({
cache_v
,
split_v
},
2
);
layers
.
assign
(
concat_k
);
layers
.
assign
(
concat_v
);
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul
(
split_q
,
concat_k
,
nullptr
,
false
,
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
dropout_qk
=
layers
.
dropout
(
softmax_qk
,
0.1
,
"upscale_in_train"
);
// MHA: QKV matmul
auto
*
matmul_qkv
=
layers
.
matmul_v2
(
dropout_qk
,
concat_v
);
auto
*
transpose_qkv
=
layers
.
transpose2
(
matmul_qkv
,
{
0
,
2
,
1
,
3
},
true
);
auto
*
reshape_qkv_out
=
layers
.
reshape2
(
transpose_qkv
,
{
1
,
128
,
1024
},
true
);
// MHA: out Linear
auto
*
weights_l
=
layers
.
data
(
"weights_l"
,
{
1024
,
1024
},
true
);
auto
*
bias_l
=
layers
.
data
(
"weightsl"
,
{
1024
,
1024
},
true
);
auto
*
linear_matmut_out
=
layers
.
matmul_v2
(
reshape_qkv_out
,
weights_l
,
nullptr
,
false
,
true
);
auto
*
c_allreduce_out
=
layers
.
c_allreduce_sum
(
linear_matmut_out
);
auto
*
linear_eltadd_out
=
layers
.
elementwise_add
(
c_allreduce_out
,
bias_l
,
nullptr
,
2
);
auto
*
dropout_qkv
=
layers
.
dropout
(
linear_eltadd_out
,
0.1
,
"upscale_in_train"
);
auto
*
attention_out
=
layers
.
elementwise_add
(
x
,
dropout_qkv
);
// FFN: pre LayerNorm
auto
*
ffn_ln_scale
=
layers
.
data
(
"ffn_ln_scale"
,
{
1024
},
true
);
auto
*
ffn_ln_bias
=
layers
.
data
(
"ffn_ln_bias"
,
{
1024
},
true
);
auto
*
ffn_ln_out
=
layers
.
layer_norm
(
attention_out
,
ffn_ln_scale
,
ffn_ln_bias
)[
0
];
auto
*
ffn_c_identity_out
=
layers
.
c_identity
(
ffn_ln_out
);
// FFN: fc1 -> gelu -> fc2
auto
*
ffn_weights0
=
layers
.
data
(
"ffn_weights0"
,
{
1024
,
4096
},
true
);
auto
*
ffn_weights1
=
layers
.
data
(
"ffn_weights1"
,
{
4096
,
1024
},
true
);
auto
*
ffn_bias0
=
layers
.
data
(
"ffn_bias0"
,
{
4096
},
true
);
auto
*
ffn_bias1
=
layers
.
data
(
"ffn_bias1"
,
{
1024
},
true
);
auto
*
ffn_matmul0_out
=
layers
.
matmul_v2
(
ffn_c_identity_out
,
ffn_weights0
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd0_out
=
layers
.
elementwise_add
(
ffn_matmul0_out
,
ffn_bias0
,
nullptr
,
2
);
auto
*
ffn_gelu_out
=
layers
.
gelu
(
ffn_eltadd0_out
);
auto
*
ffn_matmul1_out
=
layers
.
matmul_v2
(
ffn_gelu_out
,
ffn_weights1
,
nullptr
,
false
,
true
);
auto
*
ffn_c_allreduce_out
=
layers
.
c_allreduce_sum
(
ffn_matmul1_out
);
auto
*
ffn_eltadd1_out
=
layers
.
elementwise_add
(
ffn_c_allreduce_out
,
ffn_bias1
,
nullptr
,
2
);
// FFN: dropout -> elementwise_add
auto
*
ffn_dropout
=
layers
.
dropout
(
ffn_eltadd1_out
,
0.1
,
"upscale_in_train"
);
layers
.
elementwise_add
(
attention_out
,
ffn_dropout
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
);
if
(
pass
.
get
()
==
nullptr
)
LOG
(
INFO
)
<<
"get multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass "
"failed"
;
int
num_nodes_before
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
int
num_fused_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer"
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
70
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d"
,
num_nodes_before
-
70
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_decoder_fuse_qkv "
"multi-devices pass, there should be one "
"fused_multi_transformer op, but the result is %d"
,
num_fused_nodes_after
));
}
TEST
(
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
));
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
fused_multi_transformer_decoder_pass
);
USE_PASS
(
fused_multi_transformer_decoder_fuse_qkv_pass
);
USE_PASS
(
multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass
);
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc
0 → 100644
浏览文件 @
5a2e5179
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
PDNode
*
FusedMultiTransformerEncoderPattern
::
operator
()()
{
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
// pre-LayerNorm
auto
*
layer_norm
=
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
)
->
assert_more
([](
Node
*
x
)
{
if
(
x
->
outputs
.
size
()
==
3
)
{
return
true
;
}
else
{
return
false
;
}
});
layer_norm
->
LinksFrom
({
input0
,
layer_norm_bias_var
,
layer_norm_scale_var
})
.
LinksTo
(
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
// Q path Nodes
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul0_out_var
=
pattern
->
NewNode
(
matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd0
=
pattern
->
NewNode
(
eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd0_b_var
=
pattern
->
NewNode
(
eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd0_out_var
=
pattern
->
NewNode
(
eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_0
=
pattern
->
NewNode
(
reshape2_0_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_0_out_var
=
pattern
->
NewNode
(
reshape2_0_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_0
=
pattern
->
NewNode
(
transpose2_0_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
// Q path Links
matmul0
->
LinksFrom
({
layer_norm_out_var
,
matmul0_w_var
})
.
LinksTo
({
matmul0_out_var
});
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
.
LinksTo
({
eltadd0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
// K path Nodes
auto
*
matmul1
=
pattern
->
NewNode
(
matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul1_w_var
=
pattern
->
NewNode
(
matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul1_out_var
=
pattern
->
NewNode
(
matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd1
=
pattern
->
NewNode
(
eltadd1_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd1_b_var
=
pattern
->
NewNode
(
eltadd1_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd1_out_var
=
pattern
->
NewNode
(
eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_1
=
pattern
->
NewNode
(
reshape2_1_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_1_out_var
=
pattern
->
NewNode
(
reshape2_1_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_1
=
pattern
->
NewNode
(
transpose2_1_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_1_out_var
=
pattern
->
NewNode
(
transpose2_1_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul"
,
"Y"
)
->
assert_is_op_input
(
"while"
)
->
assert_more
([](
Node
*
x
)
{
if
(
x
->
outputs
.
size
()
==
2
)
{
return
true
;
}
else
{
return
false
;
}
});
// K path Links
matmul1
->
LinksFrom
({
layer_norm_out_var
,
matmul1_w_var
})
.
LinksTo
({
matmul1_out_var
});
eltadd1
->
LinksFrom
({
matmul1_out_var
,
eltadd1_b_var
})
.
LinksTo
({
eltadd1_out_var
});
reshape2_1
->
LinksFrom
({
eltadd1_out_var
}).
LinksTo
({
reshape2_1_out_var
});
transpose2_1
->
LinksFrom
({
reshape2_1_out_var
}).
LinksTo
({
transpose2_1_out_var
});
// V path Nodes
auto
*
matmul2
=
pattern
->
NewNode
(
matmul2_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul2_w_var
=
pattern
->
NewNode
(
matmul2_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul2_out_var
=
pattern
->
NewNode
(
matmul2_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd2
=
pattern
->
NewNode
(
eltadd2_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd2_b_var
=
pattern
->
NewNode
(
eltadd2_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd2_out_var
=
pattern
->
NewNode
(
eltadd2_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_2
=
pattern
->
NewNode
(
reshape2_2_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_2_out_var
=
pattern
->
NewNode
(
reshape2_2_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_2
=
pattern
->
NewNode
(
transpose2_2_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_2_out_var
=
pattern
->
NewNode
(
transpose2_2_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
)
->
assert_is_op_input
(
"while"
)
->
assert_more
([](
Node
*
x
)
{
if
(
x
->
outputs
.
size
()
==
2
)
{
return
true
;
}
else
{
return
false
;
}
});
// V path Links
matmul2
->
LinksFrom
({
layer_norm_out_var
,
matmul2_w_var
})
.
LinksTo
({
matmul2_out_var
});
eltadd2
->
LinksFrom
({
matmul2_out_var
,
eltadd2_b_var
})
.
LinksTo
({
eltadd2_out_var
});
reshape2_2
->
LinksFrom
({
eltadd2_out_var
}).
LinksTo
({
reshape2_2_out_var
});
transpose2_2
->
LinksFrom
({
reshape2_2_out_var
}).
LinksTo
({
transpose2_2_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax"
);
auto
*
softmax_qk
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk_out_repr
())
->
assert_is_op_output
(
"softmax"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_qk
=
pattern
->
NewNode
(
dropout_qk_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_qk_out_var
=
pattern
->
NewNode
(
dropout_qk_out_repr
())
->
assert_is_op_output
(
"dropout"
,
"Out"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// -> matmul_qkv
// QK path Linsk
matmul_qk
->
LinksFrom
({
transpose2_0_out_var
,
transpose2_1_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
dropout_qk
->
LinksFrom
({
softmax_qk_out_var
}).
LinksTo
({
dropout_qk_out_var
});
// QKV path Nodes
auto
*
matmul_qkv
=
pattern
->
NewNode
(
matmul_qkv_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qkv_out_var
=
pattern
->
NewNode
(
matmul_qkv_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_qkv
=
pattern
->
NewNode
(
transpose2_qkv_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_qkv_out_var
=
pattern
->
NewNode
(
transpose2_qkv_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
transpose2_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_qkv
=
pattern
->
NewNode
(
reshape2_qkv_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_qkv_out_var
=
pattern
->
NewNode
(
reshape2_qkv_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
// -> out_linear
auto
*
matmul_linear
=
pattern
->
NewNode
(
matmul_linear_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_linear_w_var
=
pattern
->
NewNode
(
matmul_linear_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul_linear_out_var
=
pattern
->
NewNode
(
matmul_linear_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_linear
=
pattern
->
NewNode
(
eltadd_linear_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_linear_b_var
=
pattern
->
NewNode
(
eltadd_linear_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_linear_out_var
=
pattern
->
NewNode
(
eltadd_linear_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_linear
=
pattern
->
NewNode
(
dropout_linear_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_linear_out_var
=
pattern
->
NewNode
(
dropout_linear_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_out
=
pattern
->
NewNode
(
eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
attention_output
=
pattern
->
NewNode
(
attention_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
();
// QKV path Links
matmul_qkv
->
LinksFrom
({
dropout_qk_out_var
,
transpose2_2_out_var
})
.
LinksTo
({
matmul_qkv_out_var
});
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
.
LinksTo
({
transpose2_qkv_out_var
});
reshape2_qkv
->
LinksFrom
({
transpose2_qkv_out_var
})
.
LinksTo
({
reshape2_qkv_out_var
});
matmul_linear
->
LinksFrom
({
reshape2_qkv_out_var
,
matmul_linear_w_var
})
.
LinksTo
({
matmul_linear_out_var
});
eltadd_linear
->
LinksFrom
({
matmul_linear_out_var
,
eltadd_linear_b_var
})
.
LinksTo
({
eltadd_linear_out_var
});
dropout_linear
->
LinksFrom
({
eltadd_linear_out_var
})
.
LinksTo
({
dropout_linear_out_var
});
eltadd_out
->
LinksFrom
({
input0
,
dropout_linear_out_var
})
.
LinksTo
({
attention_output
});
// while loop
auto
*
while0
=
pattern
->
NewNode
(
while0_repr
())
->
assert_is_op
(
"while"
);
while0
->
LinksFrom
({
transpose2_1_out_var
,
transpose2_2_out_var
});
// Feed Forward LayerNorm Nodes
auto
*
ffn_layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
ffn_layer_norm
->
LinksFrom
(
{
attention_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
({
ffn_layer_norm_out_var
,
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto
*
ffn_matmul0
=
pattern
->
NewNode
(
ffn_matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul0_w_var
=
pattern
->
NewNode
(
ffn_matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul0_out_var
=
pattern
->
NewNode
(
ffn_matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd0
=
pattern
->
NewNode
(
ffn_eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd0_b_var
=
pattern
->
NewNode
(
ffn_eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"gelu"
);
auto
*
ffn_gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
auto
*
ffn_matmul1
=
pattern
->
NewNode
(
ffn_matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul1_w_var
=
pattern
->
NewNode
(
ffn_matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul1_out_var
=
pattern
->
NewNode
(
ffn_matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd1
=
pattern
->
NewNode
(
ffn_eltadd1_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd1_b_var
=
pattern
->
NewNode
(
ffn_eltadd1_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd1_out_var
=
pattern
->
NewNode
(
ffn_eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
ffn_dropout
=
pattern
->
NewNode
(
ffn_dropout_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
ffn_dropout_out_var
=
pattern
->
NewNode
(
ffn_dropout_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd_out
=
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsOutput
();
ffn_matmul0
->
LinksFrom
({
ffn_layer_norm_out_var
,
ffn_matmul0_w_var
})
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_gelu_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
ffn_dropout
->
LinksFrom
({
ffn_eltadd1_out_var
}).
LinksTo
({
ffn_dropout_out_var
});
ffn_eltadd_out
->
LinksFrom
({
attention_output
,
ffn_dropout_out_var
})
.
LinksTo
({
ffn_output
});
return
ffn_output
;
}
PDNode
*
FusedMultiTransformerEncoderFuseQKVPattern
::
operator
()()
{
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
// pre-LayerNorm
auto
*
layer_norm
=
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
layer_norm
->
LinksFrom
({
input0
,
layer_norm_bias_var
,
layer_norm_scale_var
})
.
LinksTo
(
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
// QKV fused path Nodes
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul0_out_var
=
pattern
->
NewNode
(
matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd0
=
pattern
->
NewNode
(
eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd0_b_var
=
pattern
->
NewNode
(
eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd0_out_var
=
pattern
->
NewNode
(
eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_0
=
pattern
->
NewNode
(
reshape2_0_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_0_out_var
=
pattern
->
NewNode
(
reshape2_0_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_0
=
pattern
->
NewNode
(
transpose2_0_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"split"
,
"X"
);
auto
*
split0
=
pattern
->
NewNode
(
split0_repr
())
->
assert_is_op
(
"split"
);
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
auto
*
split0_v_out_var
=
pattern
->
NewNode
(
split0_v_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
// QKV fused path Links
matmul0
->
LinksFrom
({
layer_norm_out_var
,
matmul0_w_var
})
.
LinksTo
({
matmul0_out_var
});
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
.
LinksTo
({
eltadd0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
split0
->
LinksFrom
({
transpose2_0_out_var
})
.
LinksTo
({
split0_q_out_var
,
split0_k_out_var
,
split0_v_out_var
});
// while loop
auto
*
while0
=
pattern
->
NewNode
(
while0_repr
())
->
assert_is_op
(
"while"
);
while0
->
LinksFrom
({
split0_k_out_var
,
split0_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax"
);
auto
*
softmax_qk
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk_out_repr
())
->
assert_is_op_output
(
"softmax"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_qk
=
pattern
->
NewNode
(
dropout_qk_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_qk_out_var
=
pattern
->
NewNode
(
dropout_qk_out_repr
())
->
assert_is_op_output
(
"dropout"
,
"Out"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// -> matmul_qkv
// QK path Linsk
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
split0_k_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
dropout_qk
->
LinksFrom
({
softmax_qk_out_var
}).
LinksTo
({
dropout_qk_out_var
});
// QKV path Nodes
auto
*
matmul_qkv
=
pattern
->
NewNode
(
matmul_qkv_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qkv_out_var
=
pattern
->
NewNode
(
matmul_qkv_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_qkv
=
pattern
->
NewNode
(
transpose2_qkv_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_qkv_out_var
=
pattern
->
NewNode
(
transpose2_qkv_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
transpose2_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_qkv
=
pattern
->
NewNode
(
reshape2_qkv_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_qkv_out_var
=
pattern
->
NewNode
(
reshape2_qkv_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
// -> out_linear
auto
*
matmul_linear
=
pattern
->
NewNode
(
matmul_linear_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_linear_w_var
=
pattern
->
NewNode
(
matmul_linear_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul_linear_out_var
=
pattern
->
NewNode
(
matmul_linear_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_linear
=
pattern
->
NewNode
(
eltadd_linear_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_linear_b_var
=
pattern
->
NewNode
(
eltadd_linear_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_linear_out_var
=
pattern
->
NewNode
(
eltadd_linear_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_linear
=
pattern
->
NewNode
(
dropout_linear_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_linear_out_var
=
pattern
->
NewNode
(
dropout_linear_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_out
=
pattern
->
NewNode
(
eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
attention_output
=
pattern
->
NewNode
(
attention_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
();
// QKV path Links
matmul_qkv
->
LinksFrom
({
dropout_qk_out_var
,
split0_v_out_var
})
.
LinksTo
({
matmul_qkv_out_var
});
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
.
LinksTo
({
transpose2_qkv_out_var
});
reshape2_qkv
->
LinksFrom
({
transpose2_qkv_out_var
})
.
LinksTo
({
reshape2_qkv_out_var
});
matmul_linear
->
LinksFrom
({
reshape2_qkv_out_var
,
matmul_linear_w_var
})
.
LinksTo
({
matmul_linear_out_var
});
eltadd_linear
->
LinksFrom
({
matmul_linear_out_var
,
eltadd_linear_b_var
})
.
LinksTo
({
eltadd_linear_out_var
});
dropout_linear
->
LinksFrom
({
eltadd_linear_out_var
})
.
LinksTo
({
dropout_linear_out_var
});
eltadd_out
->
LinksFrom
({
input0
,
dropout_linear_out_var
})
.
LinksTo
({
attention_output
});
// Feed Forward LayerNorm Nodes
auto
*
ffn_layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_layer_norm_variance_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
ffn_layer_norm
->
LinksFrom
(
{
attention_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
({
ffn_layer_norm_out_var
,
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto
*
ffn_matmul0
=
pattern
->
NewNode
(
ffn_matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul0_w_var
=
pattern
->
NewNode
(
ffn_matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul0_out_var
=
pattern
->
NewNode
(
ffn_matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd0
=
pattern
->
NewNode
(
ffn_eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd0_b_var
=
pattern
->
NewNode
(
ffn_eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"gelu"
);
auto
*
ffn_gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
auto
*
ffn_matmul1
=
pattern
->
NewNode
(
ffn_matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul1_w_var
=
pattern
->
NewNode
(
ffn_matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul1_out_var
=
pattern
->
NewNode
(
ffn_matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd1
=
pattern
->
NewNode
(
ffn_eltadd1_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd1_b_var
=
pattern
->
NewNode
(
ffn_eltadd1_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd1_out_var
=
pattern
->
NewNode
(
ffn_eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
ffn_dropout
=
pattern
->
NewNode
(
ffn_dropout_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
ffn_dropout_out_var
=
pattern
->
NewNode
(
ffn_dropout_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd_out
=
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsOutput
();
ffn_matmul0
->
LinksFrom
({
ffn_layer_norm_out_var
,
ffn_matmul0_w_var
})
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_gelu_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_matmul1_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
ffn_dropout
->
LinksFrom
({
ffn_eltadd1_out_var
}).
LinksTo
({
ffn_dropout_out_var
});
ffn_eltadd_out
->
LinksFrom
({
attention_output
,
ffn_dropout_out_var
})
.
LinksTo
({
ffn_output
});
return
ffn_output
;
}
PDNode
*
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
::
operator
()()
{
auto
*
input0
=
pattern
->
NewNode
(
input0_repr
());
input0
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
// pre-LayerNorm
auto
*
layer_norm
=
pattern
->
NewNode
(
layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
layer_norm_scale_var
=
pattern
->
NewNode
(
layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
layer_norm_bias_var
=
pattern
->
NewNode
(
layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
layer_norm_mean_var
=
pattern
->
NewNode
(
layer_norm_mean_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
layer_norm_variance_var
=
pattern
->
NewNode
(
layer_norm_variance_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
layer_norm_out_var
=
pattern
->
NewNode
(
layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"c_identity"
,
"X"
);
layer_norm
->
LinksFrom
({
input0
,
layer_norm_bias_var
,
layer_norm_scale_var
})
.
LinksTo
(
{
layer_norm_out_var
,
layer_norm_mean_var
,
layer_norm_variance_var
});
// communication c_identity
auto
*
c_identity
=
pattern
->
NewNode
(
c_identity_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
c_identity_out_var
=
pattern
->
NewNode
(
c_identity_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
c_identity
->
LinksFrom
({
layer_norm_out_var
}).
LinksTo
({
c_identity_out_var
});
// QKV fused path Nodes
auto
*
matmul0
=
pattern
->
NewNode
(
matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul0_w_var
=
pattern
->
NewNode
(
matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul0_out_var
=
pattern
->
NewNode
(
matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd0
=
pattern
->
NewNode
(
eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd0_b_var
=
pattern
->
NewNode
(
eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd0_out_var
=
pattern
->
NewNode
(
eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_0
=
pattern
->
NewNode
(
reshape2_0_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_0_out_var
=
pattern
->
NewNode
(
reshape2_0_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_0
=
pattern
->
NewNode
(
transpose2_0_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_0_out_var
=
pattern
->
NewNode
(
transpose2_0_out_repr
())
->
assert_is_op_output
(
"transpose2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"split"
,
"X"
);
auto
*
split0
=
pattern
->
NewNode
(
split0_repr
())
->
assert_is_op
(
"split"
);
auto
*
split0_q_out_var
=
pattern
->
NewNode
(
split0_q_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul"
,
"X"
);
auto
*
split0_k_out_var
=
pattern
->
NewNode
(
split0_k_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
auto
*
split0_v_out_var
=
pattern
->
NewNode
(
split0_v_out_repr
())
->
assert_is_op_output
(
"split"
)
->
AsOutput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
)
->
assert_is_op_input
(
"while"
);
// QKV fused path Links
matmul0
->
LinksFrom
({
c_identity_out_var
,
matmul0_w_var
})
.
LinksTo
({
matmul0_out_var
});
eltadd0
->
LinksFrom
({
matmul0_out_var
,
eltadd0_b_var
})
.
LinksTo
({
eltadd0_out_var
});
reshape2_0
->
LinksFrom
({
eltadd0_out_var
}).
LinksTo
({
reshape2_0_out_var
});
transpose2_0
->
LinksFrom
({
reshape2_0_out_var
}).
LinksTo
({
transpose2_0_out_var
});
split0
->
LinksFrom
({
transpose2_0_out_var
})
.
LinksTo
({
split0_q_out_var
,
split0_k_out_var
,
split0_v_out_var
});
// while loop
auto
*
while0
=
pattern
->
NewNode
(
while0_repr
())
->
assert_is_op
(
"while"
);
while0
->
LinksFrom
({
split0_k_out_var
,
split0_v_out_var
});
// QK path Nodes
auto
*
matmul_qk
=
pattern
->
NewNode
(
matmul_qk_repr
())
->
assert_is_op
(
"matmul"
);
auto
*
matmul_qk_out_var
=
pattern
->
NewNode
(
matmul_qk_out_repr
())
->
assert_is_op_output
(
"matmul"
);
matmul_qk_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_qk
=
pattern
->
NewNode
(
eltadd_qk_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_qk_b_var
=
pattern
->
NewNode
(
eltadd_qk_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_qk_out_var
=
pattern
->
NewNode
(
eltadd_qk_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"softmax"
);
auto
*
softmax_qk
=
pattern
->
NewNode
(
softmax_qk_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
softmax_qk_out_var
=
pattern
->
NewNode
(
softmax_qk_out_repr
())
->
assert_is_op_output
(
"softmax"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_qk
=
pattern
->
NewNode
(
dropout_qk_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_qk_out_var
=
pattern
->
NewNode
(
dropout_qk_out_repr
())
->
assert_is_op_output
(
"dropout"
,
"Out"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
// -> matmul_qkv
// QK path Linsk
matmul_qk
->
LinksFrom
({
split0_q_out_var
,
split0_k_out_var
})
.
LinksTo
({
matmul_qk_out_var
});
eltadd_qk
->
LinksFrom
({
matmul_qk_out_var
,
eltadd_qk_b_var
})
.
LinksTo
({
eltadd_qk_out_var
});
softmax_qk
->
LinksFrom
({
eltadd_qk_out_var
}).
LinksTo
({
softmax_qk_out_var
});
dropout_qk
->
LinksFrom
({
softmax_qk_out_var
}).
LinksTo
({
dropout_qk_out_var
});
// QKV path Nodes
auto
*
matmul_qkv
=
pattern
->
NewNode
(
matmul_qkv_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_qkv_out_var
=
pattern
->
NewNode
(
matmul_qkv_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
matmul_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"transpose2"
);
auto
*
transpose2_qkv
=
pattern
->
NewNode
(
transpose2_qkv_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
transpose2_qkv_out_var
=
pattern
->
NewNode
(
transpose2_qkv_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
transpose2_qkv_out_var
->
AsIntermediate
()
->
assert_is_op_input
(
"reshape2"
);
auto
*
reshape2_qkv
=
pattern
->
NewNode
(
reshape2_qkv_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
reshape2_qkv_out_var
=
pattern
->
NewNode
(
reshape2_qkv_out_repr
())
->
assert_is_op_output
(
"reshape2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
// -> out_linear
auto
*
matmul_linear
=
pattern
->
NewNode
(
matmul_linear_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
matmul_linear_w_var
=
pattern
->
NewNode
(
matmul_linear_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
matmul_linear_out_var
=
pattern
->
NewNode
(
matmul_linear_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"c_allreduce_sum"
);
// communication c_allreduce_sum
auto
*
c_allreduce_sum
=
pattern
->
NewNode
(
c_allreduce_sum_repr
())
->
assert_is_op
(
"c_allreduce_sum"
);
auto
*
c_allreduce_sum_out_var
=
pattern
->
NewNode
(
c_allreduce_sum_out_repr
())
->
assert_is_op_output
(
"c_allreduce_sum"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_linear
=
pattern
->
NewNode
(
eltadd_linear_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
eltadd_linear_b_var
=
pattern
->
NewNode
(
eltadd_linear_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
eltadd_linear_out_var
=
pattern
->
NewNode
(
eltadd_linear_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
dropout_linear
=
pattern
->
NewNode
(
dropout_linear_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
dropout_linear_out_var
=
pattern
->
NewNode
(
dropout_linear_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_out
=
pattern
->
NewNode
(
eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
attention_output
=
pattern
->
NewNode
(
attention_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
();
// QKV path Links
matmul_qkv
->
LinksFrom
({
dropout_qk_out_var
,
split0_v_out_var
})
.
LinksTo
({
matmul_qkv_out_var
});
transpose2_qkv
->
LinksFrom
({
matmul_qkv_out_var
})
.
LinksTo
({
transpose2_qkv_out_var
});
reshape2_qkv
->
LinksFrom
({
transpose2_qkv_out_var
})
.
LinksTo
({
reshape2_qkv_out_var
});
matmul_linear
->
LinksFrom
({
reshape2_qkv_out_var
,
matmul_linear_w_var
})
.
LinksTo
({
matmul_linear_out_var
});
c_allreduce_sum
->
LinksFrom
({
matmul_linear_out_var
})
.
LinksTo
({
c_allreduce_sum_out_var
});
eltadd_linear
->
LinksFrom
({
c_allreduce_sum_out_var
,
eltadd_linear_b_var
})
.
LinksTo
({
eltadd_linear_out_var
});
dropout_linear
->
LinksFrom
({
eltadd_linear_out_var
})
.
LinksTo
({
dropout_linear_out_var
});
eltadd_out
->
LinksFrom
({
input0
,
dropout_linear_out_var
})
.
LinksTo
({
attention_output
});
// Feed Forward LayerNorm Nodes
auto
*
ffn_layer_norm
=
pattern
->
NewNode
(
ffn_layer_norm_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
ffn_layer_norm_scale_var
=
pattern
->
NewNode
(
ffn_layer_norm_scale_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
ffn_layer_norm_bias_var
=
pattern
->
NewNode
(
ffn_layer_norm_bias_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
ffn_layer_norm_mean_var
=
pattern
->
NewNode
(
ffn_layer_norm_mean_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
ffn_layer_norm_variance_var
=
pattern
->
NewNode
(
ffn_layer_norm_variance_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
auto
*
ffn_layer_norm_out_var
=
pattern
->
NewNode
(
ffn_layer_norm_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"layer_norm"
,
"Y"
)
->
assert_is_op_input
(
"c_identity"
,
"X"
);
ffn_layer_norm
->
LinksFrom
(
{
attention_output
,
ffn_layer_norm_bias_var
,
ffn_layer_norm_scale_var
})
.
LinksTo
({
ffn_layer_norm_out_var
,
ffn_layer_norm_mean_var
,
ffn_layer_norm_variance_var
});
// communication c_identity
auto
*
ffn_c_identity
=
pattern
->
NewNode
(
ffn_c_identity_repr
())
->
assert_is_op
(
"c_identity"
);
auto
*
ffn_c_identity_out_var
=
pattern
->
NewNode
(
ffn_c_identity_out_repr
())
->
assert_is_op_output
(
"c_identity"
,
"Out"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
ffn_c_identity
->
LinksFrom
({
ffn_layer_norm_out_var
})
.
LinksTo
({
ffn_c_identity_out_var
});
// Feed Forward fc1 -> gelu -> fc2 -> dropout
auto
*
ffn_matmul0
=
pattern
->
NewNode
(
ffn_matmul0_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul0_w_var
=
pattern
->
NewNode
(
ffn_matmul0_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul0_out_var
=
pattern
->
NewNode
(
ffn_matmul0_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd0
=
pattern
->
NewNode
(
ffn_eltadd0_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd0_b_var
=
pattern
->
NewNode
(
ffn_eltadd0_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd0_out_var
=
pattern
->
NewNode
(
ffn_eltadd0_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"gelu"
);
auto
*
ffn_gelu
=
pattern
->
NewNode
(
ffn_gelu_repr
())
->
assert_is_op
(
"gelu"
);
auto
*
ffn_gelu_out_var
=
pattern
->
NewNode
(
ffn_gelu_out_repr
())
->
assert_is_op_output
(
"gelu"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"matmul_v2"
);
auto
*
ffn_matmul1
=
pattern
->
NewNode
(
ffn_matmul1_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
ffn_matmul1_w_var
=
pattern
->
NewNode
(
ffn_matmul1_w_repr
())
->
AsInput
()
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
ffn_matmul1_out_var
=
pattern
->
NewNode
(
ffn_matmul1_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"c_allreduce_sum"
);
// communication c_allreduce_sum
auto
*
ffn_c_allreduce_sum
=
pattern
->
NewNode
(
ffn_c_allreduce_sum_repr
())
->
assert_is_op
(
"c_allreduce_sum"
);
auto
*
ffn_c_allreduce_sum_out_var
=
pattern
->
NewNode
(
ffn_c_allreduce_sum_out_repr
())
->
assert_is_op_output
(
"c_allreduce_sum"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd1
=
pattern
->
NewNode
(
ffn_eltadd1_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_eltadd1_b_var
=
pattern
->
NewNode
(
ffn_eltadd1_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ffn_eltadd1_out_var
=
pattern
->
NewNode
(
ffn_eltadd1_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"dropout"
);
auto
*
ffn_dropout
=
pattern
->
NewNode
(
ffn_dropout_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
ffn_dropout_out_var
=
pattern
->
NewNode
(
ffn_dropout_out_repr
())
->
assert_is_op_output
(
"dropout"
)
->
AsIntermediate
()
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
ffn_eltadd_out
=
pattern
->
NewNode
(
ffn_eltadd_out_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ffn_output
=
pattern
->
NewNode
(
ffn_output_repr
())
->
assert_is_op_output
(
"elementwise_add"
)
->
AsOutput
();
ffn_matmul0
->
LinksFrom
({
ffn_c_identity_out_var
,
ffn_matmul0_w_var
})
.
LinksTo
({
ffn_matmul0_out_var
});
ffn_eltadd0
->
LinksFrom
({
ffn_matmul0_out_var
,
ffn_eltadd0_b_var
})
.
LinksTo
({
ffn_eltadd0_out_var
});
ffn_gelu
->
LinksFrom
({
ffn_eltadd0_out_var
}).
LinksTo
({
ffn_gelu_out_var
});
ffn_matmul1
->
LinksFrom
({
ffn_gelu_out_var
,
ffn_matmul1_w_var
})
.
LinksTo
({
ffn_matmul1_out_var
});
ffn_c_allreduce_sum
->
LinksFrom
({
ffn_matmul1_out_var
})
.
LinksTo
({
ffn_c_allreduce_sum_out_var
});
ffn_eltadd1
->
LinksFrom
({
ffn_c_allreduce_sum_out_var
,
ffn_eltadd1_b_var
})
.
LinksTo
({
ffn_eltadd1_out_var
});
ffn_dropout
->
LinksFrom
({
ffn_eltadd1_out_var
}).
LinksTo
({
ffn_dropout_out_var
});
ffn_eltadd_out
->
LinksFrom
({
attention_output
,
ffn_dropout_out_var
})
.
LinksTo
({
ffn_output
});
return
ffn_output
;
}
}
// namespace patterns
template
<
typename
T
>
inline
void
QKVWeightsProcess
(
framework
::
LoDTensor
*
wq_tensor
,
framework
::
LoDTensor
*
wk_tensor
,
framework
::
LoDTensor
*
wv_tensor
,
framework
::
LoDTensor
*
bq_tensor
,
framework
::
LoDTensor
*
bk_tensor
,
framework
::
LoDTensor
*
bv_tensor
,
const
int
num_head
,
const
int
dim_head
,
const
int
dim_embed
)
{
auto
*
wq_data
=
wq_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
*
wk_data
=
wk_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
*
wv_data
=
wv_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
*
bq_data
=
bq_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
*
bk_data
=
bk_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
*
bv_data
=
bv_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
combined_w_dims
=
phi
::
make_ddim
({
3
,
num_head
,
dim_head
,
dim_embed
});
auto
combined_bias_dims
=
phi
::
make_ddim
({
3
,
num_head
,
dim_head
});
framework
::
LoDTensor
tmp_combined_w_tensor
;
tmp_combined_w_tensor
.
Resize
(
combined_w_dims
);
auto
*
tmp_combined_w_data
=
tmp_combined_w_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
std
::
vector
<
T
*>
w_vec
=
{
wq_data
,
wk_data
,
wv_data
};
// Combine the three fc weights together.
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
for
(
int
j
=
0
;
j
<
num_head
;
j
++
)
{
for
(
int
k
=
0
;
k
<
dim_head
;
k
++
)
{
for
(
int
l
=
0
;
l
<
dim_embed
;
l
++
)
{
int
out_idx
=
i
*
num_head
*
dim_head
*
dim_embed
+
j
*
dim_head
*
dim_embed
+
k
*
dim_embed
+
l
;
int
in_idx
=
l
*
num_head
*
dim_head
+
j
*
dim_head
+
k
;
tmp_combined_w_data
[
out_idx
]
=
w_vec
[
i
][
in_idx
];
}
}
}
}
wq_tensor
->
Resize
(
combined_w_dims
);
auto
*
new_combined_w_data
=
wq_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
memcpy
(
new_combined_w_data
,
tmp_combined_w_data
,
sizeof
(
T
)
*
wq_tensor
->
numel
());
framework
::
LoDTensor
tmp_combined_bias_tensor
;
tmp_combined_bias_tensor
.
Resize
(
combined_bias_dims
);
auto
*
tmp_combined_bias_data
=
tmp_combined_bias_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
size_t
bias_size
=
bq_tensor
->
numel
();
memcpy
(
tmp_combined_bias_data
,
bq_data
,
sizeof
(
T
)
*
bias_size
);
memcpy
(
tmp_combined_bias_data
+
bias_size
,
bk_data
,
sizeof
(
T
)
*
bias_size
);
memcpy
(
tmp_combined_bias_data
+
2
*
bias_size
,
bv_data
,
sizeof
(
T
)
*
bias_size
);
bq_tensor
->
Resize
(
combined_bias_dims
);
auto
*
new_combined_bias_data
=
bq_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
memcpy
(
new_combined_bias_data
,
tmp_combined_bias_data
,
sizeof
(
T
)
*
bq_tensor
->
numel
());
}
template
<
typename
T
>
inline
void
QKVWeightsProcessFuseQKV
(
framework
::
LoDTensor
*
qkv_w_tensor
,
framework
::
LoDTensor
*
qkv_b_tensor
,
const
int
num_head
,
const
int
dim_head
,
const
int
dim_embed
)
{
auto
*
qkv_w_data
=
qkv_w_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
transpose_w_dims
=
phi
::
make_ddim
({
3
,
num_head
,
dim_head
,
dim_embed
});
framework
::
LoDTensor
tmp_transpose_w_tensor
;
tmp_transpose_w_tensor
.
Resize
(
transpose_w_dims
);
auto
*
tmp_transpose_w_data
=
tmp_transpose_w_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
// transpose qkv matmul Y to QKVWeights
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
for
(
int
j
=
0
;
j
<
num_head
;
j
++
)
{
for
(
int
k
=
0
;
k
<
dim_head
;
k
++
)
{
for
(
int
l
=
0
;
l
<
dim_embed
;
l
++
)
{
int
out_idx
=
i
*
num_head
*
dim_head
*
dim_embed
+
j
*
dim_head
*
dim_embed
+
k
*
dim_embed
+
l
;
int
in_idx
=
l
*
num_head
*
3
*
dim_head
+
j
*
3
*
dim_head
+
i
*
dim_head
+
k
;
tmp_transpose_w_data
[
out_idx
]
=
qkv_w_data
[
in_idx
];
}
}
}
}
qkv_w_tensor
->
Resize
(
transpose_w_dims
);
auto
*
new_transpose_w_data
=
qkv_w_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
memcpy
(
new_transpose_w_data
,
tmp_transpose_w_data
,
sizeof
(
T
)
*
qkv_w_tensor
->
numel
());
auto
*
qkv_b_data
=
qkv_b_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
transpose_b_dims
=
phi
::
make_ddim
({
3
,
num_head
,
dim_head
});
framework
::
LoDTensor
tmp_transpose_b_tensor
;
tmp_transpose_b_tensor
.
Resize
(
transpose_b_dims
);
auto
*
tmp_transpose_b_data
=
tmp_transpose_b_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
// transpose qkv elemenwise_add Y to QKVBias
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
for
(
int
j
=
0
;
j
<
num_head
;
j
++
)
{
for
(
int
k
=
0
;
k
<
dim_head
;
k
++
)
{
int
out_idx
=
i
*
num_head
*
dim_head
+
j
*
dim_head
+
k
;
int
in_idx
=
j
*
3
*
dim_head
+
i
*
dim_head
+
k
;
tmp_transpose_b_data
[
out_idx
]
=
qkv_b_data
[
in_idx
];
}
}
}
qkv_b_tensor
->
Resize
({
3
,
num_head
,
dim_head
});
auto
*
new_transpose_b_data
=
qkv_b_tensor
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
memcpy
(
new_transpose_b_data
,
tmp_transpose_b_data
,
sizeof
(
T
)
*
qkv_b_tensor
->
numel
());
}
int
FusedMultiTransformerEncoderPass
::
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
// Create pattern.
patterns
::
FusedMultiTransformerEncoderPattern
fused_multi_transformer_pattern
(
pattern
,
name_scope
);
fused_multi_transformer_pattern
();
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
layer_norm
,
Node
*
layer_norm_scale
,
Node
*
layer_norm_bias
,
Node
*
layer_norm_mean
,
Node
*
layer_norm_variance
,
Node
*
matmul0_w
,
Node
*
matmul1_w
,
Node
*
matmul2_w
,
Node
*
eltadd0_b
,
Node
*
eltadd1_b
,
Node
*
eltadd2_b
,
Node
*
transpose2_1_out
,
Node
*
transpose2_2_out
,
Node
*
eltadd_qk_b
,
Node
*
dropout_qk
,
Node
*
reshape2_0
,
Node
*
matmul_linear_w
,
Node
*
eltadd_linear_b
,
Node
*
dropout_linear
,
Node
*
while0
,
Node
*
ffn_layer_norm
,
Node
*
ffn_layer_norm_scale
,
Node
*
ffn_layer_norm_bias
,
Node
*
ffn_layer_norm_mean
,
Node
*
ffn_layer_norm_variance
,
Node
*
ffn_matmul0_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_dropout
,
Node
*
ffn_output
)
{
auto
reshape_desc
=
reshape2_0
->
Op
();
int
num_head
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
2
);
int
dim_head
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
3
);
int
dim_embed
=
num_head
*
dim_head
;
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
// 2. each transformer layer contains 2 LayerNorm layer
auto
ln_scale_name
=
layer_norm_scale
->
Name
();
auto
ln_name
=
ln_scale_name
.
substr
(
0
,
ln_scale_name
.
find
(
'.'
));
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
auto
*
wq_tensor
=
scope
->
FindVar
(
matmul0_w
->
Name
())
->
GetMutable
<
LoDTensor
>
();
auto
*
wk_tensor
=
scope
->
FindVar
(
matmul1_w
->
Name
())
->
GetMutable
<
LoDTensor
>
();
auto
*
wv_tensor
=
scope
->
FindVar
(
matmul2_w
->
Name
())
->
GetMutable
<
LoDTensor
>
();
auto
*
bq_tensor
=
scope
->
FindVar
(
eltadd0_b
->
Name
())
->
GetMutable
<
LoDTensor
>
();
auto
*
bk_tensor
=
scope
->
FindVar
(
eltadd1_b
->
Name
())
->
GetMutable
<
LoDTensor
>
();
auto
*
bv_tensor
=
scope
->
FindVar
(
eltadd2_b
->
Name
())
->
GetMutable
<
LoDTensor
>
();
if
(
wq_tensor
->
dtype
()
==
phi
::
DataType
::
FLOAT32
)
{
QKVWeightsProcess
<
float
>
(
wq_tensor
,
wk_tensor
,
wv_tensor
,
bq_tensor
,
bk_tensor
,
bv_tensor
,
num_head
,
dim_head
,
dim_embed
);
}
else
if
(
wq_tensor
->
dtype
()
==
phi
::
DataType
::
FLOAT16
)
{
QKVWeightsProcess
<
platform
::
float16
>
(
wq_tensor
,
wk_tensor
,
wv_tensor
,
bq_tensor
,
bk_tensor
,
bv_tensor
,
num_head
,
dim_head
,
dim_embed
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32 and fp16."
));
}
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto
*
combined_w_desc
=
matmul0_w
->
Var
();
combined_w_desc
->
SetShape
({
3
,
num_head
,
dim_head
,
dim_embed
});
combined_w_desc
->
SetPersistable
(
true
);
auto
*
combined_bias_desc
=
eltadd0_b
->
Var
();
combined_bias_desc
->
SetShape
({
3
,
num_head
,
dim_head
});
combined_bias_desc
->
SetPersistable
(
true
);
scope
->
EraseVars
({
matmul1_w
->
Name
(),
matmul2_w
->
Name
()});
scope
->
EraseVars
({
eltadd1_b
->
Name
(),
eltadd2_b
->
Name
()});
// create fused_multi_transformer
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
fused_multi_transformer_op_desc
.
SetType
(
"fused_multi_transformer"
);
// 1. Input setting
fused_multi_transformer_op_desc
.
SetInput
(
"X"
,
{
input0
->
Name
()});
// pre-LayerNorm input
fused_multi_transformer_op_desc
.
SetInput
(
"LnScale"
,
{
layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"LnBias"
,
{
layer_norm_bias
->
Name
()});
// QKV computation input
fused_multi_transformer_op_desc
.
SetInput
(
"QKVW"
,
{
matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"QKVBias"
,
{
eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"SrcMask"
,
{
eltadd_qk_b
->
Name
()});
// CacheKV input
VarDesc
cache_kv_desc
(
"cache_kv"
+
std
::
to_string
(
layer_idx
));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
wq_tensor
->
dtype
()));
cache_kv_desc
.
SetPersistable
(
false
);
auto
*
cache_kv
=
graph
->
CreateVarNode
(
&
cache_kv_desc
);
OpDesc
fill_const_op_desc
(
layer_norm
->
Op
()
->
Block
());
fill_const_op_desc
.
SetType
(
"fill_constant_batch_size_like"
);
fill_const_op_desc
.
SetInput
(
"Input"
,
{
input0
->
Name
()});
fill_const_op_desc
.
SetOutput
(
"Out"
,
{
cache_kv
->
Name
()});
std
::
vector
<
int
>
shape
=
{
2
,
-
1
,
num_head
,
1024
,
dim_head
};
fill_const_op_desc
.
SetAttr
(
"shape"
,
shape
);
fill_const_op_desc
.
SetAttr
(
"input_dim_idx"
,
0
);
fill_const_op_desc
.
SetAttr
(
"output_dim_idx"
,
1
);
fill_const_op_desc
.
SetAttr
(
"value"
,
0
);
fill_const_op_desc
.
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
proto
::
VarType
::
FP32
));
auto
*
fill_const_op
=
graph
->
CreateOpNode
(
&
fill_const_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv
->
Name
()});
// Out Linear input
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
{
matmul_linear_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearBias"
,
{
eltadd_linear_b
->
Name
()});
// Feed Forward input
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnScale"
,
{
ffn_layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnBias"
,
{
ffn_layer_norm_bias
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Weight"
,
{
ffn_matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Bias"
,
{
ffn_eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Weight"
,
{
ffn_matmul1_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Bias"
,
{
ffn_eltadd1_b
->
Name
()});
// 2. Output setting
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_output
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv
->
Name
()});
// Attribute setting
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// output dropout attribute
auto
*
dropout_op
=
dropout_linear
->
Op
();
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
dropout_op
->
GetAttr
(
"dropout_prob"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
dropout_op
->
GetAttr
(
"is_test"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_implementation"
,
dropout_op
->
GetAttr
(
"dropout_implementation"
));
auto
*
fused_multi_transformer
=
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
IR_NODE_LINK_TO
(
input0
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul0_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
input0
,
fill_const_op
);
IR_NODE_LINK_TO
(
fill_const_op
,
cache_kv
);
IR_NODE_LINK_TO
(
cache_kv
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
fused_multi_transformer
,
ffn_output
);
// rewrite while OP input
// 1. delete k, v
// 2. delete matmul1/2_w eltadd1/2_w
// 3. add cache_kv
auto
while_Xs
=
while0
->
Op
()
->
Input
(
"X"
);
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
transpose2_1_out
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
transpose2_2_out
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
matmul1_w
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
matmul2_w
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
eltadd1_b
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
eltadd2_b
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
emplace_back
(
cache_kv
->
Name
());
while0
->
Op
()
->
SetInput
(
"X"
,
while_Xs
);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto
while_Outs
=
while0
->
Op
()
->
Output
(
"Out"
);
while_Outs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Outs
),
std
::
end
(
while_Outs
),
transpose2_1_out
->
Name
()),
std
::
end
(
while_Outs
));
while_Outs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Outs
),
std
::
end
(
while_Outs
),
transpose2_2_out
->
Name
()),
std
::
end
(
while_Outs
));
while_Outs
.
emplace_back
(
cache_kv
->
Name
());
while0
->
Op
()
->
SetOutput
(
"Out"
,
while_Outs
);
// link CacheKV to while
IR_NODE_LINK_TO
(
cache_kv
,
while0
)
// unlink origin KV output to while
IR_NODE_UNLINK
(
transpose2_1_out
,
while0
);
IR_NODE_UNLINK
(
transpose2_2_out
,
while0
);
IR_NODE_UNLINK
(
while0
,
transpose2_1_out
);
IR_NODE_UNLINK
(
while0
,
transpose2_2_out
);
// unlink KV weight/bias to while after merged into Q weight/bias
IR_NODE_UNLINK
(
matmul1_w
,
while0
);
IR_NODE_UNLINK
(
matmul2_w
,
while0
);
IR_NODE_UNLINK
(
eltadd1_b
,
while0
);
IR_NODE_UNLINK
(
eltadd2_b
,
while0
);
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"fused_multi_transformer_encoder pass in "
"op compat failed."
;
return
;
}
VLOG
(
4
)
<<
"handle MultiTransformer encoder fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
input0
,
input0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm
,
layer_norm
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
layer_norm_scale
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
layer_norm_bias
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
layer_norm_mean
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
layer_norm_variance
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
layer_norm_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0
,
matmul0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_out
,
matmul0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_w
,
matmul0_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1
,
matmul1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1_out
,
matmul1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul1_w
,
matmul1_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1
,
reshape2_1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_1_out
,
reshape2_1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1
,
transpose2_1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_1_out
,
transpose2_1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2
,
matmul2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2_out
,
matmul2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul2_w
,
matmul2_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2
,
reshape2_2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_2_out
,
reshape2_2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2
,
transpose2_2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_2_out
,
transpose2_2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
attention_output
,
attention_output
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
while0
,
while0
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm
,
ffn_layer_norm
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
ffn_layer_norm_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0
,
ffn_matmul0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_out
,
ffn_matmul0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_w
,
ffn_matmul0_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0
,
ffn_eltadd0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_b
,
ffn_eltadd0_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu
,
ffn_gelu
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu_out
,
ffn_gelu_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_out
,
ffn_matmul1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_w
,
ffn_matmul1_w
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1
,
ffn_eltadd1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_b
,
ffn_eltadd1_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
ffn_eltadd1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout
,
ffn_dropout
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout_out
,
ffn_dropout_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
ffn_eltadd_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_output
,
ffn_output
,
fused_multi_transformer_pattern
)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_b
,
eltadd0_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_out
,
eltadd0_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1
,
eltadd1
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1_b
,
eltadd1_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd1_out
,
eltadd1_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2
,
eltadd2
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2_b
,
eltadd2_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd2_out
,
eltadd2_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk
,
matmul_qk
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_b
,
eltadd_qk_b
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk
,
dropout_qk
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk_out
,
dropout_qk_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
fused_multi_transformer_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear
,
matmul_linear
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_w
,
matmul_linear_w
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
matmul_linear_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear
,
eltadd_linear
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
eltadd_linear_b
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
eltadd_linear_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear
,
dropout_linear
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear_out
,
dropout_linear_out
,
fused_multi_transformer_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fused_multi_transformer_pattern
)
fuse_creater
(
input0
,
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
matmul0_w
,
matmul1_w
,
matmul2_w
,
eltadd0_b
,
eltadd1_b
,
eltadd2_b
,
transpose2_1_out
,
transpose2_2_out
,
eltadd_qk_b
,
dropout_qk
,
reshape2_0
,
matmul_linear_w
,
eltadd_linear_b
,
dropout_linear
,
while0
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_matmul0_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_dropout
,
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
layer_norm_out
,
matmul0
,
matmul1
,
matmul2
,
matmul0_out
,
matmul1_out
,
matmul2_out
,
eltadd0
,
eltadd1
,
eltadd2
,
eltadd0_out
,
eltadd1_out
,
eltadd2_out
,
reshape2_0
,
reshape2_1
,
reshape2_2
,
reshape2_0_out
,
reshape2_1_out
,
reshape2_2_out
,
transpose2_0
,
transpose2_1
,
transpose2_2
,
transpose2_0_out
,
transpose2_1_out
,
transpose2_2_out
,
matmul_qk
,
matmul_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
softmax_qk_out
,
dropout_qk
,
dropout_qk_out
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_qkv
,
matmul_qkv_out
,
reshape2_qkv
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_linear
,
matmul_linear_w
,
matmul_linear_out
,
eltadd_linear
,
eltadd_linear_b
,
eltadd_linear_out
,
dropout_linear
,
dropout_linear_out
,
eltadd_out
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_layer_norm_out
,
ffn_matmul0
,
ffn_matmul1
,
ffn_matmul0_out
,
ffn_matmul1_out
,
ffn_eltadd0
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_gelu
,
ffn_gelu_out
,
ffn_dropout
,
ffn_dropout_out
,
ffn_eltadd_out
});
// Remove unneeded nodes.
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
void
FusedMultiTransformerEncoderPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
Fatal
(
"During the multi_transformer pass, The scope should not be null."
));
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
if
(
fusion_count
>
0
)
{
graph
->
Set
(
kFusedMultiTransformerEncoderPass
,
new
bool
(
true
));
}
AddStatis
(
fusion_count
);
}
FusedMultiTransformerEncoderPass
::
FusedMultiTransformerEncoderPass
()
{
AddOpCompat
(
OpCompat
(
"layer_norm"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Scale"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Mean"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Variance"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"epsilon"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
0.001
f
)
.
End
()
.
AddAttr
(
"begin_norm_axis"
)
.
IsNumGT
(
0
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
// the shape shoule be (N*H, N*H)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsType
<
bool
>
()
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
2
,
-
1
,
0
})
.
End
();
AddOpCompat
(
OpCompat
(
"reshape2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Shape"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ShapeTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"shape"
)
// -->(B, S, H, N) <--(B, S, N*H)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"transpose2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
// {0, 2, 1, 3}
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"alpha"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
1.0
f
)
.
End
()
.
AddAttr
(
"transpose_X"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"transpose_Y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"softmax"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
-
1
,
3
})
// shape is (B, H, S, S), so axis is -1 or 3
.
End
();
AddOpCompat
(
OpCompat
(
"gelu"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"approximate"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"while"
))
.
AddInput
(
"X"
)
// A set of variables, unconstrained
.
End
()
.
AddInput
(
"Condition"
)
// An scalar
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// A set of variables, unconstrained
.
End
()
.
AddOutput
(
"StepScopes"
)
// A vector of local scope, unconstrained
.
End
()
.
AddAttr
(
"sub_block"
)
.
IsType
<
framework
::
BlockDesc
*>
()
.
End
();
}
int
FusedMultiTransformerEncoderFuseQKVPass
::
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
// Create pattern.
patterns
::
FusedMultiTransformerEncoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern
(
pattern
,
name_scope
);
fused_multi_transformer_fuse_qkv_pattern
();
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
layer_norm
,
Node
*
layer_norm_scale
,
Node
*
layer_norm_bias
,
Node
*
layer_norm_mean
,
Node
*
layer_norm_variance
,
Node
*
matmul0_w
,
Node
*
eltadd0_b
,
Node
*
split0_k_out
,
Node
*
split0_v_out
,
Node
*
eltadd_qk_b
,
Node
*
dropout_qk
,
Node
*
reshape2_0
,
Node
*
matmul_linear_w
,
Node
*
eltadd_linear_b
,
Node
*
dropout_linear
,
Node
*
while0
,
Node
*
ffn_layer_norm
,
Node
*
ffn_layer_norm_scale
,
Node
*
ffn_layer_norm_bias
,
Node
*
ffn_layer_norm_mean
,
Node
*
ffn_layer_norm_variance
,
Node
*
ffn_matmul0_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_dropout
,
Node
*
ffn_output
)
{
auto
reshape_desc
=
reshape2_0
->
Op
();
int
num_head
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
2
);
int
dim_head
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
3
)
/
3
;
// 3 for qkv
int
dim_embed
=
num_head
*
dim_head
;
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
// 2. each transformer layer contains 2 LayerNorm layer
auto
ln_scale_name
=
layer_norm_scale
->
Name
();
auto
ln_name
=
ln_scale_name
.
substr
(
0
,
ln_scale_name
.
find
(
'.'
));
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
auto
*
qkv_w_tensor
=
scope
->
FindVar
(
matmul0_w
->
Name
())
->
GetMutable
<
LoDTensor
>
();
auto
*
qkv_b_tensor
=
scope
->
FindVar
(
eltadd0_b
->
Name
())
->
GetMutable
<
LoDTensor
>
();
if
(
qkv_w_tensor
->
dtype
()
==
phi
::
DataType
::
FLOAT32
)
{
QKVWeightsProcessFuseQKV
<
float
>
(
qkv_w_tensor
,
qkv_b_tensor
,
num_head
,
dim_head
,
dim_embed
);
}
else
if
(
qkv_w_tensor
->
dtype
()
==
phi
::
DataType
::
FLOAT16
)
{
QKVWeightsProcessFuseQKV
<
platform
::
float16
>
(
qkv_w_tensor
,
qkv_b_tensor
,
num_head
,
dim_head
,
dim_embed
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32 and fp16."
));
}
// create fused_multi_transformer
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
fused_multi_transformer_op_desc
.
SetType
(
"fused_multi_transformer"
);
// 1. Input setting
fused_multi_transformer_op_desc
.
SetInput
(
"X"
,
{
input0
->
Name
()});
// pre-LayerNorm input
fused_multi_transformer_op_desc
.
SetInput
(
"LnScale"
,
{
layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"LnBias"
,
{
layer_norm_bias
->
Name
()});
// QKV computation input
fused_multi_transformer_op_desc
.
SetInput
(
"QKVW"
,
{
matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"QKVBias"
,
{
eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"SrcMask"
,
{
eltadd_qk_b
->
Name
()});
// CacheKV input
VarDesc
cache_kv_desc
(
"cache_kv"
+
std
::
to_string
(
layer_idx
));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
qkv_w_tensor
->
dtype
()));
cache_kv_desc
.
SetPersistable
(
false
);
auto
*
cache_kv
=
graph
->
CreateVarNode
(
&
cache_kv_desc
);
OpDesc
fill_const_op_desc
(
layer_norm
->
Op
()
->
Block
());
fill_const_op_desc
.
SetType
(
"fill_constant_batch_size_like"
);
fill_const_op_desc
.
SetInput
(
"Input"
,
{
input0
->
Name
()});
fill_const_op_desc
.
SetOutput
(
"Out"
,
{
cache_kv
->
Name
()});
std
::
vector
<
int
>
shape
=
{
2
,
-
1
,
num_head
,
1024
,
dim_head
};
fill_const_op_desc
.
SetAttr
(
"shape"
,
shape
);
fill_const_op_desc
.
SetAttr
(
"input_dim_idx"
,
0
);
fill_const_op_desc
.
SetAttr
(
"output_dim_idx"
,
1
);
fill_const_op_desc
.
SetAttr
(
"value"
,
0
);
fill_const_op_desc
.
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
proto
::
VarType
::
FP32
));
auto
*
fill_const_op
=
graph
->
CreateOpNode
(
&
fill_const_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv
->
Name
()});
// Out Linear input
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
{
matmul_linear_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearBias"
,
{
eltadd_linear_b
->
Name
()});
// Feed Forward input
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnScale"
,
{
ffn_layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnBias"
,
{
ffn_layer_norm_bias
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Weight"
,
{
ffn_matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Bias"
,
{
ffn_eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Weight"
,
{
ffn_matmul1_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Bias"
,
{
ffn_eltadd1_b
->
Name
()});
// 2. Output setting
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_output
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv
->
Name
()});
// Attribute setting
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// output dropout attribute
auto
*
dropout_op
=
dropout_linear
->
Op
();
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
dropout_op
->
GetAttr
(
"dropout_prob"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
dropout_op
->
GetAttr
(
"is_test"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_implementation"
,
dropout_op
->
GetAttr
(
"dropout_implementation"
));
auto
*
fused_multi_transformer
=
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
IR_NODE_LINK_TO
(
input0
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul0_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
input0
,
fill_const_op
);
IR_NODE_LINK_TO
(
fill_const_op
,
cache_kv
);
IR_NODE_LINK_TO
(
cache_kv
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
fused_multi_transformer
,
ffn_output
);
// rewrite while OP input
// 1. delete k, v
// 2. delete matmul1/2_w eltadd1/2_w
// 3. add cache_kv
auto
while_Xs
=
while0
->
Op
()
->
Input
(
"X"
);
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
split0_k_out
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
split0_v_out
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
emplace_back
(
cache_kv
->
Name
());
while0
->
Op
()
->
SetInput
(
"X"
,
while_Xs
);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto
while_Outs
=
while0
->
Op
()
->
Output
(
"Out"
);
while_Outs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Outs
),
std
::
end
(
while_Outs
),
split0_k_out
->
Name
()),
std
::
end
(
while_Outs
));
while_Outs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Outs
),
std
::
end
(
while_Outs
),
split0_v_out
->
Name
()),
std
::
end
(
while_Outs
));
while_Outs
.
emplace_back
(
cache_kv
->
Name
());
while0
->
Op
()
->
SetOutput
(
"Out"
,
while_Outs
);
// link CacheKV to while
IR_NODE_LINK_TO
(
cache_kv
,
while0
)
// unlink origin KV output to while
IR_NODE_UNLINK
(
split0_k_out
,
while0
);
IR_NODE_UNLINK
(
split0_v_out
,
while0
);
IR_NODE_UNLINK
(
while0
,
split0_k_out
);
IR_NODE_UNLINK
(
while0
,
split0_v_out
);
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"fused_multi_transformer_encoder_fuse_qkv "
"pass in op compat failed."
;
return
;
}
VLOG
(
4
)
<<
"handle MultiTransformer encoder(Fuse-QKV) fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
input0
,
input0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm
,
layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
layer_norm_scale
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
layer_norm_bias
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
layer_norm_mean
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
layer_norm_variance
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
layer_norm_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0
,
matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_out
,
matmul0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_w
,
matmul0_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0
,
split0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_q_out
,
split0_q_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_k_out
,
split0_k_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_v_out
,
split0_v_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm
,
ffn_layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
ffn_layer_norm_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0
,
ffn_matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_out
,
ffn_matmul0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_w
,
ffn_matmul0_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0
,
ffn_eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_b
,
ffn_eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu
,
ffn_gelu
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu_out
,
ffn_gelu_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_out
,
ffn_matmul1_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_w
,
ffn_matmul1_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1
,
ffn_eltadd1
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_b
,
ffn_eltadd1_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
ffn_eltadd1_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout
,
ffn_dropout
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout_out
,
ffn_dropout_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
ffn_eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_output
,
ffn_output
,
fused_multi_transformer_fuse_qkv_pattern
)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_b
,
eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_out
,
eltadd0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk
,
matmul_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_b
,
eltadd_qk_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk
,
dropout_qk
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk_out
,
dropout_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear
,
matmul_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_w
,
matmul_linear_w
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
matmul_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear
,
eltadd_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
eltadd_linear_b
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
eltadd_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear
,
dropout_linear
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear_out
,
dropout_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
while0
,
while0
,
fused_multi_transformer_fuse_qkv_pattern
)
fuse_creater
(
input0
,
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
matmul0_w
,
eltadd0_b
,
split0_k_out
,
split0_v_out
,
eltadd_qk_b
,
dropout_qk
,
reshape2_0
,
matmul_linear_w
,
eltadd_linear_b
,
dropout_linear
,
while0
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_matmul0_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_dropout
,
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
layer_norm_out
,
matmul0
,
matmul0_out
,
eltadd0
,
eltadd0_out
,
reshape2_0
,
reshape2_0_out
,
transpose2_0
,
transpose2_0_out
,
split0
,
split0_q_out
,
split0_k_out
,
split0_v_out
,
matmul_qk
,
matmul_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
softmax_qk_out
,
dropout_qk
,
dropout_qk_out
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_qkv
,
matmul_qkv_out
,
reshape2_qkv
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_linear
,
matmul_linear_w
,
matmul_linear_out
,
eltadd_linear
,
eltadd_linear_b
,
eltadd_linear_out
,
dropout_linear
,
dropout_linear_out
,
eltadd_out
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_layer_norm_out
,
ffn_matmul0
,
ffn_matmul1
,
ffn_matmul0_out
,
ffn_matmul1_out
,
ffn_eltadd0
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_gelu
,
ffn_gelu_out
,
ffn_dropout
,
ffn_dropout_out
,
ffn_eltadd_out
});
// Remove unneeded nodes.
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
void
FusedMultiTransformerEncoderFuseQKVPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
Fatal
(
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."
));
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
if
(
fusion_count
>
0
)
{
graph
->
Set
(
kFusedMultiTransformerEncoderFuseQKVPass
,
new
bool
(
true
));
}
AddStatis
(
fusion_count
);
}
FusedMultiTransformerEncoderFuseQKVPass
::
FusedMultiTransformerEncoderFuseQKVPass
()
{
AddOpCompat
(
OpCompat
(
"layer_norm"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Scale"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Mean"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Variance"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"epsilon"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
0.001
f
)
.
End
()
.
AddAttr
(
"begin_norm_axis"
)
.
IsNumGT
(
0
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
// the shape shoule be (N*H, N*H)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsType
<
bool
>
()
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
2
,
-
1
,
0
})
.
End
();
AddOpCompat
(
OpCompat
(
"reshape2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Shape"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ShapeTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"shape"
)
// -->(B, S, H, N) <--(B, S, N*H)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"transpose2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
// {0, 2, 1, 3}
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"alpha"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
1.0
f
)
.
End
()
.
AddAttr
(
"transpose_X"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"transpose_Y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"softmax"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
-
1
,
3
})
// shape is (B, H, S, S), so axis is -1 or 3
.
End
();
AddOpCompat
(
OpCompat
(
"gelu"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"approximate"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"while"
))
.
AddInput
(
"X"
)
// A set of variables, unconstrained
.
End
()
.
AddInput
(
"Condition"
)
// An scalar
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// A set of variables, unconstrained
.
End
()
.
AddOutput
(
"StepScopes"
)
// A vector of local scope, unconstrained
.
End
()
.
AddAttr
(
"sub_block"
)
.
IsType
<
framework
::
BlockDesc
*>
()
.
End
();
}
int
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
::
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
// Create pattern.
patterns
::
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
fused_multi_transformer_fuse_qkv_pattern
(
pattern
,
name_scope
);
fused_multi_transformer_fuse_qkv_pattern
();
// Create New OpDesc
auto
fuse_creater
=
[
&
](
Node
*
input0
,
Node
*
layer_norm
,
Node
*
layer_norm_scale
,
Node
*
layer_norm_bias
,
Node
*
layer_norm_mean
,
Node
*
layer_norm_variance
,
Node
*
c_identity
,
Node
*
matmul0_w
,
Node
*
eltadd0_b
,
Node
*
split0_k_out
,
Node
*
split0_v_out
,
Node
*
eltadd_qk_b
,
Node
*
dropout_qk
,
Node
*
reshape2_0
,
Node
*
matmul_linear_w
,
Node
*
eltadd_linear_b
,
Node
*
dropout_linear
,
Node
*
while0
,
Node
*
ffn_layer_norm
,
Node
*
ffn_layer_norm_scale
,
Node
*
ffn_layer_norm_bias
,
Node
*
ffn_layer_norm_mean
,
Node
*
ffn_layer_norm_variance
,
Node
*
ffn_matmul0_w
,
Node
*
ffn_matmul1_w
,
Node
*
ffn_eltadd0_b
,
Node
*
ffn_eltadd1_b
,
Node
*
ffn_dropout
,
Node
*
ffn_output
)
{
auto
reshape_desc
=
reshape2_0
->
Op
();
int
num_head
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
2
);
int
dim_head
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
reshape_desc
->
GetAttr
(
"shape"
))
.
at
(
3
)
/
3
;
// 3 for qkv
// Calc index of transformer layer by LayerNorm Scale name
// This calculation assumes:
// 1. no LayerNorm before all transformer layer
// 2. each transformer layer contains 2 LayerNorm layer
auto
ln_scale_name
=
layer_norm_scale
->
Name
();
auto
ln_name
=
ln_scale_name
.
substr
(
0
,
ln_scale_name
.
find
(
'.'
));
auto
ln_idx_str
=
ln_name
.
substr
(
ln_name
.
rfind
(
'_'
)
+
1
);
int
layer_idx
=
atoi
(
ln_idx_str
.
c_str
())
/
2
;
auto
*
qkv_w_tensor
=
scope
->
FindVar
(
matmul0_w
->
Name
())
->
GetMutable
<
LoDTensor
>
();
auto
*
qkv_b_tensor
=
scope
->
FindVar
(
eltadd0_b
->
Name
())
->
GetMutable
<
LoDTensor
>
();
int
dim_embed
=
qkv_w_tensor
->
dims
()[
0
];
if
(
qkv_w_tensor
->
dtype
()
==
phi
::
DataType
::
FLOAT32
)
{
QKVWeightsProcessFuseQKV
<
float
>
(
qkv_w_tensor
,
qkv_b_tensor
,
num_head
,
dim_head
,
dim_embed
);
}
else
if
(
qkv_w_tensor
->
dtype
()
==
phi
::
DataType
::
FLOAT16
)
{
QKVWeightsProcessFuseQKV
<
platform
::
float16
>
(
qkv_w_tensor
,
qkv_b_tensor
,
num_head
,
dim_head
,
dim_embed
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"fused_multi_transformer not supported weight dtype. "
"we now only support fp32 and fp16."
));
}
// create fused_multi_transformer
OpDesc
fused_multi_transformer_op_desc
(
layer_norm
->
Op
()
->
Block
());
fused_multi_transformer_op_desc
.
SetType
(
"fused_multi_transformer"
);
// 1. Input setting
fused_multi_transformer_op_desc
.
SetInput
(
"X"
,
{
input0
->
Name
()});
// pre-LayerNorm input
fused_multi_transformer_op_desc
.
SetInput
(
"LnScale"
,
{
layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"LnBias"
,
{
layer_norm_bias
->
Name
()});
// QKV computation input
fused_multi_transformer_op_desc
.
SetInput
(
"QKVW"
,
{
matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"QKVBias"
,
{
eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"SrcMask"
,
{
eltadd_qk_b
->
Name
()});
// CacheKV input
VarDesc
cache_kv_desc
(
"cache_kv"
+
std
::
to_string
(
layer_idx
));
// FIXME: only support max_seq_len <= 1024
cache_kv_desc
.
SetDataType
(
framework
::
TransToProtoVarType
(
qkv_w_tensor
->
dtype
()));
cache_kv_desc
.
SetPersistable
(
false
);
auto
*
cache_kv
=
graph
->
CreateVarNode
(
&
cache_kv_desc
);
OpDesc
fill_const_op_desc
(
layer_norm
->
Op
()
->
Block
());
fill_const_op_desc
.
SetType
(
"fill_constant_batch_size_like"
);
fill_const_op_desc
.
SetInput
(
"Input"
,
{
input0
->
Name
()});
fill_const_op_desc
.
SetOutput
(
"Out"
,
{
cache_kv
->
Name
()});
std
::
vector
<
int
>
shape
=
{
2
,
-
1
,
num_head
,
1024
,
dim_head
};
fill_const_op_desc
.
SetAttr
(
"shape"
,
shape
);
fill_const_op_desc
.
SetAttr
(
"input_dim_idx"
,
0
);
fill_const_op_desc
.
SetAttr
(
"output_dim_idx"
,
1
);
fill_const_op_desc
.
SetAttr
(
"value"
,
0
);
fill_const_op_desc
.
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
proto
::
VarType
::
FP32
));
auto
*
fill_const_op
=
graph
->
CreateOpNode
(
&
fill_const_op_desc
);
fused_multi_transformer_op_desc
.
SetInput
(
"CacheKV"
,
{
cache_kv
->
Name
()});
// Out Linear input
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearW"
,
{
matmul_linear_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"OutLinearBias"
,
{
eltadd_linear_b
->
Name
()});
// Feed Forward input
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnScale"
,
{
ffn_layer_norm_scale
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFNLnBias"
,
{
ffn_layer_norm_bias
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Weight"
,
{
ffn_matmul0_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN1Bias"
,
{
ffn_eltadd0_b
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Weight"
,
{
ffn_matmul1_w
->
Name
()});
fused_multi_transformer_op_desc
.
SetInput
(
"FFN2Bias"
,
{
ffn_eltadd1_b
->
Name
()});
// 2. Output setting
fused_multi_transformer_op_desc
.
SetOutput
(
"Out"
,
{
ffn_output
->
Name
()});
fused_multi_transformer_op_desc
.
SetOutput
(
"CacheKVOut"
,
{
cache_kv
->
Name
()});
// Attribute setting
fused_multi_transformer_op_desc
.
SetAttr
(
"pre_layer_norm"
,
true
);
fused_multi_transformer_op_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
// output dropout attribute
auto
*
dropout_op
=
dropout_linear
->
Op
();
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_rate"
,
dropout_op
->
GetAttr
(
"dropout_prob"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"is_test"
,
dropout_op
->
GetAttr
(
"is_test"
));
fused_multi_transformer_op_desc
.
SetAttr
(
"dropout_implementation"
,
dropout_op
->
GetAttr
(
"dropout_implementation"
));
// parallel ring id
auto
*
c_identity_op
=
c_identity
->
Op
();
fused_multi_transformer_op_desc
.
SetAttr
(
"ring_id"
,
c_identity_op
->
GetAttr
(
"ring_id"
));
auto
*
fused_multi_transformer
=
graph
->
CreateOpNode
(
&
fused_multi_transformer_op_desc
);
IR_NODE_LINK_TO
(
input0
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_scale
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
layer_norm_bias
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
matmul0_w
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd0_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
eltadd_qk_b
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
input0
,
fill_const_op
);
IR_NODE_LINK_TO
(
fill_const_op
,
cache_kv
);
IR_NODE_LINK_TO
(
cache_kv
,
fused_multi_transformer
);
IR_NODE_LINK_TO
(
fused_multi_transformer
,
ffn_output
);
// rewrite while OP input
// 1. delete k, v
// 2. delete matmul1/2_w eltadd1/2_w
// 3. add cache_kv
auto
while_Xs
=
while0
->
Op
()
->
Input
(
"X"
);
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
split0_k_out
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
split0_v_out
->
Name
()),
std
::
end
(
while_Xs
));
while_Xs
.
emplace_back
(
cache_kv
->
Name
());
while0
->
Op
()
->
SetInput
(
"X"
,
while_Xs
);
// rewrite while OP output
// 1. delete k, v
// 2. add cache_kv
auto
while_Outs
=
while0
->
Op
()
->
Output
(
"Out"
);
while_Outs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Outs
),
std
::
end
(
while_Outs
),
split0_k_out
->
Name
()),
std
::
end
(
while_Outs
));
while_Outs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Outs
),
std
::
end
(
while_Outs
),
split0_v_out
->
Name
()),
std
::
end
(
while_Outs
));
while_Outs
.
emplace_back
(
cache_kv
->
Name
());
while0
->
Op
()
->
SetOutput
(
"Out"
,
while_Outs
);
// link CacheKV to while
IR_NODE_LINK_TO
(
cache_kv
,
while0
)
// unlink origin KV output to while
IR_NODE_UNLINK
(
split0_k_out
,
while0
);
IR_NODE_UNLINK
(
split0_v_out
,
while0
);
IR_NODE_UNLINK
(
while0
,
split0_k_out
);
IR_NODE_UNLINK
(
while0
,
split0_v_out
);
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"fused_multi_transformer_encoder_fuse_qkv "
"pass in op compat failed."
;
return
;
}
VLOG
(
4
)
<<
"handle MultiTransformer encoder(Fuse-QKV) fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
input0
,
input0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm
,
layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_scale
,
layer_norm_scale
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_bias
,
layer_norm_bias
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_mean
,
layer_norm_mean
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_variance
,
layer_norm_variance
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
layer_norm_out
,
layer_norm_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_identity
,
c_identity
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_identity_out
,
c_identity_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0
,
matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_out
,
matmul0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul0_w
,
matmul0_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0
,
reshape2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_0_out
,
reshape2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0
,
transpose2_0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_0_out
,
transpose2_0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0
,
split0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_q_out
,
split0_q_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_k_out
,
split0_k_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
split0_v_out
,
split0_v_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm
,
ffn_layer_norm
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_scale
,
ffn_layer_norm_scale
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_bias
,
ffn_layer_norm_bias
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_mean
,
ffn_layer_norm_mean
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_variance
,
ffn_layer_norm_variance
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_layer_norm_out
,
ffn_layer_norm_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_identity
,
ffn_c_identity
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_identity_out
,
ffn_c_identity_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0
,
ffn_matmul0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_out
,
ffn_matmul0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul0_w
,
ffn_matmul0_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0
,
ffn_eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_b
,
ffn_eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd0_out
,
ffn_eltadd0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu
,
ffn_gelu
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_gelu_out
,
ffn_gelu_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1
,
ffn_matmul1
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_out
,
ffn_matmul1_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_matmul1_w
,
ffn_matmul1_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_allreduce_sum
,
ffn_c_allreduce_sum
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_c_allreduce_sum_out
,
ffn_c_allreduce_sum_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1
,
ffn_eltadd1
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_b
,
ffn_eltadd1_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd1_out
,
ffn_eltadd1_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout
,
ffn_dropout
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_dropout_out
,
ffn_dropout_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_eltadd_out
,
ffn_eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
ffn_output
,
ffn_output
,
fused_multi_transformer_fuse_qkv_pattern
)
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0
,
eltadd0
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_b
,
eltadd0_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd0_out
,
eltadd0_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk
,
matmul_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qk_out
,
matmul_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk
,
eltadd_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_b
,
eltadd_qk_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_qk_out
,
eltadd_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk
,
softmax_qk
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
softmax_qk_out
,
softmax_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk
,
dropout_qk
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_qk_out
,
dropout_qk_out
,
fused_multi_transformer_fuse_qkv_pattern
)
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv
,
matmul_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_qkv_out
,
matmul_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv
,
reshape2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape2_qkv_out
,
reshape2_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv
,
transpose2_qkv
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose2_qkv_out
,
transpose2_qkv_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear
,
matmul_linear
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_w
,
matmul_linear_w
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_linear_out
,
matmul_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_allreduce_sum
,
c_allreduce_sum
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
c_allreduce_sum_out
,
c_allreduce_sum_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear
,
eltadd_linear
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_b
,
eltadd_linear_b
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_linear_out
,
eltadd_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear
,
dropout_linear
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
dropout_linear_out
,
dropout_linear_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fused_multi_transformer_fuse_qkv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
while0
,
while0
,
fused_multi_transformer_fuse_qkv_pattern
);
fuse_creater
(
input0
,
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
c_identity
,
matmul0_w
,
eltadd0_b
,
split0_k_out
,
split0_v_out
,
eltadd_qk_b
,
dropout_qk
,
reshape2_0
,
matmul_linear_w
,
eltadd_linear_b
,
dropout_linear
,
while0
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_matmul0_w
,
ffn_matmul1_w
,
ffn_eltadd0_b
,
ffn_eltadd1_b
,
ffn_dropout
,
ffn_output
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
({
layer_norm
,
layer_norm_scale
,
layer_norm_bias
,
layer_norm_mean
,
layer_norm_variance
,
layer_norm_out
,
c_identity
,
c_identity_out
,
matmul0
,
matmul0_out
,
eltadd0
,
eltadd0_out
,
reshape2_0
,
reshape2_0_out
,
transpose2_0
,
transpose2_0_out
,
split0
,
split0_q_out
,
split0_k_out
,
split0_v_out
,
matmul_qk
,
matmul_qk_out
,
eltadd_qk
,
eltadd_qk_out
,
softmax_qk
,
softmax_qk_out
,
dropout_qk
,
dropout_qk_out
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_qkv
,
matmul_qkv_out
,
reshape2_qkv
,
transpose2_qkv
,
transpose2_qkv_out
,
matmul_linear
,
matmul_linear_w
,
matmul_linear_out
,
c_allreduce_sum
,
c_allreduce_sum_out
,
eltadd_linear
,
eltadd_linear_b
,
eltadd_linear_out
,
dropout_linear
,
dropout_linear_out
,
eltadd_out
,
ffn_layer_norm
,
ffn_layer_norm_scale
,
ffn_layer_norm_bias
,
ffn_layer_norm_mean
,
ffn_layer_norm_variance
,
ffn_layer_norm_out
,
ffn_c_identity
,
ffn_c_identity_out
,
ffn_matmul0
,
ffn_matmul1
,
ffn_matmul0_out
,
ffn_matmul1_out
,
ffn_c_allreduce_sum
,
ffn_c_allreduce_sum_out
,
ffn_eltadd0
,
ffn_eltadd1
,
ffn_eltadd0_out
,
ffn_eltadd1_out
,
ffn_gelu
,
ffn_gelu_out
,
ffn_dropout
,
ffn_dropout_out
,
ffn_eltadd_out
});
// Remove unneeded nodes.
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
void
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
Fatal
(
"During the fused_multi_transformer_encoder pass, "
"The scope should not be null."
));
int
fusion_count
=
BuildFusion
(
graph
,
name_scope_
,
scope
);
if
(
fusion_count
>
0
)
{
graph
->
Set
(
kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass
,
new
bool
(
true
));
}
AddStatis
(
fusion_count
);
}
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
::
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
()
{
AddOpCompat
(
OpCompat
(
"layer_norm"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Scale"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Mean"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Variance"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"epsilon"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
0.001
f
)
.
End
()
.
AddAttr
(
"begin_norm_axis"
)
.
IsNumGT
(
0
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
// the shape shoule be (N*H, N*H)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// the shape shoule be (B, S, N*H)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsType
<
bool
>
()
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
2
,
-
1
,
0
})
.
End
();
AddOpCompat
(
OpCompat
(
"reshape2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Shape"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ShapeTensor"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"shape"
)
// -->(B, S, H, N) <--(B, S, N*H)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"transpose2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"XShape"
)
.
IsOptional
()
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
// {0, 2, 1, 3}
.
IsType
<
std
::
vector
<
int
>>
()
.
End
();
AddOpCompat
(
OpCompat
(
"matmul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"alpha"
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
1.0
f
)
.
End
()
.
AddAttr
(
"transpose_X"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"transpose_Y"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"softmax"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"axis"
)
.
IsIntIn
({
-
1
,
3
})
// shape is (B, H, S, S), so axis is -1 or 3
.
End
();
AddOpCompat
(
OpCompat
(
"gelu"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"approximate"
)
.
IsType
<
bool
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"while"
))
.
AddInput
(
"X"
)
// A set of variables, unconstrained
.
End
()
.
AddInput
(
"Condition"
)
// An scalar
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
// A set of variables, unconstrained
.
End
()
.
AddOutput
(
"StepScopes"
)
// A vector of local scope, unconstrained
.
End
()
.
AddAttr
(
"sub_block"
)
.
IsType
<
framework
::
BlockDesc
*>
()
.
End
();
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
fused_multi_transformer_encoder_pass
,
paddle
::
framework
::
ir
::
FusedMultiTransformerEncoderPass
);
REGISTER_PASS
(
fused_multi_transformer_encoder_fuse_qkv_pass
,
paddle
::
framework
::
ir
::
FusedMultiTransformerEncoderFuseQKVPass
);
REGISTER_PASS
(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass
,
paddle
::
framework
::
ir
::
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
);
REGISTER_PASS_CAPABILITY
(
fused_multi_transformer_encoder_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"elementwise_add"
,
1
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"transpose2"
,
0
)
.
EQ
(
"scale"
,
0
)
.
LE
(
"matmul"
,
1
)
.
EQ
(
"matmul_v2"
,
0
)
.
EQ
(
"softmax"
,
0
));
REGISTER_PASS_CAPABILITY
(
fused_multi_transformer_encoder_fuse_qkv_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"elementwise_add"
,
1
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"transpose2"
,
0
)
.
EQ
(
"scale"
,
0
)
.
LE
(
"matmul"
,
1
)
.
EQ
(
"matmul_v2"
,
0
)
.
EQ
(
"softmax"
,
0
));
REGISTER_PASS_CAPABILITY
(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"elementwise_add"
,
1
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"transpose2"
,
0
)
.
EQ
(
"scale"
,
0
)
.
LE
(
"matmul"
,
1
)
.
EQ
(
"matmul_v2"
,
0
)
.
EQ
(
"softmax"
,
0
));
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h
0 → 100644
浏览文件 @
5a2e5179
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
struct
FusedMultiTransformerEncoderPattern
:
public
PatternBase
{
FusedMultiTransformerEncoderPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fused_multi_transformer_encoder"
)
{}
PDNode
*
operator
()();
// Q, K, V path
PATTERN_DECL_NODE
(
input0
);
PATTERN_DECL_NODE
(
layer_norm
);
PATTERN_DECL_NODE
(
layer_norm_scale
);
PATTERN_DECL_NODE
(
layer_norm_bias
);
PATTERN_DECL_NODE
(
layer_norm_mean
);
PATTERN_DECL_NODE
(
layer_norm_variance
);
PATTERN_DECL_NODE
(
layer_norm_out
);
PATTERN_DECL_NODE
(
matmul0
);
PATTERN_DECL_NODE
(
matmul1
);
PATTERN_DECL_NODE
(
matmul2
);
PATTERN_DECL_NODE
(
matmul0_w
);
PATTERN_DECL_NODE
(
matmul1_w
);
PATTERN_DECL_NODE
(
matmul2_w
);
PATTERN_DECL_NODE
(
matmul0_out
);
PATTERN_DECL_NODE
(
matmul1_out
);
PATTERN_DECL_NODE
(
matmul2_out
);
PATTERN_DECL_NODE
(
eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd1
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd2
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd1_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd2_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_out
);
PATTERN_DECL_NODE
(
eltadd1_out
);
PATTERN_DECL_NODE
(
eltadd2_out
);
PATTERN_DECL_NODE
(
reshape2_0
);
PATTERN_DECL_NODE
(
reshape2_1
);
PATTERN_DECL_NODE
(
reshape2_2
);
PATTERN_DECL_NODE
(
reshape2_0_out
);
PATTERN_DECL_NODE
(
reshape2_1_out
);
PATTERN_DECL_NODE
(
reshape2_2_out
);
PATTERN_DECL_NODE
(
transpose2_0
);
PATTERN_DECL_NODE
(
transpose2_1
);
PATTERN_DECL_NODE
(
transpose2_2
);
PATTERN_DECL_NODE
(
transpose2_0_out
);
PATTERN_DECL_NODE
(
transpose2_1_out
);
PATTERN_DECL_NODE
(
transpose2_2_out
);
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
PATTERN_DECL_NODE
(
softmax_qk
);
PATTERN_DECL_NODE
(
softmax_qk_out
);
PATTERN_DECL_NODE
(
dropout_qk
);
PATTERN_DECL_NODE
(
dropout_qk_out
);
// QK, V matmul
PATTERN_DECL_NODE
(
matmul_qkv
);
PATTERN_DECL_NODE
(
matmul_qkv_out
);
PATTERN_DECL_NODE
(
reshape2_qkv
);
PATTERN_DECL_NODE
(
reshape2_qkv_out
);
PATTERN_DECL_NODE
(
transpose2_qkv
);
PATTERN_DECL_NODE
(
transpose2_qkv_out
);
// out linear
PATTERN_DECL_NODE
(
matmul_linear
);
PATTERN_DECL_NODE
(
matmul_linear_w
);
PATTERN_DECL_NODE
(
matmul_linear_out
);
PATTERN_DECL_NODE
(
eltadd_linear
);
PATTERN_DECL_NODE
(
eltadd_linear_b
);
PATTERN_DECL_NODE
(
eltadd_linear_out
);
PATTERN_DECL_NODE
(
dropout_linear
);
PATTERN_DECL_NODE
(
dropout_linear_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
eltadd_out
)
PATTERN_DECL_NODE
(
attention_output
);
// while loop
PATTERN_DECL_NODE
(
while0
);
// Feed Forward nodes
PATTERN_DECL_NODE
(
ffn_layer_norm
);
PATTERN_DECL_NODE
(
ffn_layer_norm_scale
);
PATTERN_DECL_NODE
(
ffn_layer_norm_bias
);
PATTERN_DECL_NODE
(
ffn_layer_norm_mean
);
PATTERN_DECL_NODE
(
ffn_layer_norm_variance
);
PATTERN_DECL_NODE
(
ffn_layer_norm_out
);
PATTERN_DECL_NODE
(
ffn_matmul0
);
PATTERN_DECL_NODE
(
ffn_matmul0_w
);
PATTERN_DECL_NODE
(
ffn_matmul0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_gelu
);
PATTERN_DECL_NODE
(
ffn_gelu_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_eltadd1
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_out
);
PATTERN_DECL_NODE
(
ffn_dropout
);
PATTERN_DECL_NODE
(
ffn_dropout_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
ffn_eltadd_out
)
PATTERN_DECL_NODE
(
ffn_output
);
};
struct
FusedMultiTransformerEncoderFuseQKVPattern
:
public
PatternBase
{
FusedMultiTransformerEncoderFuseQKVPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fused_multi_transformer_encoder_fuse_qkv"
)
{}
PDNode
*
operator
()();
// Q, K, V path
PATTERN_DECL_NODE
(
input0
);
PATTERN_DECL_NODE
(
layer_norm
);
PATTERN_DECL_NODE
(
layer_norm_scale
);
PATTERN_DECL_NODE
(
layer_norm_bias
);
PATTERN_DECL_NODE
(
layer_norm_mean
);
PATTERN_DECL_NODE
(
layer_norm_variance
);
PATTERN_DECL_NODE
(
layer_norm_out
);
PATTERN_DECL_NODE
(
matmul0
);
PATTERN_DECL_NODE
(
matmul0_w
);
PATTERN_DECL_NODE
(
matmul0_out
);
PATTERN_DECL_NODE
(
eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_out
);
PATTERN_DECL_NODE
(
reshape2_0
);
PATTERN_DECL_NODE
(
reshape2_0_out
);
PATTERN_DECL_NODE
(
transpose2_0
);
PATTERN_DECL_NODE
(
transpose2_0_out
);
PATTERN_DECL_NODE
(
split0
)
PATTERN_DECL_NODE
(
split0_q_out
)
PATTERN_DECL_NODE
(
split0_k_out
)
PATTERN_DECL_NODE
(
split0_v_out
)
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
PATTERN_DECL_NODE
(
softmax_qk
);
PATTERN_DECL_NODE
(
softmax_qk_out
);
PATTERN_DECL_NODE
(
dropout_qk
);
PATTERN_DECL_NODE
(
dropout_qk_out
);
// QK, V matmul
PATTERN_DECL_NODE
(
matmul_qkv
);
PATTERN_DECL_NODE
(
matmul_qkv_out
);
PATTERN_DECL_NODE
(
reshape2_qkv
);
PATTERN_DECL_NODE
(
reshape2_qkv_out
);
PATTERN_DECL_NODE
(
transpose2_qkv
);
PATTERN_DECL_NODE
(
transpose2_qkv_out
);
// while loop
PATTERN_DECL_NODE
(
while0
);
// out linear
PATTERN_DECL_NODE
(
matmul_linear
);
PATTERN_DECL_NODE
(
matmul_linear_w
);
PATTERN_DECL_NODE
(
matmul_linear_out
);
PATTERN_DECL_NODE
(
eltadd_linear
);
PATTERN_DECL_NODE
(
eltadd_linear_b
);
PATTERN_DECL_NODE
(
eltadd_linear_out
);
PATTERN_DECL_NODE
(
dropout_linear
);
PATTERN_DECL_NODE
(
dropout_linear_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
eltadd_out
)
PATTERN_DECL_NODE
(
attention_output
);
// Feed Forward nodes
PATTERN_DECL_NODE
(
ffn_layer_norm
);
PATTERN_DECL_NODE
(
ffn_layer_norm_scale
);
PATTERN_DECL_NODE
(
ffn_layer_norm_bias
);
PATTERN_DECL_NODE
(
ffn_layer_norm_mean
);
PATTERN_DECL_NODE
(
ffn_layer_norm_variance
);
PATTERN_DECL_NODE
(
ffn_layer_norm_out
);
PATTERN_DECL_NODE
(
ffn_matmul0
);
PATTERN_DECL_NODE
(
ffn_matmul0_w
);
PATTERN_DECL_NODE
(
ffn_matmul0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_gelu
);
PATTERN_DECL_NODE
(
ffn_gelu_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_eltadd1
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_out
);
PATTERN_DECL_NODE
(
ffn_dropout
);
PATTERN_DECL_NODE
(
ffn_dropout_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
ffn_eltadd_out
)
PATTERN_DECL_NODE
(
ffn_output
);
};
struct
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
:
public
PatternBase
{
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv"
)
{}
PDNode
*
operator
()();
// Q, K, V path
PATTERN_DECL_NODE
(
input0
);
PATTERN_DECL_NODE
(
layer_norm
);
PATTERN_DECL_NODE
(
layer_norm_scale
);
PATTERN_DECL_NODE
(
layer_norm_bias
);
PATTERN_DECL_NODE
(
layer_norm_mean
);
PATTERN_DECL_NODE
(
layer_norm_variance
);
PATTERN_DECL_NODE
(
layer_norm_out
);
PATTERN_DECL_NODE
(
c_identity
);
PATTERN_DECL_NODE
(
c_identity_out
);
PATTERN_DECL_NODE
(
matmul0
);
PATTERN_DECL_NODE
(
matmul0_w
);
PATTERN_DECL_NODE
(
matmul0_out
);
PATTERN_DECL_NODE
(
eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
eltadd0_out
);
PATTERN_DECL_NODE
(
reshape2_0
);
PATTERN_DECL_NODE
(
reshape2_0_out
);
PATTERN_DECL_NODE
(
transpose2_0
);
PATTERN_DECL_NODE
(
transpose2_0_out
);
PATTERN_DECL_NODE
(
split0
)
PATTERN_DECL_NODE
(
split0_q_out
)
PATTERN_DECL_NODE
(
split0_k_out
)
PATTERN_DECL_NODE
(
split0_v_out
)
// Q, K matmul
PATTERN_DECL_NODE
(
matmul_qk
);
PATTERN_DECL_NODE
(
matmul_qk_out
);
PATTERN_DECL_NODE
(
eltadd_qk
);
PATTERN_DECL_NODE
(
eltadd_qk_b
);
PATTERN_DECL_NODE
(
eltadd_qk_out
);
PATTERN_DECL_NODE
(
softmax_qk
);
PATTERN_DECL_NODE
(
softmax_qk_out
);
PATTERN_DECL_NODE
(
dropout_qk
);
PATTERN_DECL_NODE
(
dropout_qk_out
);
// QK, V matmul
PATTERN_DECL_NODE
(
matmul_qkv
);
PATTERN_DECL_NODE
(
matmul_qkv_out
);
PATTERN_DECL_NODE
(
reshape2_qkv
);
PATTERN_DECL_NODE
(
reshape2_qkv_out
);
PATTERN_DECL_NODE
(
transpose2_qkv
);
PATTERN_DECL_NODE
(
transpose2_qkv_out
);
// while loop
PATTERN_DECL_NODE
(
while0
);
// out linear
PATTERN_DECL_NODE
(
matmul_linear
);
PATTERN_DECL_NODE
(
matmul_linear_w
);
PATTERN_DECL_NODE
(
matmul_linear_out
);
PATTERN_DECL_NODE
(
c_allreduce_sum
);
PATTERN_DECL_NODE
(
c_allreduce_sum_out
);
PATTERN_DECL_NODE
(
eltadd_linear
);
PATTERN_DECL_NODE
(
eltadd_linear_b
);
PATTERN_DECL_NODE
(
eltadd_linear_out
);
PATTERN_DECL_NODE
(
dropout_linear
);
PATTERN_DECL_NODE
(
dropout_linear_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
eltadd_out
)
PATTERN_DECL_NODE
(
attention_output
);
// Feed Forward nodes
PATTERN_DECL_NODE
(
ffn_layer_norm
);
PATTERN_DECL_NODE
(
ffn_layer_norm_scale
);
PATTERN_DECL_NODE
(
ffn_layer_norm_bias
);
PATTERN_DECL_NODE
(
ffn_layer_norm_mean
);
PATTERN_DECL_NODE
(
ffn_layer_norm_variance
);
PATTERN_DECL_NODE
(
ffn_layer_norm_out
);
PATTERN_DECL_NODE
(
ffn_c_identity
);
PATTERN_DECL_NODE
(
ffn_c_identity_out
);
PATTERN_DECL_NODE
(
ffn_matmul0
);
PATTERN_DECL_NODE
(
ffn_matmul0_w
);
PATTERN_DECL_NODE
(
ffn_matmul0_out
);
PATTERN_DECL_NODE
(
ffn_eltadd0
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd0_out
);
PATTERN_DECL_NODE
(
ffn_gelu
);
PATTERN_DECL_NODE
(
ffn_gelu_out
);
PATTERN_DECL_NODE
(
ffn_matmul1
);
PATTERN_DECL_NODE
(
ffn_matmul1_w
);
PATTERN_DECL_NODE
(
ffn_matmul1_out
);
PATTERN_DECL_NODE
(
ffn_c_allreduce_sum
);
PATTERN_DECL_NODE
(
ffn_c_allreduce_sum_out
);
PATTERN_DECL_NODE
(
ffn_eltadd1
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_b
);
// ELEMENTWISE_ADD
PATTERN_DECL_NODE
(
ffn_eltadd1_out
);
PATTERN_DECL_NODE
(
ffn_dropout
);
PATTERN_DECL_NODE
(
ffn_dropout_out
);
// output elementwise_add
PATTERN_DECL_NODE
(
ffn_eltadd_out
)
PATTERN_DECL_NODE
(
ffn_output
);
};
}
// namespace patterns
class
FusedMultiTransformerEncoderPass
:
public
FusePassBase
{
public:
FusedMultiTransformerEncoderPass
();
virtual
~
FusedMultiTransformerEncoderPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"fused_multi_transformer_encoder"
};
private:
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
;
};
class
FusedMultiTransformerEncoderFuseQKVPass
:
public
FusePassBase
{
public:
FusedMultiTransformerEncoderFuseQKVPass
();
virtual
~
FusedMultiTransformerEncoderFuseQKVPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"fused_multi_transformer_encoder_fuse_qkv"
};
private:
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
;
};
class
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
:
public
FusePassBase
{
public:
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
();
virtual
~
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"multi_devices_fused_multi_transformer_encoder_fuse_qkv"
};
private:
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
const
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc
0 → 100644
浏览文件 @
5a2e5179
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h" // NOLINT
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
AddVarToScope
(
Scope
*
param_scope
,
const
std
::
string
&
name
,
const
DDim
&
dims
)
{
auto
*
tensor
=
param_scope
->
Var
(
name
)
->
GetMutable
<
LoDTensor
>
();
tensor
->
Resize
(
dims
);
tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
}
Scope
*
CreateParamScope
()
{
auto
param_scope
=
new
Scope
();
// MHA: pre Layer Norm
AddVarToScope
(
param_scope
,
"ln_scale"
,
{
1024
});
AddVarToScope
(
param_scope
,
"ln_bias"
,
{
1024
});
// MHA: QKV fc
AddVarToScope
(
param_scope
,
"weights0"
,
{
1024
,
1024
});
AddVarToScope
(
param_scope
,
"weights1"
,
{
1024
,
1024
});
AddVarToScope
(
param_scope
,
"weights2"
,
{
1024
,
1024
});
AddVarToScope
(
param_scope
,
"bias_0"
,
{
1024
});
AddVarToScope
(
param_scope
,
"bias_1"
,
{
1024
});
AddVarToScope
(
param_scope
,
"bias_2"
,
{
1024
});
// MHA: QK bias
AddVarToScope
(
param_scope
,
"biasqk"
,
{
1024
});
// MHA: out Linear
AddVarToScope
(
param_scope
,
"weights_l"
,
{
1024
,
1024
});
AddVarToScope
(
param_scope
,
"bias_l"
,
{
1024
});
// MHA: pre Layer Norm
AddVarToScope
(
param_scope
,
"ffn_ln_scale"
,
{
1024
});
AddVarToScope
(
param_scope
,
"ffn_ln_bias"
,
{
1024
});
// FFN: fc1 -> (gelu) -> fc2
AddVarToScope
(
param_scope
,
"ffn_weights0"
,
{
1024
,
4096
});
AddVarToScope
(
param_scope
,
"ffn_weights1"
,
{
4096
,
1024
});
AddVarToScope
(
param_scope
,
"ffn_bias_0"
,
{
4096
});
AddVarToScope
(
param_scope
,
"ffn_bias_1"
,
{
1024
});
return
param_scope
;
}
TEST
(
FusedMultiTransformerEncoderPass
,
basic
)
{
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (layer_norm_out, weights_1) matmul_v2 -> matmul_out1
// (layer_norm_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (transpose_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers
layers
;
// MHA: pre LayerNorm
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
1024
});
auto
*
ln_scale
=
layers
.
data
(
"ln_scale"
,
{
1024
},
true
);
auto
*
ln_bias
=
layers
.
data
(
"ln_bias"
,
{
1024
},
true
);
auto
*
ln_out
=
layers
.
layer_norm
(
x
,
ln_scale
,
ln_bias
)[
0
];
// MHA: QKV fc
auto
*
weights_0
=
layers
.
data
(
"weights0"
,
{
1024
,
1024
},
true
);
auto
*
weights_1
=
layers
.
data
(
"weights1"
,
{
1024
,
1024
},
true
);
auto
*
weights_2
=
layers
.
data
(
"weights2"
,
{
1024
,
1024
},
true
);
auto
*
matmul_out_0
=
layers
.
matmul_v2
(
ln_out
,
weights_0
,
nullptr
,
false
,
true
);
auto
*
matmul_out_1
=
layers
.
matmul_v2
(
ln_out
,
weights_1
,
nullptr
,
false
,
true
);
auto
*
matmul_out_2
=
layers
.
matmul_v2
(
ln_out
,
weights_2
,
nullptr
,
false
,
true
);
auto
*
b0
=
layers
.
data
(
"bias_0"
,
{
1024
},
true
);
auto
*
b1
=
layers
.
data
(
"bias_1"
,
{
1024
},
true
);
auto
*
b2
=
layers
.
data
(
"bias_2"
,
{
1024
},
true
);
auto
*
elementwise_out_0
=
layers
.
elementwise_add
(
matmul_out_0
,
b0
,
nullptr
,
2
);
auto
*
elementwise_out_1
=
layers
.
elementwise_add
(
matmul_out_1
,
b1
,
nullptr
,
2
);
auto
*
elementwise_out_2
=
layers
.
elementwise_add
(
matmul_out_2
,
b2
,
nullptr
,
2
);
std
::
vector
<
int
>
shape
=
{
1
,
128
,
16
,
64
};
auto
*
reshape_0
=
layers
.
reshape2
(
elementwise_out_0
,
shape
,
true
);
auto
*
reshape_1
=
layers
.
reshape2
(
elementwise_out_1
,
shape
,
true
);
auto
*
reshape_2
=
layers
.
reshape2
(
elementwise_out_2
,
shape
,
true
);
std
::
vector
<
int
>
axis
=
{
0
,
2
,
1
,
3
};
auto
*
transpose_0
=
layers
.
transpose2
(
reshape_0
,
axis
,
true
);
auto
*
transpose_1
=
layers
.
transpose2
(
reshape_1
,
axis
,
true
);
auto
*
transpose_2
=
layers
.
transpose2
(
reshape_2
,
axis
,
true
);
// Link to decoder while block
layers
.
while_loop
({
transpose_1
,
transpose_2
});
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul
(
transpose_0
,
transpose_1
,
nullptr
,
false
,
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
,
nullptr
,
-
1
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
dropout_qk
=
layers
.
dropout
(
softmax_qk
,
0.1
,
"upscale_in_train"
);
// MHA: QKV matmul
auto
*
matmul_qkv
=
layers
.
matmul_v2
(
dropout_qk
,
transpose_2
);
auto
*
transpose_qkv
=
layers
.
transpose2
(
matmul_qkv
,
{
0
,
2
,
1
,
3
},
true
);
auto
*
reshape_qkv_out
=
layers
.
reshape2
(
transpose_qkv
,
{
1
,
128
,
1024
},
true
);
// MHA: out Linear
auto
*
weights_l
=
layers
.
data
(
"weights_l"
,
{
1024
,
1024
},
true
);
auto
*
bias_l
=
layers
.
data
(
"weightsl"
,
{
1024
,
1024
},
true
);
auto
*
linear_matmut_out
=
layers
.
matmul_v2
(
reshape_qkv_out
,
weights_l
,
nullptr
,
false
,
true
);
auto
*
linear_eltadd_out
=
layers
.
elementwise_add
(
linear_matmut_out
,
bias_l
,
nullptr
,
2
);
auto
*
dropout_qkv
=
layers
.
dropout
(
linear_eltadd_out
,
0.1
,
"upscale_in_train"
);
auto
*
attention_out
=
layers
.
elementwise_add
(
x
,
dropout_qkv
);
// FFN: pre LayerNorm
auto
*
ffn_ln_scale
=
layers
.
data
(
"ffn_ln_scale"
,
{
1024
},
true
);
auto
*
ffn_ln_bias
=
layers
.
data
(
"ffn_ln_bias"
,
{
1024
},
true
);
auto
*
ffn_ln_out
=
layers
.
layer_norm
(
attention_out
,
ffn_ln_scale
,
ffn_ln_bias
)[
0
];
// FFN: fc1 -> gelu -> fc2
auto
*
ffn_weights0
=
layers
.
data
(
"ffn_weights0"
,
{
1024
,
4096
},
true
);
auto
*
ffn_weights1
=
layers
.
data
(
"ffn_weights1"
,
{
4096
,
1024
},
true
);
auto
*
ffn_bias0
=
layers
.
data
(
"ffn_bias0"
,
{
4096
},
true
);
auto
*
ffn_bias1
=
layers
.
data
(
"ffn_bias1"
,
{
1024
},
true
);
auto
*
ffn_matmul0_out
=
layers
.
matmul_v2
(
ffn_ln_out
,
ffn_weights0
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd0_out
=
layers
.
elementwise_add
(
ffn_matmul0_out
,
ffn_bias0
,
nullptr
,
2
);
auto
*
ffn_gelu_out
=
layers
.
gelu
(
ffn_eltadd0_out
);
auto
*
ffn_matmul1_out
=
layers
.
matmul_v2
(
ffn_gelu_out
,
ffn_weights1
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd1_out
=
layers
.
elementwise_add
(
ffn_matmul1_out
,
ffn_bias1
,
nullptr
,
2
);
// FFN: dropout -> elementwise_add
auto
*
ffn_dropout
=
layers
.
dropout
(
ffn_eltadd1_out
,
0.1
,
"upscale_in_train"
);
layers
.
elementwise_add
(
attention_out
,
ffn_dropout
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fused_multi_transformer_encoder_pass"
);
if
(
pass
.
get
()
==
nullptr
)
LOG
(
INFO
)
<<
"get fused_multi_transformer_encoder_pass failed"
;
int
num_nodes_before
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
int
num_fused_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer"
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
68
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder_pass, The "
"node num in graph "
"should be %d, but the result is %d"
,
num_nodes_before
-
68
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder pass, "
"there should be one fused_multi_transformer op, "
"but the result is %d"
,
num_fused_nodes_after
));
}
TEST
(
FusedMultiTransformerEncoderPass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"fused_multi_transformer_encoder_pass"
));
}
TEST
(
FusedMultiTransformerEncoderFuseQKVPass
,
basic
)
{
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
// (transpose_0) split -> split_q, split_k,
// split_v (split_k) assign -> assign_k
// (split_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers
layers
;
// MHA: pre LayerNorm
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
1024
});
auto
*
ln_scale
=
layers
.
data
(
"ln_scale"
,
{
1024
},
true
);
auto
*
ln_bias
=
layers
.
data
(
"ln_bias"
,
{
1024
},
true
);
auto
*
ln_out
=
layers
.
layer_norm
(
x
,
ln_scale
,
ln_bias
)[
0
];
// MHA: QKV fc
auto
*
weights_0
=
layers
.
data
(
"weights0"
,
{
1024
,
3072
},
true
);
auto
*
matmul_out_0
=
layers
.
matmul_v2
(
ln_out
,
weights_0
,
nullptr
,
false
,
true
);
auto
*
b0
=
layers
.
data
(
"bias_0"
,
{
3072
},
true
);
auto
*
elementwise_out_0
=
layers
.
elementwise_add
(
matmul_out_0
,
b0
,
nullptr
,
2
);
std
::
vector
<
int
>
shape
=
{
1
,
128
,
16
,
64
};
auto
*
reshape_0
=
layers
.
reshape2
(
elementwise_out_0
,
shape
,
true
);
std
::
vector
<
int
>
axis
=
{
0
,
2
,
1
,
3
};
auto
*
transpose_0
=
layers
.
transpose2
(
reshape_0
,
axis
,
true
);
auto
split_outs
=
layers
.
split
(
transpose_0
,
3
,
3
);
auto
*
split_q
=
split_outs
[
0
];
auto
*
split_k
=
split_outs
[
1
];
auto
*
split_v
=
split_outs
[
2
];
layers
.
assign
(
split_k
);
layers
.
assign
(
split_v
);
// Link to decoder while block
layers
.
while_loop
({
split_k
,
split_v
});
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul
(
split_q
,
split_k
,
nullptr
,
false
,
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
dropout_qk
=
layers
.
dropout
(
softmax_qk
,
0.1
,
"upscale_in_train"
);
// MHA: QKV matmul
auto
*
matmul_qkv
=
layers
.
matmul_v2
(
dropout_qk
,
split_v
);
auto
*
transpose_qkv
=
layers
.
transpose2
(
matmul_qkv
,
{
0
,
2
,
1
,
3
},
true
);
auto
*
reshape_qkv_out
=
layers
.
reshape2
(
transpose_qkv
,
{
1
,
128
,
1024
},
true
);
// MHA: out Linear
auto
*
weights_l
=
layers
.
data
(
"weights_l"
,
{
1024
,
1024
},
true
);
auto
*
bias_l
=
layers
.
data
(
"weightsl"
,
{
1024
,
1024
},
true
);
auto
*
linear_matmut_out
=
layers
.
matmul_v2
(
reshape_qkv_out
,
weights_l
,
nullptr
,
false
,
true
);
auto
*
linear_eltadd_out
=
layers
.
elementwise_add
(
linear_matmut_out
,
bias_l
,
nullptr
,
2
);
auto
*
dropout_qkv
=
layers
.
dropout
(
linear_eltadd_out
,
0.1
,
"upscale_in_train"
);
auto
*
attention_out
=
layers
.
elementwise_add
(
x
,
dropout_qkv
);
// FFN: pre LayerNorm
auto
*
ffn_ln_scale
=
layers
.
data
(
"ffn_ln_scale"
,
{
1024
},
true
);
auto
*
ffn_ln_bias
=
layers
.
data
(
"ffn_ln_bias"
,
{
1024
},
true
);
auto
*
ffn_ln_out
=
layers
.
layer_norm
(
attention_out
,
ffn_ln_scale
,
ffn_ln_bias
)[
0
];
// FFN: fc1 -> gelu -> fc2
auto
*
ffn_weights0
=
layers
.
data
(
"ffn_weights0"
,
{
1024
,
4096
},
true
);
auto
*
ffn_weights1
=
layers
.
data
(
"ffn_weights1"
,
{
4096
,
1024
},
true
);
auto
*
ffn_bias0
=
layers
.
data
(
"ffn_bias0"
,
{
4096
},
true
);
auto
*
ffn_bias1
=
layers
.
data
(
"ffn_bias1"
,
{
1024
},
true
);
auto
*
ffn_matmul0_out
=
layers
.
matmul_v2
(
ffn_ln_out
,
ffn_weights0
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd0_out
=
layers
.
elementwise_add
(
ffn_matmul0_out
,
ffn_bias0
,
nullptr
,
2
);
auto
*
ffn_gelu_out
=
layers
.
gelu
(
ffn_eltadd0_out
);
auto
*
ffn_matmul1_out
=
layers
.
matmul_v2
(
ffn_gelu_out
,
ffn_weights1
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd1_out
=
layers
.
elementwise_add
(
ffn_matmul1_out
,
ffn_bias1
,
nullptr
,
2
);
// FFN: dropout -> elementwise_add
auto
*
ffn_dropout
=
layers
.
dropout
(
ffn_eltadd1_out
,
0.1
,
"upscale_in_train"
);
layers
.
elementwise_add
(
attention_out
,
ffn_dropout
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fused_multi_transformer_encoder_fuse_qkv_pass"
);
if
(
pass
.
get
()
==
nullptr
)
LOG
(
INFO
)
<<
"get fused_multi_transformer_encoder_fuse_qkv_pass failed"
;
int
num_nodes_before
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
int
num_fused_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer"
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
56
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d"
,
num_nodes_before
-
56
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder_fuse_qkv "
"pass, there should be one fused_multi_transformer "
"op, but the result is %d"
,
num_fused_nodes_after
));
}
TEST
(
FusedMultiTransformerEncoderFuseQKVPass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"fused_multi_transformer_encoder_fuse_qkv_pass"
));
}
TEST
(
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
,
basic
)
{
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out) c_identity -> c_identity_out
// (c_identity_out, weights_0) matmul_v2 -> matmul_out0
// (matmul_out0) elementwise_add -> eltadd_0
// (eltadd_0) reshape2 -> reshape_0
// (reshape_0) transpose2 -> transpose_0
// (transpose_0) split -> split_q, split_k,
// split_v (split_k) assign -> assign_k
// (split_v) assign -> assign_v
// (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_all_reduce -> c_all_reduce_out
// (c_all_reduce_out) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (ffn_layer_norm_out) c_identity -> ffn_c_identity_out
// (ffn_c_identity_out, ffn_matmul0_w)matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out
// (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
Layers
layers
;
// MHA: pre LayerNorm
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
1024
});
auto
*
ln_scale
=
layers
.
data
(
"ln_scale"
,
{
1024
},
true
);
auto
*
ln_bias
=
layers
.
data
(
"ln_bias"
,
{
1024
},
true
);
auto
*
ln_out
=
layers
.
layer_norm
(
x
,
ln_scale
,
ln_bias
)[
0
];
auto
*
c_identity_out
=
layers
.
c_identity
(
ln_out
);
// MHA: QKV fc
auto
*
weights_0
=
layers
.
data
(
"weights0"
,
{
1024
,
3072
},
true
);
auto
*
matmul_out_0
=
layers
.
matmul_v2
(
c_identity_out
,
weights_0
,
nullptr
,
false
,
true
);
auto
*
b0
=
layers
.
data
(
"bias_0"
,
{
3072
},
true
);
auto
*
elementwise_out_0
=
layers
.
elementwise_add
(
matmul_out_0
,
b0
,
nullptr
,
2
);
std
::
vector
<
int
>
shape
=
{
1
,
128
,
16
,
64
};
auto
*
reshape_0
=
layers
.
reshape2
(
elementwise_out_0
,
shape
,
true
);
std
::
vector
<
int
>
axis
=
{
0
,
2
,
1
,
3
};
auto
*
transpose_0
=
layers
.
transpose2
(
reshape_0
,
axis
,
true
);
auto
split_outs
=
layers
.
split
(
transpose_0
,
3
,
3
);
auto
*
split_q
=
split_outs
[
0
];
auto
*
split_k
=
split_outs
[
1
];
auto
*
split_v
=
split_outs
[
2
];
layers
.
assign
(
split_k
);
layers
.
assign
(
split_v
);
// Link to decoder while block
layers
.
while_loop
({
split_k
,
split_v
});
// MHA: QK matmul
auto
*
matmul_qk
=
layers
.
matmul
(
split_q
,
split_k
,
nullptr
,
false
,
true
);
auto
*
bqk
=
layers
.
data
(
"biasqk"
,
{
1
,
12
,
128
,
128
},
true
);
auto
*
elementwise_qk
=
layers
.
elementwise_add
(
matmul_qk
,
bqk
);
auto
*
softmax_qk
=
layers
.
softmax
(
elementwise_qk
,
-
1
);
auto
*
dropout_qk
=
layers
.
dropout
(
softmax_qk
,
0.1
,
"upscale_in_train"
);
// MHA: QKV matmul
auto
*
matmul_qkv
=
layers
.
matmul_v2
(
dropout_qk
,
split_v
);
auto
*
transpose_qkv
=
layers
.
transpose2
(
matmul_qkv
,
{
0
,
2
,
1
,
3
},
true
);
auto
*
reshape_qkv_out
=
layers
.
reshape2
(
transpose_qkv
,
{
1
,
128
,
1024
},
true
);
// MHA: out Linear
auto
*
weights_l
=
layers
.
data
(
"weights_l"
,
{
1024
,
1024
},
true
);
auto
*
bias_l
=
layers
.
data
(
"weightsl"
,
{
1024
,
1024
},
true
);
auto
*
linear_matmut_out
=
layers
.
matmul_v2
(
reshape_qkv_out
,
weights_l
,
nullptr
,
false
,
true
);
auto
*
c_allreduce_out
=
layers
.
c_allreduce_sum
(
linear_matmut_out
);
auto
*
linear_eltadd_out
=
layers
.
elementwise_add
(
c_allreduce_out
,
bias_l
,
nullptr
,
2
);
auto
*
dropout_qkv
=
layers
.
dropout
(
linear_eltadd_out
,
0.1
,
"upscale_in_train"
);
auto
*
attention_out
=
layers
.
elementwise_add
(
x
,
dropout_qkv
);
// FFN: pre LayerNorm
auto
*
ffn_ln_scale
=
layers
.
data
(
"ffn_ln_scale"
,
{
1024
},
true
);
auto
*
ffn_ln_bias
=
layers
.
data
(
"ffn_ln_bias"
,
{
1024
},
true
);
auto
*
ffn_ln_out
=
layers
.
layer_norm
(
attention_out
,
ffn_ln_scale
,
ffn_ln_bias
)[
0
];
auto
*
ffn_c_identity_out
=
layers
.
c_identity
(
ffn_ln_out
);
// FFN: fc1 -> gelu -> fc2
auto
*
ffn_weights0
=
layers
.
data
(
"ffn_weights0"
,
{
1024
,
4096
},
true
);
auto
*
ffn_weights1
=
layers
.
data
(
"ffn_weights1"
,
{
4096
,
1024
},
true
);
auto
*
ffn_bias0
=
layers
.
data
(
"ffn_bias0"
,
{
4096
},
true
);
auto
*
ffn_bias1
=
layers
.
data
(
"ffn_bias1"
,
{
1024
},
true
);
auto
*
ffn_matmul0_out
=
layers
.
matmul_v2
(
ffn_c_identity_out
,
ffn_weights0
,
nullptr
,
false
,
true
);
auto
*
ffn_eltadd0_out
=
layers
.
elementwise_add
(
ffn_matmul0_out
,
ffn_bias0
,
nullptr
,
2
);
auto
*
ffn_gelu_out
=
layers
.
gelu
(
ffn_eltadd0_out
);
auto
*
ffn_matmul1_out
=
layers
.
matmul_v2
(
ffn_gelu_out
,
ffn_weights1
,
nullptr
,
false
,
true
);
auto
*
ffn_allreduce_out
=
layers
.
c_allreduce_sum
(
ffn_matmul1_out
);
auto
*
ffn_eltadd1_out
=
layers
.
elementwise_add
(
ffn_allreduce_out
,
ffn_bias1
,
nullptr
,
2
);
// FFN: dropout -> elementwise_add
auto
*
ffn_dropout
=
layers
.
dropout
(
ffn_eltadd1_out
,
0.1
,
"upscale_in_train"
);
layers
.
elementwise_add
(
attention_out
,
ffn_dropout
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
);
if
(
pass
.
get
()
==
nullptr
)
LOG
(
INFO
)
<<
"get multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass "
"failed"
;
int
num_nodes_before
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
int
num_fused_nodes_after
=
GetNumOpNodes
(
graph
,
"fused_multi_transformer"
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
64
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d"
,
num_nodes_before
-
64
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_fused_nodes_after
,
1
,
platform
::
errors
::
InvalidArgument
(
"After the fused_multi_transformer_encoder_fuse_qkv "
"multi-devices pass, there should be one "
"fused_multi_transformer op, but the result is %d"
,
num_fused_nodes_after
));
}
TEST
(
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
));
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
fused_multi_transformer_encoder_pass
);
USE_PASS
(
fused_multi_transformer_encoder_fuse_qkv_pass
);
USE_PASS
(
multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass
);
paddle/fluid/framework/ir/graph_helper.cc
浏览文件 @
5a2e5179
...
@@ -815,9 +815,14 @@ void GraphToProgram(const Graph &graph,
...
@@ -815,9 +815,14 @@ void GraphToProgram(const Graph &graph,
// avoid kRootBlockIndex not 0
// avoid kRootBlockIndex not 0
if
(
idx
==
kRootBlockIndex
)
continue
;
if
(
idx
==
kRootBlockIndex
)
continue
;
block
=
program_pb
.
add_blocks
();
if
(
static_cast
<
int
>
(
idx
)
<
program_pb
.
blocks_size
())
{
block
->
set_idx
(
idx
);
block
=
program_pb
.
mutable_blocks
(
idx
);
block
->
set_parent_idx
(
kRootBlockIndex
);
}
else
{
block
=
program_pb
.
add_blocks
();
block
->
set_idx
(
idx
);
block
->
set_parent_idx
(
kRootBlockIndex
);
}
GraphToBlock
(
*
graph
.
GetSubGraph
(
idx
),
GraphToBlock
(
*
graph
.
GetSubGraph
(
idx
),
block
,
block
,
sort_kind
,
sort_kind
,
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
5a2e5179
...
@@ -112,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
...
@@ -112,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
if
(
graph
.
Nodes
().
empty
())
return
false
;
if
(
graph
.
Nodes
().
empty
())
return
false
;
for
(
auto
&
node
:
GraphTraits
::
DFS
(
graph
))
{
for
(
auto
&
node
:
GraphTraits
::
DFS
(
graph
))
{
if
(
node
.
Name
().
rfind
(
"__control_var"
)
==
0
)
continue
;
for
(
const
auto
&
pdnode
:
pattern_
.
nodes
())
{
for
(
const
auto
&
pdnode
:
pattern_
.
nodes
())
{
if
(
pdnode
->
Tell
(
&
node
))
{
if
(
pdnode
->
Tell
(
&
node
))
{
VLOG
(
4
)
<<
"Node "
<<
node
.
Name
()
<<
" marked as "
<<
pdnode
->
name
();
VLOG
(
4
)
<<
"Node "
<<
node
.
Name
()
<<
" marked as "
<<
pdnode
->
name
();
...
@@ -383,7 +384,6 @@ std::string PDPattern::DotString() const {
...
@@ -383,7 +384,6 @@ std::string PDPattern::DotString() const {
// Create Edges
// Create Edges
for
(
const
auto
&
edge
:
edges
())
{
for
(
const
auto
&
edge
:
edges
())
{
if
(
!
node2dot
.
count
(
edge
.
first
)
||
!
node2dot
.
count
(
edge
.
second
))
{
if
(
!
node2dot
.
count
(
edge
.
first
)
||
!
node2dot
.
count
(
edge
.
second
))
{
LOG
(
ERROR
)
<<
"no node "
<<
edge
.
first
<<
" "
<<
edge
.
second
;
continue
;
continue
;
}
}
auto
&
src
=
node2dot
.
at
(
edge
.
first
);
auto
&
src
=
node2dot
.
at
(
edge
.
first
);
...
@@ -453,7 +453,8 @@ PDNode *PDNode::assert_var_not_persistable() {
...
@@ -453,7 +453,8 @@ PDNode *PDNode::assert_var_not_persistable() {
PDNode
*
PDNode
::
assert_is_persistable_var
()
{
PDNode
*
PDNode
::
assert_is_persistable_var
()
{
assert_is_var
();
assert_is_var
();
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
->
Var
()
->
Persistable
();
});
asserts_
.
emplace_back
(
[
=
](
Node
*
x
)
{
return
x
->
Var
()
&&
x
->
Var
()
->
Persistable
();
});
return
this
;
return
this
;
}
}
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
5a2e5179
...
@@ -1990,6 +1990,14 @@ struct AddSupportInt8 : public PatternBase {
...
@@ -1990,6 +1990,14 @@ struct AddSupportInt8 : public PatternBase {
a->outputs.push_back(b); \
a->outputs.push_back(b); \
b->inputs.push_back(a);
b->inputs.push_back(a);
// UnLink 2 ir::Nodes from each other.
#define IR_NODE_UNLINK(a, b) \
a->outputs.erase( \
std::remove(std::begin(a->outputs), std::end(a->outputs), b), \
std::end(a->outputs)); \
b->inputs.erase(std::remove(std::begin(b->inputs), std::end(b->inputs), a), \
std::end(b->inputs));
// Set the out_var as the output of the op
// Set the out_var as the output of the op
#define IR_OP_VAR_LINK(op, out_var) \
#define IR_OP_VAR_LINK(op, out_var) \
op->outputs.push_back(out_var); \
op->outputs.push_back(out_var); \
...
...
paddle/fluid/framework/ir/pass.cc
浏览文件 @
5a2e5179
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
Scope
;
namespace
ir
{
namespace
ir
{
class
Graph
;
class
Graph
;
}
// namespace ir
}
// namespace ir
...
@@ -35,6 +36,17 @@ namespace paddle {
...
@@ -35,6 +36,17 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
static
const
char
kParamScopeAttr
[]
=
"__param_scope__"
;
static
const
std
::
vector
<
std
::
string
>
support_subgraph_passes
=
{
"fused_multi_transformer_encoder_pass"
,
"fused_multi_transformer_decoder_pass"
,
"fused_multi_transformer_encoder_fuse_qkv_pass"
,
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
};
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"start to apply pass "
<<
Type
()
<<
" to graph"
;
VLOG
(
10
)
<<
"start to apply pass "
<<
Type
()
<<
" to graph"
;
CheckPrevPass
();
CheckPrevPass
();
...
@@ -65,11 +77,41 @@ Graph *Pass::Apply(Graph *graph) const {
...
@@ -65,11 +77,41 @@ Graph *Pass::Apply(Graph *graph) const {
true
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The VarDescs of persistable variable are not consistency."
));
"The VarDescs of persistable variable are not consistency."
));
applied_
=
true
;
if
(
!
graph
->
Has
(
kPassRecorder
))
{
if
(
!
graph
->
Has
(
kPassRecorder
))
{
graph
->
Set
<
PassRecorder
>
(
kPassRecorder
,
new
PassRecorder
);
graph
->
Set
<
PassRecorder
>
(
kPassRecorder
,
new
PassRecorder
);
}
}
graph
->
Get
<
PassRecorder
>
(
kPassRecorder
).
insert
(
Type
());
graph
->
Get
<
PassRecorder
>
(
kPassRecorder
).
insert
(
Type
());
if
(
graph
->
IsMainGraph
()
&&
std
::
count
(
support_subgraph_passes
.
begin
(),
support_subgraph_passes
.
end
(),
Type
()))
{
for
(
size_t
i
=
1
;
i
<
graph
->
SubGraphsSize
();
i
++
)
{
auto
*
sub_graph
=
graph
->
GetSubGraph
(
i
);
if
(
!
sub_graph
->
Has
(
framework
::
ir
::
kParamScopeAttr
))
{
sub_graph
->
SetNotOwned
<
Scope
>
(
framework
::
ir
::
kParamScopeAttr
,
&
graph
->
Get
<
Scope
>
(
framework
::
ir
::
kParamScopeAttr
));
}
ApplyImpl
(
sub_graph
);
PADDLE_ENFORCE_EQ
(
HasCircle
(
*
sub_graph
),
false
,
platform
::
errors
::
InvalidArgument
(
"Illegal pass %s. Generated graph shouldn't contain cycle."
,
Type
()));
PADDLE_ENFORCE_EQ
(
VarDescIsConsistency
(
*
sub_graph
),
true
,
platform
::
errors
::
InvalidArgument
(
"The VarDescs of persistable variable are not consistency."
));
if
(
!
sub_graph
->
Has
(
kPassRecorder
))
{
sub_graph
->
Set
<
PassRecorder
>
(
kPassRecorder
,
new
PassRecorder
);
}
sub_graph
->
Get
<
PassRecorder
>
(
kPassRecorder
).
insert
(
Type
());
}
}
applied_
=
true
;
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// Clear mkl-dnn cache,
// Passes can change params, tensors, so caching need to be discarded
// Passes can change params, tensors, so caching need to be discarded
...
...
paddle/fluid/framework/ir/pass.h
浏览文件 @
5a2e5179
...
@@ -47,6 +47,18 @@ constexpr char kPassRecorder[] = "pass_recorder";
...
@@ -47,6 +47,18 @@ constexpr char kPassRecorder[] = "pass_recorder";
constexpr
char
kEmbEltwiseLayernormPass
[]
=
constexpr
char
kEmbEltwiseLayernormPass
[]
=
"embedding_eltwise_layernorm_fuse_pass_flag"
;
"embedding_eltwise_layernorm_fuse_pass_flag"
;
constexpr
char
kMultiheadMatmulPass
[]
=
"multihead_matmul_fuse_pass_flag"
;
constexpr
char
kMultiheadMatmulPass
[]
=
"multihead_matmul_fuse_pass_flag"
;
constexpr
char
kFusedMultiTransformerEncoderPass
[]
=
"fused_multi_transformer_encoder_pass_flag"
;
constexpr
char
kFusedMultiTransformerDecoderPass
[]
=
"fused_multi_transformer_decoder_pass_flag"
;
constexpr
char
kFusedMultiTransformerEncoderFuseQKVPass
[]
=
"fused_multi_transformer_encoder_fuse_qkv_pass_flag"
;
constexpr
char
kFusedMultiTransformerDecoderFuseQKVPass
[]
=
"fused_multi_transformer_decoder_fuse_qkv_pass_flag"
;
constexpr
char
kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass
[]
=
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass_flag"
;
constexpr
char
kMultiDevicesFusedMultiTransformerDecoderFuseQKVPass
[]
=
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass_flag"
;
constexpr
char
kPrelnEmbEltwiseLayernormPass
[]
=
constexpr
char
kPrelnEmbEltwiseLayernormPass
[]
=
"preln_embedding_eltwise_layernorm_fuse_pass_flag"
;
"preln_embedding_eltwise_layernorm_fuse_pass_flag"
;
...
...
paddle/fluid/framework/ir/pass_tester_helper.h
浏览文件 @
5a2e5179
...
@@ -146,6 +146,12 @@ struct Layers {
...
@@ -146,6 +146,12 @@ struct Layers {
return
unary_op
(
"relu"
,
x
,
out
);
return
unary_op
(
"relu"
,
x
,
out
);
}
}
VarDesc
*
gelu
(
VarDesc
*
x
,
VarDesc
*
out
=
nullptr
,
bool
approximate
=
true
)
{
AttributeMap
attrs
;
attrs
[
"approximate"
]
=
approximate
;
return
unary_op
(
"gelu"
,
x
,
out
,
&
attrs
);
}
VarDesc
*
sigmoid
(
VarDesc
*
x
,
VarDesc
*
out
=
nullptr
)
{
VarDesc
*
sigmoid
(
VarDesc
*
x
,
VarDesc
*
out
=
nullptr
)
{
return
unary_op
(
"sigmoid"
,
x
,
out
);
return
unary_op
(
"sigmoid"
,
x
,
out
);
}
}
...
@@ -154,6 +160,20 @@ struct Layers {
...
@@ -154,6 +160,20 @@ struct Layers {
return
unary_op
(
"tanh"
,
x
,
out
);
return
unary_op
(
"tanh"
,
x
,
out
);
}
}
VarDesc
*
c_identity
(
VarDesc
*
x
,
VarDesc
*
out
=
nullptr
,
int
ring_id
=
-
1
)
{
AttributeMap
attrs
;
attrs
[
"ring_id"
]
=
ring_id
;
return
unary_op
(
"c_identity"
,
x
,
out
,
&
attrs
);
}
VarDesc
*
c_allreduce_sum
(
VarDesc
*
x
,
VarDesc
*
out
=
nullptr
,
int
ring_id
=
-
1
)
{
AttributeMap
attrs
;
attrs
[
"ring_id"
]
=
ring_id
;
return
unary_op
(
"c_allreduce_sum"
,
x
,
out
,
&
attrs
);
}
VarDesc
*
fc
(
VarDesc
*
input
,
VarDesc
*
fc
(
VarDesc
*
input
,
VarDesc
*
w
,
VarDesc
*
w
,
VarDesc
*
bias
,
VarDesc
*
bias
,
...
@@ -332,6 +352,37 @@ struct Layers {
...
@@ -332,6 +352,37 @@ struct Layers {
return
outs
;
return
outs
;
}
}
std
::
vector
<
VarDesc
*>
split
(
VarDesc
*
x
,
int
num_or_section
,
int
axis
=
0
)
{
std
::
vector
<
VarDesc
*>
outs
(
num_or_section
);
for
(
int
i
=
0
;
i
<
num_or_section
;
i
++
)
{
outs
[
i
]
=
lod_tensor
(
unique_name
());
}
std
::
vector
<
std
::
string
>
out_names
(
num_or_section
);
for
(
int
i
=
0
;
i
<
num_or_section
;
i
++
)
{
out_names
[
i
]
=
outs
[
i
]
->
Name
();
}
OpDesc
*
op
=
program_
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"split"
);
op
->
SetInput
(
"X"
,
{
x
->
Name
()});
op
->
SetOutput
(
"Out"
,
out_names
);
op
->
SetAttr
(
"num_or_section"
,
num_or_section
);
op
->
SetAttr
(
"axis"
,
axis
);
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
return
outs
;
}
VarDesc
*
assign
(
VarDesc
*
x
)
{
VarDesc
*
out
=
lod_tensor
(
unique_name
());
OpDesc
*
op
=
program_
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
{
x
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
out
->
Name
()});
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
return
out
;
}
VarDesc
*
matmul
(
VarDesc
*
x
,
VarDesc
*
matmul
(
VarDesc
*
x
,
VarDesc
*
y
,
VarDesc
*
y
,
VarDesc
*
alpha
=
nullptr
,
VarDesc
*
alpha
=
nullptr
,
...
@@ -459,6 +510,24 @@ struct Layers {
...
@@ -459,6 +510,24 @@ struct Layers {
return
out
;
return
out
;
}
}
VarDesc
*
while_loop
(
std
::
vector
<
VarDesc
*>
xs
,
VarDesc
*
cond
=
nullptr
)
{
VarDesc
*
out
=
lod_tensor
(
unique_name
());
VarDesc
*
step_scopes
=
lod_tensor
(
unique_name
());
if
(
cond
==
nullptr
)
cond
=
lod_tensor
(
unique_name
());
OpDesc
*
op
=
program_
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"while"
);
std
::
vector
<
std
::
string
>
xs_names
;
for
(
auto
&
x
:
xs
)
xs_names
.
emplace_back
(
x
->
Name
());
op
->
SetInput
(
"X"
,
xs_names
);
op
->
SetInput
(
"Condition"
,
{
cond
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
out
->
Name
()});
op
->
SetOutput
(
"StepScopes"
,
{
step_scopes
->
Name
()});
op
->
SetAttr
(
"sub_block"
,
{
program_
.
MutableBlock
(
0
)});
op
->
SetAttr
(
"is_test"
,
true
);
return
out
;
}
void
backward
(
std
::
vector
<
VarDesc
*>
targets
)
{
void
backward
(
std
::
vector
<
VarDesc
*>
targets
)
{
// This function is designed to simulate the structure of training program,
// This function is designed to simulate the structure of training program,
// but is constructed differently as the actual program.
// but is constructed differently as the actual program.
...
@@ -523,7 +592,10 @@ struct Layers {
...
@@ -523,7 +592,10 @@ struct Layers {
return
var
;
return
var
;
}
}
VarDesc
*
unary_op
(
std
::
string
type
,
VarDesc
*
x
,
VarDesc
*
out
=
nullptr
)
{
VarDesc
*
unary_op
(
std
::
string
type
,
VarDesc
*
x
,
VarDesc
*
out
=
nullptr
,
const
AttributeMap
*
attrs
=
nullptr
)
{
if
(
!
out
)
{
if
(
!
out
)
{
out
=
lod_tensor
(
unique_name
());
out
=
lod_tensor
(
unique_name
());
}
}
...
@@ -531,6 +603,11 @@ struct Layers {
...
@@ -531,6 +603,11 @@ struct Layers {
op
->
SetType
(
type
);
op
->
SetType
(
type
);
op
->
SetInput
(
"X"
,
{
x
->
Name
()});
op
->
SetInput
(
"X"
,
{
x
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
out
->
Name
()});
op
->
SetOutput
(
"Out"
,
{
out
->
Name
()});
if
(
attrs
)
{
for
(
auto
&
iter
:
*
attrs
)
{
op
->
SetAttr
(
iter
.
first
,
iter
.
second
);
}
}
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
static_cast
<
int
>
(
OpRole
::
kForward
));
return
out
;
return
out
;
...
...
paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
浏览文件 @
5a2e5179
...
@@ -76,6 +76,7 @@ void MemoryOptimizePass::CollectLifeCycle(
...
@@ -76,6 +76,7 @@ void MemoryOptimizePass::CollectLifeCycle(
}
else
{
}
else
{
// Normal operators.
// Normal operators.
for
(
const
Node
*
node
:
requires
)
{
for
(
const
Node
*
node
:
requires
)
{
if
(
!
node
->
Var
())
continue
;
if
(
node
->
Var
()
->
Persistable
())
continue
;
if
(
node
->
Var
()
->
Persistable
())
continue
;
std
::
string
var
=
node
->
Name
();
std
::
string
var
=
node
->
Name
();
if
(
!
lifecycles
->
count
(
var
))
{
if
(
!
lifecycles
->
count
(
var
))
{
...
@@ -133,7 +134,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
...
@@ -133,7 +134,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
// between performance and underlying principle.
// between performance and underlying principle.
std
::
unordered_set
<
std
::
string
>
black_list
;
std
::
unordered_set
<
std
::
string
>
black_list
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
()
&&
if
(
node
->
IsVar
()
&&
node
->
Var
()
&&
node
->
Var
()
->
GetType
()
==
node
->
Var
()
->
GetType
()
==
framework
::
proto
::
VarType
::
Type
::
VarType_Type_LOD_TENSOR
)
{
framework
::
proto
::
VarType
::
Type
::
VarType_Type_LOD_TENSOR
)
{
if
(
!
valid_var
(
node
))
{
if
(
!
valid_var
(
node
))
{
...
@@ -144,7 +145,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
...
@@ -144,7 +145,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
// Collect tensors from graph.
// Collect tensors from graph.
for
(
auto
*
node
:
graph
->
Nodes
())
{
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
()
&&
if
(
node
->
IsVar
()
&&
node
->
Var
()
&&
node
->
Var
()
->
GetType
()
==
node
->
Var
()
->
GetType
()
==
framework
::
proto
::
VarType
::
Type
::
VarType_Type_LOD_TENSOR
&&
framework
::
proto
::
VarType
::
Type
::
VarType_Type_LOD_TENSOR
&&
!
black_list
.
count
(
node
->
Var
()
->
Name
()))
{
!
black_list
.
count
(
node
->
Var
()
->
Name
()))
{
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
5a2e5179
...
@@ -193,22 +193,28 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
...
@@ -193,22 +193,28 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
GpuPassStrategy
::
GpuPassStrategy
()
:
PassStrategy
({})
{
GpuPassStrategy
::
GpuPassStrategy
()
:
PassStrategy
({})
{
passes_
.
assign
({
passes_
.
assign
({
// "identity_scale_op_clean_pass", //
// "identity_scale_op_clean_pass", //
"is_test_pass"
,
//
"is_test_pass"
,
//
"simplify_with_basic_ops_pass"
,
//
"simplify_with_basic_ops_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"embedding_eltwise_layernorm_fuse_pass"
,
//
"embedding_eltwise_layernorm_fuse_pass"
,
//
"multihead_matmul_fuse_pass_v2"
,
//
"multihead_matmul_fuse_pass_v2"
,
//
"gpu_cpu_squeeze2_matmul_fuse_pass"
,
//
"fused_multi_transformer_encoder_pass"
,
//
"gpu_cpu_reshape2_matmul_fuse_pass"
,
//
"fused_multi_transformer_decoder_pass"
,
//
"gpu_cpu_flatten2_matmul_fuse_pass"
,
//
"fused_multi_transformer_encoder_fuse_qkv_pass"
,
//
"gpu_cpu_map_matmul_v2_to_mul_pass"
,
//
"fused_multi_transformer_decoder_fuse_qkv_pass"
,
//
"gpu_cpu_map_matmul_v2_to_matmul_pass"
,
//
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass"
,
//
"matmul_scale_fuse_pass"
,
//
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
//
"multihead_matmul_fuse_pass_v3"
,
//
"gpu_cpu_squeeze2_matmul_fuse_pass"
,
//
"gpu_cpu_map_matmul_to_mul_pass"
,
//
"gpu_cpu_reshape2_matmul_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"gpu_cpu_flatten2_matmul_fuse_pass"
,
//
"fc_elementwise_layernorm_fuse_pass"
,
//
"gpu_cpu_map_matmul_v2_to_mul_pass"
,
//
"gpu_cpu_map_matmul_v2_to_matmul_pass"
,
//
"matmul_scale_fuse_pass"
,
//
"multihead_matmul_fuse_pass_v3"
,
//
"gpu_cpu_map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"fc_elementwise_layernorm_fuse_pass"
,
//
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
// guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录